Transformer太大了,我要把它微调成RNN
文 | 煉丹學徒
編 | 小軼
從前車馬很慢,顯卡跑的也慢,一生只夠愛一個RNN。后來時代進步了,數據量和計算力闊綽了,堆疊起來的Transformer能夠在更深更寬的模型結構里吃下去更多的數據。從19年的預訓練浪潮開始,暴力美學興起,更深的Transformer更久的預訓練更大的模型參數量,暴力出奇跡一個個NLP榜單被刷新,但誰又記得起來當初Transformer論文里“解決RNN無法并行化訓練問題”的追求效率的motivation呢?身在普通高校,手握2080Ti和Titan V,向著大廠的預訓練模型望洋興嘆,我們開始懷念起當初人人都訓練得起的LSTM和GRU。那是精巧輕量的模型,那是人人都刷的起SOTA的時代。
今天這篇來自微軟的論文告訴我們,大廠里有一些研究員也還是愛我們的,Finetuning Pretrained Transformers into RNNs,在保持性能的情況下,將預訓練好的Transformer模型微調到其RNN變體,極大地降低顯存使用和計算開銷。
論文題目:
Finetuning Pretrained Transformers into RNNs
論文鏈接:
https://arxiv.org/abs/2103.13076
Arxiv訪問慢的小伙伴也可以在 【夕小瑤的賣萌屋】訂閱號后臺回復關鍵詞 【0407】 下載論文PDF~
本文提出的模型名為 T2R,代表 Transformer to RNN 。轉換的過程為 swap-then-finetune ,即,對于一個預訓練好的 Transformer 模型,我們將其 的注意力計算改為線性 的替換模塊,然后進行微調。可以預感到,其核心就在于如何用線性的子層對注意力層進行模擬。接下來,我們對其進行詳解。
概述
在2019年EMNLP論文 Transformer Disp [1] 中,作者提出:可以將注意力層的相似度計算()替換為核函數的分數。
ICML'20的另一工作Transformers are RNNs [2]則在此基礎上進一步優化,提出了將的注意力計算替換為線性的模塊。
今天要講的 T2R 這篇文章是緊隨上面 ICML'20 這篇工作進行的。之前 Transformers are RNNs 的方法中,使用的核函數沒有參數,不可訓。而 T2R 把核函數里封裝了一個MLP變成可訓練的。T2R原文的推導直接使用了 Transformers are RNNs 與 Transformer Disp 的結論,因而推導過程并不完整。我們今天也沿著T2R的思路進行講解,如果想要更深入了解 Transformer 轉 RNN 領域的,可以閱讀下面兩篇論文:
[1] Tsai et al. Transformer Disp: A Unified Understanding of Transformer's Attention via the Lens of Kernel. EMNLP 2019
[2] Katharopoulos et al. Transformers are RNNs: Fast autoregressive transformers with linear attention. ICML 2020
Transformer開銷
Transformer 由多頭注意力層、前饋層、層歸一化層堆疊后組成。本篇論文中要替換的,就是其中的多頭注意力層。
在開始講解如何替換之前,我們還是先梳理一下傳統Transformer的多頭注意力層。整個計算過程可以總結如下圖所示:
▲傳統Transformer的多頭注意力層計算過程這張圖我們自下往上看。首先,我們將多頭注意力層的source隱狀態記作,target隱狀態記作。
如何理解此處的source和target:比如,在解碼器的編碼器-解碼器注意力層中,就是編碼器端的序列長度,就是解碼器端的長度。在自回歸推斷的解碼器自注意力層中,就是已生成序列(加上自己)的長度,等于1,指當前要預測的這個字符。
從隱狀態,我們通過線性變換得到。則,注意力層的輸出為:
其中, 操作 旨在計算和的相似度(這里劃重點!等一會兒就要對這個計算動手腳了!):
上述的多頭注意力的計算是我們熟知的。論文對其復雜度進行了分析。設多頭數為,每個頭的隱狀態長度,每個的隱狀態總長 ,則有如下結論:
特征計算:即由隱狀態計算得到的過程,復雜度分別為 , 和
注意力計算: 由 計算得到最終輸出的過程,復雜度為 ,與 的長度成平方關系。
推斷時的顯存:,與已經解碼的長度線性相關。
注意力層的RNN替代方案
T2R的注意力層計算過程則如下圖所示:
首先,我們注意到原始的注意力計算中, 和 的相似度計算方式()需要先進行點乘,放縮后再進行指數運算,難以開展后續的近似優化。所以這里的關鍵之處就在于,T2R把的相似度計算方案替換為核函數的乘積:
此處,和的參數都是通過一個單層MLP學習得到的。 是維矩陣,是維bias向量,即,T2R的相似度計算核函數將原本維的向量降到了維然后進行相似度計算。對于多頭計算中的每一個頭,他們的和是獨立學出來的。因此,T2R在每一層中,共增加了個可學習的參數(小于總參數量的2%)。
我們把新的相似度計算方法代入到注意力的輸出式中,得到:
記,,則:
而根據 Transformers are RNNs [2] 的結論,此處的可以視作RNN遞歸的隱狀態。比如,在解碼器端做自回歸生成時,每個詞向它前文的單詞進行注意力計算來預測下一個詞,和可以被定義為遞歸的隱狀態:
注意到我們主要討論的函數是針對來計算相似度的,而是由喂入該層的隱狀態線性變化得到的。為了加速推斷速度,具體實現中把和代入,得到從隱狀態,直接線性變換得到的結果,從而在推斷的時候不需要計算,而從隱狀態直接計算得到相似度的值,即:
其中,
此時的開銷:
特征計算:我們記輸出維的特征向量,則生成的復雜度為 , 和
注意力計算: 由計算得到最終輸出的過程,假設k<<M,N,此時復雜度為,與的長度成線性關系。
推斷時的顯存:假設k<<M,則占用顯存,為常數。
Transformer和T2R對比
講到這里,我們再對比一下傳統Transformer和T2R的差異:
特征計算:計算不變,計算由, 降為,
注意力計算: 由降為,平方->線性。
推斷時的顯存:由降為,線性->常數。
實驗
數據集的效果
T2R主要使用ELU和RFA作為baseline進行比較。ELU和RFA為此前的另外兩篇使用核函數轉Transformer為RNN工作。因為ELU和RFA的核函數都是不可訓練的,所以無法取代預訓練好的模型里的注意力層進行功能上的替換和擬合。
首先,T2R在語言模型上開展了實驗。數據集使用WikiText-103,評測指標使用困惑度 perplexity 。發現T2R因為在核函數中放置了可訓練的MLP,在加載預訓練模型時獲得更大的收益。
此外,T2R在翻譯任務上開展實驗,使用數據集 WMT14 EN-DE,WMT14 EN-FR 和 WMT17 ZH-EN。研究員們發現雖然隨機初始化時,T2R弱于另外兩個baseline,但是加載預訓練后反超另外兩個baseline。
生成時的加速和顯存節省
研究員發現 T2R 比另外兩個模型的推斷速度更快(如下左圖所示),因為使用了更小的特征維度,以及更快的特征計算方法。對于推斷時的顯存占用,Transformer 隨著輸出序列的增長而線性增加,轉為 RNN 結構的模型則保持常數(如下右圖所示)。
消融實驗
隨著核函數輸出特征尺寸的增大,其效果也更加接近Transformer。相比于之前的工作,T2R 可以通過控制特征尺寸從而在效果和速度間權衡。
小結
本文提出的T2R,在 Transformers are RNNs 的基礎上,將無參數的核函數封裝為 MLP 加激活函數,從而可訓練。在此基礎上,T2R 替換掉預訓練 Transformer 的注意力層,從而降低了計算消耗和顯存使用,并且得到和原預訓練模型相似的結果。
后臺回復關鍵詞【入群】
加入賣萌屋NLP/IR/Rec與求職討論群
后臺回復關鍵詞【頂會】
獲取ACL、CIKM等各大頂會論文集!
總結
以上是生活随笔為你收集整理的Transformer太大了,我要把它微调成RNN的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 12种NumpyPandas高效技巧
- 下一篇: 95后CV工程师晒出工资单:狠补了这个,