日韩av黄I国产麻豆传媒I国产91av视频在线观看I日韩一区二区三区在线看I美女国产在线I麻豆视频国产在线观看I成人黄色短片

歡迎訪問 生活随笔!

生活随笔

當前位置: 首頁 >

【小白学PyTorch】6.模型的构建访问遍历存储(附代码)

發布時間:2025/3/8 41 豆豆
生活随笔 收集整理的這篇文章主要介紹了 【小白学PyTorch】6.模型的构建访问遍历存储(附代码) 小編覺得挺不錯的,現在分享給大家,幫大家做個參考.

<<小白學PyTorch>>

小白學PyTorch | 5 torchvision預訓練模型與數據集全覽

小白學PyTorch | 4 構建模型三要素與權重初始化

小白學PyTorch | 3 淺談Dataset和Dataloader

小白學PyTorch | 2 淺談訓練集驗證集和測試集

小白學PyTorch | 1 搭建一個超簡單的網絡

小白學PyTorch | 動態圖與靜態圖的淺顯理解

文章目錄:

  • 1 模型構建函數

    • 1.1 add_module

    • 1.2 ModuleList

    • 1.3 Sequential

    • 1.4 小總結

  • 2 遍歷模型結構

    • 2.1 modules()

    • 2.2 named_modules()

    • 2.3 parameters()

  • 3 保存與載入

本文是對一些函數的學習。函數主要包括下面四個方便:

  • 模型構建的函數:add_module,add_module,add_module

  • 訪問子模塊:add_module,add_module,add_module,add_module

  • 網絡遍歷:add_module,add_module

  • 模型的保存與加載:add_module,add_module,add_module

1 模型構建函數

torch.nn.Module是所有網絡的基類,在PyTorch實現模型的類中都要繼承這個類(這個在之前的課程中已經提到)。在構建Module中,Module是一個包含其他的Module的,類似于,你可以先定義一個小的網絡模塊,然后把這個小模塊作為另外一個網絡的組件。因此網絡結構是呈現樹狀結構

我們先簡單定義一個網絡:

import?torch.nn?as?nn import?torch? class?MyNet(nn.Module):def?__init__(self):super(MyNet,self).__init__()self.conv1?=?nn.Conv2d(3,64,3)self.conv2?=?nn.Conv2d(64,64,3)def?forward(self,x):x?=?self.conv1(x)x?=?self.conv2(x)return?x net?=?MyNet() print(net)

輸出結果:MyNet中有兩個屬性conv1和conv2是兩個卷積層,在正向傳播forward的過程中,依次調用這兩個卷積層實現網絡的功能。

1.1 add_module

這種是最常見的定義網絡的功能,在有些項目中,會看到這樣的方法add_module。我們用這個方法來重寫上面的網絡:

class?MyNet(nn.Module):def?__init__(self):super(MyNet,self).__init__()self.add_module('conv1',nn.Conv2d(3,64,3))self.add_module('conv2',nn.Conv2d(64,64,3))def?forward(self,x):x?=?self.conv1(x)x?=?self.conv2(x)return?x

其實add_module(name,layer)和self.name=layer實現了相同的功能,個人感覺也許是因為add_module可以使用字符串來定義變量名字,所以可以放在循環中?反正這個先了解熟悉熟悉

上面的兩種方法都是一層一層的添加layer,如果網絡復雜的話,那就需要寫很多重復的代碼了。因此接下來來講解一下網絡模塊的構建,torch.nn.ModuleList和torch.nn.Sequential

1.2 ModuleList

ModuleList按照字面意思是用list的形式保存網絡層的。這樣就可以先將網絡需要的layer構建好,保存到一個list,然后通過ModuleList方法添加到網絡中.

class?MyNet(nn.Module):def?__init__(self):super(MyNet,self).__init__()self.linears?=?nn.ModuleList([nn.Linear(10,10)?for?i?in?range(5)])def?forward(self,x):for?l?in?self.linears:x?=?l(x)return?x net?=?MyNet() print(net)

