lstm数学推导_手推公式:LSTM单元梯度的详细的数学推导
長(zhǎng)短期記憶是復(fù)雜和先進(jìn)的神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)的重要組成部分。本文的主要思想是解釋其背后的數(shù)學(xué)原理,所以閱讀本文之前,建議首先對(duì)LSTM有一些了解。
介紹
上面是單個(gè)LSTM單元的圖表。我知道它看起來(lái)可怕,但我們會(huì)通過(guò)一個(gè)接一個(gè)的文章,希望它會(huì)很清楚。
解釋
基本上一個(gè)LSTM單元有4個(gè)不同的組件。忘記門、輸入門、輸出門和單元狀態(tài)。我們將首先簡(jiǎn)要討論這些部分的使用,然后深入討論數(shù)學(xué)部分。
忘記門
顧名思義,這部分負(fù)責(zé)決定在最后一步中扔掉或保留哪些信息。這是由第一個(gè)s型層完成的。
根據(jù)ht-1(以前的隱藏狀態(tài))和xt(時(shí)間步長(zhǎng)t的當(dāng)前輸入),它為單元格狀態(tài)C_t-1中的每個(gè)值確定一個(gè)介于0到1之間的值。
遺忘門和上一個(gè)狀態(tài)
如果為1,所有的信息保持原樣,如果為0,所有的信息都被丟棄,對(duì)于其他的值,它決定有多少來(lái)自前一個(gè)狀態(tài)的信息被帶入下一個(gè)狀態(tài)。
輸入門
Christopher Olah博客的解釋在輸入門發(fā)生了什么:
下一步是決定在單元格狀態(tài)中存儲(chǔ)什么新信息。這包括兩部分。首先,一個(gè)稱為“輸入門層”的sigmoid層決定我們將更新哪些值。接下來(lái),一個(gè)tanh層創(chuàng)建一個(gè)新的候選值的向量,C~t,可以添加到狀態(tài)中。在下一步中,我們將結(jié)合這兩者來(lái)創(chuàng)建對(duì)狀態(tài)的更新。
現(xiàn)在這兩個(gè)值i。e i_t和c~t結(jié)合決定什么新的輸入是被輸入到狀態(tài)。
單元狀態(tài)
單元狀態(tài)充當(dāng)LSTM的內(nèi)存。這就是它們?cè)谔幚磔^長(zhǎng)的輸入序列時(shí)比普通RNN表現(xiàn)得更好的地方。在每一個(gè)時(shí)間步長(zhǎng),前一個(gè)單元狀態(tài)(Ct-1)與遺忘門結(jié)合,以決定什么信息要被傳送,然后與輸入門(it和c~t)結(jié)合,形成新的單元狀態(tài)或單元的新存儲(chǔ)器。
狀態(tài)的計(jì)算公式
輸出門
最后,LSTM單元必須給出一些輸出。從上面得到的單元狀態(tài)通過(guò)一個(gè)叫做tanh的雙曲函數(shù),因此單元狀態(tài)值在-1和1之間過(guò)濾。
LSTM單元的基本單元結(jié)構(gòu)已經(jīng)介紹完成,繼續(xù)推導(dǎo)在實(shí)現(xiàn)中使用的方程。
推導(dǎo)先決條件
推導(dǎo)方程的核心概念是基于反向傳播、成本函數(shù)和損失。除此以外還假設(shè)您對(duì)高中微積分(計(jì)算導(dǎo)數(shù)和規(guī)則)有基本的了解。
變量:對(duì)于每個(gè)門,我們有一組權(quán)重和偏差,表示為:
Wf,bf->遺忘門的權(quán)重和偏差Wi,bi->輸入門的權(quán)重和偏差Wc,bc->單元狀態(tài)的權(quán)重和偏差Wo,bo->輸出門的權(quán)重和偏差Wv ,bv -> 與Softmax層相關(guān)的權(quán)重和偏差ft, it,ctiledet, o_t -> 輸出使用的激活函數(shù)af, ai, ac, ao -> 激活函數(shù)的輸入J是成本函數(shù),我們將根據(jù)它計(jì)算導(dǎo)數(shù)。注意(下劃線(_)后面的字符是下標(biāo))
前向傳播推導(dǎo)
門的計(jì)算公式
狀態(tài)的計(jì)算公式
以遺忘門為例說(shuō)明導(dǎo)數(shù)的計(jì)算。我們需要遵循下圖中紅色箭頭的路徑。
我們畫出一條從f_t到代價(jià)函數(shù)J的路徑,也就是
ft→Ct→h_t→J。
反向傳播完全發(fā)生在相同的步驟中,但是是反向的
ft←Ct←h_t←J。
J對(duì)ht求導(dǎo),ht對(duì)Ct求導(dǎo),Ct對(duì)f_t求導(dǎo)。
所以如果我們?cè)谶@里觀察,J和ht是單元格的最后一步,如果我們計(jì)算dJ/dht,那么它可以用于像dJ/dC_t這樣的計(jì)算,因?yàn)?
dJ/dCt = dJ/dht * dht/dCt(鏈?zhǔn)椒▌t)
同樣,對(duì)第一點(diǎn)提到的所有變量的導(dǎo)數(shù)也要計(jì)算。
現(xiàn)在我們已經(jīng)準(zhǔn)備好了變量并且清楚了前向傳播的公式,現(xiàn)在是時(shí)候通過(guò)反向傳播來(lái)推導(dǎo)導(dǎo)數(shù)了。我們將從輸出方程開(kāi)始因?yàn)槲覀兛吹皆谄渌匠讨幸彩褂昧送瑯拥膶?dǎo)數(shù)。這時(shí)就要用到鏈?zhǔn)椒▌t了。我們現(xiàn)在開(kāi)始吧。
反向傳播推導(dǎo)
lstm的輸出有兩個(gè)值需要計(jì)算。
Softmax:對(duì)于交叉熵?fù)p失的導(dǎo)數(shù),我們將直接使用最終的方程。
隱藏狀態(tài)是ht。ht是w.r的微分。根據(jù)鏈?zhǔn)椒▌t,推導(dǎo)過(guò)程如下圖所示。
輸出門相關(guān)變量:ao和ot,微分的完整方程如下:
dJ/dVt * dVt/dht * dht/dO_t
dJ/dVt * dVt/dht可以寫成dJ/dht(我們從隱藏狀態(tài)得到這個(gè)值)。
ht的值= ot * tanh(ct) ->所以我們只需要對(duì)ht w.r求導(dǎo)。t o_t。其區(qū)別如下:
同樣,a_o和J之間的路徑也顯示出來(lái)。微分的完整方程如下:
dJ/dVt * dVt/dht * dt /da_o
dJ/dVt * dVt/dht * dht/dOt可以寫成dJ/dOt(我們從上面的o_t得到這個(gè)值)。
Ct是單元的單元狀態(tài)。除此之外,我們還處理候選單元格狀態(tài)ac和c~_t。
Ct的推導(dǎo)很簡(jiǎn)單,因?yàn)閺腃t到J的路徑很簡(jiǎn)單。Ct→ht→Vt→j,因?yàn)槲覀円呀?jīng)有了dJ/dht,我們直接微分ht w.r。t Ct。
ht = ot * tanh(ct) ->所以我們只需要對(duì)ht w.r求導(dǎo)。t C_t。
微分的完整方程如下:
dJ/dht * dht/dCt * dCt/dc~_t
可以將dJ/dht * dht/dCt寫成dJ/dCt(我們?cè)谏厦嬗羞@個(gè)值)。
Ct的值如圖9公式5所示(下圖第3行最后一個(gè)Ct缺少波浪號(hào)(~)符號(hào)->書寫錯(cuò)誤)。所以我們只需要對(duì)C_t w.r求導(dǎo)。t c ~ _t。
ac:如下圖所示為ac到J的路徑。根據(jù)箭頭,微分的完整方程如下:
dJ/dht * dht/dCt * dCt/ da_c
dJ/dht * dht/dCt * dCt/dc_t可以寫成dJ/dc_t(我們?cè)谏厦嬗羞@個(gè)值)。
所以我們只需要對(duì)c~t w.r求導(dǎo)。t ac。
輸入門相關(guān)變量:it和ai
微分的完整方程如下:
dt / dt * dt /dit
可以將dJ/dht * dht/dCt寫入為dJ/dCt(我們?cè)趩卧駹顟B(tài)中有這個(gè)值)。所以我們只需要對(duì)Ct w.r求導(dǎo)。t it。
a_i:微分的完整方程如下:
dJ/dht * dht/dCt * dt /da_i
dJ/dht * dht/dCt * dCt/dit可以寫成dJ/dit(我們?cè)谏厦嬗羞@個(gè)值)。所以我們只需要對(duì)i_t w.r求導(dǎo)。t ai。
遺忘門相關(guān)變量:ft和af
微分的完整方程如下:
dJ/dht * dht/dCt * dCt/df_t
可以將dJ/dht * dht/dCt寫入為dJ/dCt(我們?cè)趩卧駹顟B(tài)中有這個(gè)值)。所以我們只需要對(duì)Ct w.r求導(dǎo)。t ft。
a_f:微分的完整方程如下:
dJ/dht * dht/dCt * dft/da_t
dJ/dht * dht/dCt * dCt/dft可以寫成dJ/dft(我們?cè)谏厦嬗羞@個(gè)值)。所以我們只需要對(duì)ftw.r求導(dǎo)。t af。
Lstm的輸入
每個(gè)單元格i有兩個(gè)與輸入相關(guān)的變量。前一個(gè)單元格狀態(tài)C_t-1和前一個(gè)隱藏狀態(tài)與當(dāng)前輸入連接,即
[ht-1,xt] > Z_t
C_t-1:這是Lstm單元的內(nèi)存。圖5顯示了單元格狀態(tài)。c - t-1的推導(dǎo)很簡(jiǎn)單因?yàn)橹挥衏 - t和c - t。
Zt:如下圖所示,Zt進(jìn)入四個(gè)不同的路徑,af,ai,ao,ac。
Zt→af→ft→Ct→h_t→J。- >遺忘門
Zt→ai→it→Ct→h_t→J。- >輸入門
Zt→ac→c~t→Ct→h_t→J。->單元狀態(tài)
Zt→ao→ot→Ct→h_t→J。- >輸出門
權(quán)重和偏差
W和b的推導(dǎo)很簡(jiǎn)單。下面的推導(dǎo)是針對(duì)Lstm的輸出門的。對(duì)于其余的門,對(duì)權(quán)重和偏差也進(jìn)行了類似的處理。
輸入和遺忘門的權(quán)重和偏差
輸出和輸出門的權(quán)重和偏差
J/dWf = dJ/daf。daf / dWf ->遺忘門
dJ/dWi = dJ/dai。dai / dWi ->輸入門
dJ/dWv = dJ/dVtdVt/ dWv ->輸出門
dJ/dWo = dJ/dao。dao / dWo ->輸出門
我們完成了所有的推導(dǎo)。但是有兩點(diǎn)需要強(qiáng)調(diào)
到目前為止,我們所做的只是一個(gè)時(shí)間步長(zhǎng)。現(xiàn)在我們要讓它只進(jìn)行一次迭代。
所以如果我們有總共T個(gè)時(shí)間步長(zhǎng),那么每一個(gè)時(shí)間步長(zhǎng)的梯度會(huì)在T個(gè)時(shí)間步長(zhǎng)結(jié)束時(shí)相加,所以每次迭代結(jié)束時(shí)的累積梯度為:
每次迭代結(jié)束時(shí)的累積梯度用來(lái)更新權(quán)重
總結(jié)
LSTM是非常復(fù)雜的結(jié)構(gòu),但它們工作得非常好。具有這種特性的RNN主要有兩種類型:LSTM和GRU。
訓(xùn)練LSTMs也是一項(xiàng)棘手的任務(wù),因?yàn)橛性S多超參數(shù),而正確地組合通常是一項(xiàng)困難的任務(wù)。
作者:Rahuljha
deephub翻譯組
總結(jié)
以上是生活随笔為你收集整理的lstm数学推导_手推公式:LSTM单元梯度的详细的数学推导的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 怎么pe安装windows7 如何在PE
- 下一篇: u盘转为ntfs读不出来了怎么办 u盘转