Pytorch 加载和保存模型
目錄
保存和加載模型
1.??什么是狀態(tài)字典:state_dict?
2.保存和加載推理模型
2.1 保存/加載 state_dict (推薦使用)
2.2 保存/加載完整模型
3. 保存和加載 Checkpoint 用于推理/繼續(xù)訓(xùn)練
4. 在一個(gè)文件中保存多個(gè)模型
5. 使用在不同模型參數(shù)下的熱啟動(dòng)模式
6. 通過設(shè)備保存/加載模型
6.1 保存到 CPU、加載到 CPU
6.2 保存到 GPU、加載到 GPU
6.3 保存到 CPU,加載到 GPU
6.4 保存 torch.nn.DataParallel 模型
?
截取自PyTorch官方教程中文版,這書好像是拼接的,沒有頁碼。
其中一個(gè)版本地址:http://www.pytorch123.com/SecondSection/what_is_pytorch/,好像沒有我下載的這本全
?
保存和加載模型
? ? 當(dāng)保存和加載模型時(shí),需要熟悉三個(gè)核心功能:
? ? torch.save :將序列化對象保存到磁盤。此函數(shù)使用Python的 pickle 模塊進(jìn)行序列化。使用此函數(shù)可以保存如模型、tensor、? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?字典等各種對象。
? ?torch.load :使用pickle的 unpickling 功能將pickle對象文件反序列化到內(nèi)存。此功能還可以有助于設(shè)備加載數(shù)據(jù)。
? ?torch.nn.Module.load_state_dict :使用反序列化函數(shù) state_dict 來加載模型的參數(shù)字典。
?
?
1.?什么是狀態(tài)字典:state_dict?
在PyTorch中, torch.nn.Module 模型的可學(xué)習(xí)參數(shù)(即權(quán)重和偏差)包含在模型的參數(shù)中,(使用 model.parameters() 可以進(jìn)行訪問)。 state_dict 是Python字典對象,它將每一層映射到其參數(shù)張量。注意,只有具有可學(xué)習(xí)參數(shù)的層(如卷積層,線性層等)的模型 才具有 state_dict 這一項(xiàng)。目標(biāo)優(yōu)化 torch.optim 也有 state_dict 屬性,它包含有關(guān)優(yōu)化器的狀態(tài)信息,以及使用的超參數(shù)。
? ? 因?yàn)?/span>state_dict的對象是Python字典,所以它們可以很容易的保存、更新、修改和恢復(fù),為PyTorch模型和優(yōu)化器添加了大量模塊。
下面通過從簡單模型訓(xùn)練一個(gè)分類器中來了解一下 state_dict 的使用。
import torch.nn as nn import torch.nn.functional as F import torch.optim as optim# 定義模型 class TheModelClass(nn.Module):def __init__(self):super(TheModelClass, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# 初始化模型 model = TheModelClass() # 初始化優(yōu)化器 optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) # 打印模型的狀態(tài)字典 print("Model's state_dict:") for param_tensor in model.state_dict():print(param_tensor, "\t", model.state_dict()[param_tensor].size()) # 打印優(yōu)化器的狀態(tài)字典 print("Optimizer's state_dict:") for var_name in optimizer.state_dict():print(var_name, "\t", optimizer.state_dict()[var_name])運(yùn)行結(jié)果:
?
?
2.保存和加載推理模型
2.1 保存/加載 state_dict (推薦使用)
(1)保存
PATH = 'test.pt'torch.save(model.state_dict(),? PATH)?
(2)加載
model = TheModelClass(*args, **kwargs)model.load_state_dict(torch.load(PATH))model.eval()? ? 當(dāng)保存好模型用來推斷的時(shí)候,只需要保存模型學(xué)習(xí)到的參數(shù),使用 torch.save() 函數(shù)來保存模型 state_dict ,它會(huì)給模型恢復(fù)提供 最大的靈活性,這就是為什么要推薦它來保存的原因。
? ? 在 PyTorch 中最常見的模型保存使‘.pt’或者是‘.pth’作為模型文件擴(kuò)展名。
? ? 請記住,在運(yùn)行推理之前,務(wù)必調(diào)用 model.eval() 去設(shè)置 dropout 和 batch normalization 層為評估模式。如果不這么做,可能導(dǎo)致 模型推斷結(jié)果不一致。
? ? 注意:load_state_dict() 函數(shù)只接受字典對象,而不是保存對象的路徑。這就意味著在你傳給load_state_dict() 函數(shù)之前,你必須反序列化 你保存的 state_dict 。例如,你無法通過model.load_state_dict(PATH) 來加載模型。
?
2.2 保存/加載完整模型
(1)保存
PATH = 'test.pt'torch.save(model, PATH)(2)加載
# 模型類必須在此之前被定義model = torch.load(PATH)model.eval()? ? 在 PyTorch 中最常見的模型保存使用‘.pt’或者是‘.pth’作為模型文件擴(kuò)展名。
? ? 此部分保存/加載過程使用最直觀的語法并涉及最少量的代碼。以 Python `pickle 模塊的方式來保存模型。這種方法的缺點(diǎn)是序列化數(shù)據(jù)受 限于某種特殊的類而且需要確切的字典結(jié)構(gòu)。這是因?yàn)?/span>pickle無法保存模型類本身。相反,它保存包含類的文件的路徑,該文件在加載時(shí)使用。 因此,當(dāng)在其他項(xiàng)目使用或者重構(gòu)之后,您的代碼可能會(huì)以各種方式中斷。
? ? 請記住,在運(yùn)行推理之前,務(wù)必調(diào)用 model.eval() 設(shè)置 dropout 和 batch normalization 層為評估模式。如果不這么做,可能導(dǎo)致模型推斷結(jié)果不一致。
?
?
3. 保存和加載 Checkpoint 用于推理/繼續(xù)訓(xùn)練
(1)保存
PATH = 'test.tar' torch.save({'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,... }, PATH)(2)加載
model = TheModelClass(*args, **kwargs) optimizer = TheOptimizerClass(*args, **kwargs) checkpoint = torch.load(PATH) model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) epoch = checkpoint['epoch'] loss = checkpoint['loss'] model.eval() # - or - model.train()? ? 要保存多個(gè)組件,請?jiān)谧值渲薪M織它們并使用 torch.save() 來序列化字典。PyTorch 中常見的保存checkpoint 是使用 .tar 文件擴(kuò)展名。
? ? 當(dāng)保存成 Checkpoint 的時(shí)候,可用于推理或者是繼續(xù)訓(xùn)練,保存的不僅僅是模型的 state_dict。保存優(yōu)化器的 state_dict 也很重要, 因?yàn)樗鳛槟P陀?xùn)練更新的緩沖區(qū)和參數(shù)。你也許想保存其他項(xiàng)目,比如最新記錄的訓(xùn)練損失,外部的 torch.nn.Embedding 層等等。
? ? 要加載項(xiàng)目,首先需要初始化模型和優(yōu)化器,然后使用 torch.load() 來加載本地字典。這里,你可以非常容易的通過簡單查詢字典來訪問你所保存的項(xiàng)目。
? ? 請記住在運(yùn)行推理之前,務(wù)必調(diào)用 model.eval() 去設(shè)置 dropout 和 batch normalization 為評估。如果不這樣做,有可能得到不一致的推斷結(jié)果。 如果你想要恢復(fù)訓(xùn)練,請調(diào)用 model.train() 以確保這些層處于訓(xùn)練模式。
?
?
4. 在一個(gè)文件中保存多個(gè)模型
(1)保存
PATH = 'test.tar' torch.save({'modelA_state_dict': modelA.state_dict(),'modelB_state_dict': modelB.state_dict(),'optimizerA_state_dict': optimizerA.state_dict(),'optimizerB_state_dict': optimizerB.state_dict(),... }, PATH)(2)加載
modelA = TheModelAClass(*args, **kwargs)modelB = TheModelBClass(*args, **kwargs)optimizerA = TheOptimizerAClass(*args, **kwargs)optimizerB = TheOptimizerBClass(*args, **kwargs)checkpoint = torch.load(PATH)modelA.load_state_dict(checkpoint['modelA_state_dict'])modelB.load_state_dict(checkpoint['modelB_state_dict'])optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])modelA.eval()modelB.eval()# - or -modelA.train()modelB.train()? ? PyTorch 中常見的保存 checkpoint 是使用 .tar 文件擴(kuò)展名。
? ? 要加載項(xiàng)目,首先需要初始化模型和優(yōu)化器,然后使用 torch.load() 來加載本地字典。這里,你可以非常容易的通過簡單查詢字典來訪問你所保存的項(xiàng)目。
? ? 請記住在運(yùn)行推理之前,務(wù)必調(diào)用 model.eval() 去設(shè)置 dropout 和 batch normalization 為評估。如果不這樣做,有可能得到不一致的推斷結(jié)果。 如果你想要恢復(fù)訓(xùn)練,請調(diào)用 model.train() 以確保這些層處于訓(xùn)練模式。
?
?
5. 使用在不同模型參數(shù)下的熱啟動(dòng)模式
(1)保存
PATH = 'test.pt' torch.save(modelA.state_dict(), PATH)(2)加載
modelB = TheModelBClass(*args, **kwargs) modelB.load_state_dict(torch.load(PATH), strict=False)? ? 在遷移學(xué)習(xí)或訓(xùn)練新的復(fù)雜模型時(shí),部分加載模型或加載部分模型是常見的情況。利用訓(xùn)練好的參數(shù),有助于熱啟動(dòng)訓(xùn)練過程,并希望幫助你的模型比從頭開始訓(xùn)練能夠更快地收斂。
? ? 無論是從缺少某些鍵的 state_dict 加載還是從鍵的數(shù)目多于加載模型的 state_dict , 都可以通過在load_state_dict() 函數(shù)中將 strict 參數(shù)設(shè)置為 False 來忽略非匹配鍵的函數(shù)。
? ? 如果要將參數(shù)從一個(gè)層加載到另一個(gè)層,但是某些鍵不匹配,主要修改正在加載的 state_dict 中的參數(shù)鍵的名稱以匹配要在加載到模型中的鍵即可。
?
?
6. 通過設(shè)備保存/加載模型
6.1 保存到 CPU、加載到 CPU
(1)保存
PATH = 'test.pt' torch.save(model.state_dict(), PATH)(2)加載
device = torch.device('cpu')model = TheModelClass(*args, **kwargs)model.load_state_dict(torch.load(PATH, map_location=device))? ? 當(dāng)從CPU上加載模型在GPU上訓(xùn)練時(shí), 將 torch.device('cpu') 傳遞給 torch.load() 函數(shù)中的map_location 參數(shù).在這種情況下,使用 map_location 參數(shù)將張量下的存儲(chǔ)器動(dòng)態(tài)的重新映射到CPU設(shè)備。
?
6.2 保存到 GPU、加載到 GPU
(1)保存
PATH = 'test.pt' torch.save(model.state_dict(), PATH)(2)加載
device = torch.device("cuda")model = TheModelClass(*args, **kwargs)model.load_state_dict(torch.load(PATH))model.to(device)# 確保在你提供給模型的任何輸入張量上調(diào)用input = input.to(device)? ? 當(dāng)在GPU上訓(xùn)練并把模型保存在GPU,只需要使用 model.to(torch.device('cuda')) ,將初始化的 model 轉(zhuǎn)換為 CUDA 優(yōu)化模型。另外,請 務(wù)必在所有模型輸入上使用.to(torch.device('cuda')) 函數(shù)來為模型準(zhǔn)備數(shù)據(jù)。請注意,調(diào)用 my_tensor.to(device)
會(huì)在GPU上返回 my_tensor 的副本。 因此,請記住手動(dòng)覆蓋張量: my_tensor=my_tensor.to(torch.device('cuda')) 。
?
6.3 保存到 CPU,加載到 GPU
(1)保存
PATH = 'test.pt' torch.save(model.state_dict(), PATH)(2)加載
device = torch.device("cuda") model = TheModelClass(*args, **kwargs) # Choose whatever GPU device number you want model.load_state_dict(torch.load(PATH, map_location="cuda:0")) model.to(device) # 確保在你提供給模型的任何輸入張量上調(diào)用input = input.to(device)? ? 在CPU上訓(xùn)練好并保存的模型加載到GPU時(shí),將 torch.load() 函數(shù)中的 map_location 參數(shù)設(shè)置為 cuda:device_id 。這會(huì)將模型加載到 指定的GPU設(shè)備。接下來,請務(wù)必調(diào)用model.to(torch.device('cuda')) 將模型的參數(shù)張量轉(zhuǎn)換為 CUDA 張量。最后,確保在所有模型輸入上使用 .to(torch.device('cuda')) 函數(shù)來為CUDA優(yōu)化模型。請注意,調(diào)用my_tensor.to(device) 會(huì)在GPU上返回 my_tensor 的新副本。它不會(huì)覆蓋 my_tensor 。 因此, 請手動(dòng)覆蓋張量 my_tensor = my_tensor.to(torch.device('cuda')) 。
?
6.4 保存 torch.nn.DataParallel 模型
(1)保存
PATH = 'test.pt' torch.save(model.state_dict(), PATH)(2)加載
# 加載任何你想要的設(shè)備? torch.nn.DataParallel 是一個(gè)模型封裝,支持并行GPU使用。要普通保存 DataParallel 模型,請保存 model.module.state_dict() 。 這樣,你就可以非常靈活地以任何方式加載模型到你想要的設(shè)備中。
總結(jié)
以上是生活随笔為你收集整理的Pytorch 加载和保存模型的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Pytorch 加载部分预训练模型并冻结
- 下一篇: matplotlib 画多条折线图且x