论文阅读|DeiT
Training data-efficient image transformers & distillation through attention
論文鏈接:https://export.arxiv.org/pdf/2012.12877
代碼:https://github.com/facebookresearch/deit
摘要
純基于注意力的神經(jīng)網(wǎng)絡(luò)被證明可以解決圖像分類(lèi)等圖像理解任務(wù),但是這些高性能的網(wǎng)絡(luò)結(jié)構(gòu)通常需要使用大型的基礎(chǔ)設(shè)施預(yù)先訓(xùn)練了數(shù)億個(gè)圖像,因此限制了他們的采用。
為此,對(duì)于這種設(shè)計(jì)龐大的預(yù)訓(xùn)練量,作者提出了一種convolution-free transformers的結(jié)構(gòu),只在Imagenet上進(jìn)行訓(xùn)練,就具有競(jìng)爭(zhēng)力。在不需要其他額外數(shù)據(jù)進(jìn)行預(yù)訓(xùn)練的情況下,在ImageNet上達(dá)到了top-1 accuracy 達(dá)到83.1%的效果。
此外,作者還引入了teacher-student策略,依賴(lài)于一個(gè)蒸餾標(biāo)記(distillation token),確保學(xué)生通過(guò)注意力從老師那里學(xué)習(xí)。當(dāng)卷積網(wǎng)絡(luò)作為teacher時(shí),效果達(dá)到了85.2%的準(zhǔn)確率
DeiT能取得更好效果的方法:
- better hyperparameter更好的超參數(shù)設(shè)置(模型初始化,learning-rate等設(shè)置),
- data augmentation(多種數(shù)據(jù)增強(qiáng)方式)
- distillation(知識(shí)蒸餾)
Introduction
ViT的問(wèn)題:
- ViT需要大量的GPU資源
- ViT的預(yù)訓(xùn)練數(shù)據(jù)集JFT-300M并沒(méi)有公開(kāi)
- 超參數(shù)設(shè)置不好很容易導(dǎo)致訓(xùn)練效果差
- 只用ImageNet訓(xùn)練準(zhǔn)確率沒(méi)有很好
對(duì)于VIT訓(xùn)練數(shù)據(jù)巨大,超參數(shù)難設(shè)置導(dǎo)致訓(xùn)練效果不好的問(wèn)題,提出了DeiT。
DeiT : Data-efficient image Transformers
DeiT的模型和VIT的模型幾乎是相同的,可以理解為本質(zhì)上是在訓(xùn)一個(gè)VIT。
DeiT特點(diǎn):
- DeiT不包含卷積層,可以在沒(méi)有外部數(shù)據(jù)的情況下實(shí)現(xiàn)與ImageNet上的最新技術(shù)相媲美的結(jié)果。
- 引入了一種基于distillation token的新蒸餾過(guò)程,它扮演著與class token相同的角色,只不過(guò)它的目的是學(xué)習(xí)教師網(wǎng)絡(luò)中的預(yù)測(cè)結(jié)果,兩個(gè)標(biāo)記都通過(guò)注意力在轉(zhuǎn)換器中交互。
- 通過(guò)蒸餾,圖像transformers從一個(gè)convnet學(xué)到的比從另一個(gè)性能相當(dāng)?shù)膖ransformers學(xué)到的更多。
- 在Imagenet上預(yù)先學(xué)習(xí)的模型在轉(zhuǎn)移到不同的下游任務(wù)時(shí)是有競(jìng)爭(zhēng)力的
Related work
Knowledge Distillation 知識(shí)蒸餾
參考資料:(58條消息) 深度學(xué)習(xí)之知識(shí)蒸餾(Knowledge Distillation)_AndyJ的學(xué)習(xí)之旅-CSDN博客_知識(shí)蒸餾溫度
簡(jiǎn)單來(lái)說(shuō)就是用teacher模型去訓(xùn)練student模型,通常teacher模型更大而且已經(jīng)訓(xùn)練好了,student模型是我們當(dāng)前需要訓(xùn)練的模型。在這個(gè)過(guò)程中,teacher模型是不訓(xùn)練的。
軟蒸餾 soft distillation
?當(dāng)teacher模型和student模型拿到相同的圖片時(shí),都進(jìn)行各自的前向,這時(shí)teacher模型就拿到了具有分類(lèi)信息的feature,在進(jìn)行softmax之前先除以一個(gè)參數(shù)?,叫做temperature(蒸餾溫度),然后softmax得到soft labels(區(qū)別于one-hot形式的hard-label)。
student模型也是除以同一個(gè)?,然后softmax得到一個(gè)soft-prediction,我們希望student模型的soft-prediction和teacher模型的soft labels盡量接近,使用KLDivLoss進(jìn)行兩者之間的差距度量,計(jì)算一個(gè)對(duì)應(yīng)的損失teacher loss。
在訓(xùn)練的時(shí)候,我們是可以拿的到訓(xùn)練圖片的真實(shí)的ground truth(hard label)的,可以看到上面圖中student模型下面一路,就是預(yù)測(cè)結(jié)果和真是標(biāo)簽之間計(jì)算交叉熵crossentropy。
交叉熵:損失函數(shù)|交叉熵?fù)p失函數(shù) - 知乎 (zhihu.com)
然后兩路計(jì)算的損失:KLDivLoss和CELoss,按照一個(gè)加權(quán)關(guān)系計(jì)算得到一個(gè)總損失total loss,反向修改參數(shù)的時(shí)候這個(gè)teacher模型是不做訓(xùn)練的,只依據(jù)total loss訓(xùn)練student模型。
其中表示的是教師網(wǎng)絡(luò)的輸出概率,表示學(xué)生網(wǎng)絡(luò)的輸出概率,表示蒸餾溫度,λ表示Kullback-Leibler散度損失與交叉熵?fù)p失之間的權(quán)重因子,y表示真實(shí)標(biāo)簽,ψ表示softmax函數(shù)
公式很容易可以理解,loss為學(xué)生網(wǎng)絡(luò)與真實(shí)標(biāo)簽的損失加上學(xué)生網(wǎng)絡(luò)輸出值與教師網(wǎng)絡(luò)輸出值的標(biāo)簽分布差異。一方面希望學(xué)生網(wǎng)絡(luò)的輸出值與真實(shí)標(biāo)簽相近,同時(shí)還希望其與教師網(wǎng)絡(luò)的輸出分布相近,這樣才可以學(xué)習(xí)到教師網(wǎng)絡(luò)對(duì)某些錯(cuò)誤數(shù)據(jù)與正確數(shù)據(jù)的相識(shí)情況。
?
硬蒸餾 hard diatillation
其中:,一方面使得學(xué)生網(wǎng)絡(luò)與真實(shí)標(biāo)簽的損失最小,同時(shí)也希望與教師網(wǎng)絡(luò)得出來(lái)的標(biāo)簽損失最小,這兩個(gè)損失各占一半的權(quán)重。
對(duì)于給定的圖像,與教師相關(guān)的硬標(biāo)簽可能會(huì)根據(jù)具體的數(shù)據(jù)增強(qiáng)而改變,而這種選擇比傳統(tǒng)的選擇更好,教師預(yù)測(cè)與真實(shí)labely扮演相同的角色。還要注意,硬標(biāo)簽也可以通過(guò)標(biāo)簽平滑轉(zhuǎn)換為軟標(biāo)簽。
軟蒸餾是限制student和teacher的模型輸出類(lèi)別分布盡可能接近,而硬蒸餾是限制兩種模型輸出的類(lèi)別標(biāo)簽盡可能接近。
KLDivloss
KL散度,又叫相對(duì)熵,用于衡量?jī)蓚€(gè)分布(連續(xù)分布和離散分布)之間的距離,在knowledge distillation中,兩個(gè)分布為teacher模型和student模型的softmax輸出。
當(dāng)兩個(gè)分布很相近時(shí)候,對(duì)應(yīng)class的預(yù)測(cè)值就會(huì)很接近,取log之后的差值就會(huì)很小,KL散度就很小。當(dāng)兩個(gè)分布完全一致時(shí)候,KL散度就等于0。
transformer中加入蒸餾——distillation token
在VIT中時(shí)使用class tokens去做分類(lèi)的,相當(dāng)于是一個(gè)額外的patch,這個(gè)patch去學(xué)習(xí)和別的patch之間的關(guān)系,然后連classifier,計(jì)算CELoss。在DeiT中為了做蒸餾,又額外加一個(gè)distill token,這個(gè)distill token也是去學(xué)和其他tokens之間的關(guān)系,然后連接teacher model計(jì)算KLDivLoss,那CELoss和KLDivLoss共同加權(quán)組合成一個(gè)新的loss取指導(dǎo)student model訓(xùn)練(知識(shí)蒸餾中teacher model不訓(xùn)練)。
在預(yù)測(cè)階段,class token和distill token分別產(chǎn)生一個(gè)結(jié)果,然后將其加權(quán)(分別0.5),再加在一起,得到最終的結(jié)果做預(yù)測(cè)。
在patches中加入與class token類(lèi)似的distillation token,兩者的通過(guò)網(wǎng)絡(luò)時(shí)的計(jì)算方式相同,區(qū)別在于class token目標(biāo)是重現(xiàn)ground truth標(biāo)簽,而distillation token目標(biāo)是重現(xiàn)教師模型的預(yù)測(cè),Distillation token讓模型從教師模型輸出中學(xué)習(xí),文章發(fā)現(xiàn):
- 最初class token和distillation token區(qū)別很大,余弦相似度為0.06
- 隨著class 和 distillation embedding互相傳播和學(xué)習(xí),通過(guò)網(wǎng)絡(luò)逐漸變得相似,到最后一層,余弦相似度為0.93,相似但不相同
- 當(dāng)用一個(gè)class token替換distillation token時(shí),兩個(gè)class token輸出的余弦相似度為0.999,網(wǎng)絡(luò)性能與一個(gè)class token相近,而加入distillation token的網(wǎng)絡(luò)性能明顯提升。這表明distillation token的設(shè)定是有效的。
Experiments
DeiT不同參數(shù)
定義了與ViT-B參數(shù)相同的DeiT-B模型,和更小的DeiT-S、DeiT-Ti模型,區(qū)別在于heads數(shù)目和embedding dimension不同。超參數(shù)如下:
?teacher模型的選擇
實(shí)驗(yàn)發(fā)現(xiàn)RegNetY-16GF是效果最好的教師模型,后續(xù)實(shí)驗(yàn)?zāi)J(rèn)選擇。
CNN效果更好,這可能是因?yàn)閠ransformer可以學(xué)到CNN的歸納假設(shè)。CNN是有inductive bias的,例如局部感受野,參數(shù)共享等,這些設(shè)計(jì)比較適應(yīng)于圖像任務(wù),這里將CNN作為teacher,可以通過(guò)蒸餾,使得Transformer學(xué)習(xí)得到CNN的inductive bias,從而提升Transformer對(duì)圖像任務(wù)的處理能力。
同時(shí)還可以發(fā)現(xiàn),學(xué)生網(wǎng)絡(luò)可以取得超越老師的性能,能夠在準(zhǔn)確率和吞吐量權(quán)衡方面做的更好。
m小標(biāo)表示使用了蒸餾策略的網(wǎng)絡(luò)模型。↑384表示student在224*224圖像上進(jìn)行預(yù)訓(xùn)練,然后在384*384圖像上進(jìn)行fine-tune。
蒸餾策略的選擇
硬蒸餾的性能比軟蒸餾更好。
在pretrain上測(cè)試的時(shí)候,distillation token和class token兩個(gè)一起用性能更佳,這表明兩個(gè)token提供了對(duì)分類(lèi)有用的互補(bǔ)信息。只用一個(gè)的時(shí)候,distillation token性能略好于class token,這可能是因?yàn)閐istillation token里有更多從CNN中學(xué)到的歸納假設(shè)
與其他模型性能對(duì)比
?訓(xùn)練策略和消融實(shí)驗(yàn)
初始化和超參數(shù)?
參數(shù)初始化方式:truncated normal distribution(截?cái)鄻?biāo)準(zhǔn)分布)
soft蒸餾參數(shù):= 3 , = 0.1
數(shù)據(jù)增強(qiáng)
總結(jié)