1. Saving and Loading
-
Saving and loading a PyTorch model.
-
Note: the pickle module is not secure, i.e. you should only unpickle (load) data you trust.
This also applies to loading PyTorch models: only load PyTorch models saved from a source you trust.
Method | Memo |
---|---|
torch.save | Saves a serialized object to disk using Python's pickle; models, tensors, and dictionaries (e.g. a state_dict) can all be saved |
torch.load | Uses pickle to deserialize a saved object file into memory |
torch.nn.Module.load_state_dict | Loads a model's parameter dictionary (state_dict) into a model instance |
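As a minimal, self-contained sketch of how the three methods fit together (using a tiny `nn.Linear` stand-in model and an illustrative filename, not the workflow model below):

```python
import torch
from torch import nn

# a tiny stand-in model just for illustration
tiny = nn.Linear(in_features=2, out_features=1)

# torch.save: serialize the learned parameters (state_dict) to disk
torch.save(obj=tiny.state_dict(), f="tiny_model.pth")

# torch.load: deserialize the saved state_dict back into memory
state = torch.load(f="tiny_model.pth")

# torch.nn.Module.load_state_dict: copy the parameters into a fresh instance
fresh = nn.Linear(in_features=2, out_features=1)
fresh.load_state_dict(state)

# the fresh instance now carries identical weights
print(torch.equal(tiny.weight, fresh.weight))  # True
```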
2. state_dict() Saving Model
import torch
from pathlib import Path
# 1:create model directory
MODEL_PATH = Path("ModelDir")
MODEL_PATH.mkdir(parents=True,exist_ok=True)
# 2:create model save path
MODEL_NAME = "pytorch_workflow_model_v0.pth"
MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME
# 3:save model state dict
print(f"Saving Model To: {MODEL_SAVE_PATH}")
# saving the state_dict() saves only the model's learned parameters
torch.save(obj=model_0.state_dict(),f=MODEL_SAVE_PATH)
Saving Model To: ModelDir\pytorch_workflow_model_v0.pth
# check saved file path
!ls -l ModelDir/pytorch_workflow_model_v0.pth
# ls *.pth, ls -R DirectoryName: requires bash_kernel to be installed
-rw-rw-r-- 1 elf elf 1063 June 2 10:02 ModelDir/pytorch_workflow_model_v0.pth
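The ~1 KB file above holds only the parameter dictionary. A quick way to see what a state_dict contains is to iterate over its name/tensor pairs; a sketch using a hypothetical `nn.Linear(1, 1)` stand-in for `LinearRegressionModel` (one weight, one bias):

```python
import torch
from torch import nn

# hypothetical stand-in for LinearRegressionModel
model = nn.Linear(in_features=1, out_features=1)

# state_dict maps each parameter name to its tensor
for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape))
# weight (1, 1)
# bias (1,)
```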
3. state_dict() Loading Model
# instantiate a new instance of the model (instantiated with random weights)
loaded_model_0 = LinearRegressionModel()
# load the saved state_dict (updates the new instance with the trained weights)
loaded_model_0.load_state_dict(torch.load(f=MODEL_SAVE_PATH))
<All keys matched successfully>
# 1:put the loaded model into evaluation mode
loaded_model_0.eval()
# 2:use the inference mode context manager to make predictions
with torch.inference_mode():
    # perform a forward pass on the test data with the loaded model
    loaded_model_pred = loaded_model_0(X_test)
# compare previous model predictions with loaded model predictions (these should be the same)
y_pred == loaded_model_pred
tensor([[True],
[True],
[True],
[True],
[True],
[True],
[True],
[True],
[True],
[True]])
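Elementwise `==` returns all `True` here because both models run the identical computation on identical weights; for floating-point outputs in general, `torch.allclose` is the more tolerant check. A self-contained sketch with a stand-in linear model (hypothetical names; `LinearRegressionModel` and `X_test` come from earlier in the notebook):

```python
import torch
from torch import nn

torch.manual_seed(42)
model = nn.Linear(in_features=1, out_features=1)
X = torch.arange(0, 1, 0.1).unsqueeze(dim=1)

model.eval()
with torch.inference_mode():
    pred_a = model(X)
    pred_b = model(X)  # same weights -> same forward pass

# exact equality holds for identical computations...
print(bool((pred_a == pred_b).all()))   # True
# ...while allclose also tolerates tiny floating-point differences
print(torch.allclose(pred_a, pred_b))   # True
```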
-
The loaded model's predictions match the previous model's predictions (made before saving), indicating the model is being saved and loaded as expected.
-
There are many ways to save and load PyTorch models:
PyTorch Guide Saving Loading Model
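One such alternative covered in the guide is saving a checkpoint dictionary that bundles the model's state_dict with the optimizer state and training progress, so training can be resumed. A hedged sketch (`EPOCH` and `checkpoint.pth` are illustrative names, not from the notebook above):

```python
import torch
from torch import nn

model = nn.Linear(in_features=1, out_features=1)
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)
EPOCH = 5  # illustrative value

# a checkpoint bundles everything needed to resume training
torch.save(obj={"epoch": EPOCH,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict()},
           f="checkpoint.pth")

# restore: rebuild the objects, then load their states
checkpoint = torch.load(f="checkpoint.pth")
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
print(checkpoint["epoch"])  # 5
```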