Poly-encoders(2020 ICLR)
1.摘要
Cross - encoder 對(duì)句子對(duì)進(jìn)行完全self-attention ,Bi - encoder分別對(duì)句子對(duì)進(jìn)行編碼。前者往往性能更好,但實(shí)際使用起來(lái)太慢。在這項(xiàng)工作中,作者開(kāi)發(fā)了一種新的Transformer 架構(gòu):Poly - encoder,它學(xué)習(xí)全局而不是token級(jí)的self-attention功能。
2.介紹
本文對(duì)BERT模型進(jìn)行改進(jìn),將其用于多句子評(píng)分:給定一個(gè)輸入上下文,給一組候選標(biāo)簽的評(píng)分。這是檢索和對(duì)話任務(wù)的常見(jiàn)形式,它們必須考慮兩個(gè)方面:
因此提出Poly - encoder,相比于Cross - encoder 速度更快,預(yù)測(cè)質(zhì)量比 Bi - encoder更高。并且發(fā)現(xiàn),使用與下游任務(wù)更相似的數(shù)據(jù)對(duì)Poly - encoder進(jìn)行預(yù)訓(xùn)練相比于BERT帶來(lái)了顯著的受益。
3. task
考慮兩個(gè)任務(wù):
3. 對(duì)話任務(wù)(句子選擇):ConvAI2 , DSTC7,Ubuntu V2
4. IR任務(wù)(文章檢索):Wikipedia Article Search
4.方法
4.1 Transformer 和pre-train策略
Transformers:本文的transformer(Bi , Cross,Ploy)在后面描述,它們與BERT有相同大小和尺寸,12層,12個(gè)attention heads,hidden size 768。除了考慮預(yù)訓(xùn)練的BERT權(quán)重外,作者還重新訓(xùn)練了兩個(gè)Transformer:
前者是為了驗(yàn)證再現(xiàn)類(lèi)似BERT的設(shè)置給我們提供了與之前報(bào)告相同的結(jié)果,而后者測(cè)試對(duì)更類(lèi)似于感興趣的下游任務(wù)的數(shù)據(jù)的預(yù)訓(xùn)練是否有所幫助。
Input Representation:本文的預(yù)訓(xùn)練輸入[INPUT,LABEL]的連接,其中兩者都被特殊的標(biāo)記[SSS]包圍。
在Reddit上進(jìn)行預(yù)訓(xùn)練時(shí),輸入的是上下文,label的是下一句話。當(dāng)在維基百科和多倫多圖書(shū)上進(jìn)行預(yù)訓(xùn)練時(shí),輸入是一句話,標(biāo)簽是文本中的下一句話。每個(gè)輸入令牌都表示為三個(gè)嵌入的總和:token嵌入、position嵌入和segment嵌入,輸入標(biāo)記的段為0,標(biāo)簽標(biāo)記的段為1。Pre-training Procedure:對(duì)于維基百科數(shù)據(jù),與BERT相同,采用MLM訓(xùn)練。在Reddit上的預(yù)訓(xùn)練中添加了一個(gè)下一個(gè)話語(yǔ)預(yù)測(cè)任務(wù),它與BERT略有不同,因?yàn)橐粋€(gè)話語(yǔ)可以由幾個(gè)句子組成。在訓(xùn)練過(guò)程中,一半時(shí)間是真實(shí)的下一句話,另一半時(shí)間是從數(shù)據(jù)集中隨機(jī)抽取的一句話。
使用Adam優(yōu)化器,其學(xué)習(xí)率為2e-4,β1= 0.9,β2= 0.98,沒(méi)有L2權(quán)重衰減,線性學(xué)習(xí)率預(yù)熱,以及學(xué)習(xí)率的平方根倒數(shù)衰減。在所有層上使用0.1的dropout。
batch size :32000個(gè)由相似長(zhǎng)度的連接[INPUT,LABEL]組成的token。在32個(gè)GPU上訓(xùn)練模型14天。
Fine-tuning:經(jīng)過(guò)預(yù)訓(xùn)練后,考慮三種架構(gòu)來(lái)微調(diào)Transformer :Bi - encoder,Cross - encoder Poly - encoder。
預(yù)先定義顏色標(biāo)志
4.2 Bi - encoder
在Bi - encoder中,輸入上下文和候選標(biāo)簽都被編碼成向量:
其中T1T_1T1?和T2T_2T2?是兩個(gè)預(yù)訓(xùn)練好的Transformer ,它們最初以相同的權(quán)重開(kāi)始,但在微調(diào)期間可以單獨(dú)更新。TTT(xxx) = h1h_1h1?,…,hNh_NhN?是Transformer TTT的輸出,rrreeeddd((()))是一個(gè)函數(shù),它將向量序列簡(jiǎn)化為一個(gè)向量。由于輸入和標(biāo)簽是分開(kāi)編碼的,因此兩者的段標(biāo)記都是000。
h1h_1h1?對(duì)應(yīng)于token : [S]。本文考慮了通過(guò)rrreeeddd((()))將輸出減少為一個(gè)表示的三種方法:
結(jié)果在實(shí)驗(yàn)中展示。
Scoring : 由點(diǎn)積sss(ccctttxxxttt,cccaaannndddi_ii?) = yyyc_cc?t_tt?x_xx?t_tt??·?yyyc_cc?a_aa?n_nn?d_dd?i_ii?給出的候選cccaaannndddi_ii?的分?jǐn)?shù)。網(wǎng)絡(luò)被訓(xùn)練成最小化交叉熵?fù)p失,其中l(wèi)ogits是yyyc_cc?t_tt?x_xx?t_tt??·?yyyc_cc?a_aa?n_nn?d_dd?1_11?,…,yyyc_cc?t_tt?x_xx?t_tt??·?yyyc_cc?a_aa?n_nn?d_dd?n_nn?,其中c_cc?a_aa?n_nn?d_dd?1_11?是正確的標(biāo)簽,其他標(biāo)簽是從訓(xùn)練集中選擇的(每個(gè)batch 中其他標(biāo)簽為負(fù)例)。
Inference speed :使用FAISS庫(kù)構(gòu)建索引,存儲(chǔ)embedding向量,在推理時(shí),只需要點(diǎn)積操作這一步驟。
4.3 Cross-encoder
上下文和標(biāo)簽的connection:
fffiiirrrsssttt是一個(gè)函數(shù),取最后一層輸出的第一個(gè)向量([S] token )。Cross-encoder能夠在上下文和候選之間執(zhí)行self-attention,從而產(chǎn)生比Bi - encoder更豐富的提取機(jī)制。
Scoring : 用一個(gè)線性層WWW映射Transformer的輸出為一個(gè)標(biāo)量作為評(píng)分:
LLLooossssss與Bi - encoder同為交叉熵函數(shù)。與Bi - encoder不同,Cross-encoder不能將同batch內(nèi)的其他標(biāo)簽作為負(fù)例回收,因此在訓(xùn)練集中采樣負(fù)例。Cross-encoder使用的內(nèi)存比Bi - encoder多得多,導(dǎo)致batch小得多。
Inference speed:在推理時(shí),每個(gè)候選對(duì)象都必須與輸入上下文相連接,并且必須通過(guò)整個(gè)模型的正向傳播。所以并不適用大規(guī)模對(duì)象。
4.4 Poly-encoder
Poly-encoder目標(biāo)是從上述兩個(gè)類(lèi)型encoder中獲得最佳結(jié)果。
像Bi - encoder一樣,Poly-encoder對(duì)上下文和標(biāo)簽使用兩個(gè)獨(dú)立的Transformer,候選對(duì)象被編碼成單個(gè)向量yyyc_cc?a_aa?n_nn?d_dd?i_ii?。因此,可以使用預(yù)計(jì)算緩存來(lái)實(shí)現(xiàn)Poly-encoder方法。
但是,輸入的上下文通常比候選長(zhǎng)得多,Poly-encoder用mmm個(gè)向量聚合Transformer輸出表示為:(yyy1^11c_cc?t_tt?x_xx?t_tt?,…,yyym^mmc_cc?t_tt?x_xx?t_tt?),其中mmm將影響推理速度。為了獲得代表輸入的這mmm個(gè)全局特征,學(xué)習(xí)mmm個(gè)上下文codes (c1c_1c1?,…,cmc_mcm?),其中,cic_ici?通過(guò)關(guān)注前一層的所有輸出來(lái)表示yyyi^iic_cc?t_tt?x_xx?t_tt?。也就是說(shuō),獲得yyyi^iic_cc?t_tt?x_xx?t_tt?:(這里其實(shí)就是Inner - attention的想法)
mmm個(gè)上下文代碼是隨機(jī)初始化的,并在微調(diào)期間學(xué)習(xí)。最后,給定mmm個(gè)全局上下文特征,使用yyyc_cc?a_aa?n_nn?d_dd?i_ii?作為查詢來(lái)處理它們:(這里就是attention思路)
5 Experiments
data
評(píng)價(jià)指標(biāo)
Recall@C , MRR
5.1 Bi-encoder and Cross-encoder
首先實(shí)驗(yàn)原始BERT的權(quán)重 微調(diào)。Bi-encoder的情況下,可以通過(guò)將其他batch視為負(fù)訓(xùn)練樣本來(lái)使用大量的負(fù)樣本,從而避免重新計(jì)算它們的嵌入。在8個(gè)Nvidia Volta v100 GPUs上,使用浮點(diǎn)運(yùn)算,在ConvAI2上達(dá)到512個(gè)元素的batch。表2顯示,在這種設(shè)置下,使用更大的batch獲得更高的性能,其中511個(gè)負(fù)例產(chǎn)生最佳結(jié)果。
對(duì)于其他任務(wù),將batch大小保持在256,因?yàn)檫@些數(shù)據(jù)集中較長(zhǎng)的序列會(huì)占用更多內(nèi)存。Cross-encoder計(jì)算量更大,因?yàn)槊看味急仨氈匦掠?jì)算(上下文,候選)對(duì)的嵌入。因此,將其批量限制為16,并從訓(xùn)練集中提供負(fù)隨機(jī)樣本。對(duì)于DSTC7和Ubuntu V2,選15個(gè)負(fù)例;對(duì)于ConvAI2,數(shù)據(jù)集提供19個(gè)負(fù)例。
本文嘗試了兩個(gè)優(yōu)化器:權(quán)重衰減為0.01的Adam和沒(méi)有權(quán)重衰減的Adamax ;
基于驗(yàn)證集的性能,選擇在使用BERT權(quán)重時(shí)與Adam一起微調(diào)。學(xué)習(xí)速率初始化為5e-5,Bi-encoder和poly-encoder預(yù)熱100次,Cross-encoder預(yù)熱1000次。在每半個(gè)epoch,在有效集合上評(píng)估的損失平穩(wěn)區(qū)間,學(xué)習(xí)率下降0.4倍。 表3顯示了使用帶有衰減優(yōu)化器的Adam對(duì)BERT提供的權(quán)重的各個(gè)層進(jìn)行微調(diào)時(shí)的驗(yàn)證性能。 除詞嵌入外,對(duì)整個(gè)網(wǎng)絡(luò)進(jìn)行微調(diào)非常重要。
5.2 Poly-encoder
poly - encoder 后面的數(shù)字表示為集合codes mmm的大小,得到的結(jié)論是,在計(jì)算成本允許的情況下,mmm越大,模型性能越好。有些甚至優(yōu)于Cross-encoder,但是Cross-encoder的計(jì)算成本是巨大的。見(jiàn)下圖推理時(shí)間比較表。
CPU:80 core Intel Xeon processor CPU E5-2698
GPU:單個(gè)Nvidia Quadro GP100 using cuda 10.0 and cudnn 7.4
另外,在與下游任務(wù)相似的數(shù)據(jù)上訓(xùn)練時(shí),性能相比于BERT有相當(dāng)大的提升。
附錄
A . 在四個(gè)數(shù)據(jù)集上的訓(xùn)練時(shí)間(小時(shí)):
B. : Bi-encoder的輸出采樣策略比較:(first表示取[SSS]位置的token輸出,Avg first 16 outputs表示取前16個(gè)輸出(h1h_1h1?,h2h_2h2?,…,h1h_1h1?6_66?)…),可以看到直接取[SSS]效果最好。
C. 上下文向量的替代選擇
本文考慮了其他幾種從輸出(hhh1^11c_cc?t_tt?x_xx?t_tt?,…,hhhN^NNc_cc?t_tt?x_xx?t_tt? )導(dǎo)出上下文向量(yyy1^11c_cc?t_tt?x_xx?t_tt?,…,yyym^mmc_cc?t_tt?x_xx?t_tt?)的方法(mmm《《《《《《NNN):
下表報(bào)告了m為{1,4,16,64,360}時(shí)各個(gè)策略的評(píng)價(jià)指標(biāo):
三個(gè)數(shù)據(jù)集上的整體評(píng)價(jià):
可以看到,在不同數(shù)據(jù)集上Learnt-m和First-m各有千秋。
總結(jié)
以上是生活随笔為你收集整理的Poly-encoders(2020 ICLR)的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 智慧医疗、互联网医疗相关术语
- 下一篇: 项目中使用completablefutu