对seq2seq的一些个人理解
對(duì)seq2seq的一些個(gè)人理解
原創(chuàng)?2017年05月10日 11:43:25因?yàn)樽霎呍O(shè)用到seq2seq框架,網(wǎng)上關(guān)于seq2seq的資料很多,但關(guān)于seq2seq的代碼則比較少,閱讀tensorflow的源碼則需要跳來(lái)跳去比較麻煩(其實(shí)就是博主懶)。踩了很多坑后,形成了一些個(gè)人的理解,在這里記錄下,如果有人恰好路過(guò),歡迎指出錯(cuò)誤~
seq2seq圖解如下:?
上圖中,C是encoder輸出的最終狀態(tài),作為decoder的初始狀態(tài);W是encoder的最終輸出,作為decoder的初始輸入。
具體到tensorflow代碼中(tensorflow r1.1.0cpu版本),查閱tf.contrib.rnn.BasicLSTMCell的源碼如下:
class BasicLSTMCell(RNNCell):def __init__(self, num_units, forget_bias=1.0,input_size=None, state_is_tuple=True, activation=tanh,reuse=None):super(BasicLSTMCell, self).__init__(_reuse=reuse)if not state_is_tuple:logging.warn("%s: Using a concatenated state is slower and will soon be deprecated. Use state_is_tuple=True.", self)if input_size is not None:logging.warn("%s: The input_size parameter is deprecated.", self)self._num_units = num_unitsself._forget_bias = forget_biasself._state_is_tuple = state_is_tupleself._activation = activation@propertydef state_size(self):return (LSTMStateTuple(self._num_units, self._num_units) if self._state_is_tuple else 2 * self._num_units)@propertydef output_size(self):return self._num_unitsdef call(self, inputs, state):"""Long short-term memory cell (LSTM)."""# Parameters of gates are concatenated into one multiply for efficiency.if self._state_is_tuple:c, h = stateelse:c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)concat = _linear([inputs, h], 4 * self._num_units, True)# i = input_gate, j = new_input, f = forget_gate, o = output_gatei, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))new_h = self._activation(new_c) * sigmoid(o)if self._state_is_tuple:new_state = LSTMStateTuple(new_c, new_h)else:new_state = array_ops.concat([new_c, new_h], 1)return new_h, new_state- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
令調(diào)用LSTM的命令為:
output,state = tf.contrib.rnn.BasicLSTMCell(input,init_state)- 1
可知,state其實(shí)是包含了output在內(nèi)的。state[0]才是真正的state,即圖中的C;state[1]是output,即圖中的W。這樣一來(lái),最后輸出的output其實(shí)就顯得雞肋了。(如果要在encode和decode之間搞事情的話,這點(diǎn)就比較重要了。博主就是踩了這個(gè)坑。。。當(dāng)然如果不在這里搞事情的話就可以完美繞過(guò)這個(gè)坑)
知道這點(diǎn)后,那么接下來(lái)的就好理解多了。博主之前曾有過(guò)一段時(shí)間的疑惑,那就是seq2seq的decode_input到底是什么?如果跟target只是移了一個(gè)位,其他完全不變的話,那要encoder干什么?知道了上面的背景后,我們不難知道,教程中decode_input跟target的移位只是加速訓(xùn)練過(guò)程。而在具體應(yīng)用中,decode_input可以是encode的最后一個(gè)輸出,也可以自己設(shè)定一個(gè)全零的數(shù)組。個(gè)人覺(jué)得設(shè)定全零的數(shù)組比較好,因?yàn)槌跏紶顟B(tài)就已經(jīng)包含了encode的最后一個(gè)輸出了,而且全零數(shù)組可以當(dāng)作是一個(gè)開(kāi)始的標(biāo)識(shí)(至于seq2seq具體的訓(xùn)練過(guò)程可視化,可以閱讀2017年ACL的一篇文章Visualizing and Understanding Neural Machine Translation?http://nlp.csai.tsinghua.edu.cn/~ly/papers/acl2017_dyz.pdf)
最后,還說(shuō)幾點(diǎn)比較零散的:?
1、對(duì)于短句(<30詞),可以不進(jìn)行輸入翻轉(zhuǎn),模型收斂地稍微慢一點(diǎn)而已;對(duì)于長(zhǎng)句則最好進(jìn)行翻轉(zhuǎn)?
2、多閱讀教程,多實(shí)踐。上手操作永遠(yuǎn)是學(xué)習(xí)的最佳途徑
總結(jié)
以上是生活随笔為你收集整理的对seq2seq的一些个人理解的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 谷歌开源 tf-seq2seq,你也能用
- 下一篇: tensorflow中的seq2seq例