超硬核 ICML’21 | 如何使自然语言生成提速五倍,且显存占用减低99%
文 | 煉丹學(xué)徒
編 | 小軼
我們忽略掉引言和介紹,直接把工作的效果丟上來,相信就足夠令自然語言生成的相關(guān)同學(xué)心動(dòng)——對(duì)于任何一個(gè)已有的Transformer生成模型,只需根據(jù)本文算法更改attention的計(jì)算順序,就可以實(shí)現(xiàn)
成倍速度提升!
顯存使用量降低到原來百分之個(gè)位數(shù)!
不需要重新訓(xùn)練!
保證輸出結(jié)果與原來完全一致!
以BART為例,本文方法可以把顯存使用率降低為原來的96分之一!是的,不需要在效率和質(zhì)量中做權(quán)衡!無腦地將本文策略應(yīng)用到你的Transformer里,龐大的自回歸預(yù)訓(xùn)練的生成模型速度也會(huì)變得可以接受!你甚至可以大膽地去和蒸餾模型、剪枝模型、(半)非自回歸模型比較速度。
仔細(xì)想想,我們自然語言生成的過程中,其實(shí)只有編碼和解碼是必須要計(jì)算的開銷,而作者們發(fā)現(xiàn),顯卡計(jì)算的時(shí)間遠(yuǎn)小于CPU操作和顯存IO的時(shí)間,并最終進(jìn)行理論分析降低顯存的耗時(shí)、優(yōu)化代碼降低CPU操作的耗時(shí),顯著降低顯存占用和提升生成速度。本文正是聚焦在了顯存優(yōu)化的部分。
對(duì)于該方法的理論描述先是發(fā)表在了ICML 2021上。之后文章作者又將EL-Attention等相關(guān)技術(shù)封裝成了一個(gè)工具包供大家一鍵調(diào)用,項(xiàng)目名稱FastSeq,在2021 ACL Demo paper里獲得了5 5 4的高分,并被兩位審稿人推薦為best demo paper。
感興趣的讀者可以直接安裝FastSeq工具包,僅需要一行代碼引入該庫函數(shù),只要你用的是常見的Facebook Fairseq或者Huggingface Transformers中的模型,import 完 FastSeq,甚至不需要改代碼,就可以獲得如下的加速效果:
論文題目:
EL-Attention: Memory Efficient Lossless Attention for Generation
FastSeq項(xiàng)目鏈接:
https://github.com/microsoft/fastseq
論文鏈接:
https://arxiv.org/pdf/2105.04779.pdf
Arxiv訪問慢的小伙伴也可以在 【夕小瑤的賣萌屋】訂閱號(hào)后臺(tái)回復(fù)關(guān)鍵詞 【0609】 下載論文PDF~
簡單回顧Transformer的注意力計(jì)算
注意力層中,輸入是Q,K,V即query、key、value矩陣,輸出是Q、K、V隱狀態(tài)維度相同,與Q的批大小、序列長度相同的隱狀態(tài)矩陣。訓(xùn)練過程中,自注意力層Q=K=V=隱狀態(tài)H;編碼器-解碼器注意力層中,Q=解碼器的隱狀態(tài)H,K=V=編碼器的隱狀態(tài)H。推斷過程中,自注意力層K=V=H是已經(jīng)輸出的前文隱狀態(tài),Q是預(yù)測的下一個(gè)詞;編碼器-解碼器注意力層中,K=V=H是編碼器的隱狀態(tài),Q是解碼器里預(yù)測的下一個(gè)詞。
計(jì)算時(shí),我們先把輸入的Q,K,V線性變換,得到多頭的隱狀態(tài)變小一些的(i代表第幾個(gè)頭),然后對(duì)于每一個(gè)頭,點(diǎn)乘歸一獲得注意力分布,用這個(gè)加權(quán)把的值取過來,再把這個(gè)個(gè)頭的低緯度信息線性方便換到之前Q,K,V的維度上作為這個(gè)頭的隱狀態(tài)計(jì)算結(jié)果,把每個(gè)頭i的隱狀態(tài)結(jié)果相加獲得最終結(jié)果。在自回歸推斷時(shí),無論在自注意力層,還是編碼器-解碼器注意力層中,Q都是一個(gè)單詞,而K和V為輸入編碼后的隱狀態(tài)或者已經(jīng)解碼的前文,都是比較長的內(nèi)容。為了表達(dá)方便,后續(xù)描述中,為經(jīng)過線性變換前的隱狀態(tài),經(jīng)過線性變換后的記為,表示多頭注意力中第i頭的內(nèi)容。(詳細(xì)的公式描述可以看推導(dǎo)章節(jié))
推斷過程中,由于需要進(jìn)行beam search,所以往往把編碼器的輸出重復(fù)beam size份并cache起來。其次,因?yàn)槊看沃荒茴A(yù)測下一個(gè)詞,所以自注意力層、編碼器解碼器注意力層里的會(huì)被cache起來避免重復(fù)計(jì)算,而是要預(yù)測的下一個(gè)詞的隱狀態(tài),因此不可能被cache起來,是我們想要計(jì)算的東西。
Transformer Beam Search為啥這么慢
我們回憶一下Transformer生成訓(xùn)練的時(shí)候,forward一次的速度是非常快的,但是 為什么真正去 beam search 然后推斷 inference 的時(shí)候卻很慢。
首先我們知道,訓(xùn)練和推斷的時(shí)候,編碼器端的運(yùn)行是相同的,所以變慢的原因都在解碼器端。即使我們?cè)O(shè)置beam search增加了一些計(jì)算量,但是實(shí)際上我們等待的時(shí)間遠(yuǎn)遠(yuǎn)大于理論上增加的計(jì)算量,把常見的生成任務(wù)的測試集完整的生成一遍結(jié)果,動(dòng)輒幾個(gè)小時(shí)的等待時(shí)間,到底花在哪里了呢?通過每個(gè)調(diào)用函數(shù)的時(shí)間消耗分析,作者得出了結(jié)論是:推斷的代碼中,把完整的矩陣運(yùn)算打散成了每次只能預(yù)測后續(xù)一個(gè)詞,零散的運(yùn)算(從訓(xùn)練時(shí) teacher forcing 的完整矩陣的Q,K,V計(jì)算,變成了推斷時(shí)每次Q都只有一個(gè)詞,去和K,V自回歸地計(jì)算若干遍)從而顯存的帶寬成為了推斷速度的瓶頸。
由于有cache技術(shù)的存在,beam search 時(shí)我們往往把計(jì)算過的隱狀態(tài)存起來反復(fù)使用以避免重復(fù)計(jì)算(如果不cache,會(huì)更慢,因?yàn)橐磸?fù)計(jì)算重復(fù)內(nèi)容。后續(xù)分析會(huì)告訴我們,cache的速度瓶頸在顯存IO,不cache的速度瓶頸在計(jì)算速度),頻繁的顯存內(nèi)容搬運(yùn)和粗放的顯存使用,導(dǎo)致GPU memory IO的時(shí)間超過了計(jì)算時(shí)間,顯卡一直在等顯存內(nèi)容的搬運(yùn)。如果再加上去除重復(fù)的輸出等等CPU的操作,速度就更慢了。
我們可以再看一下本篇推送引言部分的推斷時(shí)長分布圖。左側(cè)優(yōu)化前的推斷時(shí)間里,CPU相關(guān)的后處理占用了最多的時(shí)間,消耗了6.8秒;其次是庫函數(shù)中往往支持去除相鄰的連續(xù)的多少個(gè)詞的連續(xù)出現(xiàn)的問題,也就是圖中的ngram block函數(shù),去處理反復(fù)生成相同單詞短語的問題,消耗了4.5秒。顯存的搬運(yùn)也是時(shí)間的大頭,3.5秒,比真正解碼計(jì)算的時(shí)間3s要多。編碼只用了最少的時(shí)間,因?yàn)橹挥幸淮魏唵蔚膄orward。所有的這些時(shí)間里,只有編碼和解碼是必須消耗的,EL-Attention解決掉了cache的問題,FaseSeq項(xiàng)目的其他部分解決了CPU相關(guān)計(jì)算的問題,最終把不必要的計(jì)算去除,優(yōu)化達(dá)到耗時(shí)最少。本篇后續(xù)只介紹EL-Attention部分的提升。
Transformer 推斷過程顯存IO瓶頸
根據(jù)論文作者的分析,Transformer自然語言生成時(shí)的顯存IO瓶頸主要由以下三個(gè)問題組成:
1)在解碼器中的編碼器-解碼器注意力子層,把編碼器的輸出經(jīng)過每個(gè)子層不同的線性變換得到每一層都不一樣的多頭矩陣矩陣存儲(chǔ)。這就導(dǎo)致,層的解碼器,需要把encoded hidden states存遍。甚至由于開了beam search,當(dāng)前Transformer的各個(gè)庫函數(shù)中,解碼器中的每一層都還把自己層計(jì)算出來的編碼器K,V又要再重復(fù)beam size遍,占用了大量的顯存空間。解碼器中的 自注意力子層 也有相同的問題,存儲(chǔ)的同樣是經(jīng)過線性變換后的多頭矩陣。
2)在beam search過程中,因?yàn)?strong>每一步的寬度搜索,都會(huì)導(dǎo)致beam candidates的得分發(fā)生變化從而導(dǎo)致重新排序,以及生成結(jié)束符時(shí)從candidates隊(duì)列向finished隊(duì)列搬運(yùn)的過程,從而導(dǎo)致大量的memory IO消耗。
3)在顯卡中,如果兩個(gè)三維矩陣運(yùn)算時(shí),他們的第一維大小相同,則運(yùn)算通過并行運(yùn)算其中的各個(gè)二維矩陣運(yùn)算完成。推斷過程中,Q只是下一個(gè)詞的隱狀態(tài),而K,V則顯存占用比Q大得多,描述整個(gè)上文/輸入信息。Q對(duì)K和V的運(yùn)算,反復(fù)加載大量顯存占用的K和V,增大IO吞吐量負(fù)擔(dān)。(EL-Attention后面則減小query的第一維,增大query第二維,從而通過一次矩陣運(yùn)算得到完整的各個(gè)頭的計(jì)算結(jié)果,避免了反復(fù)加載key的值)
優(yōu)化方案
后續(xù)的一切優(yōu)化和計(jì)算的更改都是保證計(jì)算結(jié)果與原始Transformer完全一致的情況下展開和推導(dǎo)的
為了推導(dǎo)出更適合推斷過程的計(jì)算順序,降低存儲(chǔ)量,讓矩陣的運(yùn)算更高效,還能保持輸出結(jié)果一致,本章節(jié)介紹EL-Attention如何進(jìn)行MultiHead Attention(多頭注意力計(jì)算)的等效替換。本章節(jié)里,仍然是使用,代表線性變換前的隱狀態(tài),代表線性變換后的低多頭注意力里第i頭的結(jié)果,。
相比于存儲(chǔ)解碼器段每層計(jì)算過的,EL-Attention只緩存經(jīng)過線性變換之前的隱狀態(tài)H,由當(dāng)前要預(yù)測詞的Q和線性變換前的H,直接計(jì)算得到注意力層的結(jié)果,從而將原始的注意力計(jì)算
變成:
很明顯的我們看到,原始計(jì)算里使用的是經(jīng)過線性變換后的多頭進(jìn)行計(jì)算,而EL-Attention中,則直接使用輸入的進(jìn)行計(jì)算。這就是本文的核心做法,只cache隱狀態(tài)H而非多個(gè)低維度的多頭,從而進(jìn)行更高效的矩陣運(yùn)算,顯著減少顯存占用。
其中,和是兩個(gè)線性變換。為了表達(dá)方便,我們略去了部分計(jì)算比如矩陣運(yùn)算中的bias。完整的計(jì)算方法可以看下面的推導(dǎo)章節(jié)。此時(shí),我們可以拋棄所有的計(jì)算過的緩存,從而只緩存一份隱狀態(tài) 即可。其中,無需把Q計(jì)算到隱狀態(tài)變小的多頭狀態(tài)進(jìn)行零散矩陣運(yùn)算,而是直接在原本的hidden size進(jìn)行更加完整的矩陣運(yùn)算,詳細(xì)內(nèi)容見推導(dǎo)章節(jié)如下:
推導(dǎo)
本章節(jié)我們一起看一下推導(dǎo),確保EL-Attention的計(jì)算結(jié)果是和MultiHead Attention完全一致的。回顧傳統(tǒng)的注意力計(jì)算方法,是將輸入的 Q,K,V 線性變換得到維度更小,但是多份的多頭隱狀態(tài),對(duì)于每一個(gè)頭i,進(jìn)行注意力計(jì)算,然后再用 線性變換到之前隱狀態(tài)的大維度,把每個(gè)頭的隱狀態(tài)加起來。
我們假設(shè)原本 的隱狀態(tài)是維的( 均為 維),多頭數(shù)為 ,每一頭的隱狀態(tài)是 維的,則 ,,,。
其中,,。我們記 ,則:
其中,
最終我們得到,
在推斷過程中,。
降低了多少
直觀的減少顯存使用
顯存占用:假設(shè)編碼器-解碼器注意力層 beam search 的 size 大小為b,解碼器層,則原始的 beam search會(huì)緩存 倍的encoded隱狀態(tài)。其中的倍是因?yàn)楫?dāng)前的庫函數(shù)實(shí)現(xiàn)不佳,重復(fù)beam size份造成的,可以簡單的優(yōu)化掉,剩下的倍通過EL-Attention優(yōu)化掉。即,編碼器-解碼器注意力子層中,把encoded hidden states的顯存占用降為。類似的,解碼器的自注意力子層中,可以把顯存占用降低1/2。
詳細(xì)的計(jì)算復(fù)雜度和顯存優(yōu)化
進(jìn)一步分析,EL-Attention分析注意力計(jì)算中三個(gè)步驟的計(jì)算復(fù)雜度和顯存占用復(fù)雜度。它把注意力的計(jì)算分解成三部分進(jìn)行分析,第一部分是Build Key and Value(即原本計(jì)算中的把H線性變換到多頭的),第二部分是Build Query(即原本計(jì)算中的把Q線性變換到多頭的),第三部分是進(jìn)行注意力的計(jì)算。
首先看Build Key and Value,傳統(tǒng)的做法中,如果不cache,則需要每次進(jìn)行的計(jì)算,然后把計(jì)算結(jié)果存起來(存儲(chǔ)復(fù)雜度)。他的計(jì)算復(fù)雜度高,需要反復(fù)重新計(jì)算,cache則相反。而EL-Attention中,由于直接使用原始的輸入K,V進(jìn)行計(jì)算,無需計(jì)算出多頭的那些,因此計(jì)算和顯存都為0。
其次是Build Query,對(duì)于要預(yù)測的下一個(gè)詞的計(jì)算是繞不開的,所以無論傳統(tǒng)做法中是否cache,Q都要被計(jì)算到多頭的,因此計(jì)算復(fù)雜度和顯存使用相同。EL-Attetnion的這一步是函數(shù),由于多乘了將多頭的低hidden size隱狀態(tài)變成原本的高h(yuǎn)idden size計(jì)算,因此此處顯存多使用了(多頭數(shù))倍。然而這個(gè)其實(shí)很小,因?yàn)楫吘筈只有后續(xù)要預(yù)測的那一個(gè)單詞的隱狀態(tài)。最終是注意力計(jì)算部分,可以看到,因?yàn)闆]有緩存那些計(jì)算過的,EL-Attention的計(jì)算復(fù)雜度增大為倍,與此同時(shí),顯存消耗降低了。
為了比較上述三個(gè)步驟,用計(jì)算換取減少顯存的操作是否收益大于付出,EL-Attention使用下面的圖來表示這種權(quán)衡的收益。下圖中,橫軸是顯存的使用量,縱軸是計(jì)算量,面積代表時(shí)間消耗。傳統(tǒng)做法的時(shí)間消耗由三部分組成,圖中為無邊框的藍(lán)色的大圈,灰色的大圈和橙色的小圈。EL-Attention的時(shí)間消耗由兩部分組成,虛線邊框的灰色小圈和橙色圈,可以看到,由于重新平衡了指令密度,顯存消耗和計(jì)算消耗,總時(shí)間消耗(兩個(gè)虛線邊框圓的總面積)明顯小于傳統(tǒng)做法(三個(gè)無邊框圓的總面積)。
實(shí)驗(yàn)結(jié)果
首先,因?yàn)镋L-Attention優(yōu)化后的輸出結(jié)果與優(yōu)化前的Transformer模型完全一致,不需要重新訓(xùn)練,只需要優(yōu)化推斷的計(jì)算順序,因此performance和輸出結(jié)果,原始論文中沒有展示。為了分析速度,首先,EL-Attention使用固定的假輸入去分析速度影響,他固定了編碼器端輸入1024長,然后嘗試不同的解碼器段長度、不同的beam size去比較EL-Attention和原始attention的速度。我們可以看到,cache機(jī)制雖然增加了顯存使用,但因?yàn)楸苊饬酥貜?fù)計(jì)算,明顯比不cahe的速度快,而EL-Attention則又明顯的優(yōu)于帶cache的beam search生成。
此外,EL-Attention在真實(shí)模型和數(shù)據(jù)集上開展試驗(yàn)。它使用Transformer,BART,GPT-2作為實(shí)驗(yàn)?zāi)P?#xff0c;其中Transformer和BART為編碼器-解碼器結(jié)構(gòu),GPT-2為只有解碼器的結(jié)構(gòu),在SQuAD 1.1問題生成、XSum摘要任務(wù)、CNN/DM摘要任務(wù)上開展試驗(yàn)。beam size越大,EL-Attention的加速效果越明顯,EL-Attention作者很保守的把所有模型的beam size都開的比較小,只有4,就有了若干倍的速度提升。
EL-Attention由于顯著地減少了顯存的占用,所以可以在有限的顯存里,把batch size開大很多倍。通過增大batch size的方法,繼續(xù)提高GPU的使用率和推斷吞吐量(下表的顯存占用對(duì)比令人吃驚):
總結(jié)
EL-Attention通過分析自然語言生成中的速度瓶頸,精確定位到了顯存IO的問題,然后通過理論分析顯存的計(jì)算方案,找到了若干致命問題,對(duì)于已經(jīng)訓(xùn)練好的模型,通過提出新的計(jì)算順序和算法來在對(duì)原輸出無損的情況下,優(yōu)化計(jì)算量和顯存使用,從而達(dá)到了降顯存、加速生成的效果。
萌屋作者:煉丹學(xué)徒
在微軟搬磚的聯(lián)培博士在讀生,擅長烹飪和摸魚,被迫掌握豐富的增肥和減肥經(jīng)驗(yàn)。祝大家吃好喝好,減肥成功。
作品推薦
把數(shù)據(jù)集刷穿是什么體驗(yàn)?MetaQA已100%準(zhǔn)確率
Transformer太大了,我要把它微調(diào)成RNN
后臺(tái)回復(fù)關(guān)鍵詞【入群】
加入賣萌屋NLP/IR/Rec與求職討論群
后臺(tái)回復(fù)關(guān)鍵詞【頂會(huì)】
獲取ACL、CIKM等各大頂會(huì)論文集!
創(chuàng)作挑戰(zhàn)賽新人創(chuàng)作獎(jiǎng)勵(lì)來咯,堅(jiān)持創(chuàng)作打卡瓜分現(xiàn)金大獎(jiǎng)總結(jié)
以上是生活随笔為你收集整理的超硬核 ICML’21 | 如何使自然语言生成提速五倍,且显存占用减低99%的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 11月AI大事件回顾:GPT3开放使用/
- 下一篇: 史上最大多模态图文数据集发布!