48小时单GPU训练DistilBERT!这个检索模型轻松达到SOTA
?PaperWeekly 原創(chuàng) ·?作者 | Maple小七
單位 | 北京郵電大學(xué)
研究方向 | 自然語(yǔ)言處理
論文標(biāo)題:?
Efficiently Teaching an Effective Dense Retriever with Balanced Topic Aware Sampling
收錄會(huì)議:
SIGIR 2021
論文鏈接:
https://arxiv.org/abs/2104.06967
代碼鏈接:
https://github.com/sebastian-hofstaetter/tas-balanced-dense-retrieval
基于 BERT 的稠密檢索模型雖然在 IR 領(lǐng)域取得了階段性的成功,但檢索模型的訓(xùn)練、索引和查詢效率一直是 IR 社區(qū)關(guān)注的重點(diǎn)問題,雖然超越 SOTA 的檢索模型越來越多,但模型的訓(xùn)練成本也越來越大,以至于要訓(xùn)練最先進(jìn)的稠密檢索模型通常都需要 8×V100 的配置。而采用本文提出的 TAS-Balanced 和 Dual-supervision 訓(xùn)練策略,我們僅需要在單個(gè)消費(fèi)級(jí) GPU 上花費(fèi) 48 小時(shí)從頭訓(xùn)練一個(gè) 6 層的 DistilBERT 就能取得 SOTA 結(jié)果,這再一次證明了當(dāng)前大部分稠密檢索模型的訓(xùn)練是緩慢且低效的。
緒言
在短短的兩年時(shí)間內(nèi),當(dāng)初被質(zhì)疑是 Neural Hype 的 Neural IR 現(xiàn)在已經(jīng)被 IR 社區(qū)廣泛接受,不少開源搜索引擎也逐漸支持了基于 BERT 的稠密檢索(dense retrieval),基本達(dá)到了開箱即用的效果。其中,DPR 提出的 是當(dāng)前最主流的稠密檢索模型,然而眾所周知的是, 的可遷移性遠(yuǎn)不如 BM25 這類 learning-free 的傳統(tǒng)檢索方法,想要在具體的業(yè)務(wù)場(chǎng)景下使用 并取得理想的結(jié)果,我們通常需要準(zhǔn)備充足的標(biāo)注數(shù)據(jù)進(jìn)一步訓(xùn)練檢索模型。
因此,如何高效地訓(xùn)練一個(gè)又快又好的 一直是 Neural IR 的研究熱點(diǎn)。目前來看,改進(jìn) 主要有兩條路線可走,其中一條路線是改變 batch 內(nèi)的樣本組合,讓模型能夠獲取更豐富的對(duì)比信息:
優(yōu)化模型的訓(xùn)練過程:這類方法的代表作是 ANCE 提出的動(dòng)態(tài)負(fù)采樣策略,其基本思路是在訓(xùn)練過程中定期刷新索引,從而為模型提供更優(yōu)質(zhì)的難負(fù)樣本,而不是像 DPR 那樣僅從 BM25 中獲取負(fù)樣本。在此基礎(chǔ)上,LTRe 指出目前的檢索模型其實(shí)是按 learning to rank 來訓(xùn)練的,因?yàn)橛?xùn)練過程中模型僅能看到一個(gè) batch 內(nèi)的樣本,但如果我們只訓(xùn)練 query encoder,凍結(jié) passage embedding,我們就可以按照 learning to retrieve 的方式計(jì)算全局損失,而不是僅計(jì)算一個(gè) batch 的損失。除此之外,RocketQA 提出了 Cross Batch 技巧來增大 batch size,由于檢索模型采用對(duì)比損失訓(xùn)練,因此理論上增大 batch size 帶來的基本都是正收益。
然而,這三種策略都在原始的 的基礎(chǔ)上增加了額外的計(jì)算成本,并且實(shí)現(xiàn)都比較復(fù)雜。除此之外,我們也可以利用知識(shí)蒸餾(knowledge distillation)為模型提供更優(yōu)質(zhì)的監(jiān)督信號(hào):
優(yōu)化模型的監(jiān)督信號(hào):?我們可以將表達(dá)能力更強(qiáng)但運(yùn)行效率更低的 或 當(dāng)作 teacher model 來為 提供 soft label。在檢索模型的訓(xùn)練中,知識(shí)蒸餾的損失函數(shù)有很多可能的選擇,本文僅討論 pairwise loss 和 in-batch negative loss,其中 in-batch negative loss 在 pairwise loss 的基礎(chǔ)上將 batch 內(nèi)部其他 query 的負(fù)樣本也當(dāng)作當(dāng)前 query 的負(fù)樣本,這兩類蒸餾 loss 的詳細(xì)定義后文會(huì)講。
本文同樣是在上述兩個(gè)方面對(duì) 做出優(yōu)化,在訓(xùn)練過程方面,作者提出了 Balanced Topic Aware Sampling(TAS-Balanced)策略來構(gòu)建 batch 內(nèi)的訓(xùn)練樣本;在監(jiān)督信號(hào)方面,作者提出了將 pairwise loss 和 in-batch negative loss 結(jié)合的 dual-supervision 蒸餾方式。
Dual Supervision
越來越多的證據(jù)表明知識(shí)蒸餾能夠帶來稠密檢索模型性能的提升,本文將 提供的 pairwise loss 和 提供的 in-batch negative loss 結(jié)合起來為 提供監(jiān)督信號(hào),下面先簡(jiǎn)單介紹一下 teacher model 和 student model。
Teacher Model:、?
是當(dāng)前應(yīng)用最為廣泛的排序模型,它簡(jiǎn)單地將 query 和 passage 的拼接作為 的輸入序列,然后對(duì) 輸出向量做一個(gè)線性變換得到相關(guān)性打分:
是一個(gè)經(jīng)典的多向量表示模型,它將 query 和 passage 之間的交互簡(jiǎn)化為 max-sum 來克服 無法緩存 passage 向量的問題,其基本思路是首先對(duì) query 和 passage 分別編碼
然后計(jì)算每個(gè) query term 和每個(gè) passage term 的點(diǎn)積相似度,按 doc term 做 max-pooling 并按 query term 求和獲取 query 和 passage 的相似度:
雖然理論上 可以對(duì) passage 建立離線索引,但存儲(chǔ) passage 多向量表示的資源開銷是非常大的,并且該存儲(chǔ)成本隨著語(yǔ)料庫(kù)的 term 數(shù)量呈線性增長(zhǎng),再加上 max-sum 的操作也會(huì)帶來額外的計(jì)算成本,因此這里我們將 當(dāng)作 的 teacher。
Student Model:?
DPR 提出的 僅使用二元標(biāo)簽和 BM25 生成的負(fù)樣本訓(xùn)練模型, 首先將 query 和 passage 獨(dú)立編碼為單個(gè)向量:
然后計(jì)算 和 的點(diǎn)積相似度:
在檢索階段, 首先對(duì) query 編碼,然后利用 faiss 做最大內(nèi)積檢索,下表展示了在單個(gè)消費(fèi)級(jí) GPU 上 6 層 DistilBERT 在 800 萬(wàn) passage 集合上的檢索速度。
2.1 Dual-Teacher Supervision
如果僅看監(jiān)督信號(hào)的質(zhì)量, 提供的 in-batch negative loss 當(dāng)然是最優(yōu)質(zhì)的。然而, 雖然在表達(dá)能力上比 更強(qiáng),但它實(shí)際上很少用于計(jì)算 in-batch negative loss,因?yàn)? 需要單獨(dú)編碼每個(gè) query-passage 樣本對(duì),所以其計(jì)算開銷隨著 batch size 二次增長(zhǎng),而 解耦了 query 和 passage 的表示,因此它的開銷是隨著 batch size 線性增長(zhǎng)的,其 in-batch negative loss 的計(jì)算效率要高得多。
因此這里我們只讓 提供 pairwise loss,具體來說,我們首先利用訓(xùn)練好的 對(duì)訓(xùn)練集中所有的 query-passage 樣本對(duì)打分,然后計(jì)算 的蒸餾損失,蒸餾損失的具體形式有很多選擇,這里作者選擇了 Margin-MSE loss 作為 pairwise loss:
其中 和 分別為 和 。
我們同時(shí)讓 提供 in-batch negative loss:
in-batch negative loss 中的 其實(shí)也可以替換成別的 loss,作者在后續(xù)實(shí)驗(yàn)中也嘗試了一些看起來更有效的 listwise loss,然而實(shí)驗(yàn)結(jié)果表明 Margin-MSE loss 依舊是最佳的選擇。因此,作者最終提出的蒸餾 loss 是 pairwise loss 和 in-batch negative loss 的加權(quán)平均,在后續(xù)實(shí)驗(yàn)中,作者設(shè)加權(quán)系數(shù) :
Balanced Topic Aware Sampling
在原始的 的訓(xùn)練中,我們首先隨機(jī)地從 query 集合 中采樣 個(gè) ,然后再為每個(gè) 隨機(jī)采樣一個(gè)正樣本 和一個(gè)負(fù)樣本 組成一個(gè) batch:
其中 表示從集合 無放回地采樣 個(gè)樣本。由于訓(xùn)練集是非常大的,每個(gè) batch 中的 幾乎都是沒有相關(guān)性的,但是當(dāng)我們計(jì)算 in-batch negative loss 時(shí),query 不僅和自身的 交互,也和別的 query 對(duì)應(yīng)的 交互,然而,由于 對(duì)模型來說大概率是簡(jiǎn)單樣本,因此它所能提供的信息增益是非常少的,這也導(dǎo)致了每個(gè) batch 所能提供的信息量偏少,使得檢索模型需要長(zhǎng)時(shí)間的訓(xùn)練才能收斂。
3.1?TAS
針對(duì)這個(gè)問題,作者提出了 Topic Aware Sampling(TAS)策略來構(gòu)建 batch 內(nèi)的訓(xùn)練樣本,具體來說,在訓(xùn)練之前,我們先利用 k-means 算法將 query 聚類到 k 個(gè) cluster 中:
其中 query 的表示 由基線模型 提供, 為 的聚類中心,這樣,每個(gè) cluster 中的 query 都是主題相關(guān)的,在構(gòu)建 batch 的時(shí)候,我們可以先從 cluster 的集合 中隨機(jī)抽樣 個(gè) cluster,然后在每個(gè) cluster 上隨機(jī)抽樣 個(gè) query:
在后續(xù)的實(shí)驗(yàn)中,作者為 40 萬(wàn)個(gè) query 創(chuàng)建了?k=2000?個(gè) cluster,并設(shè) batch size 大小為 b=32,組建 batch 時(shí)隨機(jī)抽樣的 cluster 數(shù)量為 n=1,這樣,每個(gè) batch 中的樣本都來自于同一個(gè) cluster。如下圖所示,相比于在整個(gè) query 集合上隨機(jī)抽樣,TAS 策略生成的 batch 內(nèi)部的 query 有更高的主題相似性。
3.2?TAS-balanced
在組建 batch 的時(shí)候,我們還需要為每個(gè)采樣到的 query 配置正負(fù)樣本對(duì) 。不難想到,幾乎所有 query 對(duì)應(yīng)的 都比 少得多,如果用獨(dú)立隨機(jī)抽樣的方式獲取 和 ,那么組成的 的 margin(也就是 )大概率是很大的,因此大部分 對(duì)模型來說是簡(jiǎn)單樣本,因?yàn)槟P秃苋菀讓? 和 分開。
因此,我們可以在 TAS 策略的基礎(chǔ)上進(jìn)一步均衡 batch 內(nèi)正負(fù)樣本對(duì)的 margin 分布以減少 high margin(low information)的正負(fù)樣本對(duì)。具體來說,針對(duì)每個(gè) query,我們首先計(jì)算它對(duì)應(yīng)的樣本對(duì)集合的最小 margin 和最大 margin,然后將該區(qū)間分割為 個(gè)子區(qū)間,在為 query 配置 時(shí),我們首先從這 個(gè)子區(qū)間中隨機(jī)選擇一個(gè)子區(qū)間,然后從 margin 落在該子區(qū)間內(nèi)的 集合中隨機(jī)采樣并組成一個(gè)訓(xùn)練樣本:
這樣,在構(gòu)建一個(gè) batch 的時(shí)候,我們首先需要采樣一個(gè) cluster,然后采樣 b?個(gè) query,接下來為每個(gè) query 采樣一個(gè) margin 子區(qū)間,最后在該子區(qū)間上采樣一個(gè)正負(fù)樣本對(duì),這整套流程就是所謂的 TAS-balanced batch sampling:
需要注意的是,TAS-balanced 策略不會(huì)影響模型的訓(xùn)練速度,因?yàn)?batch 的構(gòu)建是可以并行處理或者預(yù)先處理好的。TAS-balanced 策略組建的 batch 對(duì)模型來說整體的難度更大,因此為模型提供了更多的信息量,即使采用較小的 batch size,模型也能很好地收斂。如下表所示,我們可以在消費(fèi)級(jí)顯卡上(11GB 內(nèi)存)高效地訓(xùn)練 而不需要昂貴的 8×V100 的配置,因?yàn)樵摲椒ú恍枰?ANCE 那樣重復(fù)刷新索引,也不需要像 RocketQA 那樣進(jìn)行超大批量的訓(xùn)練。
3.3 Experiment
作者選擇 MSMACRO-Passage 官方提供的 4000 萬(wàn)正負(fù)樣本對(duì)作為檢索模型的訓(xùn)練集,并選擇 MSMACRO-DEV(sparsely-judged,包含 6980 個(gè) query)和 TREC-DL 19/20(densely-judged,包含 43/54 個(gè) query)作為驗(yàn)證集。同時(shí) 和 ?均采用 6 層的 DistilBERT 初始化,且沒有使用預(yù)訓(xùn)練的檢索模型。
Results
4.1 Source of Effectiveness
首先我們對(duì)作者提出的 Dual-supervision 做消融實(shí)驗(yàn),如下表所示。對(duì)于基于 pairwise loss 的知識(shí)蒸餾,Margin-MSE loss 的優(yōu)越性已經(jīng)被之前的論文證明,所以這里僅討論 in-batch negative loss 的有效性。作者對(duì)比了基于 listwise loss 的 KL Divergence、ListNet 和 Lambdarank,實(shí)驗(yàn)結(jié)果表明這些損失的效果都不如 Margin-MSE loss,尤其是在 R@1K 上面。
為什么 pairwise 的 Margin-MSE 比 listwise loss 更好呢?因?yàn)?Margin-MSE 不僅僅是讓模型去學(xué)習(xí) teacher 所給出的排序,同時(shí)還學(xué)習(xí) teacher score 的分布,由于 batch 內(nèi)部樣本的 order 實(shí)際上是有偏的,它并不能準(zhǔn)確刻畫樣本間的真實(shí)距離,因此比起學(xué)習(xí) order,學(xué)習(xí) score 分布其實(shí)是一種更精確的方式。另外,由于 teacher 和 student 在訓(xùn)練階段所使用的損失是一致的,這也會(huì)讓 student 更容易學(xué)習(xí)到 teacher 的 score 分布。
接下來我們對(duì) TAS-Balanced 策略做消融實(shí)驗(yàn),如下表所示。總體來說,TAS-balanced 策略加上 Dual-supervision 蒸餾可以在各個(gè)數(shù)據(jù)集上取得最優(yōu)性能。值得關(guān)注的是,在單獨(dú)的 pairwise loss 的監(jiān)督下使用 TAS 策略其實(shí)并不能帶來明顯的提升,這是因?yàn)?TAS 是面向 in-batch negative loss 設(shè)計(jì)的,使用 pairwise loss 訓(xùn)練時(shí),batch 內(nèi)的樣本是沒有交互的,因此 TAS 也就不會(huì)起作用。而 TAS-balanced 策略會(huì)影響正負(fù)樣本對(duì)的組成方式,因此會(huì)對(duì) pairwise loss 產(chǎn)生一定的影響。
4.2 Comparing to Baselines
下表對(duì)比了作者的模型和其他模型的性能,對(duì)比最后三行,我們可以發(fā)現(xiàn)一個(gè)有趣的現(xiàn)象:增大 batch size 在 TREC-DL 這類 densely-judge 的數(shù)據(jù)集上沒有帶來提升,但在 MSMACRO-DEV 這類 sparsely-judge 的數(shù)據(jù)集上會(huì)帶來持續(xù)的提升。?因此作者猜想增大 batch size 會(huì)導(dǎo)致模型在 sparsely-judge 的 MSMACRO 上過擬合,RocketQA 的 SOTA 表現(xiàn)可能僅僅是因?yàn)樗?batch size 夠大。
4.3 TAS-Balanced Retrieval in a Pipeline
為了進(jìn)一步證明方法的有效性,作者嘗試將 TAS-Balance 訓(xùn)練的檢索模型應(yīng)用到召回-排序系統(tǒng)中。眾所周知,稠密檢索和稀疏檢索是互補(bǔ)的,且融合稀疏檢索幾乎不會(huì)影響召回速度,因此作者考慮將稀疏檢索的 docT5query 的檢索結(jié)果和 TAS-balanced 稠密檢索模型的結(jié)果融合,然后使用最先進(jìn)的 mono-duo-T5 排序模型對(duì)檢索結(jié)果做重排。
選擇不同的召回模型、排序模型和不同大小的候選集,我們可以得到不同延遲水平的檢索系統(tǒng)。如上表所示,作者提出的模型在各個(gè)延遲水平上均取得了優(yōu)異的表現(xiàn)。值得注意的是,在高延遲系統(tǒng)中,排序模型 mono-duo-T5 是在 BM25 的召回結(jié)果上訓(xùn)練的,這實(shí)際上會(huì)導(dǎo)致訓(xùn)練測(cè)試分布不一致的問題,所以 TAS-B+mono-duo-T5 甚至沒能超越 BM25+mono-duo-T5,為了取得更好的性能,我們應(yīng)該先訓(xùn)召回模型,然后在召回模型的給出召回結(jié)果上訓(xùn)練排序模型,這其實(shí)也間接反映了當(dāng)前的排序模型泛化性不足的問題。
Discussion
本篇論文最大的亮點(diǎn)是 TAS-Balanced 策略的高效性,使用作者的模型,我們僅需要在單個(gè)消費(fèi)級(jí) GPU 上從頭訓(xùn)練 48 小時(shí)就能取得 SOTA 結(jié)果,極大地降低了檢索模型的訓(xùn)練成本,這在之前是無法想象的。實(shí)際上,比起 NLP 社區(qū),IR 社區(qū)更加強(qiáng)調(diào)模型和數(shù)據(jù)的 Efficiency,這一課題在將來也一定會(huì)受到持續(xù)的關(guān)注。
特別鳴謝
感謝 TCCI 天橋腦科學(xué)研究院對(duì)于 PaperWeekly 的支持。TCCI 關(guān)注大腦探知、大腦功能和大腦健康。
更多閱讀
#投 稿?通 道#
?讓你的文字被更多人看到?
如何才能讓更多的優(yōu)質(zhì)內(nèi)容以更短路徑到達(dá)讀者群體,縮短讀者尋找優(yōu)質(zhì)內(nèi)容的成本呢?答案就是:你不認(rèn)識(shí)的人。
總有一些你不認(rèn)識(shí)的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學(xué)者和學(xué)術(shù)靈感相互碰撞,迸發(fā)出更多的可能性。?
PaperWeekly 鼓勵(lì)高校實(shí)驗(yàn)室或個(gè)人,在我們的平臺(tái)上分享各類優(yōu)質(zhì)內(nèi)容,可以是最新論文解讀,也可以是學(xué)術(shù)熱點(diǎn)剖析、科研心得或競(jìng)賽經(jīng)驗(yàn)講解等。我們的目的只有一個(gè),讓知識(shí)真正流動(dòng)起來。
📝?稿件基本要求:
? 文章確系個(gè)人原創(chuàng)作品,未曾在公開渠道發(fā)表,如為其他平臺(tái)已發(fā)表或待發(fā)表的文章,請(qǐng)明確標(biāo)注?
? 稿件建議以?markdown?格式撰寫,文中配圖以附件形式發(fā)送,要求圖片清晰,無版權(quán)問題
? PaperWeekly 尊重原作者署名權(quán),并將為每篇被采納的原創(chuàng)首發(fā)稿件,提供業(yè)內(nèi)具有競(jìng)爭(zhēng)力稿酬,具體依據(jù)文章閱讀量和文章質(zhì)量階梯制結(jié)算
📬?投稿通道:
? 投稿郵箱:hr@paperweekly.site?
? 來稿請(qǐng)備注即時(shí)聯(lián)系方式(微信),以便我們?cè)诟寮x用的第一時(shí)間聯(lián)系作者
? 您也可以直接添加小編微信(pwbot02)快速投稿,備注:姓名-投稿
△長(zhǎng)按添加PaperWeekly小編
🔍
現(xiàn)在,在「知乎」也能找到我們了
進(jìn)入知乎首頁(yè)搜索「PaperWeekly」
點(diǎn)擊「關(guān)注」訂閱我們的專欄吧
·
總結(jié)
以上是生活随笔為你收集整理的48小时单GPU训练DistilBERT!这个检索模型轻松达到SOTA的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 2021胡润全球富豪榜 钟睒睒成全球前十
- 下一篇: EMNLP 2021 | 正则表达式与神