一文详解pytorch的“动态图”与“自动微分”技术
前言
眾所周知,Pytorch是一個(gè)非常流行且深受好評(píng)的深度學(xué)習(xí)訓(xùn)練框架。這與它的兩大特性“動(dòng)態(tài)圖”、“自動(dòng)微分”有非常大的關(guān)系。“動(dòng)態(tài)圖”使得pytorch的調(diào)試非常簡(jiǎn)單,每一個(gè)步驟,每一個(gè)流程都可以被我們精確的控制、調(diào)試、輸出。甚至是在每個(gè)迭代都能夠重構(gòu)整個(gè)網(wǎng)絡(luò)。這在其他基于靜態(tài)圖的訓(xùn)練框架中是非常不方便處理的。在靜態(tài)圖的訓(xùn)練框架中,必須先構(gòu)建好整個(gè)網(wǎng)絡(luò),然后開(kāi)始訓(xùn)練。如果想在訓(xùn)練過(guò)程中輸出中間節(jié)點(diǎn)的數(shù)據(jù)或者是想要改變一點(diǎn)網(wǎng)絡(luò)的結(jié)構(gòu),就需要非常復(fù)雜的操作,甚至是不可實(shí)現(xiàn)的。而“自動(dòng)微分”技術(shù)使得在編寫深度學(xué)習(xí)網(wǎng)絡(luò)的時(shí)候,只需要實(shí)現(xiàn)算子的前向傳播即可,無(wú)需像caffe那樣對(duì)同一個(gè)算子需要同時(shí)實(shí)現(xiàn)前向傳播和反向傳播。由于反向傳播一般比前向傳播要復(fù)雜,并且手動(dòng)推導(dǎo)反向傳播的時(shí)候很容易出錯(cuò),所以“自動(dòng)微分”能夠極大的節(jié)約勞動(dòng)力,提升效率。
動(dòng)態(tài)圖
用過(guò)caffe或者tensorflow的同學(xué)應(yīng)該知道,在訓(xùn)練之前需要構(gòu)建一個(gè)神經(jīng)網(wǎng)絡(luò),caffe里面使用配置文件prototxt來(lái)進(jìn)行描述,tensorflow中使用python代碼來(lái)描述。訓(xùn)練之前,框架都會(huì)有一個(gè)解析和構(gòu)建神經(jīng)網(wǎng)絡(luò)的過(guò)程。構(gòu)建完了之后再進(jìn)行數(shù)據(jù)讀取和訓(xùn)練。在訓(xùn)練過(guò)程中網(wǎng)絡(luò)一般是不會(huì)變的,所以叫做“靜態(tài)圖”。想要獲取中間變量的輸出,可以是可以,就是比較麻煩一些,caffe使用c++訓(xùn)練的話,需要獲取layer的top,然后打印,tensorflow需要通過(guò)session來(lái)獲取。但是如果想要控制網(wǎng)絡(luò)的運(yùn)行,比如讓網(wǎng)絡(luò)停在某一個(gè)OP之后,這是很難做到的。即無(wú)法精確的控制網(wǎng)絡(luò)運(yùn)行的每一步,只能等網(wǎng)絡(luò)運(yùn)行完了,然后通過(guò)相關(guān)的接口去獲取相關(guān)的數(shù)據(jù)。而pytorch的“動(dòng)態(tài)圖”機(jī)制就可以對(duì)網(wǎng)絡(luò)實(shí)現(xiàn)非常精確的控制。在pytorch運(yùn)行之前,不會(huì)去創(chuàng)建所謂的神經(jīng)網(wǎng)絡(luò),這完全由python代碼定義的forward函數(shù)來(lái)描述。即我們手工編寫的forward函數(shù)就是pytorch前向運(yùn)行的動(dòng)態(tài)圖。當(dāng)代碼執(zhí)行到哪一句的時(shí)候,網(wǎng)絡(luò)就運(yùn)行到哪一步。所以當(dāng)你對(duì)forward函數(shù)進(jìn)行調(diào)試,斷點(diǎn),修改的時(shí)候,神經(jīng)網(wǎng)絡(luò)也就被相應(yīng)的調(diào)試、中斷和修改了。也就是說(shuō)pytorch的forwad代碼就是神經(jīng)網(wǎng)絡(luò)的執(zhí)行流,或者說(shuō)就是pytorch的“動(dòng)態(tài)圖”。對(duì)forward的控制就是對(duì)神經(jīng)網(wǎng)絡(luò)的控制。如下圖所示:
正因?yàn)檫@樣的實(shí)現(xiàn)機(jī)制,使得對(duì)神經(jīng)網(wǎng)絡(luò)的調(diào)試可以像普通python代碼那樣進(jìn)行調(diào)試,非常的方便和友好。并且可以在任何時(shí)候,修改網(wǎng)絡(luò)的結(jié)構(gòu),這就是動(dòng)態(tài)圖的好處。
自動(dòng)微分
上面的動(dòng)態(tài)圖詳解了pytorch如何構(gòu)建前向傳播的動(dòng)態(tài)神經(jīng)網(wǎng)絡(luò)的,實(shí)際上pytorch并沒(méi)有顯式的去構(gòu)建一個(gè)所謂的動(dòng)態(tài)圖,本質(zhì)就是按照f(shuō)orward的代碼執(zhí)行流程走了一遍而已。那么對(duì)于反向傳播,因?yàn)槲覀儧](méi)有構(gòu)建反向傳播的代碼,pytorch也就無(wú)法像前向傳播那樣,通過(guò)我們手動(dòng)編寫的代碼執(zhí)行流進(jìn)行反向傳播。那么pytorch是如何實(shí)現(xiàn)精確的反向傳播的呢?其實(shí)最大的奧秘就藏在tensor的grad_fn屬性里面。有的同學(xué)可能在調(diào)試pytorch代碼的時(shí)候已經(jīng)不經(jīng)意的遇到過(guò)這個(gè)grad_fn屬性。如下圖所示:
Pytorch中的tensor對(duì)象都有一個(gè)叫做grad_fn的屬性,它實(shí)際上是一個(gè)鏈表,實(shí)現(xiàn)在pytorch源碼的autograd下面。該屬性記錄了該tensor是如何由前一個(gè)tensor產(chǎn)生的。在深入探究grad_fn之前,先來(lái)了解一下pytroch中的leaf tensor和非leaf tensor。
?
Leaf/非leaf tensor:
Pytorch中的tensor有兩種產(chǎn)生方式,一種是憑空創(chuàng)建的,例如一些op里面的params,訓(xùn)練的images,這些tensor,他們不是由其他tensor計(jì)算得來(lái)的,而是通過(guò)torch.zeros(),torch.ones(),torch.from_numpy()等憑空創(chuàng)建出來(lái)的。另外一種產(chǎn)生方式是由某一個(gè)tensor經(jīng)過(guò)一個(gè)op計(jì)算得到,例如tensor a通過(guò)conv計(jì)算得到tensor b。其實(shí)這兩種op創(chuàng)建方式對(duì)應(yīng)的就是leaf節(jié)點(diǎn)(葉子節(jié)點(diǎn))和非leaf(非葉子節(jié)點(diǎn))。如下圖所示,為一個(gè)cnn網(wǎng)絡(luò)中的leaf節(jié)點(diǎn)和非leaf節(jié)點(diǎn)。黃色的節(jié)點(diǎn)對(duì)應(yīng)的tensor就是憑空生成的,是leaf節(jié)點(diǎn);藍(lán)色的tensor就是通過(guò)其他tensor計(jì)算得來(lái)的,是非leaf節(jié)點(diǎn)。那么顯而易見(jiàn),藍(lán)色的非leaf節(jié)點(diǎn)的grad_fn是有值的,因?yàn)樗奶荻刃枰^續(xù)向后傳播給創(chuàng)建它的那個(gè)節(jié)點(diǎn)。而黃色的leaf節(jié)點(diǎn)的grad_fn為None,因?yàn)樗麄儾皇怯善渌?jié)點(diǎn)創(chuàng)建而來(lái),他們的梯度不需要繼續(xù)反向傳播。
深究grad_fn:
grad_fn是python層的封裝,其實(shí)現(xiàn)對(duì)應(yīng)的就是pytorch源碼在autograd下面的node對(duì)象,為C++實(shí)現(xiàn),如下圖所示:
node其實(shí)是一個(gè)鏈表,有一個(gè)next_edges_屬性,里面存儲(chǔ)著指向下一級(jí)的所有node。注意它不是一個(gè)簡(jiǎn)單的單向鏈表,因?yàn)楹芏鄑ensor可能是由多個(gè)tensor創(chuàng)建來(lái)的。例如tensor a = tensor b + tensor c. 那么tensor a的grad_fn屬性里面的next_edges就會(huì)有兩個(gè)指針,分別指向tensor b和tensor c的grad_fn屬性。在python層,next_edges_屬性被封裝成了next_functions。因此正確的說(shuō)法是:tensor a的grad_fn屬性里面的next_ functions,指向了tensor b和tensor c的grad_fn屬性。其實(shí)有了這個(gè)完整的鏈表,就已經(jīng)完整的表達(dá)了反向傳播的計(jì)算圖。就可以完成完整的反向傳播了。 下面我們通過(guò)一個(gè)小例子來(lái)進(jìn)一步說(shuō)明grad_fn是如何表達(dá)反向傳播計(jì)算圖的。首先我們定義一個(gè)非常簡(jiǎn)單的網(wǎng)絡(luò):僅有兩個(gè)conv層,一個(gè)relu層,一個(gè)pool層,如下圖所示(conv層帶有參數(shù)weights和bias):
對(duì)應(yīng)的代碼片段如下所示:
class TinyCnn(torch.nn.Module):def __init__(self, arg_dict={}):super(TinyCnn, self).__init__()self.conv = torch.nn.Conv2d(3, 3, kernel_size=2, stride=2)self.relu = torch.nn.ReLU(inplace=True)self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)def forward(self, images):conv_out = self.conv(images)relu_out = self.relu(conv_out)pool_out = self.pool(relu_out)return pool_outcnn = TinyCnn() loss_fun = torch.nn.BCELoss() images = torch.rand((1,3,4,4)) labels = torch.rand((1,3,1,1)) preds = cnn(images) loss = loss_fun(preds, labels) loss.backward()那么當(dāng)代碼執(zhí)行到loss = loss_fun(preds, labels),我們看看loss的grad_fn以及其對(duì)應(yīng)的next_functions:
可以看到loss的grad_fn為:<BinaryCrossEntropyBackward object at 0x000001A07E079FC8>,而它的next_functions為:(<MaxPool2DWithIndicesBackward object at 0x000001A07E08BC88>, 0),繼續(xù)跟蹤MaxPool2DWithIndicesBackward的nex_functions為:(<ReluBackward1 object at 0x000001A07E079B88>, 0),如果繼續(xù)跟蹤下去,整個(gè)反向傳播的計(jì)算圖就非常的直觀了,使用下圖表示:
Images由于是葉子節(jié)點(diǎn),且不需要求梯度,因此ThnnConv2DBackward的第一個(gè)next_functions對(duì)應(yīng)的是None。而conv中的weights和bias雖然也是葉子節(jié)點(diǎn),但是需要求梯度,因此增加了一個(gè)AccumulateGrad的方法,表示可累計(jì)梯度,實(shí)際上就是對(duì)weights和bias的梯度的保存。
grad_fn是如何生成的?
有了上面的介紹,其實(shí)大家應(yīng)該已經(jīng)大致了解了pytorch自動(dòng)微分的大致流程。實(shí)際上是通過(guò)tensor的gran_fn來(lái)組織的,grad_fn本質(zhì)上是一個(gè)鏈表,指向下一級(jí)別的tensor的grad_fn,因此通過(guò)這樣一個(gè)鏈表構(gòu)成了一個(gè)完整的反向計(jì)算的動(dòng)態(tài)圖。那么最后有一個(gè)問(wèn)題就是tensor的grad_fn是如何構(gòu)建的?無(wú)論是我們自己編寫的上層代碼,還是在pytorch底層的op實(shí)現(xiàn)里面,并沒(méi)有顯示的去創(chuàng)建grad_fn,那么它是在何時(shí),又是如何組裝的?實(shí)際上通過(guò)編譯pytorch源碼就能發(fā)現(xiàn)端倪。Pytorch會(huì)對(duì)所有底層算子進(jìn)一個(gè)二次封裝,在做完正常的op前向之后,增加了grad_fn的設(shè)置,next_functions的設(shè)置等流程。如下圖所示為原始卷積的前向流程和經(jīng)過(guò)pytroch自動(dòng)封裝的卷積前向計(jì)算流程對(duì)比。可以明顯的看到多了一些對(duì)grad_fn設(shè)置的代碼。
后記
以上流程就是pytorch的“動(dòng)態(tài)圖”與“自動(dòng)微分”的核心邏輯。基于pytorch1.6.0源碼分析,由于作者才疏學(xué)淺,且涉獵范圍有限,難免有所錯(cuò)誤,如果有不對(duì)的地方,還請(qǐng)見(jiàn)諒并指正。
?
總結(jié)
以上是生活随笔為你收集整理的一文详解pytorch的“动态图”与“自动微分”技术的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 像素密度推高到 5000 PPI!麻省理
- 下一篇: [多图/秒懂]白话OpenPose,最受