Transformer的七十二变
?PaperWeekly 原創(chuàng) ·?作者|李明曉
學(xué)校|魯汶大學(xué)博士生
研究方向|自然語言處理
自 2017 年 Google 提出 Transformer 后,其在各項(xiàng) NLP 任務(wù)中都取得了 SOTA 的表現(xiàn)。然而其自身的結(jié)構(gòu)缺陷導(dǎo)致了兩個問題:
1)由于自注意力機(jī)制每次都要計(jì)算所有詞之間的注意力,其所需計(jì)算復(fù)雜度為輸入長度的平方;2)Transformer 需要事先設(shè)定輸入長度,這導(dǎo)致了其對于長程關(guān)系的捕捉有了一定限制,并且由于需要對輸入文檔進(jìn)行分割會導(dǎo)致語意上的碎片化。
近年來許多工作通過對 Transformer 結(jié)構(gòu)的調(diào)整優(yōu)化來緩解以上兩個問題。
本文分為兩部分,第一部分介紹和比較的三個模型(Star-Transformer 和 BP-Transformer)試圖在時間復(fù)雜度和空間復(fù)雜度上優(yōu)化 Transformer。第二部分介紹和比較的兩個模型(Transformer-XL 和 Compressivetransformer)試圖解決上面提出的第二個問題。
對 Transformer 不了解的可先閱讀該博客:
https://jalammar.github.io/illustrated-transformer/
更高效的Transformer
1. Star-Transformer
論文標(biāo)題:Star-Transformer
論文來源:NAACL 2019
論文鏈接:https://arxiv.org/abs/1902.09113
代碼鏈接:https://github.com/fastnlp/fastNLP
原始的 Transformer 在計(jì)算注意力的時候,序列中每個元素要和所有元素進(jìn)行計(jì)算,也是這樣的計(jì)算方式導(dǎo)致了其復(fù)雜度為序列長度的平方。
同時 Transformer 這樣所有元素直接相互作用的計(jì)算方式?jīng)]能夠很好地使用我們所知道的一些語言序列上的特性,比如語言序列中相鄰的詞往往本身就會有較強(qiáng)的相關(guān)性。
對于這個問題,Star-Transformer 在注意力機(jī)制的計(jì)算上進(jìn)行了優(yōu)化,構(gòu)建了一個星狀的結(jié)構(gòu),所有序列中直接相鄰的元素可以直接相互作用,而非直接相鄰的元素則通過中心元素實(shí)現(xiàn)間接得信息傳遞。
具體結(jié)構(gòu)比較如下圖所示,左邊為正常的 Transformer,右邊為 Star-Transformer。
下圖為 Star-Transformer 的參數(shù)更新算法。在初始化階段,衛(wèi)星節(jié)點(diǎn) 的初始值為相應(yīng)的詞向量,而中心節(jié)點(diǎn)?的初始值為所有衛(wèi)星節(jié)點(diǎn)詞向量的平均值。
算法中參數(shù)更新分為兩步:第一步為衛(wèi)星節(jié)點(diǎn)的更新,第二步為中心節(jié)點(diǎn)的更新。兩步的更新都是基于多頭注意力機(jī)制。
對于衛(wèi)星節(jié)點(diǎn),計(jì)算多頭注意力機(jī)制時只需考慮該節(jié)點(diǎn)狀態(tài)與直接相鄰節(jié)點(diǎn),中心節(jié)點(diǎn),該節(jié)點(diǎn)詞向量和本節(jié)點(diǎn)上一時刻狀態(tài)的信息交互(如下圖中 )。
因?yàn)橹行墓?jié)點(diǎn)擔(dān)負(fù)著所有衛(wèi)星節(jié)點(diǎn)之間的信息交互,因此中心節(jié)點(diǎn)在更新時須與自己上一時刻的信息和所有衛(wèi)星節(jié)點(diǎn)進(jìn)行信息交互。同時為了表示位置信息,在衛(wèi)星節(jié)點(diǎn)中還必須拼接上表示位置信息的可學(xué)習(xí)的向量。
該模型在使用中,針對序列的下游任務(wù)使用衛(wèi)星節(jié)點(diǎn)的輸出,而針對語言推理文本分類這種需要整個句子的任務(wù)則可以使用中心節(jié)點(diǎn)的輸出。
作者的實(shí)驗(yàn)中表明,該非直接的聯(lián)系方式同樣能夠?qū)W習(xí)到長程聯(lián)系,同時在一些任務(wù)上的也取得了比 Transformer 更好的表現(xiàn)。
2. BP-Transformer
論文標(biāo)題:BP-Transformer: Modelling Long-Range Context via Binary Partitioning
論文來源:NAACL 2019
論文鏈接:https://arxiv.org/abs/1911.04070
代碼鏈接:https://github.com/yzh119/BPT
BP-Transformer 采用一個層級(從細(xì)粒度到粗粒度)的注意力計(jì)算機(jī)制來改進(jìn)原始的 Transformer。其能夠?qū)?Transformer 在計(jì)算注意力時的時間復(fù)雜度從 降低到 。
名字中 BP 指的是 Binary partitioning,即二分。在 BP-Transformer 中首先將一整個序列通過二分手段構(gòu)建為一顆二叉樹,二叉樹的葉子節(jié)點(diǎn)即為序列中的元素值,而中間節(jié)點(diǎn)則是序列中的片段。
整個結(jié)構(gòu)可以看為圖神經(jīng)網(wǎng)絡(luò),序列元素和序列片段為圖中的節(jié)點(diǎn),而節(jié)點(diǎn)間的聯(lián)系為圖的邊。邊分為兩種:第一種為 Affiliated Edges 連接片段與組成該片段的葉子節(jié)點(diǎn),另一種為 Contextual Edges 連接葉子節(jié)點(diǎn)和與其相關(guān)的葉子節(jié)點(diǎn)或片段節(jié)點(diǎn)。
整個結(jié)構(gòu)如下圖所示, 為可學(xué)習(xí)的相對位置表示, 的下標(biāo)記第一個數(shù)字表示該節(jié)點(diǎn)在二叉樹中的層級,第二個數(shù)據(jù)表示為在該層級與葉子節(jié)點(diǎn)連接的第幾個節(jié)點(diǎn)。
葉子節(jié)點(diǎn)的 Contextual Edges 可通過往上遞歸求得。例如位置為 的元素與其相連構(gòu)成 Contexttual Edges 的節(jié)點(diǎn)為以下節(jié)點(diǎn),不同行代表的是在二叉樹上不同層的節(jié)點(diǎn),其中? 。
如果 為奇數(shù)則 。 為超參數(shù),表示二叉樹中每個層級由多少個節(jié)點(diǎn)與葉子節(jié)點(diǎn)連接。
構(gòu)建完整個圖后,該模型可通過以下算法更新參數(shù):
其中 GSA (Graph Self-Attention) 為:
加入相對位置后,注意力的計(jì)算可修正為以下公式:
A(u) 為所有與 u 節(jié)點(diǎn)想連的節(jié)點(diǎn),由上面公式可見 GSA 其實(shí)就是多頭注意力機(jī)制,只是相比原始 Transformer 計(jì)算一個節(jié)點(diǎn)與所有節(jié)點(diǎn)的注意力,這里只計(jì)算節(jié)點(diǎn)與其相鄰節(jié)點(diǎn)的注意力,而因?yàn)樵诙鏄渲杏锌鐚哟蔚墓?jié)點(diǎn)連接即有自節(jié)點(diǎn)元素和中間節(jié)點(diǎn)元素(片段)的連接,就實(shí)現(xiàn)在計(jì)算不同粒度下的注意力。
該模型在初始化時,葉子節(jié)點(diǎn)初始化為相應(yīng)的詞向量,而片段節(jié)點(diǎn)則初始化為零。在針對像語言模型這種序列型的下游任務(wù)中,可使用葉子節(jié)點(diǎn)的輸出,而針對像文本分類等需要用的整個句子的則使用二叉樹根節(jié)點(diǎn)的輸出。
作者在多個任務(wù)中測試,結(jié)果表明相比原始的注意力計(jì)算方式,該模型在長文本任務(wù)中取得了更好的表現(xiàn)。
學(xué)習(xí)更長語義聯(lián)系的Transformer
1. Transformer-XL
論文標(biāo)題:Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context
論文來源:ACL 2019
論文鏈接:https://arxiv.org/abs/1901.02860
代碼鏈接:https://github.com/kimiyoung/transformer-xl
相比原始 Transformer,Transformer-XL 有以下兩個變化:1)引入循環(huán)機(jī)制,使得新模型能夠?qū)W習(xí)到更長的語義聯(lián)系;2)拋棄絕對位置表示,采用相對位置表示。
1.1 循環(huán)機(jī)制
在原始 Transformer 中,每個序列的計(jì)算相互獨(dú)立,因此也導(dǎo)致了其只能夠?qū)W習(xí)到同個序列內(nèi)的語義聯(lián)系。而在 Transformer-XL 中,每個序列計(jì)算后的隱狀態(tài)會參與到下一個序列的計(jì)算當(dāng)中,使得模型能夠?qū)W習(xí)到跨序列的語義聯(lián)系。
如下圖所示,左邊為原始 Transformer,右邊為 Transformer-XL。
相比原始 Transformer,Transformer-XL 模型的計(jì)算當(dāng)中加入綠色連線,使得當(dāng)層的輸入取決于本序列和上一個序列前一層的輸出。具體計(jì)算公式如下:
其中 h 為隱藏層,n 為層數(shù), τ 表示序列數(shù),W 為模型參數(shù),° 表示矩陣拼接。SG 意為 stop-gradient,即停止梯度計(jì)算,這樣雖然在計(jì)算中運(yùn)用了前一個序列的計(jì)算結(jié)果,但是在反向傳播中并不對其進(jìn)行梯度的更新。
式子一:將上一序列上一層隱狀態(tài)與本序列上一層隱狀態(tài)進(jìn)行矩陣拼接,這也是 Transformer-XL 實(shí)現(xiàn)循環(huán)機(jī)制的關(guān)鍵。
式子二:計(jì)算注意力機(jī)制所需的 q,k,v。與原始 Transformer 不同的是 k,v 的計(jì)算是取決于由式一得到的隱狀態(tài),而 q 則是只含有本序列的信息。在注意力的計(jì)算中,q 與 k,v 的相互作用讓模型實(shí)現(xiàn)了跨序列的語義學(xué)習(xí)。
式子三:常規(guī)的 Transformer 層計(jì)算。
Transformer-XL 通過引入跨層的循環(huán)機(jī)制,使得模型能夠?qū)W習(xí)到跨序列的語義信息。這樣跨層的方式也使得其能夠?qū)W習(xí)到的語義長度受限于網(wǎng)絡(luò)深度,具體依賴關(guān)系為 N*(L-1) 用大 O 表示可近似為 O(N*L),N 為網(wǎng)絡(luò)深度,L 為序列長度。如下圖所示,序列長度為 4,網(wǎng)絡(luò)深度為 3。
1.2 相對位置編碼
由于注意力機(jī)制忽視了位置信息,因此在 Transformer 中需要加入位置編碼。原始 Transformer 采用了正弦/余弦函數(shù)來編碼絕對位置信息。然而在 Transformer-XL 中,若采用和 Transformer 一樣的絕對位置編碼,那么不同序列間同個位置會得到同樣的編碼。
因此這種方法在 Transformer-XL 中行不通,為了解決這個問題 Transformer-XL 采用了相對位置編碼。
以下公式和分別為原始 Transformer 和 Transformer-XL 中注意力的計(jì)算公式。在其中 E 表示詞的 Embedding,而 U 表示絕對位置編碼。在中 R 為相對位置表示,該相對位置表示也是一個正弦函數(shù)表示。
相比,除了用相對位置表示 R 替代了絕對位置表示 U 后,還用兩個可學(xué)習(xí)參數(shù) u 和 v 替代了中的 query 位置的映射,同時將原本對 key 的映射矩陣分成兩組矩陣和,分別生成基于內(nèi)容的 key 向量和基于位置的 key 向量。
替換后中四項(xiàng)分別代表:(a) 基于內(nèi)容的尋址;(b) 基于內(nèi)容的位置偏差;(c) 全部內(nèi)容偏差;(d) 全局位置偏差。
采用相對位置編碼后,Transformer-XL 具體的計(jì)算公式如下:
2. Compressive Transformer
論文標(biāo)題:Compressive Transformers for Long-Range Sequence Modelling
論文來源:ICLR 2020
論文鏈接:https://arxiv.org/abs/1911.05507
為了增加 Transformer 可以學(xué)習(xí)到的語義長度,Compressiv Transformer 在原 Transformer 的結(jié)構(gòu)上增加了一個記憶模塊和一個壓縮記憶模塊。
每一個序列計(jì)算后其隱狀態(tài)會被放入記憶模塊中,然后記憶模塊中的部分原有記憶會被壓縮然后放入壓縮記憶模塊中,這時壓縮記憶模塊中的部分記憶則會被拋棄掉。
如下圖所示,壓縮記憶模塊和記憶模塊維度皆為 6,而序列長度為 3。箭頭和f表示對記憶模塊中的記憶進(jìn)行壓縮并放入壓縮記憶模塊中。
Compressive Transformer 具體的算法細(xì)節(jié)如下,其中m表示記憶模塊,cm 表示壓縮記憶模塊,h 為隱狀態(tài),d 為 Embedding 維度,為壓縮記憶模塊長度,為記憶模塊長度,c 為壓縮常數(shù),l 為層數(shù)。
下圖為一個簡易示意圖,紅色表示計(jì)算注意力,藍(lán)色表示將計(jì)算過的序列存入記憶模塊和壓縮記憶模塊過程。
在論文中作者嘗試了如下幾個不同的壓縮函數(shù):1)max/mean pooling;2)1Dconvolution;3)dialated convolutions;4)most-used。實(shí)驗(yàn)表明在 WIKITEXT-103 數(shù)據(jù)集中 1D convolution 表現(xiàn)最好。
同時為了更好的學(xué)習(xí)壓縮函數(shù)的參數(shù),模型訓(xùn)練時使用了一個輔助的損失函數(shù)(因?yàn)槿羰且蕾嚹P偷膿p失函數(shù),則梯度需要經(jīng)過很長的時序才能傳到存貯的老的記憶,類似于 RNN 里梯隊(duì)消失問題)。
該損失函數(shù)為注意力重建損失函數(shù),旨在測量通過更新后的記憶計(jì)算的注意力和使用原本記憶計(jì)算的注意力之間的差距。通過最小化該差距來確保有效的壓縮信息。
通過引入記憶模塊后,Compressive Transformer 能夠捕捉的語義長度為 O(L*(+c) 其中為壓縮記憶模塊長度,為記憶模塊長度,c 為壓縮常數(shù)。
相比較 Transformer-XL 的 O(LN),Compressive Transformer 通過將計(jì)算后的序列保存在記憶模塊中有效的提高了模型捕捉長程語義的能力。
Reference
BP-Transformer: Modelling Long-Range Context via Binary Partitioning.Zihao Ye, Qipeng Guo, Quan Gan, Xipeng Qiu, Zheng Zhang
Star-Transformer.Qipeng Guo, Xipeng Qiu, Pengfei Liu, Yunfan Shao, Xiangyang Xue, Zheng Zhang
COMPRESSIVE TRANSFORMERS FOR LONG-RANGE SEQUENCE MODELLING, Jack W. Rae?Anna Potapenko?Siddhant M. Jayakumar?Chloe Hillier Timothy P. Lillicrap
Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context, Zihang Dai?12, Zhilin Yang?12, Yiming Yang1, Jaime Carbonell1, Quoc V. Le2, Ruslan Salakhutdinov1
點(diǎn)擊以下標(biāo)題查看更多往期內(nèi)容:?
BERT在多模態(tài)領(lǐng)域中的應(yīng)用
淺談Knowledge-Injected BERTs
從Word2Vec到BERT
后 BERT 時代的那些 NLP 預(yù)訓(xùn)練模型
兩行代碼玩轉(zhuǎn) Google BERT 句向量詞向量
????
現(xiàn)在,在「知乎」也能找到我們了
進(jìn)入知乎首頁搜索「PaperWeekly」
點(diǎn)擊「關(guān)注」訂閱我們的專欄吧
關(guān)于PaperWeekly
PaperWeekly 是一個推薦、解讀、討論、報道人工智能前沿論文成果的學(xué)術(shù)平臺。如果你研究或從事 AI 領(lǐng)域,歡迎在公眾號后臺點(diǎn)擊「交流群」,小助手將把你帶入 PaperWeekly 的交流群里。
總結(jié)
以上是生活随笔為你收集整理的Transformer的七十二变的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 曲奇云盘因平台运营问题停服:腾讯云关闭服
- 下一篇: 阿里巴巴淘系开源大型3D家具数据集(3D