用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识
用tensorflow搭建RNN(LSTM)進(jìn)行MNIST 手寫(xiě)數(shù)字辨識(shí)
循環(huán)神經(jīng)網(wǎng)絡(luò)RNN相比傳統(tǒng)的神經(jīng)網(wǎng)絡(luò)在處理序列化數(shù)據(jù)時(shí)更有優(yōu)勢(shì),因?yàn)镽NN能夠?qū)⒓尤肷?#xff08;下)文信息進(jìn)行考慮。一個(gè)簡(jiǎn)單的RNN如下圖所示:
將這個(gè)循環(huán)展開(kāi)得到下圖:
上一時(shí)刻的狀態(tài)會(huì)傳遞到下一時(shí)刻。這種鏈?zhǔn)教匦詻Q定了RNN能夠很好的處理序列化的數(shù)據(jù),RNN 在語(yǔ)音識(shí)別,語(yǔ)言建模,翻譯,圖片描述等問(wèn)題上已經(jīng)取得了很到的結(jié)果。
根據(jù)輸入、輸出的不同和是否有延遲等一些情況,RNN在應(yīng)用中有如下一些形態(tài):
RNN存在的問(wèn)題
RNN能夠把狀態(tài)傳遞到下一時(shí)刻,好像對(duì)一部分信息有記憶能力一樣,如下圖:
h3
的值可能會(huì)由x1,x2的值來(lái)決定。
但是,對(duì)于一些復(fù)雜場(chǎng)景
由于距離太遠(yuǎn),中間間隔了太多狀態(tài),x1,x2對(duì)ht+1
的值幾乎起不到任何作用。(梯度消失和梯度爆炸)
LSTM(Long Short Term Memory)
由于RNN不能很好地處理這種問(wèn)題,于是出現(xiàn)了LSTM(Long Short Term Memory)一種加強(qiáng)版的RNN(LSTM可以改善梯度消失問(wèn)題)。簡(jiǎn)單來(lái)說(shuō)就是原始RNN沒(méi)有長(zhǎng)期的記憶能力,于是就給RNN加上了一些記憶控制器,實(shí)現(xiàn)對(duì)某些信息能夠較長(zhǎng)期的記憶,而對(duì)某些信息只有短期記憶能力。
如上圖所示,LSTM中存在Forget Gate,Input Gate,Output Gate來(lái)控制信息的流動(dòng)程度。
RNN:
LSTN:
加號(hào)圓圈表示線性相加,乘號(hào)圓圈表示用gate來(lái)過(guò)濾信息。
Understanding LSTM中對(duì)LSTM有非常詳細(xì)的介紹。(對(duì)應(yīng)的中文翻譯)
LSTM MNIST手寫(xiě)數(shù)字辨識(shí)
實(shí)際上,圖片文字識(shí)別這類(lèi)任務(wù)用CNN來(lái)做效果更好,但是這里想要強(qiáng)行用LSTM來(lái)做一波。
MNIST_data中每一個(gè)image的大小是28*28,以行順序作為序列輸入,即第一行的28個(gè)像素作為$x_{0}
,第二行為
x_1,...,第28行的28個(gè)像素作為
x_28$輸入,一個(gè)網(wǎng)絡(luò)結(jié)構(gòu)總共的輸入是28個(gè)維度為28的向量,輸出值是10維的向量,表示的是0-9個(gè)數(shù)字的概率值。這是一個(gè)many to one的RNN結(jié)構(gòu)。
下面直接上代碼:
這里outputs,final_state = tf.nn.dynamic_rnn(...).
final_state包含兩個(gè)量,第一個(gè)為c保存了每個(gè)LSTM任務(wù)最后一個(gè)cell中每個(gè)神經(jīng)元的狀態(tài)值,第二個(gè)量h保存了每個(gè)LSTM任務(wù)最后一個(gè)cell中每個(gè)神經(jīng)元的輸出值,所以c和h的維度都是[BATCH_SIZE,NUM_UNITS]。
outputs的維度是[BATCH_SIZE,TIME_STEP,NUM_UNITS],保存了每個(gè)step中cell的輸出值h。
由于這里是一個(gè)many to one的任務(wù),只需要最后一個(gè)step的輸出outputs[:, -1, :],output = tf.layers.dense(inputs=outputs[:, -1, :], units=N_CLASSES) 通過(guò)一個(gè)全連接層將輸出限制為N_CLASSES。
訓(xùn)練過(guò)程輸出:
train loss: 2.2990 | test accuracy: 0.13 train loss: 0.1347 | test accuracy: 0.96 train loss: 0.0620 | test accuracy: 0.97 train loss: 0.0788 | test accuracy: 0.98 train loss: 0.0160 | test accuracy: 0.98 train loss: 0.0084 | test accuracy: 0.99 train loss: 0.0436 | test accuracy: 0.99 train loss: 0.0104 | test accuracy: 0.98 train loss: 0.0736 | test accuracy: 0.99 train loss: 0.0154 | test accuracy: 0.98 train loss: 0.0407 | test accuracy: 0.98 train loss: 0.0109 | test accuracy: 0.98 train loss: 0.0722 | test accuracy: 0.98 train loss: 0.1133 | test accuracy: 0.98 train loss: 0.0072 | test accuracy: 0.99 train loss: 0.0352 | test accuracy: 0.98可以看到,雖然RNN是擅長(zhǎng)處理序列類(lèi)的任務(wù),在MNIST手寫(xiě)數(shù)字圖片辨識(shí)這個(gè)任務(wù)上,RNN同樣可以取得很高的正確率。
參考:
http://colah.github.io/posts/2015-08-Understanding-LSTMs/
https://yjango.gitbooks.io/superorganism/content/lstmgru.html
參考代碼
https://www.cnblogs.com/sandy-t/p/6930608.html
有些人,一輩子都沒(méi)有得到過(guò)自己想要的,因?yàn)樗麄兛偸前胪径鴱U
總結(jié)
以上是生活随笔為你收集整理的用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: python 格式化输出%和format
- 下一篇: To disable deprecati