加载dict_PyTorch 7.保存和加载pytorch模型的两种方法
眾所周知,python的對(duì)象都可以通過torch.save和torch.load函數(shù)進(jìn)行保存和加載(不知道?那你現(xiàn)在知道了(*^_^*)),比如:
x1 = {"d":"ddf","dd":'fdsf'} torch.save(x1, 'a1.pt')x2 = ["ddf",'fdsf'] torch.save(x2, 'a2.pt')x3 = 1 torch.save(x3, 'a3.pt')x4 = torch.ones(3) torch.save(x4, 'a4.pt')讀取的時(shí)候也是一樣:
x5 = torch.load('a1.pt')x6 = torch.load('a2.pt')x7 = torch.load('a3.pt')x8 = torch.load('a4.pt')這種非常簡單粗暴,直接把整個(gè)對(duì)象扔進(jìn)磁盤文件里保存,所以對(duì)于我們訓(xùn)練好的模型來說,因?yàn)橛?xùn)練好的模型也是一個(gè)對(duì)象,所以我們也可以使用這個(gè)方法把訓(xùn)練好的模型對(duì)象直接扔進(jìn)去。但是這樣有一個(gè)問題,就是模型對(duì)象開銷比較大,比如最近包含1350億個(gè)參數(shù)的那個(gè)有名的神經(jīng)網(wǎng)絡(luò)模型,如果把它保存到磁盤里面沒有百八十T是保存不下的。所以我們是不是可以僅僅保存模型里面的關(guān)鍵數(shù)據(jù)呢?
答案是,可以!
因?yàn)闆Q定一個(gè)模型是什么樣有兩方面的因素,一個(gè)是模型的結(jié)構(gòu)是什么,另一個(gè)是模型的參數(shù)是什么,這兩個(gè)定了,這個(gè)模型也就確定了。模型的結(jié)構(gòu)在我們初始化模型對(duì)象的時(shí)候就定了,比如對(duì)于任意一個(gè)模型類,我們初始化它的兩個(gè)對(duì)象,這兩個(gè)對(duì)象代表的模型的結(jié)構(gòu)肯定是一樣的,區(qū)別就在于它們的參數(shù)不一樣。所以我們保存模型的關(guān)鍵就是保存模型的參數(shù),而模型的結(jié)構(gòu)每次用的時(shí)候新建一個(gè)對(duì)象就好了,然后從磁盤里把模型的參數(shù)讀取出來賦給這個(gè)對(duì)象。是不是超級(jí)簡單?
那我們?cè)趺茨玫侥P偷膮?shù)呢?巧了!
模型的state_dict()函數(shù)就是返回模型的所有參數(shù)的(這個(gè)函數(shù)是nn.Module的,所以所有繼承了nn.Module的模型類都有這個(gè)函數(shù)),比如:
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.hidden = nn.Linear(3, 2)self.act = nn.ReLU()self.output = nn.Linear(2, 1)def forward(self, x):a = self.act(self.hidden(x))return self.output(a)net = MLP() net.state_dict()輸出:
OrderedDict([('hidden.weight',tensor([[-0.4195, 0.2609, 0.4325],[-0.4031, 0.2078, 0.2077]])),('hidden.bias', tensor([ 0.0755, -0.1408])),('output.weight', tensor([[0.2473, 0.6614]])),('output.bias', tensor([0.6191]))])有的同學(xué)可能注意到了,self.act層的參數(shù)沒有包含進(jìn)來!
大哥,self.act層沒有參數(shù)好嗎(捂臉)
還有的同學(xué)可能想問,那有的層有參數(shù)、有的層沒有參數(shù),那萬一加載的時(shí)候把某個(gè)參數(shù)給錯(cuò)了層怎么辦?
完全不會(huì)!注意看,state_dict()返回的是一個(gè)字典,每一個(gè)張量都對(duì)應(yīng)的有層的名字,清清楚楚,絕對(duì)沒有問題。
那這樣就簡單了,舉個(gè)例子看一下:
X = torch.randn(2, 3) Y = net(X) # 這個(gè)net就是上面創(chuàng)建的那個(gè)對(duì)象,我們把它的參數(shù)保存起來,然后新建一個(gè)net2,然后把保存的這些參數(shù)加載進(jìn)net2,這樣我們把X輸入net2得到的Y2應(yīng)該與Y是相等的PATH = "./net.pt" torch.save(net.state_dict(), PATH)net2 = MLP() net2.load_state_dict(torch.load(PATH)) Y2 = net2(X) Y2 == Y輸出:
tensor([[1],[1]], dtype=torch.uint8)輸出的張量,代表Y2 == Y比較結(jié)果為true,也就是說是一樣的,驗(yàn)證了我們的猜想(上面代碼注釋中的那個(gè)猜想)。
好了,以上就是pytorch保存和加載模型的兩種方法,是不是非常簡單?
陽陽:保存和加載pytorch模型的兩種方法,選哪個(gè)好??zhuanlan.zhihu.com總結(jié)
以上是生活随笔為你收集整理的加载dict_PyTorch 7.保存和加载pytorch模型的两种方法的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 家装灯线走线图_电工装修走线图_电工装饰
- 下一篇: drds 解决问题_DRDS 错误代码如