pytorch基础知识整理(三)模型保存与加载
生活随笔
收集整理的這篇文章主要介紹了
pytorch基础知识整理(三)模型保存与加载
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
1, torch.save(); troch.load()
torch.save()使用python的pickle模塊把目標保存到磁盤,可以用來保存模型、張量、字典等,文件后綴名一般用pth或pt或pkl。torch.load()使用python的pickle模塊實現從磁盤加載。可以用此來直接保存或加載完整模型:
torch.save(model, 'PATH.pth') model = torch.load('PATH.pth')注意:pytorch1.6以后保存的模型使用zip壓縮,所以保存的模型無法被1.6以前的版本加載,如果要跨版本使用,需要做以下修改
torch.save(model, 'PATH.pth', _use_new_zipfile_serialization=False)2, .state_dict(); .load_state_dict()
模型的框架已經在程序代碼中了,因此訓練好的模型只需要保存模型的參數即可供推理使用。model.state_dict()以字典的形式保存模型的參數,字典的鍵是參數名,值是參數值的張量。得到狀態字典后還需用torch.save()固化到磁盤。
除模型外,優化器optimizer也可以保存和加載狀態字典。
注意在多卡GPU訓練時,保存和加載模型需要在model后加上module,即
torch.save(model.module.state_dict(), 'PATH.pth') model.module.load_state_dict(torch.load('PATH.pth'))3, 保存checkpoint
如果是訓練中途保存用于繼續訓練,就不僅要保存權重參數,還要保存當前epoch,優化器的狀態,當前的損失值等,可以統一打包到一個字典中保存為checkpoint,此時文件后綴名一般用tar。
#保存: torch.save({'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,...}, PATH) ##加載: checkpoint = torch.load(PATH) model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) epoch = checkpoint['epoch'] loss = checkpoint['loss']總結
以上是生活随笔為你收集整理的pytorch基础知识整理(三)模型保存与加载的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: pytorch基础知识整理(一)自动求导
- 下一篇: pytorch基础知识整理(四) 模型