深度學(xué)習(xí)之循環(huán)神經(jīng)網(wǎng)絡(luò)(4)RNN層使用方法 1. SimpleRNNCell 2. 多層SimpleRNNCell網(wǎng)絡(luò) 3. SimpleRNN層
?在介紹完循環(huán)神經(jīng)網(wǎng)絡(luò)的算法原理之后,我們來(lái)學(xué)習(xí)如何在TensorFlow中實(shí)現(xiàn)RNN層。在TensorFlow中,可以通過(guò)layers.SimpleRNNCell來(lái)完成
σ(Wxhxt+Whhht?1+b)σ(\boldsymbol W_{xh} \boldsymbol x_t+\boldsymbol W_{hh} \boldsymbol h_{t-1}+\boldsymbol b) σ ( W x h ? x t ? + W h h ? h t ? 1 ? + b ) 的計(jì)算。需要注意的是,在TensorFlow中,RNN表示通用意義上的循環(huán)神經(jīng)網(wǎng)絡(luò),對(duì)于我們目前介紹的基礎(chǔ)循環(huán)神經(jīng)網(wǎng)絡(luò),它一般叫做
SimpleRNN 。SimpleRNN與SimpleRNNCell的區(qū)別在于,帶Cell的層僅僅是完成了一個(gè)時(shí)間戳的前向運(yùn)算,不帶Cell的層一般是基于Cell層實(shí)現(xiàn)的,它在內(nèi)部已經(jīng)完成了多個(gè)時(shí)間戳的循環(huán)運(yùn)算,因此使用起來(lái)更為方便快捷。
?我們先來(lái)介紹SimpleRNNCell的使用方法,再介紹SimpleRNN層的使用方法。
1. SimpleRNNCell
?以某輸入特征長(zhǎng)度n=4n=4 n = 4 (即輸入的詞由幾個(gè)位組成,例如[x1,x2,x3,x4][x_1,x_2,x_3,x_4] [ x 1 ? , x 2 ? , x 3 ? , x 4 ? ] 代表一個(gè)詞,那么這個(gè)詞輸入的特征長(zhǎng)度n=4n=4 n = 4 ),Cell狀態(tài)向量特征長(zhǎng)度為h=3h=3 h = 3 (即每個(gè)時(shí)間戳tt t 上狀態(tài)張量h\boldsymbol h h 的特征長(zhǎng)度,其中ht=σ(Wxhxt+Whhht?1+b)\boldsymbol h_t=σ(\boldsymbol W_{xh} \boldsymbol x_t+\boldsymbol W_{hh} \boldsymbol h_{t-1}+\boldsymbol b) h t ? = σ ( W x h ? x t ? + W h h ? h t ? 1 ? + b ) )為例,首先我們新建一個(gè)SimpleRNNCell,不需要指定序列長(zhǎng)度ss s ,代碼如下:
from tensorflow.keras import layerscell = layers.SimpleRNNCell(3) # 創(chuàng)建RNN Cell,內(nèi)存向量長(zhǎng)度為3
cell.build(input_shape=(None, 4)) # 輸出特征長(zhǎng)度n=4
print(cell.trainable_variables) # 打印Wxh,Whh,b張量
運(yùn)行結(jié)果如下圖所示:
可以看到,SimpleRNNCell內(nèi)部維護(hù)了3個(gè)張量,kernel變量即Wxh\boldsymbol W_{xh} W x h ? 張量,recurrent_kernel變量即Whh\boldsymbol W_{hh} W h h ? 張量,bias變量即偏置b\boldsymbol b b 向量。但是RNN的Memory向量h\boldsymbol h h 并不由SimpleRNNCell維護(hù),需要用戶(hù)自行初始化向量h0\boldsymbol h_0 h 0 ? 并記錄每個(gè)時(shí)間戳上的ht\boldsymbol h_t h t ? 。 ?通過(guò)調(diào)用Cell實(shí)例即可完成前向運(yùn)算: ot,[ht]=Cell(xt,[ht?1])\boldsymbol o_t,[\boldsymbol h_t]=\text{Cell}(\boldsymbol x_t,[\boldsymbol h_{t-1}]) o t ? , [ h t ? ] = Cell ( x t ? , [ h t ? 1 ? ] ) 對(duì)于SimpleRNNCell來(lái)說(shuō),ot=ht\boldsymbol o_t=\boldsymbol h_t o t ? = h t ? ,并沒(méi)有經(jīng)過(guò)額外的線性層轉(zhuǎn)換,是同一個(gè)對(duì)象; [ht][\boldsymbol h_t ] [ h t ? ] 通過(guò)一個(gè)List包裹起來(lái),這么設(shè)置是為了與LSTM、GRU等RNN變種格式統(tǒng)一。在循環(huán)神經(jīng)網(wǎng)絡(luò)的初始化階段,狀態(tài)向量h0\boldsymbol h_0 h 0 ? 一般初始化為全0向量,例如:
import tensorflow as tf
from tensorflow.keras import layers# 初始化狀態(tài)向量,用列表包裹,統(tǒng)一格式
h0 = [tf.zeros([4, 64])]
x = tf.random.normal([4, 80, 100]) # 生成輸入張量,4個(gè)80單詞的句子
xt = x[:, 0, :] # 所有句子的第1個(gè)單詞
# 構(gòu)建輸入特征n=100,序列長(zhǎng)度s=80,狀態(tài)長(zhǎng)度h=64的Cell
cell = layers.SimpleRNNCell(64)
out, h1 = cell(xt, h0) # 前向計(jì)算
print(out.shape, h1[0].shape)
運(yùn)行結(jié)果如下圖所示:
可以看到經(jīng)過(guò)一個(gè)時(shí)間戳的計(jì)算后,輸出張量的shape都為[b,h][b,h] [ b , h ] ,打印出這兩者的id如下:
print(id(out), id(h1[0]))
運(yùn)行結(jié)果如下圖所示:
兩者id一致,即狀態(tài)向量直接作為輸出向量。對(duì)于長(zhǎng)度為ss s 的訓(xùn)練來(lái)說(shuō),需要循環(huán)通過(guò)Cell類(lèi)ss s 次才算完成一次網(wǎng)絡(luò)層的前向運(yùn)算。例如:
h = h0 # h保存每個(gè)時(shí)間戳上的狀態(tài)向量表
# 在序列長(zhǎng)度的維度解開(kāi)輸入,得到xt:[b,n]
for xt in tf.unstack(x, axis=1):out, h = cell(xt, h) # 前向計(jì)算,out和h均被覆蓋
# 最終輸出可以聚合每個(gè)時(shí)間戳上的輸出,也可以只取最后時(shí)間戳的輸出
out = out
print(out)
注: stack與unstack操作詳見(jiàn)深度學(xué)習(xí)(12)TensorFlow高階操作一: 合并與分割 運(yùn)行結(jié)果如下所示:
tf.Tensor(
[[-0.2677562 0.48381227 0.6812192 -0.9811414 -0.95568067 0.040094060.9091274 -0.99737924 -0.2262151 0.63944376 -0.9013501 0.99265766-0.09092428 -0.73986536 0.93987006 0.23288447 0.94647026 -0.93396217-0.98536897 0.813241 0.2766947 -0.25673908 -0.8504294 -0.459959980.7178784 -0.01069952 -0.35384497 -0.7301667 -0.42860696 -0.91951180.96424794 0.93540084 -0.8629409 0.54582363 -0.8481167 -0.88403730.998077 0.7212096 -0.31695995 -0.33156037 0.9733765 0.470587730.9242944 -0.9082541 -0.52866167 0.9714778 0.00706163 -0.22000399-0.22981001 0.35692227 0.9605445 0.73061293 0.7366635 0.16476776-0.19073665 0.8935988 0.88425654 0.6517266 0.43205526 -0.50979210.35872322 -0.42525575 0.4747447 -0.82216126][ 0.7601129 -0.84973663 0.07257108 0.4074115 0.85890645 0.40316352-0.49802104 -0.46189487 -0.97344846 0.33110482 0.22007078 -0.6415040.9584318 0.48941532 0.9487777 -0.4180974 0.403612 -0.888372960.08162808 0.7211986 0.41622642 -0.7644256 -0.9502167 0.7591871-0.76903707 0.54298973 -0.746649 0.8129116 0.4728199 -0.94986810.3234566 -0.0890093 0.24190407 -0.9862916 -0.8878334 0.367466330.7691656 -0.42555293 -0.9808859 0.38541916 -0.9439697 0.337625320.08059914 -0.85767996 0.4056216 0.20410602 0.14420648 -0.13230170.30473053 -0.49009833 -0.93254864 -0.24999127 0.37115836 -0.99332243-0.36034808 0.20640826 -0.7829328 -0.9780473 0.9820612 0.855272-0.38713187 0.4631712 -0.85671836 0.93773407][-0.8166884 0.9315958 0.9914151 -0.02406353 0.969364 0.94645980.00994941 0.8504445 0.94859165 -0.16112864 -0.6662656 -0.346213160.09543993 0.99852633 0.82953227 0.13884324 0.31297988 -0.87489945-0.2261714 -0.538083 -0.89584523 0.6099533 0.37234947 0.48815438-0.99152416 -0.4111157 0.54102963 0.04263455 0.88183767 0.7196480.9528599 0.83965695 0.9976097 0.18376695 -0.7623534 -0.964800830.6696029 -0.98376185 -0.5559587 0.00449213 -0.6537535 0.81219536-0.12517768 -0.2835646 0.59366983 -0.05620956 -0.93757176 -0.93908745-0.8291558 0.49092057 -0.5600199 0.8819002 0.9935199 -0.034000450.8780934 -0.95364624 0.8695547 -0.7339648 0.8402839 -0.81569680.5867556 0.91071755 -0.90019256 0.35353023][ 0.22378804 -0.8577115 -0.9898991 -0.13579184 -0.62707126 0.75127320.97475743 -0.32653475 0.10074215 0.37040377 0.9023177 0.9916534-0.07671987 0.9395568 -0.9829578 0.8298786 -0.89936143 0.241889730.66845345 -0.98227394 0.9421198 0.06484365 -0.6897533 -0.8032981-0.9555406 0.89325416 -0.998938 0.6184817 0.7335308 0.80259580.47380012 -0.57215804 0.78067404 0.9241867 -0.35744458 -0.9898047-0.9867912 0.8348301 0.63395554 -0.9371648 0.64949805 0.999461230.9775342 -0.7392784 0.9140785 -0.27143526 -0.99955875 0.94143620.2668406 0.35915926 -0.82338697 0.9371646 -0.8107094 -0.97396994-0.9549606 -0.2685872 0.61806554 -0.94509935 -0.98313683 0.95650950.23148598 0.81570596 -0.5302515 0.63490754]], shape=(4, 64), dtype=float32)
最后一個(gè)時(shí)間戳的輸出變量out將作為網(wǎng)絡(luò)的最終輸出。實(shí)際上,也可以將每個(gè)時(shí)間戳上的輸出保存,然后求和或者均值,將其作為網(wǎng)絡(luò)的最終輸出。
2. 多層SimpleRNNCell網(wǎng)絡(luò)
?和卷積神經(jīng)網(wǎng)絡(luò)一樣,循環(huán)神經(jīng)網(wǎng)絡(luò)雖然在時(shí)間軸上面展開(kāi)了多次,但只能算一個(gè)網(wǎng)絡(luò)層。通過(guò)在深度方向堆疊多個(gè)Cell類(lèi)來(lái)實(shí)現(xiàn)深層卷積神經(jīng)網(wǎng)絡(luò)一樣的效果,大大的提升網(wǎng)絡(luò)的表達(dá)能力。但是和卷積神經(jīng)網(wǎng)絡(luò)動(dòng)輒幾十、上百的深度層數(shù)來(lái)比,循環(huán)神經(jīng)網(wǎng)絡(luò)很容易出現(xiàn)梯度彌散和梯度爆炸的現(xiàn)象,深層的循環(huán)神經(jīng)網(wǎng)絡(luò)訓(xùn)練起來(lái)非常困難,目前常見(jiàn)的循環(huán)神經(jīng)網(wǎng)絡(luò)層數(shù)一般控制在十層以?xún)?nèi)。 ?我們這里以?xún)蓪拥难h(huán)神經(jīng)網(wǎng)絡(luò)為例,介紹利用Cell方式構(gòu)建多層RNN網(wǎng)絡(luò)。首先新建兩個(gè)SimpleRNNCell單元,代碼如下:
import tensorflow as tf
from tensorflow.keras import layersx = tf.random.normal([4, 80, 100])
xt = x[:, 0, :] # 取第一個(gè)時(shí)間戳的輸入x0
# 構(gòu)建2個(gè)Cell,先cell0,后cell1,內(nèi)存狀態(tài)張量長(zhǎng)度都為64
cell0 = layers.SimpleRNNCell(64)
cell1 = layers.SimpleRNNCell(64)
h0 = [tf.zeros([4, 64])] # cell0的初始狀態(tài)向量
h1 = [tf.zeros([4, 64])] # cell1的初始狀態(tài)向量
在時(shí)間軸上面循環(huán)計(jì)算多次來(lái)實(shí)現(xiàn)整個(gè)網(wǎng)絡(luò)的前向計(jì)算,每個(gè)時(shí)間戳上的輸入xt首先通過(guò)第一層,得到輸出out0,再通過(guò)第二層,得到輸出out1,代碼如下:
for xt in tf.unstack(x, axis=1):# xt作為輸入,輸出為0out0, h0 = cell0(xt, h0)# 上一個(gè)cell的輸出out0作為本cell的輸入out1, h1 = cell1(out0, h1)print(out0.shape, out1.shape)
運(yùn)行結(jié)果如下所示:
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
(4, 64) (4, 64)
共80行,即共80個(gè)時(shí)間戳上的out0和out1。 上述方式先完成一個(gè)時(shí)間戳上的輸入在所有層上的傳播,再循環(huán)計(jì)算所有時(shí)間戳上的輸入。 ?實(shí)際上,也可以先完成輸入在第一層上所有時(shí)間戳的計(jì)算,并保存第一層在所有時(shí)間戳上的輸出列表,再計(jì)算第二層、第三層等的傳播。代碼如下:
# 保存上一層的所有時(shí)間戳上面的輸出
middle_sequences = []
# 計(jì)算第一層的所有時(shí)間戳上的輸出,并保存
for xt in tf.unstack(x, axis=1):out0, h0 = cell0(xt, h0)middle_sequences.append(out0)print(out0.shape)
# 計(jì)算第二層的所有時(shí)間戳上的輸出
# 如果不是末層,需要保存所有時(shí)間戳上面的輸出
for xt in middle_sequences:out1, h1 = cell1(out0, h1)print(out1.shape)
使用這種方式的話,我們需要一個(gè)額外的List來(lái)保存上一層所有時(shí)間戳上面的狀態(tài)信息: middle_sequences.append(out0)。這兩種方式效果相同,可以根據(jù)個(gè)人喜好選擇編程風(fēng)格。 ?需要注意的是,循環(huán)神經(jīng)網(wǎng)絡(luò)的每一層、每一個(gè)時(shí)間戳上面均由狀態(tài)輸出,那么對(duì)于后續(xù)任務(wù)來(lái)說(shuō),我們應(yīng)該收集哪些狀態(tài)輸出最有效呢?一般來(lái)說(shuō),最末層Cell的狀態(tài)有可能保存了高層的全局語(yǔ)義特征,因此一般使用最末層的輸出作為后續(xù)任務(wù)網(wǎng)絡(luò)的輸入。更特別地,每層最后一個(gè)時(shí)間戳上的狀態(tài)輸出包含了整個(gè)序列的全局信息,如果只希望選用一個(gè)狀態(tài)變量來(lái)完成后續(xù)任務(wù),比如情感分類(lèi)問(wèn)題,一般選用最末層、最末時(shí)間戳的狀態(tài)輸出最為合適。
3. SimpleRNN層
?通過(guò)SimpleRNNCell層的使用,我們可以非常深入地理解循環(huán)神經(jīng)網(wǎng)絡(luò)前向運(yùn)算的每個(gè)細(xì)節(jié),但是在實(shí)際使用中,為了簡(jiǎn)便,不希望手動(dòng)參與循環(huán)神經(jīng)網(wǎng)絡(luò)內(nèi)部的計(jì)算過(guò)程,比如每一層的h狀態(tài)向量的初始化,以及每一層在時(shí)間軸上展開(kāi)的運(yùn)算。通過(guò)SimpleRNN層蓋層接口可以非常方便地幫助我們實(shí)現(xiàn)此目的。 ?比如我們要完成單層循環(huán)神經(jīng)網(wǎng)絡(luò)的前向計(jì)算,可以方便地實(shí)現(xiàn)如下:
import tensorflow as tf
from tensorflow.keras import layerslayer = layers.SimpleRNN(64) # 創(chuàng)建狀態(tài)張量長(zhǎng)度為64的SimpleRNN層
x = tf.random.normal([4, 80, 100])
out = layer(x) # 和普通卷積網(wǎng)絡(luò)一樣,一行代碼即可獲得輸出
print(out.shape)
運(yùn)行結(jié)果如下圖所示:
可以看到,通過(guò)SimpleRNN可以?xún)H需一行代碼即可完成整個(gè)前向運(yùn)算過(guò)程,它默認(rèn)返回最后一個(gè)時(shí)間戳上的輸出。 ?如果希望返回所有時(shí)間戳上的輸出列表,可以設(shè)置return_sequences=True參數(shù),代碼如下:
import tensorflow as tf
from tensorflow.keras import layers# 創(chuàng)建RNN時(shí),設(shè)置返回所有時(shí)間戳上的輸出
layer = layers.SimpleRNN(64, return_sequences=True) # 創(chuàng)建狀態(tài)張量長(zhǎng)度為64的SimpleRNN層
x = tf.random.normal([4, 80, 100])
out = layer(x) # 前向計(jì)算
print(out.shape) # 輸出,自動(dòng)進(jìn)行了concat操作
運(yùn)行結(jié)果如下圖所示: 可以看到,返回的輸出張量shape為[4,80,64][4,80,64] [ 4 , 8 0 , 6 4 ] ,中間維度的80即為時(shí)間戳維度。同樣的,對(duì)于多層循環(huán)神經(jīng)網(wǎng)絡(luò),我們可以通過(guò)堆疊多個(gè)SimpleRNN實(shí)現(xiàn),如兩層的網(wǎng)絡(luò),用法和普通的網(wǎng)絡(luò)類(lèi)似。例如:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Sequentialnet = keras.Sequential([ # 構(gòu)建2層RNN網(wǎng)絡(luò)# 除最末層外,都需要返回所有時(shí)間戳的輸出,用作下一層的輸入layers.SimpleRNN(64, return_sequences=True),layers.SimpleRNN(64),
])
x = tf.random.normal([4, 80, 100])
out = net(x) # 前向計(jì)算
print(out.shape) # 輸出,自動(dòng)進(jìn)行了concat操作
運(yùn)行結(jié)果如下圖所示:
每層都需要上一層的每個(gè)時(shí)間戳上面的狀態(tài)輸出,因此除了最末層以外,所有的RNN層都需要返回每個(gè)時(shí)間戳上面的狀態(tài)輸出,通過(guò)設(shè)置return_sequences=True來(lái)實(shí)現(xiàn)。可以看到,使用SimpleRNN層,與卷積神經(jīng)網(wǎng)絡(luò)的用法類(lèi)似,非常簡(jiǎn)潔和高效。
創(chuàng)作挑戰(zhàn)賽 新人創(chuàng)作獎(jiǎng)勵(lì)來(lái)咯,堅(jiān)持創(chuàng)作打卡瓜分現(xiàn)金大獎(jiǎng)
總結(jié)
以上是生活随笔 為你收集整理的深度学习之循环神经网络(4)RNN层使用方法 的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
如果覺(jué)得生活随笔 網(wǎng)站內(nèi)容還不錯(cuò),歡迎將生活随笔 推薦給好友。