RNN神经网络产生梯度消失和梯度爆炸的原因及解决方案
1、RNN模型結(jié)構(gòu)
循環(huán)神經(jīng)網(wǎng)絡(luò)RNN(Recurrent Neural Network)會(huì)記憶之前的信息,并利用之前的信息影響后面結(jié)點(diǎn)的輸出。也就是說,循環(huán)神經(jīng)網(wǎng)絡(luò)的隱藏層之間的結(jié)點(diǎn)是有連接的,隱藏層的輸入不僅包括輸入層的輸出,還包括上時(shí)刻隱藏層的輸出。下圖為RNN模型結(jié)構(gòu)圖:
2、RNN前向傳播算法
RNN前向傳播公式為:
其中:
St為t時(shí)刻的隱含層狀態(tài)值;
Ot為t時(shí)刻的輸出值;
①是隱含層計(jì)算公式,U是輸入x的權(quán)重矩陣,St-1是t-1時(shí)刻的狀態(tài)值,W是St-1作為輸入的權(quán)重矩陣,$\Phi $是激活函數(shù);
②是輸出層計(jì)算公司,V是輸出層的權(quán)重矩陣,f是激活函數(shù)。
損失函數(shù)(loss function)采用交叉熵$L_{t}=-\overline{o_{t}}logo_{_{t}}$(Ot是t時(shí)刻預(yù)測(cè)輸出,$\overline{o_{t}}$是t時(shí)刻正確的輸出)
那么對(duì)于一次訓(xùn)練任務(wù)中,損失函數(shù)$L=\sum_{i=1}^{T}-\overline{o_{t}}logo_{_{t}}$, T是序列總長(zhǎng)度。
假設(shè)初始狀態(tài)St為0,t=3 有三段時(shí)間序列時(shí),由 ① 帶入②可得到
t1、t2、t3 各個(gè)狀態(tài)和輸出為:
t=1:
狀態(tài)值:$s_{1}=\Phi (Ux_{1}+Ws_{0})$
輸出:$o_{1}=f(V\Phi (Ux_{1}+Ws_{0}))$
t=2:
狀態(tài)值:$s_{2}=\Phi (Ux_{2}+Ws_{1})$
輸出:$o_{2}=f(V\Phi (Ux_{2}+Ws_{1}))=f(V\Phi (Ux_{2}+W\Phi(Ux_{1}+Ws_{0})))$
t=3:
狀態(tài)值:$s_{3}=\Phi (Ux_{3}+Ws_{2})$
輸出:$o_{3}=f(V\Phi (Ux_{3}+Ws_{2}))=\cdots =f(V\Phi (Ux_{3}+W\Phi(Ux_{2}+W\Phi(Ux_{1}+Ws_{0}))))$
3、RNN反向傳播算法
BPTT(back-propagation through time)算法是針對(duì)循層的訓(xùn)練算法,它的基本原理和BP算法一樣。其算法本質(zhì)還是梯度下降法,那么該算法的關(guān)鍵就是計(jì)算各個(gè)參數(shù)的梯度,對(duì)于RNN來說參數(shù)有 U、W、V。
反向傳播
現(xiàn)對(duì)t=3時(shí)刻的U、W、V求偏導(dǎo),由鏈?zhǔn)椒▌t得到:
可以簡(jiǎn)寫成:
觀察③④⑤式,可知,對(duì)于 V 求偏導(dǎo)不存在依賴問題;但是對(duì)于 W、U 求偏導(dǎo)的時(shí)候,由于時(shí)間序列長(zhǎng)度,存在長(zhǎng)期依賴的情況。主要原因可由 t=1、2、3 的情況觀察得 , St會(huì)隨著時(shí)間序列向前傳播,同時(shí)St是 U、W 的函數(shù)。
前面得出的求偏導(dǎo)公式⑥,取其中累乘的部分出來,其中激活函數(shù)Φ 通常是tanh函數(shù) ,則
4、梯度爆炸和梯度消失的原因
激活函數(shù)tanh和它的導(dǎo)數(shù)圖像如下:
由上圖可知當(dāng)激活函數(shù)是tanh函數(shù)時(shí),tanh函數(shù)的導(dǎo)數(shù)最大值為1,又不可能一直都取1這種情況,實(shí)際上這種情況很少出現(xiàn),那么也就是說,大部分都是小于1的數(shù)在做累乘,若當(dāng)t很大的時(shí)候,$\prod_{j=k-1}^{t}tan{h}'W$中的$\prod_{j=k-1}^{t}tan{h}'$趨向0,舉個(gè)例子:0.850=0.00001427247也已經(jīng)接近0了,這是RNN中梯度消失的原因。
再看⑦部分:
$\prod_{j=k-1}^{3}\frac{\partial s_{j}}{\partial s_{j-1}}=\prod_{j=k-1}^{3}tan{h}'W$
如果參數(shù) W 中的值太大,隨著序列長(zhǎng)度同樣存在長(zhǎng)期依賴的情況,$\prod_{j=k-1}^{t}tan{h}'W$中的$\prod_{j=k-1}^{t}tan{h}'$趨向于無窮,那么產(chǎn)生問題就是梯度爆炸。
在平時(shí)運(yùn)用中,RNN比較深,使得梯度爆炸或者梯度消失問題會(huì)比較明顯。
5、解決梯度爆炸和梯度消失的方案
1)采使用ReLu激活函數(shù)
面對(duì)梯度消失問題,可以采用ReLu作為激活函數(shù),下圖為ReLu函數(shù):
ReLU函數(shù)在定義域大于0部分的導(dǎo)數(shù)恒等于1,這樣可以解決梯度消失的問題,(雖然恒等于1很容易發(fā)生梯度爆炸的情況,但可通過設(shè)置適當(dāng)?shù)拈撝悼山鉀Q)。
另外計(jì)算方便,計(jì)算速度快,可以加速網(wǎng)絡(luò)訓(xùn)練。但是,定義域負(fù)數(shù)部分恒等于零,這樣會(huì)造成神經(jīng)元無法激活(可通過合理設(shè)置學(xué)習(xí)率,降低發(fā)生的概率)。
ReLU有優(yōu)點(diǎn)也有缺點(diǎn),其中的缺點(diǎn)可以通過其他操作取避免或者減低發(fā)生的概率,是目前使用最多的激活函數(shù)。
還可以通過更改內(nèi)部結(jié)構(gòu)來解決梯度消失和梯度爆炸問題,那就是LSTM了。
2)使用長(zhǎng)短記憶網(wǎng)絡(luò)LSTM
使用長(zhǎng)短期記憶(LSTM)單元和相關(guān)的門類型神經(jīng)元結(jié)構(gòu)可以減少梯度爆炸和梯度消失問題,LSTM的經(jīng)典圖為:
可以抽象為:
三個(gè)×分別代表的就是forget gate,input gate,output gate,而我認(rèn)為L(zhǎng)STM最關(guān)鍵的就是forget gate這個(gè)部件。這三個(gè)gate是如何控制流入流出的呢,其實(shí)就是通過下面ft,it,ot三個(gè)函數(shù)來控制,因?yàn)?\sigma (x)$代表sigmoid函數(shù)) 的值是介于0到1之間的,剛好用趨近于0時(shí)表示流入不能通過gate,趨近于1時(shí)表示流入可以通過gate。
$f_{t}=\sigma (W_{f}X_{t}+b_{f})$
$i_{t}=\sigma (W_{i}X_{t}+b_{i})$
$o_{t}=\sigma (W_{o}X_{t}+b_{o})$
LSTM當(dāng)前的狀態(tài)值為:$S_{t}=f_{t}S_{t-1}+i_{t}X_{t}$,表達(dá)式展開后得:
$S_{t}=\sigma (W_{f}X_{t}+b_{f})S_{t-1}+\sigma (W_{i}X_{t}+b_{i})X_{t}$
如果加上激活函數(shù):
$S_{t}=tanh[\sigma (W_{f}X_{t}+b_{f})S_{t-1}+\sigma (W_{i}X_{t}+b_{i})X_{t}]$
上文中講到傳統(tǒng)RNN求偏導(dǎo)的過程包含:
$\prod_{j=k-1}^{t}\frac{\partial s_{j}}{\partial s_{j-1}}=\prod_{j=k-1}^{t}tan{h}'W$
對(duì)于LSTM同樣也包含這樣的一項(xiàng),但是在LSTM中為:
$\prod_{j=k-1}^{t}\frac{\partial s_{j}}{\partial s_{j-1}}=\prod_{j=k-1}^{t}tan{h}'\sigma (W_{f}X_{t}+b_{f})$
假設(shè)$Z=tanh'(x)\sigma (y)$,則Z的函數(shù)圖像如下圖所示:
可以看到該函數(shù)值基本上不是0就是1。
傳統(tǒng)RNN的求偏導(dǎo)過程:
$\frac{\sigma L_{3}}{\sigma W}=\sum_{k=0}^{t}\frac{\partial L_{3}}{\partial o_{3}}\frac{\partial o_{3}}{\partial s_{3}}(\prod_{j=k-1}^{3}\frac{\partial s_{j}}{\partial s_{j-1}})\frac{\partial s_{k}}{\partial W}$
如果在LSTM中上式可能就會(huì)變成:
$\frac{\sigma L_{3}}{\sigma W}=\sum_{k=0}^{t}\frac{\partial L_{3}}{\partial o_{3}}\frac{\partial o_{3}}{\partial s_{3}}\frac{\partial s_{k}}{\partial W}$
因?yàn)?\prod_{j=k-1}^{3}\frac{\partial s_{j}}{\partial s_{j-1}}=\prod_{j=k-1}^{3}tan{h}'\sigma (W_{f}X_{t}+b_{f})\approx 0|1$,這樣解決了傳統(tǒng)RNN中梯度消失的問題。
參考
https://www.jiqizhixin.com/articles/2019-01-17-7
https://zhuanlan.zhihu.com/p/28687529
總結(jié)
以上是生活随笔為你收集整理的RNN神经网络产生梯度消失和梯度爆炸的原因及解决方案的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。