1. Saving and Loading

  • Saving and loading a PyTorch model

  • Note: the pickle module is not secure, so only unpickle (load) data you trust.
    This also applies to loading PyTorch models: only load models saved from
    sources you trust;

  • https://docs.python.org/3/library/pickle.html
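
  • Note: recent PyTorch versions support torch.load(..., weights_only=True)
    (the default from PyTorch 2.6 on), which restricts unpickling to tensors and
    primitive container types; a minimal sketch, with a hypothetical file name:

import torch

# weights_only=True limits unpickling to tensors and primitive containers,
# reducing the attack surface when loading files from less-trusted sources
state = torch.load("model.pth", weights_only=True)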

Method Memo

torch.save

  • Uses Python's pickle utility to serialize objects to disk; torch.save can
    save models, tensors, and various other Python objects (such as dictionaries);

torch.load

  • Uses pickle's unpickling facilities to deserialize pickled Python object
    files, such as models, tensors, or dictionaries, and loads them into memory;

  • the device the object is loaded onto (CPU, GPU, etc.) can also be set;

torch.nn.Module.load_state_dict

  • Loads a model's parameter dictionary (model.state_dict()) from a saved state_dict() object
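
  • A minimal sketch tying torch.save and torch.load together on a plain
    dictionary of tensors (file and variable names hypothetical; the
    load_state_dict step is worked through in the sections below):

import torch

# any picklable object can be saved; here, a hypothetical checkpoint dictionary
checkpoint = {"weights": torch.randn(2, 2), "step": 100}

# serialize the object to disk (pickle under the hood)
torch.save(checkpoint, "checkpoint.pt")

# deserialize it; map_location controls the device the tensors are loaded onto
restored = torch.load("checkpoint.pt", map_location="cpu")
print(restored["step"])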

2. Saving a Model with state_dict()

  • A: use the pathlib module to create a directory for saving models
    (here called ModelDir);

  • B: create a file path to save the model to;

  • C: call torch.save(obj, f):
    obj is the target model's state_dict(), f is the filename to save the model to;

  • Note: PyTorch saved models or objects usually end with .pt or .pth;

from pathlib import Path

import torch

# 1: create the model directory
MODEL_PATH = Path("ModelDir")
MODEL_PATH.mkdir(parents=True, exist_ok=True)

# 2: create the model save path
MODEL_NAME = "pytorch_workflow_model_v0.pth"
MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME

# 3: save the model state dict
print(f"Saving Model To: {MODEL_SAVE_PATH}")

# saving only the state_dict() saves just the model's learned parameters
torch.save(obj=model_0.state_dict(), f=MODEL_SAVE_PATH)
Saving Model To: ModelDir\pytorch_workflow_model_v0.pth
# check saved file path
!ls -l ModelDir/pytorch_workflow_model_v0.pth

# ls *.pth, ls -R DirectoryName: requires bash_kernel to be installed
-rw-rw-r-- 1 elf elf 1063 June 2 10:02 ModelDir/pytorch_workflow_model_v0.pth

3. Loading a Model with state_dict()

  • There is now a saved model state_dict() at ModelDir/pytorch_workflow_model_v0.pth,
    which can be loaded with torch.nn.Module.load_state_dict(torch.load(f));

  • f is the filepath of the saved model state_dict();

  • Why is torch.load() called inside torch.nn.Module.load_state_dict()?

  • Because only the model's state_dict() (a dictionary of learned parameters,
    not the entire model) was saved, the state_dict() must first be loaded with
    torch.load();

  • that state_dict() is then passed to a new instance of the model
    (a subclass of nn.Module);

  • Why not save the entire model? Saving the entire model, rather than just
    the state_dict(), is more intuitive;

  • Saving Loading Model For Entire Model

  • Downside of saving the entire model: the serialized data is bound to the
    specific classes and the exact directory structure used when the model was
    saved, so the code can break in various ways when used in another project
    or after refactoring (a sketch of this approach appears at the end of this
    section);

  • The more flexible approach is therefore to save and load only the
    state_dict(), which is essentially a dictionary of the model's parameters;

  • To test this, create another instance of LinearRegressionModel(); it is a
    subclass of torch.nn.Module, so it has the built-in method load_state_dict()

# instantiate a new instance of the model (it will start with random weights)
loaded_model_0 = LinearRegressionModel()

# load the saved state_dict (this updates the new instance with the trained weights)
loaded_model_0.load_state_dict(torch.load(f=MODEL_SAVE_PATH))
<All keys matched successfully>
  • Everything appears to match up; to test the loaded model, run inference
    (make predictions) on the test data

  • PyTorch Inference Rules:

    • set the model to evaluation mode: model.eval();

    • make predictions via the inference-mode context manager:
      with torch.inference_mode(): …;

    • all predictions should be made with objects on the same device:
      e.g. data and model both on the GPU, or data and model both on the CPU
      (see the sketch after this list);
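
  • A minimal sketch of the same-device rule (assuming a CUDA device may be
    available; the variable name device_preds is illustrative, and the worked
    example just below keeps everything on the CPU):

# put the model and the data on the same device before predicting
device = "cuda" if torch.cuda.is_available() else "cpu"
loaded_model_0.to(device)

with torch.inference_mode():
    device_preds = loaded_model_0(X_test.to(device))

# move the model back to the CPU so the CPU-only example below runs as written
loaded_model_0.to("cpu")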

# 1: put the loaded model into evaluation mode
loaded_model_0.eval()

# 2: use the inference-mode context manager to make predictions
with torch.inference_mode():
    # perform a forward pass on the test data with the loaded model
    loaded_model_pred = loaded_model_0(X_test)

# compare the previous model's predictions with the loaded model's predictions
# (these should be the same)
y_pred == loaded_model_pred
tensor([[True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True]])
  • The loaded model's predictions match the previous model's predictions
    (made before saving), which indicates the model is being saved and loaded
    as expected;
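
  • Note: the elementwise == comparison works here because both forward passes
    are deterministic and produce bit-identical outputs; for floating-point
    predictions that may differ by rounding, torch.allclose is the more robust
    check:

torch.allclose(y_pred, loaded_model_pred)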

  • There are many ways to save and load PyTorch models; see the PyTorch guide:
    Saving and Loading Models
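
  • For completeness, a minimal sketch of the entire-model approach discussed
    above (path name hypothetical; the model's class definition must still be
    importable when loading, and recent PyTorch versions need weights_only=False
    because the file holds a full pickled object):

# save the whole model object, not just its state_dict()
torch.save(obj=model_0, f="ModelDir/entire_model_v0.pth")

# load it back; LinearRegressionModel must be importable at load time
loaded_entire_model = torch.load("ModelDir/entire_model_v0.pth",
                                 weights_only=False)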