PyTorch模型的保存加载以及数据的可视化
文章目錄
- PyTorch模型的保存和加載
- 模塊和張量的序列化和反序列化
- 模塊狀態(tài)字典的保存和載入
- PyTorch數(shù)據(jù)的可視化
- TensorBoard的使用
- 總結(jié)
PyTorch模型的保存和加載
在深度學(xué)習(xí)模型的訓(xùn)練過程中,如何周期性地對(duì)模型做存檔非常重要。
一方面,深度學(xué)習(xí)模型的訓(xùn)練是一個(gè)長期的過程,一般來說,大的模型可能運(yùn)行數(shù)天或者數(shù)周,這樣可能就會(huì)在訓(xùn)練的過程中出現(xiàn)一些問題。由于模型一般在運(yùn)行時(shí)保存在計(jì)算機(jī)的內(nèi)存或者顯存中,一旦出現(xiàn)問題可能會(huì)導(dǎo)致模型訓(xùn)練結(jié)果的丟失。另一方面,對(duì)于訓(xùn)練好的模型,經(jīng)常要對(duì)實(shí)際的數(shù)據(jù)進(jìn)行預(yù)測,這就要求訓(xùn)練好的模型權(quán)重能以一定的格式保存到硬盤中,方便后續(xù)使用時(shí)直接載入原來的權(quán)重。
基于這兩點(diǎn)的共同要求,PyTorch提供了很好的機(jī)制來進(jìn)行模型的保存和加載。
模塊和張量的序列化和反序列化
由于PyTorch的模塊和張量本質(zhì)上是torch.nn.Module和torch.tensor類的實(shí)例,而PyTorch自帶了一系列方法可將這些類的實(shí)例轉(zhuǎn)換成字符串,所以這些實(shí)例可以通過Python序列化方法進(jìn)行序列化和反序列化。
PyTorch里面集成了Python自帶的pickle包對(duì)模塊和張量進(jìn)行序列化。張量的序列化過程本質(zhì)上是把張量的信息,包括數(shù)據(jù)類型和存儲(chǔ)位置以及攜帶的數(shù)據(jù)等轉(zhuǎn)換為字符串,而這些字符串隨后可以使用Python自帶的文件IO函數(shù)進(jìn)行存儲(chǔ)。同樣也可以通過文件IO函數(shù)讀取存儲(chǔ)的字符串然后將字符串逆向解析成PyTorch的模塊和張量。
保存和載入的函數(shù)簽名如下:
torch.save(obj, f, pickle_modeule=pickle, pickle_protocol=2) torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args)torch.save函數(shù)傳入的第一個(gè)參數(shù)是PyTorch中可以被序列化的對(duì)象,包括模型和張量等。第二個(gè)參數(shù)是存儲(chǔ)文件的路徑,序列化的結(jié)果將會(huì)被保留在這個(gè)路徑里。第三個(gè)參數(shù)是默認(rèn)的,傳入的是序列化的庫,第四個(gè)參數(shù)是pickle協(xié)議,即如何把對(duì)象轉(zhuǎn)換成字符串的規(guī)范,其協(xié)議版本有0~4版本。
troch.load函數(shù)在給定序列化后的文件路徑以后,就能輸出PyTorch的對(duì)象。其第一個(gè)參數(shù)是文件路徑,第二個(gè)參數(shù)是張量存儲(chǔ)位置的映射,第三個(gè)參數(shù)和torch.save中的一致,最后一個(gè)參數(shù)用來指定傳給pickle_module.load的參數(shù)。
模塊狀態(tài)字典的保存和載入
在PyTorch中一般模型可以由兩種保存方式,第一種是直接保存模型的實(shí)例,第二種是保存模型的狀態(tài)字典,一個(gè)模型的狀態(tài)字典包含模型所有參數(shù)的名字以及名字對(duì)應(yīng)的張量。通過調(diào)用state_lict方法,可以獲取當(dāng)前模型的狀態(tài)字典。
如下代碼所示:
lm = LinearModel(5) # 定義線性模型 lm.state_dict() # 獲取狀態(tài)字典 print(lm.state_dict()) t = lm.state_dict # 保存狀態(tài)字典lm = LinearModel(5) # 重新定義線性模型 lm.state_dict() # 新的狀態(tài)字典 print(lm.state_dict())lm.load_state_dict(t)# 載入原來的狀態(tài)字典 print(lm.state_dict())得到的結(jié)果如下所示:
OrderedDict([('weight', tensor([[ 0.5843],[-1.0206],[-0.7556],[ 1.3406],[ 0.1169]])), ('bias', tensor([-0.2059]))]) OrderedDict([('weight', tensor([[-0.0865],[ 0.1113],[ 0.9502],[-2.3019],[ 0.2588]])), ('bias', tensor([-0.0324]))]) OrderedDict([('weight', tensor([[ 0.5843],[-1.0206],[-0.7556],[ 1.3406],[ 0.1169]])), ('bias', tensor([-0.2059]))])可以看到線性回歸模型實(shí)例反悔了OrderedDict的對(duì)象,即順序字典,其中有兩個(gè)鍵值對(duì),分別對(duì)應(yīng)著權(quán)重和偏置的張量。獲取新的狀態(tài)字典和原來的不同,當(dāng)通過load_state_dict方法傳入舊的狀態(tài)字典讓模型載入?yún)?shù)后,發(fā)現(xiàn)模型的參數(shù)更新為原來的模型參數(shù)。
一般來說推薦使用state_dict方法獲取狀態(tài)字典,然后保存該張量字典來保存模型,這樣可以最大限度地減小代碼對(duì)PyTorch版本的依賴性。
PyTorch數(shù)據(jù)的可視化
TensorBoard的使用
TensorBoard是一個(gè)數(shù)據(jù)可視化工具,能夠直觀地顯示深度學(xué)習(xí)過程中張量的變化,從這個(gè)變化中就可以很容易地了解到模型在訓(xùn)練中的行為,包括但不限于損失函數(shù)的下降趨勢是否合理、張量分量的分布是否在訓(xùn)練中發(fā)生變化以及輸出訓(xùn)練過程中的圖片等等。
這里還是使用博士頓地區(qū)房價(jià)數(shù)據(jù)的線性回歸模型來舉例:
from sklearn.datasets import load_boston from torch.utils.tensorboard import SummaryWriter import torch import torch.nn as nnclass LinearModel(nn.Module):def __init__(self, ndim):super(LinearModel, self).__init__()self.ndim = ndimself.weight = nn.Parameter(torch.randn(ndim, 1))self.bias = nn.Parameter(torch.randn(1))def forward(self, x):return x.mm(self.weight) + self.biasboston = load_boston() lm = LinearModel(13) criterion = nn.MSELoss() optim = torch.optim.SGD(lm.parameters(), lr = 1e-6) data = torch.tensor(boston["data"], requires_grad=True, dtype=torch.float32) target = torch.tensor(boston["target"], dtype=torch.float32) writer = SummaryWriter() # 定義TensorBoard輸出類for step in range(10000):predict = lm(data)loss = criterion(predict, target)writer.add_scalar("Loss/train", loss, step) # 輸出損失函數(shù)writer.add_histogram("Param/weight", lm.weight, step) # 輸出權(quán)重直方圖writer.add_histogram("Param/bias", lm.bias, step) # 輸出偏置直方圖if step and step % 1000 == 0:print("Loss:{:.3f}".format(loss.item()))optim.zero_grad()loss.backward()optim.step()這里相比于之前增加了SummaryWriter的構(gòu)造函數(shù),在構(gòu)造一個(gè)摘要寫入器的實(shí)例以后,可以調(diào)用實(shí)例的方法來添加需要寫入摘要的張量信息。這里主要寫入了一個(gè)標(biāo)量數(shù)據(jù)和兩個(gè)直方圖數(shù)據(jù)。
通過運(yùn)行訓(xùn)練的代碼,在運(yùn)行10000個(gè)epoch之后,可以發(fā)現(xiàn)在當(dāng)前目錄下多了以文件夾runs,runs下面有一個(gè)文件夾,具體的文件夾名字與訓(xùn)練開始時(shí)間、用戶主機(jī)名稱有關(guān)。
接下來可以運(yùn)行tensorboard --logdir ./runs命令,發(fā)現(xiàn)Tensorboard的服務(wù)器已經(jīng)啟動(dòng),可以通過瀏覽器訪問http://127.0.0.1:6006,查看Tensorboard網(wǎng)頁界面,如圖所示:
從圖中可以看出TensorBoard的界面可以顯示很多值,比如SCALARS、DISTRIBUTIONS、HISTOGRAMS等。在默認(rèn)情況下,TensorBoard只顯示寫入數(shù)據(jù)類型的幾個(gè)標(biāo)簽,這里主要是add_scalar產(chǎn)生的SCALARS和add_histogram產(chǎn)生的DISTRIBUTIONS和HISTOGRAMS標(biāo)簽。
SCLARS圖像主要是損失函數(shù)隨著訓(xùn)練步數(shù)變化的曲線。
DISTRIBUTION主要顯示權(quán)重值和偏置的最大和最小的邊界隨著訓(xùn)練步數(shù)的變化過程。
HISTOGRAMS主要顯示權(quán)重和偏置的直方圖隨著訓(xùn)練步數(shù)的變化過程。
除了上面演示的add_scalar和add_histogram方法外,TensorBoard的SummaryWriter還有一系列其他方法添加不同的數(shù)據(jù)到TensorBoard界面中,包括但不限于可以寫入圖片信息的add_image、顯示準(zhǔn)確率召回率曲線的add_pr_curve等等。
總結(jié)
PyTorch通過復(fù)用Python自帶的序列化函數(shù)庫pickle,同時(shí)構(gòu)建了張量和模塊的序列化方法,來實(shí)現(xiàn)深度學(xué)習(xí)模型的保存和載入。深度學(xué)習(xí)模型也可以很方便地輸出和載入當(dāng)前模型參數(shù)的狀態(tài)字典,該狀態(tài)字典和模型的分離也方便不同版本PyTorch訓(xùn)練模型之間的相互兼容。
為了能夠方便地觀察深度學(xué)習(xí)的中間結(jié)果和張量,以及損失函數(shù)的變化情況,PyTorch還集成了TensorBoard相關(guān)的插件,能夠方便地在網(wǎng)頁中對(duì)深度學(xué)習(xí)的模型輸出的中間張量進(jìn)行可視化,也方便了用戶對(duì)深度學(xué)習(xí)模型的調(diào)試和效果評(píng)估。
《新程序員》:云原生和全面數(shù)字化實(shí)踐50位技術(shù)專家共同創(chuàng)作,文字、視頻、音頻交互閱讀總結(jié)
以上是生活随笔為你收集整理的PyTorch模型的保存加载以及数据的可视化的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: PyTorch中的数据输入和预处理
- 下一篇: GNS3下载安装