輸出結果是:

這個ModuleList主要是用在讀取config文件來構建網絡模型中的,下面用VGG模型的構建為例子:

vgg_cfg?=?[64,?64,?'M',?128,?128,?'M',?256,?256,?256,?'C',?512,?512,?512,?'M',512,?512,?512,?'M']def?vgg(cfg,?i,?batch_norm=False):layers?=?[]in_channels?=?ifor?v?in?cfg:if?v?==?'M':layers?+=?[nn.MaxPool2d(kernel_size=2,?stride=2)]elif?v?==?'C':layers?+=?[nn.MaxPool2d(kernel_size=2,?stride=2,?ceil_mode=True)]else:conv2d?=?nn.Conv2d(in_channels,?v,?kernel_size=3,?padding=1)if?batch_norm:layers?+=?[conv2d,?nn.BatchNorm2d(v),?nn.ReLU(inplace=True)]else:layers?+=?[conv2d,?nn.ReLU(inplace=True)]in_channels?=?vreturn?layersclass?Model1(nn.Module):def?__init__(self):super(Model1,self).__init__()self.vgg?=?nn.ModuleList(vgg(vgg_cfg,3))def?forward(self,x):for?l?in?self.vgg:x?=?l(x) m1?=?Model1() print(m1)

先讀取網絡結構的配置文件vgg_cfg然后根據這個文件創建對應的Layer list,然后使用ModuleList添加到網絡中,這樣可以快速創建不同的網絡(用上面為例子的話,可以通過修改配置文件,然后快速修改網絡結構

1.3 Sequential

在一些自己做的小項目中,Sequential其實用的更為頻繁。依然重寫最初最簡單的例子:

class?MyNet(nn.Module):def?__init__(self):super(MyNet,self).__init__()self.conv?=?nn.Sequential(nn.Conv2d(3,64,3),nn.Conv2d(64,64,3))def?forward(self,x):x?=?self.conv(x)return?x net?=?MyNet() print(net)

運行結果:

觀察細致的朋友可以發現這個問題,Seqential內的網絡層是默認用數字進行標號的,而一開始我們使用self.conv1和self.conv2的時候,使用conv1和conv2作為標號的。

我們如何修改Sequential中網絡層的名稱呢?這里需要使用到collections.OrderedDict有序字典。Sequential是支持有序字典構建的。

from?collections?import?OrderedDict? class?MyNet(nn.Module):def?__init__(self):super(MyNet,self).__init__()self.conv?=?nn.Sequential(OrderedDict([('conv1',nn.Conv2d(3,64,3)),('conv2',nn.Conv2d(64,64,3))]))def?forward(self,x):x?=?self.conv(x)return?x net?=?MyNet() print(net)

輸出結果:

1.4 小總結

  • 單獨增加一個網絡層或者子模塊,可以用add_module或者直接賦予屬性;

  • ModuleList可以將一個Module的List增加到網絡中,自由度較高。

  • Sequential按照順序產生一個Module模塊。這里推薦習慣使用OrderedDict的方法進行構建。對網絡層加上規范的名稱,這樣有助于后續查找與遍歷

2 遍歷模型結構

本章節使用下面的方法進行遍歷之前提到的Module。(個人理解,Module是多個layer的合并,但是一個layer可以說成Module。 ) 先定義一個網絡吧,隨便寫一個:

import?torch.nn?as?nn import?torch? from?collections?import?OrderedDict class?MyNet(nn.Module):def?__init__(self):super(MyNet,self).__init__()self.conv1?=?nn.Conv2d(in_channels=3,out_channels=64,kernel_size=3)self.conv2?=?nn.Conv2d(64,64,3)self.maxpool1?=?nn.MaxPool2d(2,2)self.features?=?nn.Sequential(OrderedDict([('conv3',?nn.Conv2d(64,128,3)),('conv4',?nn.Conv2d(128,128,3)),('relu1',?nn.ReLU())]))def?forward(self,x):x?=?self.conv1(x)x?=?self.conv2(x)x?=?self.maxpool1(x)x?=?self.features(x)return?x net?=?MyNet() print(net)

輸出結果是:

2.1 modules()

在第四課中初始化模型各個層的參數的時候,用到了這個方法,現在我們再來理解一下:

for?idx,m?in?enumerate(net.modules()):print(idx,"-",m)

運行結果:

上面那個網絡構建的時候用到了Sequential,所以網絡中其實是嵌套了一個小的Module,這就是之前提到的樹狀結構,然后上面便利的時候也是樹狀結構的便利過程,可以看出來應該是一個深度遍歷的過程。

  • 首先第一個輸出的是最大的那個Module,也就是整個網絡,0-Model整個網絡模塊;

  • 1-2-3-4是網絡的四個子模塊,4-Sequential中間仍然包含子模塊

  • 5-6-7是模塊4-Sequential的子模塊。

【總結】

modules()是遞歸的返回網絡的各個module(深度遍歷),從最頂層直到最后的葉子的module。

2.2 named_modules()

named_modules()和module()類似,只是同時返回name和module。

for?idx,(name,m)?in?enumerate(net.named_modules()):print(idx,"-",name)

輸出結果:

2.3 parameters()

for?p?in?net.parameters():print(type(p.data),p.size())

運行結果:

輸出的是四個卷積層的權重矩陣參數和偏置參數。值得一提的是,對網絡進行訓練時需要將parameters()作為優化器optimizer的參數。

optimizer?=?torch.optim.SGD(net.parameters(),lr?=?0.001,momentum=0.9)

總之呢,這個parameters()是返回網絡所有的參數,主要用在給optimizer優化器用的。而要對網絡的某一層的參數做處理的時候,一般還是使用named_parameters()方便一些。

for?idx,(name,m)?in?enumerate(net.named_parameters()):print(idx,"-",name,m.size())

輸出結果:

【小擴展】

我個人有時會使用下面的方法來獲取參數:

for?idx,(name,m)?in?enumerate(net.named_modules()):if?isinstance(m,nn.Conv2d):print(m.weight.shape)print(m.bias.shape)

先判斷是否是卷積層,然后獲取其參數,輸出結果:

3 保存與載入

PyTorch使用torch.save和torch.load方法來保存和加載網絡,而且網絡結構和參數可以分開的保存和加載。

torch.save(model,'model.pth')?#?保存 model?=?torch.load("model.pth")?#?加載

pytorch中網絡結構和模型參數是可以分開保存的。上面的方法是兩者同時保存到了.pth文件中,當然,你也可以僅僅保存網絡的參數來減小存儲文件的大小。注意:如果你僅僅保存模型參數,那么在載入的時候,是需要通過運行代碼來初始化模型的結構的。

torch.save(model.state_dict(),"model.pth")?#?保存參數 model?=?MyNet()?#?代碼中創建網絡結構 params?=?torch.load("model.pth")?#?加載參數 model.load_state_dict(params)?#?應用到網絡結構中

至此,我們今天已經學習了不少的內容,大家對PyTorch的掌握更近一步了呢~

- END -

往期精彩回顧適合初學者入門人工智能的路線及資料下載機器學習及深度學習筆記等資料打印機器學習在線手冊深度學習筆記專輯《統計學習方法》的代碼復現專輯 AI基礎下載機器學習的數學基礎專輯獲取一折本站知識星球優惠券,復制鏈接直接打開:https://t.zsxq.com/662nyZF本站qq群1003271085。加入微信群請掃碼進群(如果是博士或者準備讀博士請說明):

總結

以上是生活随笔為你收集整理的【小白学PyTorch】6.模型的构建访问遍历存储(附代码)的全部內容,希望文章能夠幫你解決所遇到的問題。

如果覺得生活随笔網站內容還不錯,歡迎將生活随笔推薦給好友。