【Transformer】ViT:An image is worth 16x16: transformers for image recognition at scale
文章目錄
- 一、背景和動機
- 二、方法
- 三、效果
- 四、Vision Transformer 學(xué)習(xí)到圖像的哪些特征了
- 五、代碼
代碼鏈接:https://github.com/lucidrains/vit-pytorch
論文連接:https://openreview.net/pdf?id=YicbFdNTTy
一、背景和動機
Transformer 在 NLP 領(lǐng)域取得了很好的效果,但在計算機視覺領(lǐng)域還沒有很多應(yīng)用,所以作者想要借鑒其在 NLP 中的方法,在計算機視覺的分類任務(wù)中進(jìn)行使用。
二、方法
由于 Transformer 在 NLP 中使用時,都是接受一維的輸入,而圖像是二維的結(jié)構(gòu),所以需要先把圖像切分成大小相等的patch,然后再編碼成一個序列,送入 Transformer
Vision Transformer 的過程:
1、輸入圖像分割成 patch,并使用可學(xué)習(xí)的線性變換將其拉伸成 D 維,為輸入 Transformer 做準(zhǔn)備
-
首先,輸入圖像為 x∈RH×W×Cx\in R^{H \times W \times C}x∈RH×W×C,將其 reshape 為 xp∈RN×(P2C)x_p\in R{N \times (P^2 C)}xp?∈RN×(P2C),切分后的 patch 大小為 P×PP \times PP×P,N=HW/P2N=HW/P^2N=HW/P2 為 patch 的個數(shù)。這里,NNN 也掌握著 Transformer 的效率。
-
Transformer 在每層都使用維度大小為 DDD 的向量輸入,所以在送入 Transformer 之前,會使用可學(xué)習(xí)的線性影射(公式1)來將拉平的 patch 信息轉(zhuǎn)換為 D 維。
-
輸入:3x224x224
-
patch 大小 PPP:32x32
-
patch 個數(shù) NNN:7x7=49
-
D:128
2、給 D 維的 Transformer 輸入后面連接一個 class token
將輸入編碼成 B×N×DB\times N \times DB×N×D 輸入 Transformer 之前,會給 NNN 這個維度增加一個 class token,變成 N+1N+1N+1 維,這個 class token 是一個大小為 1×1×D1\times 1 \times D1×1×D 的可學(xué)習(xí)向量,表示 Transformer 的 Encoder 的輸出(zL0z_L^0zL0?),也就是作為圖像的特征表達(dá) y。
3、給上面的結(jié)果加上位置編碼
添加位置編碼能夠保留圖像的位置信息,作者使用可學(xué)習(xí)的 1D 位置編碼(因為從文獻(xiàn)來看,使用 2D 編碼也沒有帶來理想的效果)。ViT 中的位置編碼是隨機生成可學(xué)習(xí)參數(shù),沒有做過多設(shè)計,這樣的設(shè)計。
4、送入 Transformer Encoder
這里的 Encoder 由多層的 multiheaded self-attention(MSA)和 MLP 組成,每層之前都會使用 Layernorm,每個 block 之后都會使用殘差連接。
歸一化方法:
Transformer 中一般都使用 LayerNorm,LayerNorm 和 BatchNorm 的區(qū)別如下圖所示:
- LayerNorm:對一個 batch 的所有通道進(jìn)行歸一化(均值為 0,方差為 1)
- BatchNorm:對一個通道的所有 batch 進(jìn)行歸一化(均值為 0,方差為 1)
三、效果
- 使用中等大小的數(shù)據(jù)集(如 ImageNet),Transformer 比 ResNet 的效果稍微差點,作者認(rèn)為原因在于 Transformer 缺少了 CNN 中的歸納偏置(平移不變性和位置),泛化的也較差
- 使用大型數(shù)據(jù)集訓(xùn)練時(約 14M-300M images),作者發(fā)現(xiàn)大型的數(shù)據(jù)訓(xùn)練會勝過歸納偏置帶來的效果,ViT 在使用了大型數(shù)據(jù)集預(yù)訓(xùn)練(ImageNet-21k 或 in-house JFT-300M)然后遷移到其他任務(wù)時,效果優(yōu)于 CNN。
歸納偏置是什么?
歸納偏置可以理解成在算法在設(shè)計之初就加入的一種人為偏好,將某種方式的解優(yōu)于其他解,既包含低層數(shù)據(jù)分布假設(shè),也包含模型設(shè)計。
在深度學(xué)習(xí)時代,這種歸納性偏好更為明顯。比如深度神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)就偏好性的認(rèn)為,層次化處理信息有更好效果;卷積神經(jīng)網(wǎng)絡(luò)認(rèn)為信息具有空間局部性(locality),可以用滑動卷積共享權(quán)重方式降低參數(shù)空間;反饋神經(jīng)網(wǎng)絡(luò)則將時序信息考慮進(jìn)來強調(diào)順序重要性;圖網(wǎng)絡(luò)則是認(rèn)為中心節(jié)點與鄰居節(jié)點的相似性會更好引導(dǎo)信息流動。不同的網(wǎng)絡(luò)結(jié)構(gòu)創(chuàng)新就體現(xiàn)了不同的歸納性偏。
之前計算機視覺任務(wù)大都依賴于 CNN,CNN 有兩個內(nèi)置的歸納偏置:
- 局部相關(guān)性
- 權(quán)重共享
但基于注意力模型的 Transformer 最小化了歸納偏置,所以在大數(shù)據(jù)集上進(jìn)行訓(xùn)練時,效果甚至可以超過 CNN,但小數(shù)據(jù)集上因為缺少了這種歸納偏置,所以難以總結(jié)到有意義的特征。
CNN 有較好的歸納偏置,所以數(shù)據(jù)少的時候也能實現(xiàn)好的效果,但數(shù)據(jù)量大的時候,這些歸納偏置就會限制其效果,但 Transformer 不會被其限制,所以在大數(shù)據(jù)集上表現(xiàn)更好一些。
四、Vision Transformer 學(xué)習(xí)到圖像的哪些特征了
為了理解 Transformer 是如何學(xué)習(xí)到圖像特征的,作者分析了其內(nèi)部的特征表達(dá):
- Transformer 的第一層將 flattened patch 線性影射到了一個低維空間(公式 1),圖 7 左側(cè)可視化了前幾個主要的學(xué)習(xí)到的 embedding filters,這些組件類似于每個patch內(nèi)精細(xì)結(jié)構(gòu)的低維表示的可信基函數(shù)。
- 線性投影之后,加上位置編碼,圖 7 中間展示了模型學(xué)習(xí)了在位置嵌入相似度下對圖像內(nèi)距離進(jìn)行編碼,即離得近的 patches 更趨向于有相似的位置嵌入,然后就有了 row-column 結(jié)構(gòu),同一行或同一列的 patches 有相似的嵌入。
- 自注意力機制能夠提取整幅圖像的信息,作者為了探究這種注意力給網(wǎng)絡(luò)起了多大的作用,根據(jù)注意力的權(quán)重計算了其在空間中的平均距離(圖 7 右),這種”注意力距離”類似于 CNN 中的感受野的大小。作者注意到,一些 heads 趨向于關(guān)注最底層的大部分圖像,這表明模型確實使用了全局整合信息的能力。其他注意頭在較低層上的注意距離一直很小。這種高度位置化的注意在混合模型中不太明顯,這些模型在Transformer之前應(yīng)用了ResNet(圖7,右),這表明它可能具有與cnn中的低層卷積層類似的功能。注意距離隨網(wǎng)絡(luò)深度的增加而增加。從全局來看,發(fā)現(xiàn)模型關(guān)注與分類語義相關(guān)的圖像區(qū)域(圖6)。
上面中間的熱力圖可視化,某個位置和自己的余弦相似度肯定是最高的,然后和同行同列相似度次高,其他位置較低,這也能基本想通,因為位置本來就表示的某個像素在圖像中的某行某列,符合可視化結(jié)果。
五、代碼
總體過程:
- 輸入原圖:[1, 3, 224, 224]
- patch 編碼:[1, 49, 1024]
- cls_token:[1, 50, 1024]
- 位置編碼:[1, 50, 1024]
- Transformer:Attention + FeedForward [1, 50, 1024]
- 取第一組向量(或均值)作為全局特征:[1, 1024]
- MLP 輸出預(yù)測類別:[1, 1000],1000為類別數(shù)
調(diào)用 ViT 的方法:
pip install vit-pytorch import torch from vit_pytorch import ViTv = ViT(image_size = 256,patch_size = 32,num_classes = 1000,dim = 1024,depth = 6,heads = 16,mlp_dim = 2048,dropout = 0.1,emb_dropout = 0.1 )img = torch.randn(1, 3, 256, 256)preds = v(img) # (1, 1000)Patch embedding 結(jié)構(gòu):
Sequential((0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=32, p2=32)(1): Linear(in_features=3072, out_features=1024, bias=True) )Transformer 結(jié)構(gòu):
Transformer((layers): ModuleList((0): ModuleList((0): PreNorm((norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(fn): Attention((attend): Softmax(dim=-1)(to_qkv): Linear(in_features=1024, out_features=3072, bias=False)(to_out): Sequential((0): Linear(in_features=1024, out_features=1024, bias=True)(1): Dropout(p=0.1, inplace=False))))(1): PreNorm((norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(fn): FeedForward((net): Sequential((0): Linear(in_features=1024, out_features=2048, bias=True)(1): GELU()(2): Dropout(p=0.1, inplace=False)(3): Linear(in_features=2048, out_features=1024, bias=True)(4): Dropout(p=0.1, inplace=False)))))(1): ModuleList((0): PreNorm((norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(fn): Attention((attend): Softmax(dim=-1)(to_qkv): Linear(in_features=1024, out_features=3072, bias=False)(to_out): Sequential((0): Linear(in_features=1024, out_features=1024, bias=True)(1): Dropout(p=0.1, inplace=False))))(1): PreNorm((norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(fn): FeedForward((net): Sequential((0): Linear(in_features=1024, out_features=2048, bias=True)(1): GELU()(2): Dropout(p=0.1, inplace=False)(3): Linear(in_features=2048, out_features=1024, bias=True)(4): Dropout(p=0.1, inplace=False)))))(2): ModuleList((0): PreNorm((norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(fn): Attention((attend): Softmax(dim=-1)(to_qkv): Linear(in_features=1024, out_features=3072, bias=False)(to_out): Sequential((0): Linear(in_features=1024, out_features=1024, bias=True)(1): Dropout(p=0.1, inplace=False))))(1): PreNorm((norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(fn): FeedForward((net): Sequential((0): Linear(in_features=1024, out_features=2048, bias=True)(1): GELU()(2): Dropout(p=0.1, inplace=False)(3): Linear(in_features=2048, out_features=1024, bias=True)(4): Dropout(p=0.1, inplace=False)))))(3): ModuleList((0): PreNorm((norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(fn): Attention((attend): Softmax(dim=-1)(to_qkv): Linear(in_features=1024, out_features=3072, bias=False)(to_out): Sequential((0): Linear(in_features=1024, out_features=1024, bias=True)(1): Dropout(p=0.1, inplace=False))))(1): PreNorm((norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(fn): FeedForward((net): Sequential((0): Linear(in_features=1024, out_features=2048, bias=True)(1): GELU()(2): Dropout(p=0.1, inplace=False)(3): Linear(in_features=2048, out_features=1024, bias=True)(4): Dropout(p=0.1, inplace=False)))))(4): ModuleList((0): PreNorm((norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(fn): Attention((attend): Softmax(dim=-1)(to_qkv): Linear(in_features=1024, out_features=3072, bias=False)(to_out): Sequential((0): Linear(in_features=1024, out_features=1024, bias=True)(1): Dropout(p=0.1, inplace=False))))(1): PreNorm((norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(fn): FeedForward((net): Sequential((0): Linear(in_features=1024, out_features=2048, bias=True)(1): GELU()(2): Dropout(p=0.1, inplace=False)(3): Linear(in_features=2048, out_features=1024, bias=True)(4): Dropout(p=0.1, inplace=False)))))(5): ModuleList((0): PreNorm((norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(fn): Attention((attend): Softmax(dim=-1)(to_qkv): Linear(in_features=1024, out_features=3072, bias=False)(to_out): Sequential((0): Linear(in_features=1024, out_features=1024, bias=True)(1): Dropout(p=0.1, inplace=False))))(1): PreNorm((norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(fn): FeedForward((net): Sequential((0): Linear(in_features=1024, out_features=2048, bias=True)(1): GELU()(2): Dropout(p=0.1, inplace=False)(3): Linear(in_features=2048, out_features=1024, bias=True)(4): Dropout(p=0.1, inplace=False)))))) )前傳方式:
def forward(self, x):for attn, ff in self.layers:# attn: attention# ff: feedforwardx = attn(x) + xx = ff(x) + xreturn xMLP 結(jié)構(gòu):
Sequential((0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(1): Linear(in_features=1024, out_features=1000, bias=True) )總結(jié)
以上是生活随笔為你收集整理的【Transformer】ViT:An image is worth 16x16: transformers for image recognition at scale的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 苹果现已推出 Apple Music R
- 下一篇: 【Transformer】DETR: E