angularjs中state的参数4_一文梳理pytorch保存和重载模型参数攻略
訓練過程中保存模型參數,就不怕斷電了——沃資基·索德
在訓練完成之前,我們需要每隔一段時間保存模型當前參數值,一方面可以防止斷電重跑,另一方面可以觀察不同迭代次數模型的表現;在訓練完成以后,我們需要保存模型參數值用于后續的測試過程。所以,保存的對象包含網絡參數值、優化器參數值、epoch值等等。
一、定義一個容易識別的網絡
在正式介紹模型的保存和加載之前,我們首先定義一個基本的網絡Net,它只包含一個全連接層:
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.layer = nn.Linear(1, 1)self.layer.weight = nn.Parameter(torch.FloatTensor([[10]]))self.layer.bias = nn.Parameter(torch.FloatTensor([1]))def forward(self, x):y = self.layer(x)return y我將全連接的權重w和偏差b分別設置為10和1,全連接的計算方式如下:
假設輸入x=1,可以知道y值為11:
測試一下輸出是不是11,代碼如下:
x = torch.FloatTensor([[1]]) net = Net() out = net(x) print(out)輸出:tensor([[11.]], grad_fn=<AddmmBackward>),說明上述計算是正確的。不采用參數隨機初始化,而是用特殊的數值初始化,是因為我們希望重載模型的時候,能夠從特殊數值一眼判斷出保存和重載過程是否正確,也可以把權重設置為一張圖片數值,然后判斷加載的參數值能不能恢復原圖。
二、保存Net的參數值
保存模型參數之前,需要知道Net的參數值存儲在其state_dict(狀態字典)屬性中,我們查看一下net的state_dict包含哪些參數:
print(net.state_dict())我們將會得到net包含的所有參數名稱與參數值:
包含一個weight和一個bias,對應的值分別是10和1,和我們之前定義的全連接層一致。我們需要保存的就是這個state_dict,保存的函數為“torch.save()”,參數是我們需要保存的dict和存儲路徑:
torch.save(obj=net.state_dict(), f="models/net.pth")現在,同級目錄models下將會出現net.pth文件,pth文件中的內容就是net的參數名稱和值對應的state_dict,如下:
三、加載Net參數值并用于新的模型
最后一個步驟就是從pth文件中重新獲取Net參數值,并把參數值裝載到新定義的Model對象中。這里我們重新定義一個結構和Net類相同的類Model,區別僅僅是Model參數初始值和Net不同,代碼如下:
class Model(nn.Module):def __init__(self):super(Model, self).__init__()self.layer = nn.Linear(1, 1)self.layer.weight = nn.Parameter(torch.FloatTensor([[0]]))self.layer.bias = nn.Parameter(torch.FloatTensor([0]))def forward(self, x):out = self.layer(x)return out這里將Model的初始值權重w和偏差都設置為0,查看其state_dict:
model = Model() print(model.state_dict())得到的w和b值與預期相同,均為0,如下:
現在,我們將model對象的參數值設置為net.pth中的值,需要使用“model.load_state_dict()”函數重置model的參數值為"torch.load(models/ net.pth)"中的參數值,如下:
model.load_state_dict(torch.load("models/net.pth")) print(model.state_dict())至此,model的w和b值就不再是0了,而是net中w和b對應的10和1,如下:
其中參數值重載的核心函數為“model.load_state_dict()”,每個繼承自nn.Module的網絡都能通過這個函數設定參數值。
四、優化器與epoch的保存
保存優化器參數值和epoch值的主要目的是用于繼續訓練,保存的流程依舊是先“torch.save()”再“torch.load_state_dict()”,我們首先定義一個Adam優化器、一個任意的epoch值與net如下:
net = Net() Adam = optim.Adam(params=net.parameters(), lr=0.001, betas=(0.5, 0.999)) epoch = 96現在,創建一個字典來保存所有的對象,并用save函數保存這個字典:
all_states = {"net": net.state_dict(), "Adam": Adam.state_dict(), "epoch": epoch} torch.save(obj=all_states, f="models/all_states.pth")所有的對象都被保存到models文件夾下了:
可以使用load()函數把所有的對象再次提取出來:
reload_states = torch.load("models/all_states.pth") print(reload_states)得到的所有參數如下:
五、總結
pytorch中state_dict()和load_state_dict()函數配合使用可以實現狀態的獲取與重載,load()和save()函數配合使用可以實現參數的存儲與讀取。其中最重要的部分是“字典”的概念,因為參數對象的存儲是需要“名稱”——“值”對應(即鍵值對),讀取時也是通過鍵值對讀取的。
參考:
https://www.pytorchtutorial.com/pytorch-note5-save-and-restore-models/
https://blog.csdn.net/Code_Mart/article/details/88254444
總結
以上是生活随笔為你收集整理的angularjs中state的参数4_一文梳理pytorch保存和重载模型参数攻略的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 工业机器人电柜布线_沙湾附近回收工业锅炉
- 下一篇: 机器人煮面机创始人_秋天的第一杯枸杞拿铁