理解transformer
文章目錄
- 1 注意力機(jī)制
- 2 自注意力機(jī)制
- 3 自注意力機(jī)制加強(qiáng)版
- 4 Transformer的結(jié)構(gòu)
- 4.1 input
- 4.2 encoder
- 4.2.1 Multi-head attention
- 4.2.2 殘差鏈接
- 4.2.3 層正則化layer norm
- 4.2.4 前饋神經(jīng)網(wǎng)絡(luò) feed forward network
- 4.3 decoder
- 4.3.1 輸入
- 4.3.1 Masked Multi-head attention
- 4.3.2 Multi-head attention
- 4.3.3 前饋神經(jīng)網(wǎng)絡(luò) feed forward network
- 4.4 The Final Linear and Softmax Layer
- 4.5 decoder總結(jié)
- 4.4 The Loss Function
- 4.5 代碼閱讀
1 注意力機(jī)制
在LSTM模型中,將encoder最后一個(gè)時(shí)間步的隱狀態(tài)作為decoder的輸入,這會(huì)導(dǎo)致一個(gè)問(wèn)題:會(huì)丟失很多前面的信息。畢竟隱狀態(tài)的維度是有限的,能承載的信息也是有限的。距離越遠(yuǎn),丟的信息越多。注意力機(jī)制的提出就是用來(lái)解決這個(gè)問(wèn)題的。
注意力機(jī)制是在seq2seq模型中提出的。注意力機(jī)制是seq2seq帶給nlp最好的禮物。
我們希望h1,h2,h3,h4h_1,h_2,h_3,h_4h1?,h2?,h3?,h4?都能參與到解碼器的計(jì)算中。這樣就需要給他們分配一個(gè)權(quán)重,那這個(gè)權(quán)重怎么學(xué)習(xí)呢?
我們把編碼器的輸出變量h1,h2,h3,h4h_1,h_2,h_3,h_4h1?,h2?,h3?,h4?稱為value。
把解碼器的每一步輸出hi′h_i'hi′?稱為query。
使用(query,key)對(duì)計(jì)算權(quán)重。key就是h1,h2,h3,h4h_1,h_2,h_3,h_4h1?,h2?,h3?,h4?。
上一步得到的權(quán)重與value(這里仍然是h1,h2,h3,h4h_1,h_2,h_3,h_4h1?,h2?,h3?,h4?)加權(quán)平均得到注意力輸出。
這樣的模式可以推廣到所有需要注意力模型的地方。
value通常是上一個(gè)階段的輸出,key和value是一樣的。
query通常是另外的一個(gè)向量。
使用 (query,key) 進(jìn)行加權(quán)求和,得到的值作為權(quán)重,再對(duì)value進(jìn)行加權(quán)求和。
這樣的話在decoder階段,每一步的輸入會(huì)有3個(gè)變量:編碼器經(jīng)過(guò)注意力加權(quán)后的輸出h(這個(gè)時(shí)候的query是hi?1′h_{i-1}'hi?1′?),解碼器上一步的隱狀態(tài)hi?1′h_{i-1}'hi?1′?,上一步的輸出yi?1y_{i-1}yi?1?。
Bahdanau是注意力機(jī)制的一種計(jì)算方法,也是現(xiàn)在很多工具包中的實(shí)現(xiàn)方法。
當(dāng)前處于解碼器中的第i步
在解碼器中上一步的隱狀態(tài)si?1s_{i-1}si?1?,上一步的輸出yi?1y_{i-1}yi?1?,這一步的上下文向量cic_ici?
在編碼器中的最后輸出的隱狀態(tài)為h1h_1h1?,h2h_2h2?,h3h_3h3?,h4h_4h4?
為了計(jì)算cic_ici?,需要使用注意力機(jī)制來(lái)解決。
value : h1h_1h1?,h2h_2h2?,h3h_3h3?,h4h_4h4?
query : si?1s_{i-1}si?1?
key : h1h_1h1?,h2h_2h2?,h3h_3h3?,h4h_4h4?
我們會(huì)使用(query,key)計(jì)算value的權(quán)重。對(duì)于其中第j個(gè)權(quán)重的計(jì)算方式是這樣的:
eij=a(si?1,hj)e_{ij}=a(s_{i-1,h_j})eij?=a(si?1,hj??) # 這里是將兩個(gè)向量拼接
αij=exp(eij)∑j=1Txexp(eik)\alpha_{ij}=\dfrac{exp(e_{ij})}{\sum_{j=1}^{T_x}exp(e_{ik})}αij?=∑j=1Tx??exp(eik?)exp(eij?)? #這里會(huì)保證權(quán)重和為1。這里Tx=4{T_x}=4Tx?=4
ci=∑j=1Txαijhjc_i=\sum_{j=1}^{T_x}\alpha_{ij}h_jci?=∑j=1Tx??αij?hj? #計(jì)算得到上下文向量,Tx=4{T_x}=4Tx?=4
si=f(si?1,yi?1,ci)s_i=f(s_{i-1},y_{i-1},c_i)si?=f(si?1?,yi?1?,ci?) #得到第i步的隱狀態(tài)
P(yi∣y1,y2...yi?1,X)=g(yi?1,si,ci)P(y_i|y_1,y_2...y_{i-1},X) = g(y_{i-1},s_i,c_i)P(yi?∣y1?,y2?...yi?1?,X)=g(yi?1?,si?,ci?) #得到第i步的輸出
2 自注意力機(jī)制
在上面的介紹中 (query,key)計(jì)算得到一個(gè)權(quán)重。這里query是不同于key的向量。自注意力機(jī)制中query是key的一部分。就是說(shuō)key通過(guò)自己注意自己學(xué)習(xí)到權(quán)重。所以稱為自注意力機(jī)制。
例如在學(xué)習(xí)x2x_2x2?的權(quán)重參數(shù)值時(shí)候,使用x2x_2x2?作為query,x1,x3,x4x_1,x_3,x_4x1?,x3?,x4?作為key和value。
對(duì)每個(gè)位置都計(jì)算得到權(quán)重參數(shù),然后加權(quán)平均得到y2y_2y2?。
同理y3,y4,y1y_3,y_4,y_1y3?,y4?,y1?的計(jì)算也是一樣。
dkd_kdk?是embedding的維度。在歸一化之前會(huì)對(duì)每一個(gè)分?jǐn)?shù)除以定值(embedding的維度開(kāi)根號(hào))。這樣可以讓softmax的分布更加平滑。
3 自注意力機(jī)制加強(qiáng)版
增強(qiáng)版的自注意力機(jī)制是
1 不使用x2x_2x2?作為query,而是先對(duì)x2x_2x2?做線性變換:Wqx2W_qx_2Wq?x2?,之后的向量作query。
2 x1,x3,x4x_1,x_3,x_4x1?,x3?,x4?不直接作為key和value,而是先做線性變換之后再做key和value。Wkx1W_kx_1Wk?x1?作為key,Wvx1W_vx_1Wv?x1?作為value。
其余步驟相同。
這樣的模型有更多的參數(shù),模型性能也更加強(qiáng)大。
4 Transformer的結(jié)構(gòu)
以下內(nèi)容會(huì)部分來(lái)自于The Illustrated Transformer【譯】
了解了注意力機(jī)制的變遷之后,我們?cè)賮?lái)看transformer結(jié)構(gòu)。Transformer是在"Attention is All You Need"中提出的。這是一篇刷爆朋友圈的論文。因?yàn)樗男Ч诂F(xiàn)有效果有了較大幅度的提升。
transformer與之前一些結(jié)構(gòu)的不同在于:
- 雙向LSTM:一個(gè)模型想要包含當(dāng)前位置的信息,前一個(gè)位置的信息,后一個(gè)位置的信息
- CNN:一個(gè)位置包含的信息取決于kernel size大小
- transformer:可以得到全局信息
transformer 是由input、encoder、decoder和output四部分組成的。
encoder組件由6層首尾相連的encoder組成。decoder組件是由6層decoder組成。
4.1 input
transformer模型的輸入由詞向量以及位置編碼兩部分組成。
詞向量是使用word-piece。數(shù)據(jù)集是英-法 WMT 2014,包含36M 句子,這些句子被分為 32000 word-piece 詞匯。每個(gè)詞匯使用dmodel=512d_{model}=512dmodel?=512來(lái)表示。
每個(gè)位置都定義了一個(gè)encoding。 在transformer中一直在做加權(quán)平均,沒(méi)有前后順序,這就會(huì)成為bag of words。
在這里有些位置用sin,有些位置用cos,表示位置信息。每個(gè)位置的encoding是什么樣子并不重要。重要的是每個(gè)位置的encoding不一樣
位置信息encoding之后 與 詞向量相加,也就是 embed(word) + embed(position),整體作為輸入送入到encoder。embed(position)的位置也是512。
按照偶數(shù)位sin,奇數(shù)位cos的方式,得到的結(jié)果確實(shí)是i,j越接近,pm.pnp_m.p_npm?.pn?越大。相對(duì)位置越遠(yuǎn),點(diǎn)乘的結(jié)果越?。
4.2 encoder
6個(gè)encoder結(jié)構(gòu)完全相同,但是參數(shù)不共享。
4.2.1 Multi-head attention
多頭注意力機(jī)制是transformer模型中的重要改進(jìn)。這個(gè)模型使用的是自注意力機(jī)制加強(qiáng)版。這部分內(nèi)容在前面已經(jīng)介紹了。這里重點(diǎn)介紹一下Multi-head。
不是對(duì)輸入做一個(gè)Attention,而是需要做多個(gè)Attention。
假如每個(gè)單詞512維度,這里有h個(gè)scaled dot-product attention。每一套可以并行計(jì)算。 Q K V 做了不同的affine變換,投射到不同的空間,得到不同的維度,也就是WX+b變換。不同head的長(zhǎng)度一樣,但是映射參數(shù)是不一樣的。
之后過(guò)一個(gè)scaled dot-product attention。
h個(gè)結(jié)果concat
然后再做Linear
論文中h=8,dk=dv=dmodel/h=64d_k=d_v=d_{model/h}=64dk?=dv?=dmodel/h?=64,dmodel=512d_{model}=512dmodel?=512
做Attention,Q K V 形狀是不會(huì)發(fā)生變化的,每個(gè)的形狀還是 seq_length,x,hidden_size。
公式如下:
輸入X,包含token embedding和position embedding
- 對(duì)X做變換
Qi=QWiQQ^i=QW^Q_iQi=QWiQ?,Ki=KWiKK^i=KW^K_iKi=KWiK?,Vi=VWiVV^i=VW^V_iVi=VWiV?
每一次映射不共享參數(shù),每一次映射會(huì)有(WiQ,WiK,WiV)W^Q_i,W^K_i,W^V_i)WiQ?,WiK?,WiV?)三個(gè)參數(shù)。
- 對(duì)多頭中的某一組做attention
Attention(Qi,Ki,Vi)=KiQiTdkViAttention(Q_i,K_i,V_i)=\dfrac{K_iQ_i^T}{\sqrt{d_k}}V_iAttention(Qi?,Ki?,Vi?)=dk??Ki?QiT??Vi?
headi=Attention(Qi,Ki,Vi)head_i=Attention(Q_i,K_i,V_i)headi?=Attention(Qi?,Ki?,Vi?)
h組并行計(jì)算
- 拼接之后輸出
MultiHead(Q,K,V)=Concat(head1,head2,...head5)MultiHead(Q,K,V)=Concat(head_1,head_2,...head_5)MultiHead(Q,K,V)=Concat(head1?,head2?,...head5?)
經(jīng)過(guò)multi-head之后,得到h1,h2,h3,h4h_1,h_2,h_3,h_4h1?,h2?,h3?,h4?。
看圖上怎么還有一個(gè)Linear???
4.2.2 殘差鏈接
殘差鏈接是這樣的。
將輸入x加到multi-head或者feed network的輸出h上。這樣可以加快訓(xùn)練。
這一步得到的結(jié)果記為h1′,h2′,h3′,h4′h_1',h_2',h_3',h_4'h1′?,h2′?,h3′?,h4′?。
4.2.3 層正則化layer norm
層正則化,是對(duì)殘差鏈接的結(jié)果做正則化。
對(duì)h1′,h2′,h3′,h4′h_1',h_2',h_3',h_4'h1′?,h2′?,h3′?,h4′?這4個(gè)向量分別計(jì)算每個(gè)向量的均值μ\muμ和方差σ\sigmaσ。
γ\gammaγ和β\betaβ是共享的參數(shù),在模型中需要訓(xùn)練。
γ\gammaγ和β\betaβ可以在一定程度上抵消掉正則的操作。為什么正則了又要抵消呢?
這樣做可以讓每一個(gè)時(shí)間步的值更平均一些,差異不會(huì)特別大。
這一步的輸出是h1′′,h2′′,h3′′,h4′′h_1'',h_2'',h_3'',h_4''h1′′?,h2′′?,h3′′?,h4′′?。
4.2.4 前饋神經(jīng)網(wǎng)絡(luò) feed forward network
對(duì)于上一步的結(jié)果加一個(gè)前饋神經(jīng)網(wǎng)絡(luò)。
FFN(x)=max(0,xW1+b1)W2+b2FFN(x) = max(0, xW_1 + b_1 )W_2 + b_2FFN(x)=max(0,xW1?+b1?)W2?+b2?
在每一個(gè)時(shí)間步會(huì)做一個(gè)y=F(x)的變化,得到另外的100維的向量。
對(duì)這一步的結(jié)果再加一個(gè)殘差鏈接和層正則化。
這樣就得到一個(gè)transformer block。
輸入->Multi head attention ->殘差鏈接->層正則化->Feed-forward Network->殘差鏈接->層正則化。
在實(shí)際使用過(guò)程中層正則化會(huì)放在Multi head attention或者Feed-forward Network-前面。
4.3 decoder
decoder組件是由多個(gè)decoder組成的。在本模型中是6個(gè)decoder。
每一個(gè)decoder是由Masked Multi-head attention, Multi-head attention以及Feed-Forward三部分組成。
4.3.1 輸入
在decoder中的第一步輸入是始位置表示以及 encoder組件的輸出:K和V。經(jīng)過(guò)decoder之后,輸出第一個(gè)單詞I。第二步的輸入是第一步的輸出,以及K和V。
4.3.1 Masked Multi-head attention
Masked Multi-head attention的輸入是decoder前一步的輸出。第一個(gè)位置為起始位置表示。
接著加上位置編碼。整體作為輸入送入 Masked Multi-head attention。
在這里,處理第i步的時(shí)候,只能使用第1步到第i-1的向量做attention。這就是Masked含義。i位置之后的信息不可見(jiàn)。
之后做殘差鏈接。
再之后做層正則化。將結(jié)果送入Multi-head attention。
4.3.2 Multi-head attention
Multi-head attention這一部分與encoder的Multi-head attention相同。輸入是encoder組件的輸出K和V,以及Masked Multi-head attention的輸出,三部分作為輸入。
經(jīng)過(guò)Multi-head attention->殘差鏈接->層正則化,輸出。
4.3.3 前饋神經(jīng)網(wǎng)絡(luò) feed forward network
上一步的輸出,經(jīng)過(guò)前饋神經(jīng)網(wǎng)絡(luò)的結(jié)果作為輸出。
至此,一個(gè)decoder完成。其輸出作為下一個(gè)decoder的輸入。
4.4 The Final Linear and Softmax Layer
解碼器最后輸出浮點(diǎn)向量,如何將它轉(zhuǎn)成詞?這是最后的線性層和softmax層的主要工作。
線性層是個(gè)簡(jiǎn)單的全連接層,將解碼器的最后輸出映射到一個(gè)非常大的logits向量上。假設(shè)模型已知有1萬(wàn)個(gè)單詞(輸出的詞表)從訓(xùn)練集中學(xué)習(xí)得到。那么,logits向量就有1萬(wàn)維,每個(gè)值表示是某個(gè)詞的可能傾向值。
softmax層將這些分?jǐn)?shù)轉(zhuǎn)換成概率值(都是正值,且加和為1),最高值對(duì)應(yīng)的維上的詞就是這一步的輸出單詞。
4.5 decoder總結(jié)
encoder組件從輸入序列的處理開(kāi)始,最后的encoder組件的輸出被轉(zhuǎn)換為K和V,它倆被每個(gè)解碼器的"encoder-decoder atttention"層來(lái)使用,幫助解碼器集中于輸入序列的合適位置。
下面的步驟一直重復(fù)直到一個(gè)特殊符號(hào)出現(xiàn)表示解碼器完成了翻譯輸出。每一步的輸出被喂到下一個(gè)解碼器中。正如編碼器的輸入所做的處理,對(duì)解碼器的輸入增加位置向量。
4.4 The Loss Function
如何對(duì)比兩個(gè)概率分布呢?簡(jiǎn)單采用 cross-entropy或者Kullback-Leibler divergence中的一種。
4.5 代碼閱讀
總結(jié)
以上是生活随笔為你收集整理的理解transformer的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: LeetCode答案汇总(持续更新...
- 下一篇: array专题8