bert速度提升fastbert
FastBERT
自從BERT問(wèn)世以來(lái),大多數(shù)NLP任務(wù)的效果都有了一次質(zhì)的飛躍。BERT Large在GLUE test上甚至提升了7個(gè)點(diǎn)之多。但BERT同時(shí)也開(kāi)啟了模型的“做大做深”之路,普通玩家根本訓(xùn)不起,高端玩家雖然訓(xùn)得起但也不一定用得起。
所以BERT之后的發(fā)展也比較清晰,一部分壕大佬們繼續(xù)搞預(yù)訓(xùn)練提升效果,當(dāng)你對(duì)BERT Large望而卻步的時(shí)候,又出了GPT2,又雙出了威震天Megatron-LM,又雙叒出了T5,又雙叒叕出了DeepSpeed。。。每次都是照著一個(gè)數(shù)量級(jí)去加,剩下的人只能默默觀望,翻翻《顯存不夠,如何訓(xùn)練大型神經(jīng)網(wǎng)絡(luò)?》看哪個(gè)trick可以用上。
另一部分大佬著力于給BERT瘦身提升速度。比如剪枝,剪掉多余的連接、多余的注意力頭、甚至LayerDrop[1]直接砍掉一半Transformer層;再比如量化,把FP32改成FP16或者INT8;還有蒸餾,用一個(gè)學(xué)生模型來(lái)學(xué)習(xí)大模型的知識(shí),不僅要學(xué)logits,還要學(xué)attention score。。。
然而,大部分減肥方法都會(huì)帶來(lái)精度的下降。剪枝會(huì)直接降低模型的擬合能力,量化雖然有提升但也有瓶頸,蒸餾的不確定性最大,很難預(yù)知你的BERT教出來(lái)怎樣的學(xué)生。
但!是!
昨天刷到了一篇讓我眼前一亮的文章《FastBERT: a Self-distilling BERT with Adaptive Inference Time》[2],是北大+騰訊+北師大的ACL2020。作者提出了一種新的inference速度提升方式,相比單純的student蒸餾有更高的確定性,且可以自行權(quán)衡效果與速度,簡(jiǎn)單實(shí)用。
后臺(tái)回復(fù)【0409】獲取論文PDF噢~
FastBERT
模型結(jié)構(gòu)
FastBERT的創(chuàng)新點(diǎn)很容易理解,就是在每層Transformer后都去預(yù)測(cè)樣本標(biāo)簽,如果某樣本預(yù)測(cè)結(jié)果的置信度很高,就不用繼續(xù)計(jì)算了。論文把這個(gè)邏輯稱為樣本自適應(yīng)機(jī)制(Sample-wise adaptive mechanism),就是自適應(yīng)調(diào)整每個(gè)樣本的計(jì)算量,容易的樣本通過(guò)一兩層就可以預(yù)測(cè)出來(lái),較難的樣本則需要走完全程。
那么問(wèn)題來(lái)了,用什么去預(yù)測(cè)中間層的結(jié)果呢?作者的解決方案是給每層后面接一個(gè)分類(lèi)器,畢竟分類(lèi)器比Transformer需要的成本小多了:
注:FLOPs (floating point operations)是Tensorflow中提供的浮點(diǎn)數(shù)計(jì)算量統(tǒng)計(jì)
于是模型的整體結(jié)構(gòu)就呼之欲出了:
作者將原BERT模型稱為主干(Backbone),每個(gè)分類(lèi)器稱為分支(Branch)。
要注意的是,這里的分支Classifier都是最后一層的分類(lèi)器蒸餾來(lái)的,作者將這稱為自蒸餾(Self-distillation)。就是在預(yù)訓(xùn)練和精調(diào)階段都只更新主干參數(shù),精調(diào)完后freeze主干參數(shù),用分支分類(lèi)器(圖中的student)蒸餾主干分類(lèi)器(圖中的teacher)的概率分布。
之所以叫自蒸餾,是因?yàn)橹暗恼麴s都是用兩個(gè)模型去做,一個(gè)模型學(xué)習(xí)另一個(gè)模型的知識(shí),而FastBERT是自己(分支)蒸餾自己(主干)的知識(shí)。值得注意的是,蒸餾時(shí)需要freeze主干部分,保證pretrain和finetune階段學(xué)習(xí)的知識(shí)不被影響,僅用brach 來(lái)盡可能的擬合teacher的分布。
那為什么不直接用標(biāo)注數(shù)據(jù)訓(xùn)分支分類(lèi)器呢?因?yàn)橹苯佑?xùn)效果不好唄(攤手~下面是作者在消融實(shí)驗(yàn)給出的結(jié)果:
可以看到,非蒸餾的結(jié)果沒(méi)有蒸餾要好。個(gè)人認(rèn)為是合理的,因?yàn)檫@兩種方式在精調(diào)階段的目標(biāo)不一樣。非自蒸餾是在精調(diào)階段訓(xùn)練所有分類(lèi)器,目標(biāo)函數(shù)有所改變,迫使前幾層編碼器抽取更多的任務(wù)feature。但BERT強(qiáng)大的能力與網(wǎng)絡(luò)深度的相關(guān)性很大,所以過(guò)早地判斷不一定準(zhǔn)確,致使效果下降。
同時(shí),使用自蒸餾還有一點(diǎn)重要的好處,就是不再依賴于標(biāo)注數(shù)據(jù)。蒸餾的效果可以通過(guò)源源不斷的無(wú)標(biāo)簽數(shù)據(jù)來(lái)提升。
模型訓(xùn)練與推理
了解模型結(jié)構(gòu)之后,訓(xùn)練與推理也就很自然了。只比普通的BERT模型多了自蒸餾這個(gè)步驟:
-
Pre-training:同BERT系模型是一樣的,網(wǎng)上那么多開(kāi)源的模型也可以隨意拿來(lái)~
-
Fine-tuning for Backbone:主干精調(diào),也就是給BERT加上分類(lèi)器,用任務(wù)數(shù)據(jù)訓(xùn)練,這里也用不到分支分類(lèi)器,可以盡情地優(yōu)化
-
Self-distillation for branch:分支自蒸餾,用無(wú)標(biāo)簽任務(wù)數(shù)據(jù)就可以,將主干分類(lèi)器預(yù)測(cè)的概率分布蒸餾給分支分類(lèi)器。這里使用KL散度衡量分布距離,loss是所有分支分類(lèi)器與主干分類(lèi)器的KL散度之和
-
Adaptive inference:自適應(yīng)推理,及根據(jù)分支分類(lèi)器的結(jié)果對(duì)樣本進(jìn)行層層過(guò)濾,簡(jiǎn)單的直接給結(jié)果,困難的繼續(xù)預(yù)測(cè)。這里作者定義了新的不確定性指標(biāo),用預(yù)測(cè)結(jié)果的熵來(lái)衡量,熵越大則不確定性越大:
效果
對(duì)于每層分類(lèi)結(jié)果,作者用“Speed”代表不確定性的閾值,和推理速度是正比關(guān)系。因?yàn)殚撝翟叫?=> 不確定性越小 => 過(guò)濾的樣本越少 => 推理速度越慢。
模型最終在12個(gè)數(shù)據(jù)集(6個(gè)中文的和6個(gè)英文的)上的表現(xiàn)還是很好的:
可以看到,在Speed=0.2時(shí)速度可以提升1-10倍,且精度下降全部在0.11個(gè)點(diǎn)之內(nèi),甚至部分任務(wù)上還有細(xì)微提升。相比之下HuggingFace的DistillBERT的波動(dòng)就比較劇烈了,6層模型速度只提升2倍,但精度下降最高會(huì)達(dá)到7個(gè)點(diǎn)。
總結(jié)
FastBERT是一個(gè)在工程上十分實(shí)用的模型,通過(guò)提前輸出簡(jiǎn)單樣本的預(yù)測(cè)結(jié)果,減少模型的計(jì)算負(fù)擔(dān),從而提高推理速度。雖然每層都多了一個(gè)分類(lèi)器,但分類(lèi)器的計(jì)算量也比Transformer小了兩個(gè)數(shù)量級(jí),對(duì)速度影響較小。后續(xù)的分支自蒸餾也設(shè)計(jì)的比較巧妙,可以利用無(wú)監(jiān)督數(shù)據(jù)不斷提升分支分類(lèi)器的效果。
另外,Sample-wise adaptive mechanism和Self-distillation這兩個(gè)idea也是在本文首次提出,相信會(huì)起到拋玉引玉的作用,衍生更多此類(lèi)工作。論文本身也還有一些想象空間,比如分別優(yōu)化每層分類(lèi)器,因?yàn)樵谥鞲杀籪reeze的情況下各個(gè)分支是獨(dú)立的;或者自蒸餾unfreeze主干,再加上數(shù)據(jù)自蒸餾一把,說(shuō)不定還會(huì)有性能提升。
值得一提的是,本文的一作劉偉杰(北大)正是K-BERT[3]的作者,那也是一篇我很喜歡的文章,把知識(shí)融入到BERT的方式比較優(yōu)雅,真心期待作者有更多的idea~
最后再回來(lái)夸一下,FastBERT著實(shí)很實(shí)用,而且完全不會(huì)影響到手頭調(diào)好的BERT,只需要蒸餾幾個(gè)淺層分類(lèi)器,再把判斷機(jī)制加上就可以了。而且比起不太穩(wěn)定的蒸餾來(lái)說(shuō)放在線上也更有底,穩(wěn)穩(wěn)的幸福。
唯一的遺憾是源碼要在文章發(fā)表時(shí)才會(huì)放出來(lái),一起去催更吧~
https://github.com/autoliuweijie/FastBERT
總結(jié)
以上是生活随笔為你收集整理的bert速度提升fastbert的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: oracle sqlserver 查看
- 下一篇: 显卡不够时,如何训练大型网络