vit transformer中的cls_token
1.源碼
# timm.model.vision_transformer def forward_head(self, x, pre_logits: bool = False):'''# self.global_pool == 'avg'則取所有token的均值作為一個(gè)類別的表征# self.global_pool == 'token'則取第一個(gè)cls_token作為一個(gè)類別的表征'''if self.global_pool: # [bs,token,dim] -> [bs,dim] 經(jīng)過gapx = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] x = self.fc_norm(x) # bs, dim=768 -> bs, class_numreturn x if pre_logits else self.head(x)2.說明
假設(shè)我們將原始圖像切分成共9個(gè)小圖像塊,最終的輸入序列長(zhǎng)度卻是10,也就是說我們這里人為的增加了一個(gè)向量進(jìn)行輸入,我們通常將人為增加的這個(gè)向量稱為 Class Token。
我們可以想象,如果沒有這個(gè)向量,也就是將9個(gè)向量(1~9)輸入 Transformer 結(jié)構(gòu)中進(jìn)行編碼,我們最終會(huì)得到9個(gè)編碼向量,可對(duì)于圖像分類任務(wù)而言,我們應(yīng)該選擇哪個(gè)輸出向量進(jìn)行后續(xù)分類呢?
方案一,即vit的方案:ViT算法提出了一個(gè)可學(xué)習(xí)的嵌入向量 Class Token( 向量0),將它與9個(gè)向量一起輸入到 Transformer 結(jié)構(gòu)中,輸出10個(gè)編碼向量,然后用這個(gè) Class Token 進(jìn)行分類預(yù)測(cè)。即,基于添加的cls_token執(zhí)行類別預(yù)測(cè),位置在所有token的第一個(gè)位置token[0],見編碼中的x[:,0]
方案二,取除了cls_token之外的所有token的均值作為類別特征表示,即編碼中的x[:, self.num_tokens:].mean(dim=1)
?3.思考
根據(jù)自注意機(jī)制,每個(gè)patch token一定程度上聚合了全局信息,但是主要是自身特征。ViT論文還使用了所有token取平均的方式,這意味每個(gè)patch對(duì)預(yù)測(cè)的貢獻(xiàn)相同,似乎不太合理?。實(shí)際上,這樣做的效果基本和引入cls_token差不多。
參考:
?vit 中的 cls_token 與 position_embed 理解_mingqian_chu的博客-CSDN博客_cls token
ViT為何引入cls_token_gltangwq的博客-CSDN博客_cls token
總結(jié)
以上是生活随笔為你收集整理的vit transformer中的cls_token的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 用友系统检查iis服务器不符,安装用友U
- 下一篇: PCB多层板设计总结