nn.functional 和 nn.Module入门讲解
本文來自《20天吃透Pytorch》
一,nn.functional 和 nn.Module
前面我們介紹了Pytorch的張量的結(jié)構(gòu)操作和數(shù)學(xué)運(yùn)算中的一些常用API。
利用這些張量的API我們可以構(gòu)建出神經(jīng)網(wǎng)絡(luò)相關(guān)的組件(如激活函數(shù),模型層,損失函數(shù))。
Pytorch和神經(jīng)網(wǎng)絡(luò)相關(guān)的功能組件大多都封裝在 torch.nn模塊下。
這些功能組件的絕大部分既有函數(shù)形式實(shí)現(xiàn),也有類形式實(shí)現(xiàn)。
其中nn.functional(一般引入后改名為F)有各種功能組件的函數(shù)實(shí)現(xiàn)。例如:
(激活函數(shù)) * F.relu * F.sigmoid * F.tanh * F.softmax
(模型層) * F.linear * F.conv2d * F.max_pool2d * F.dropout2d * F.embedding
(損失函數(shù)) * F.binary_cross_entropy * F.mse_loss * F.cross_entropy
為了便于對參數(shù)進(jìn)行管理,一般通過繼承 nn.Module 轉(zhuǎn)換成為類的實(shí)現(xiàn)形式,并直接封裝在 nn 模塊下。例如:
(激活函數(shù)) * nn.ReLU * nn.Sigmoid * nn.Tanh * nn.Softmax
(模型層) * nn.Linear * nn.Conv2d * nn.MaxPool2d * nn.Dropout2d * nn.Embedding
(損失函數(shù)) * nn.BCELoss * nn.MSELoss * nn.CrossEntropyLoss
二,使用nn.Module來管理參數(shù)
在Pytorch中,模型的參數(shù)是需要被優(yōu)化器訓(xùn)練的,因此,通常要設(shè)置參數(shù)為 requires_grad = True 的張量。
同時(shí),在一個(gè)模型中,往往有許多的參數(shù),要手動管理這些參數(shù)并不是一件容易的事情。
Pytorch一般將參數(shù)用nn.Parameter來表示,并且用nn.Module來管理其結(jié)構(gòu)下的所有參數(shù)。
三,使用nn.Module來管理子模塊
實(shí)際上nn.Module除了可以管理其引用的各種參數(shù),還可以管理其引用的子模塊,功能十分強(qiáng)大。
一般情況下,我們都很少直接使用 nn.Parameter來定義參數(shù)構(gòu)建模型,而是通過一些拼裝一些常用的模型層來構(gòu)造模型。
這些模型層也是繼承自nn.Module的對象,本身也包括參數(shù),屬于我們要定義的模塊的子模塊。
nn.Module提供了一些方法可以管理這些子模塊。
children() 方法: 返回生成器,包括模塊下的所有子模塊。
named_children()方法:返回一個(gè)生成器,包括模塊下的所有子模塊,以及它們的名字。
modules()方法:返回一個(gè)生成器,包括模塊下的所有各個(gè)層級的模塊,包括模塊本身。
named_modules()方法:返回一個(gè)生成器,包括模塊下的所有各個(gè)層級的模塊以及它們的名字,包括模塊本身。
其中chidren()方法和named_children()方法較多使用。
modules()方法和named_modules()方法較少使用,其功能可以通過多個(gè)named_children()的嵌套使用實(shí)現(xiàn)。
i = 0 for child in net.children():i+=1print(child,"\n") print("child number",i) i = 0 for name,child in net.named_children():i+=1print(name,":",child,"\n") print("child number",i) i = 0 for module in net.modules():i+=1print(module) print("module number:",i)下面我們通過named_children方法找到embedding層,并將其參數(shù)設(shè)置為不可訓(xùn)練(相當(dāng)于凍結(jié)embedding層)。
children_dict = {name:module for name,module in net.named_children()}print(children_dict) embedding = children_dict["embedding"] embedding.requires_grad_(False) #凍結(jié)其參數(shù)總結(jié)
以上是生活随笔為你收集整理的nn.functional 和 nn.Module入门讲解的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 梦到半个西瓜是什么意思
- 下一篇: Dataset和DataLoader构建