谷歌提出 RNN 版 Transformer,或为长文本建模的当前最优解
文 | 小軼
今天給大家介紹一篇谷歌的最新工作,解決的是 Transformer 的長(zhǎng)文本處理問(wèn)題。在原生 Transformer 中,attention 的復(fù)雜度是輸入序列長(zhǎng)度的平方級(jí)別,因此限制了它處理長(zhǎng)文本的能力。簡(jiǎn)單來(lái)說(shuō),本文提出的解決方案就是把 Transformer當(dāng)做 RNN 中的循環(huán)單元來(lái)用。
和傳統(tǒng) RNN 的區(qū)別只在于:傳統(tǒng) RNN encoder 每個(gè)循環(huán)單元負(fù)責(zé)編碼一個(gè) token,而本文中每個(gè)循環(huán)單元負(fù)責(zé)編碼一段長(zhǎng)度為 的文本片段,且每個(gè)循環(huán)單元都由構(gòu)造相同的 Transformer Block 來(lái)實(shí)現(xiàn)。如此一來(lái),每個(gè)片段在編碼時(shí),都能用類似 RNN 的方式,同時(shí)考慮之前文本中的信息了。
想法很簡(jiǎn)單,但具體實(shí)現(xiàn)起來(lái)還是有一些難點(diǎn)。接下來(lái),我們展開(kāi)介紹一下本文所提出的 Block-Recurrent Transformer。
論文標(biāo)題:
BLOCK-RECURRENT TRANSFORMERS
論文鏈接:
https://arxiv.org/pdf/2203.07852.pdf
滑動(dòng)注意力機(jī)制
先來(lái)看一下每個(gè) block 的 attention 范圍。本文采用的是一種滑動(dòng)窗口注意力機(jī)制,一種專門(mén)針對(duì)長(zhǎng)文檔場(chǎng)景的技術(shù)。由于文本過(guò)長(zhǎng),讓每個(gè) token 都 attend 到整個(gè)文本中的所有 token 難以實(shí)現(xiàn)。在滑動(dòng)窗口注意力機(jī)制中:每個(gè) token 只需要 attend 到它的前 個(gè) token。在本文中,滑動(dòng)窗口長(zhǎng)度 與每個(gè)循環(huán)單元所需處理的文本長(zhǎng)度 相等,即 。
上圖示例中,假設(shè)窗口長(zhǎng)度為 8;相應(yīng)地,輸入文本也被分為長(zhǎng)度為 8 的片段,交由 Transformer blocks 分別處理。圖中淺藍(lán)色區(qū)域表示了 attention 范圍。
圖中兩個(gè)黑框分別對(duì)應(yīng)了兩個(gè) Transformer block 。8個(gè)紅色標(biāo)記點(diǎn),代表右下角那個(gè) block 所需要處理的 8 個(gè) token??梢钥吹?#xff0c;每個(gè) block 的 attention 矩陣大小為 。因此,對(duì)于長(zhǎng)度為 N 的輸入來(lái)說(shuō),整個(gè)模型的 attention 復(fù)雜度為 O(N)。
循環(huán)單元
接下來(lái),我們就往每個(gè) Transformer block 內(nèi)部看看,究竟是如何實(shí)現(xiàn)循環(huán)的。
▲傳統(tǒng) RNN 結(jié)構(gòu)類似傳統(tǒng) RNN,每個(gè)循環(huán)單元:
輸入是input embeddings 和 current state
輸出是 output embeddings 和 next state
所以,我們這里所需要理解的兩個(gè)核心問(wèn)題也就是:在 Block-Recurrent Transformer 中,這兩個(gè)輸出分別是如何得到的?
垂直方向:如何得到 output embeddings?
下圖展示了得到 output embeddings 的過(guò)程。
▲垂直方向:如何得到 output embeddings和傳統(tǒng)的 Transformer layer 非常相像,差別只集中在紅框標(biāo)識(shí)出來(lái)的部分。在這一部分中,為了融合上一個(gè)循環(huán)單元給的 current state 信息,他們將 input embeddings 和 current state vectors 做了一個(gè) cross attention。另一方面,input embeddings 自身也會(huì)過(guò)一個(gè) self-attention 層。這兩部分拼接后,通過(guò)線性層融合在了一起。
水平方向:如何得到 next state?
下圖展示了得到 next state 的過(guò)程。
▲水平方向:循環(huán)單元之間如何傳遞 state與傳統(tǒng) Transformer 不同的地方,用紅色和粉色框標(biāo)識(shí)了。紅色部分,同樣是用 cross attention 將 input embeddings 和 current state vectors 融合。粉色部分則是用兩個(gè) gate 替代了原本 Transformer 中的殘差層。這兩個(gè) gate 的作用與 LSTM 中的遺忘門(mén)類似,控制了對(duì)前一個(gè) state 信息的保留程度。
垂直方向如何多層疊加?
最后還有一個(gè)問(wèn)題。我們都知道,傳統(tǒng) Transformer Encoder 通常是由多個(gè) Transformer Layer 疊加起來(lái)的。也就是下圖中那個(gè) 的意義。那么,在 Block-Recurrent Transformer 中,如何實(shí)現(xiàn)垂直方向上的多層疊加呢?
▲傳統(tǒng) Transformer Encoder文中討論了兩種方式,Single Recurrent Layer 和 Feedback。
▲Single Recurrent LayerSingle Recurrent Layer (SRL) 的實(shí)現(xiàn)比較簡(jiǎn)單。我簡(jiǎn)單花了張示意圖,大致如上圖所示。垂直方向上疊加的多個(gè)層:大多數(shù)都是普通的 Transformer Layer;只有其中的一層,在水平方向上接收了 current state,做了循環(huán)操作。這種方式的運(yùn)算復(fù)雜度也比較低,只相當(dāng)于在普通的 Transformer 基礎(chǔ)上多加了一層 layer 的運(yùn)算量。也就是說(shuō),如果垂直疊加了 12 層,相當(dāng)于普通 Transformer 疊加 13 層的運(yùn)算量。
▲FeedbackFeedback 在 SRL 的基礎(chǔ)上,current state 還會(huì)廣播給其他 Transformer Layer。這些層會(huì)用 cross attention 的方式,將 current state 的信息融合。實(shí)驗(yàn)中,Feedback 比 SRL 性能有小幅提升,不過(guò)它的模型參數(shù)更多,訓(xùn)練時(shí)長(zhǎng)也要陡增 35~40%。
實(shí)驗(yàn)
實(shí)驗(yàn)在三個(gè)長(zhǎng)文本數(shù)據(jù)集上進(jìn)行,分別是 PG19,arxiv 和 Github。評(píng)測(cè)任務(wù)是自回歸語(yǔ)言建模,指標(biāo)為 perplexity。結(jié)果如下圖所示。
其中,黃色高亮的是本文所提出方法的兩個(gè)變種,獲得了 SOTA 的效果。
紅色框出的是三個(gè)比較重要的 baseline。其中,上面兩個(gè) baseline 是此前經(jīng)典的長(zhǎng)文檔處理模型 Transformer-XL 的兩個(gè)變種。可以看到本文方法的性能要比他們好不少。
最后一行的 Memorizing Transformer 同樣是谷歌的工作,剛剛被 ICLR'2022 錄用。其基本思想是:編碼長(zhǎng)文本時(shí),模型一邊往下讀,一邊把之前見(jiàn)過(guò)的所有 token 保存在一個(gè)數(shù)據(jù)庫(kù)中;在讀當(dāng)前片段時(shí),會(huì)用 kNN 的方式找到數(shù)據(jù)庫(kù)中相似的內(nèi)容,然后和當(dāng)前內(nèi)容同時(shí)交互編碼。
可以看到,這個(gè)模型的效果其實(shí)和本文方法相差不大,但復(fù)雜度要高很多,運(yùn)算時(shí)延也要長(zhǎng)[1]。雖然...但是,本文并沒(méi)有把 Memorizing Transformer 的 step time 明確寫(xiě)在表格中。個(gè)人感覺(jué)有些不妥。
小結(jié)
本文的想法其實(shí)很簡(jiǎn)單:把 Transformer 作為 RNN 的循環(huán)單元,解決長(zhǎng)文本問(wèn)題。我相信想到過(guò)類似 idea 的應(yīng)該早有人在。我確實(shí)也看到了類似的 previous works,不過(guò)它們的模型復(fù)雜度和性能效果都遜于本文。
就本文來(lái)說(shuō),只是擁有一個(gè) idea 肯定是不夠的,還要解決很多問(wèn)題,包括:
相鄰的 block 之間如何以適配 Transformer 的方式傳遞信息
模型設(shè)計(jì)的時(shí)候還要同時(shí)考慮到將運(yùn)算復(fù)雜度的降到最低,能并行運(yùn)算的絕不搞串行
還有最后工程實(shí)現(xiàn)上的一些問(wèn)題。比如說(shuō),模型訓(xùn)練的時(shí)候是否會(huì)像傳統(tǒng) RNN 一樣遇到梯度消失的問(wèn)題?如果有,該如何解決?我在本篇推送中,沒(méi)有涵蓋這方面的討論。原文確實(shí)提了一些方法來(lái)提高模型訓(xùn)練的穩(wěn)定性。
從一個(gè)宏觀的 idea 到真正落實(shí),還是有很長(zhǎng)距離的。所以還是不能輕易地說(shuō)一篇論文的 idea “too simple”。
往期回顧
《Longformer:超越RoBERTa,為長(zhǎng)文檔而生的預(yù)訓(xùn)練模型》
《告別自注意力,谷歌為T(mén)ransformer打造新內(nèi)核Synthesizer》
《Google綜述:細(xì)數(shù)Transformer模型的17大高效變種》
萌屋作者:小軼
是小軼,不是小秩!更不要叫小鐵!高冷的形象是需要大家共同維護(hù)的!作為成熟的大人,正在勤儉節(jié)約、兢兢業(yè)業(yè),為成為一名合格的(但是仍然發(fā)量充足的)PhD而努力著。日常沉迷對(duì)話系統(tǒng)。說(shuō)不定,正在和你對(duì)話的,并不是不是真正的小軼哦(!?)
“高冷?那是站在冰箱頂端的意思啦。” ?——白鹡鸰
作品推薦:
寫(xiě)了一篇關(guān)于 NLP 綜述的綜述!
全球44家機(jī)構(gòu),55位大佬,歷時(shí)兩年,打造最強(qiáng)NLG評(píng)測(cè)基準(zhǔn)!
谷歌重磅:可以優(yōu)化自己的優(yōu)化器!手動(dòng)調(diào)參或?qū)⒊蔀闅v史!?
ACL20 Best Paper揭曉!NLP模型評(píng)價(jià)體系或?qū)⒂瓉?lái)重大轉(zhuǎn)折
后臺(tái)回復(fù)關(guān)鍵詞【入群】
加入賣萌屋NLP、CV與搜推廣求職討論群
后臺(tái)回復(fù)關(guān)鍵詞【頂會(huì)】
獲取ACL、CIKM等各大頂會(huì)論文集!
?
[1] Memorizing Transformers https://arxiv.org/abs/2203.08913
總結(jié)
以上是生活随笔為你收集整理的谷歌提出 RNN 版 Transformer,或为长文本建模的当前最优解的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 回顾经典,Netflix的推荐系统架构
- 下一篇: 2020年,中国AI创业公司将走向何方