Swin Transformer理论讲解
本文提出了一個(gè)新的視覺Transformer,稱為Swin Transformer,它可以作為計(jì)算機(jī)視覺的一個(gè)通用骨干(backbone)。將Transformer從語(yǔ)言改編為視覺的挑戰(zhàn)來(lái)自于兩個(gè)領(lǐng)域之間的差異,比如視覺實(shí)體的尺度變化很大,以及與文本中的文字相比,圖像中的像素分辨率很高。為了解決這些差異,我們提出了一個(gè)層次化的Transformer,其表示方法是通過 S \textbf{S} Shifted win \textbf{win} windows來(lái)計(jì)算的。移位的窗口方案通過將自我注意(self-attention)的計(jì)算限制在不重疊的局部窗口,同時(shí)也允許跨窗口的連接,從而帶來(lái)了更高的效率。這種分層結(jié)構(gòu)具有在不同尺度上建模的靈活性,并且相對(duì)于圖像大小具有線性計(jì)算復(fù)雜性。Swin Transformer的這些特質(zhì)使其與廣泛的視覺任務(wù)兼容,包括圖像分類(ImageNet-1K上87.3%的最高準(zhǔn)確率)和密集預(yù)測(cè)任務(wù),如物體檢測(cè)(COCO test-dev上58.7%的APbox和51.1%的APmask)和語(yǔ)義分割(ADE20K val上53.5% mIoU)。它的性能超過了以前的最先進(jìn)水平,在COCO上為+2.7% APbox和+2.6% APmask,在ADE20K上為 +3.2% mIoU,證明了基于Transformer的模型作為視覺骨干的潛力。分層設(shè)計(jì)和移位窗口的方法也被證明對(duì)所有MLP架構(gòu)有益。代碼和模型在this https URL公開提供。
This paper presents a new vision Transformer, called Swin Transformer, that capably serves as a general-purpose backbone for computer vision. Challenges in adapting Transformer from language to vision arise from differences between the two domains, such as large variations in the scale of visual entities and the high resolution of pixels in images compared to words in text. To address these differences, we propose a hierarchical Transformer whose representation is computed with \textbf{S}hifted \textbf{win}dows. The shifted windowing scheme brings greater efficiency by limiting self-attention computation to non-overlapping local windows while also allowing for cross-window connection. This hierarchical architecture has the flexibility to model at various scales and has linear computational complexity with respect to image size. These qualities of Swin Transformer make it compatible with a broad range of vision tasks, including image classification (87.3 top-1 accuracy on ImageNet-1K) and dense prediction tasks such as object detection (58.7 box AP and 51.1 mask AP on COCO test-dev) and semantic segmentation (53.5 mIoU on ADE20K val). Its performance surpasses the previous state-of-the-art by a large margin of +2.7 box AP and +2.6 mask AP on COCO, and +3.2 mIoU on ADE20K, demonstrating the potential of Transformer-based models as vision backbones. The hierarchical design and the shifted window approach also prove beneficial for all-MLP architectures. The code and models are publicly available at~\url{this https URL}.
Subjects: Computer Vision and Pattern Recognition (cs.CV); Machine Learning (cs.LG)
Cite as: arXiv:2103.14030 [cs.CV]
(or arXiv:2103.14030v2 [cs.CV] for this version)
https://doi.org/10.48550/arXiv.2103.14030
Focus to learn more
Submission history
From: Han Hu [view email]
[v1] Thu, 25 Mar 2021 17:59:31 UTC (1,064 KB)
[v2] Tue, 17 Aug 2021 16:41:34 UTC (1,065 KB)
ICCV 2021 Best Paper
論文地址:https://doi.org/10.48550/arXiv.2103.14030
源碼地址:https://github.com/microsoft/Swin-Transformer
0. 引言
0.1 Swin Transformer與Vision Transformer的對(duì)比
二者的不同之處:
在ViT模型中,是直接對(duì)特征圖下采樣16倍,在后面的結(jié)構(gòu)中也一致保持這樣的下采樣規(guī)律不變(只有16x下采樣,不Swin Transformer那樣有多種下采樣尺度 -> 這樣就導(dǎo)致ViT不能構(gòu)建出具有層次性的特征圖)
0.2 Swin Transformer與其他網(wǎng)絡(luò)準(zhǔn)確率對(duì)比分析
0.2.1 ImageNet-1K數(shù)據(jù)集準(zhǔn)確率對(duì)比
這些模型先在ImageNet-1K數(shù)據(jù)集上進(jìn)行預(yù)訓(xùn)練后,再在ImageNet-1K上的表現(xiàn)。
可以看到:
- RegNet的準(zhǔn)確率整體表現(xiàn)是不如EfficientNet系列的(考慮到EfficientNet有不同的輸入尺寸,其實(shí)這么比較也不是那么公平),模型的參數(shù)量也比EfficientNet要大。
- ViT整體的Top-1準(zhǔn)確率是最低的,而且尤其是ViT-L/16在參數(shù)量和FLOPs上“一騎絕塵😂”,我個(gè)人猜測(cè)是因?yàn)閂iT的參數(shù)量過于大,模型的容量也很大,所以需要大量的數(shù)據(jù)去擬合,很明顯,ImageNet-1K并不能滿足它。
- DeiT最小規(guī)格的模型是不如RegNet和EfficientNet的,但最高規(guī)格的準(zhǔn)確率強(qiáng)于ReNet。
- Swin-B的準(zhǔn)確率是所有模型中最高的,且相比ViT而言,其準(zhǔn)確率提升很大。
0.2.2 ImageNet-22K預(yù)訓(xùn)練后在ImageNet-1K的準(zhǔn)確率
我們看一下這些模型先在ImageNet-22K數(shù)據(jù)集上進(jìn)行預(yù)訓(xùn)練后,再在ImageNet-1K上的表現(xiàn)。
ImageNet-22K規(guī)模遠(yuǎn)大于ImageNet-1K
從表中可以看到,在ImageNet-22K預(yù)訓(xùn)練后,所有模型的ImageNet-1K準(zhǔn)確率都有提升。
- ViT-B/16的準(zhǔn)確率提升6.1個(gè)點(diǎn)(+7.83%)
- ViT-L/16的準(zhǔn)確率提升8.7個(gè)點(diǎn)(+11.17%)
- Swin-B(2242)的準(zhǔn)確率提升1.7個(gè)點(diǎn)(+2.04%)
- Swin-B(3842)的準(zhǔn)確率提升1.9個(gè)點(diǎn)(+2.25%)
- Swin-L(3842)的準(zhǔn)確率為最高(但此時(shí)的FLOPs仍比ViT要低)
1. Swin Transformer框架
假設(shè)我們的輸入圖片的shape為 H × W × 3 H \times W \times 3 H×W×3的圖片,首先通過Patch Partition模塊 -> 圖片的shape變?yōu)? H 4 × W 4 × 48 \frac{H}{4} \times \frac{W}{4} \times 48 4H?×4W?×48。接下來(lái)再依次通過 Stage1 ~ Stage8。
這個(gè)結(jié)構(gòu)非常像ResNet,圖片首先通過一個(gè)stem層,之后經(jīng)過若干個(gè)Stage結(jié)構(gòu)對(duì)特征圖進(jìn)行特征提取和下采樣。
1.1 注意事項(xiàng)
- 在Swin Transformer中,每經(jīng)過依次下采樣, H , W H, W H,W會(huì)減半,而 C C C會(huì)翻倍。
- Stage1和其他Stage n n n不同的是,Stage1的第一個(gè)層結(jié)構(gòu)是Linear Embedding層,而其他Stage的第一層是Patch Merging層。
1.2 Patch Partition
partition 英[pɑ??t??n] 美[pɑ?r?t??n]
n. 隔斷; 分割; 隔扇; 隔板墻; 分治; 瓜分;
vt. 分割; 使分裂;
假如左邊的矩形是輸入圖片,shape為 4 × 4 × 3 4 \times 4 \times 3 4×4×3(注意是三通道而非單通道)。Patch Partition會(huì)使用一個(gè) 4 × 4 4 \times 4 4×4大小的窗口對(duì)輸入圖像進(jìn)行分割。分割之后對(duì)每一個(gè)小的窗口在channel方向進(jìn)行展平處理。即圖片的長(zhǎng)度和寬度縮小4倍,而channel變?yōu)?×4×3=48。
經(jīng)過Patch Partition層之后,tensor經(jīng)過Linear Embedding層對(duì)輸入特征圖的channel進(jìn)行調(diào)整。通過調(diào)整之后,特征圖的channel變?yōu)?span id="ozvdkddzhkzd" class="katex--inline"> C C C。這里的 C C C具體為多少是根據(jù)Swin Transformer的具體類型進(jìn)行調(diào)整的。
Note:
- 在Stage1的Linear Embeddding層中還包含了一個(gè)Linear Norm層。
- 這里的Patch Partition和Linear Embedding層看起來(lái)很高大上,說(shuō)白了是通過一個(gè)卷積層實(shí)現(xiàn)的。
- Patch Partition使用卷積核大小為4×4,個(gè)數(shù)為48,stride=4的二維卷積實(shí)現(xiàn) -> nn.Conv2d(inp=3, oup=48, kernel_size=(4, 4), stride=4)
- Linear Embedding使用的是tensor.flatten()和一維卷積實(shí)現(xiàn),即nn.Conv1d(inp=48, oup=C, kernel_size=1, stirde=1),最后加上一個(gè)nn.LinearNorm()即可。
- Swin Transformer Block的次數(shù)都是偶數(shù)次。
- 那么為什么是偶數(shù)次呢?
- 因?yàn)樵诙询BSwin Transformer Block時(shí),先使用圖3(b)中的左邊的Block,再使用右邊的Block。
- 左邊的Block的W-MSA其實(shí)就是一個(gè)Multi-head Self-attention模塊(Window Multi-head Self-attenton)
- 右邊的Block的SW-MSA本質(zhì)上也是一個(gè)Multi-head Self-attention模塊(Shifted Multi-head Self-attenton)
- 這兩個(gè)MSA是成對(duì)使用的,所以Swin Transformer Block的次數(shù)都是偶數(shù)次
1.3 Patch Merging
Patch Merging的實(shí)際作用是下采樣。通過Patch Merging后,特征圖的高和寬會(huì)縮減為原來(lái)的一半,Channel會(huì)翻倍。
從上圖可以看到,特征圖的尺寸變?yōu)樵瓉?lái)的一半,深度(通道數(shù))翻倍。
- 4 -> 2
- 1 -> 2
2. W-MSA(Windows Multi-head Self-Attention)
對(duì)于普通的MSA模塊,會(huì)對(duì)輸入特征圖的每一個(gè)像素求解 Q , K , V Q, K, V Q,K,V,每一個(gè)像素求得的 Q Q Q 會(huì)和特征圖上每一個(gè)像素的 K K K 進(jìn)行匹配。然后再進(jìn)行一系列的操作。
而對(duì)應(yīng)Window Multi-head Self-Attention而言,首先會(huì)對(duì)特征圖進(jìn)行分割處理,分割為一個(gè)一個(gè)的Window,然后在每一個(gè)Window內(nèi)部開始執(zhí)行MSA。注意:在進(jìn)行MSA時(shí),Window與Window之間是沒有任何通信的。
這么設(shè)計(jì)WMSA的目的是:減少計(jì)算量。
同樣的,這樣的設(shè)計(jì)也會(huì)引入一些缺點(diǎn):Window之間無(wú)法進(jìn)行信息交互。這將會(huì)導(dǎo)致特征圖的感受野變小,沒法看到全局的視野,這肯定對(duì)最終的預(yù)測(cè)結(jié)果有影響。
3. W-MSA和MSA理論上的計(jì)算量對(duì)比
3.1 MSA計(jì)算量推導(dǎo)
首先回憶下單頭Self-Attention的公式:
A t t e n t i o n ( Q , K , V ) = S o f t M a x ( Q K T d ) V {\rm Attention}(Q, K, V) = {\rm SoftMax}(\frac{QK^T}{\sqrtozvdkddzhkzd})V Attention(Q,K,V)=SoftMax(d?QKT?)V
對(duì)于特征圖中的每個(gè)像素(或稱作token,patch),都要通過 W q , W k , W v W_q, W_k, W_v Wq?,Wk?,Wv? 生成對(duì)應(yīng)的 q u e r y ( q ) query(q) query(q), k e y ( k ) key(k) key(k) 以及 v a l u e ( v ) value(v) value(v)。這里假設(shè) q , k , v q, k, v q,k,v 的向量長(zhǎng)度與特征圖的深度 C C C 保持一致。那么對(duì)應(yīng)所有像素生成 Q Q Q 的過程如下式:
Q h w × C = A h w × C ? W q C × C Q^{hw \times C} = A^{hw \times C} \cdot W_q^{C \times C} Qhw×C=Ahw×C?WqC×C?
- A h w × C A^{hw \times C} Ahw×C 為將所有像素(token)拼接在一起得到的矩陣(一共有 h w hw hw 個(gè)像素,每個(gè)像素的深度為 C C C)
- W q C × C W^{C \times C}_q WqC×C? 為生成 q u e r y query query 的變換矩陣(因?yàn)檩斎胼敵鎏卣鲌D通道數(shù)不變,所以是 C × C C \times C C×C)
- Q h w × C Q^{hw \times C} Qhw×C 為所有像素通過 W q C × C W^{C \times C}_q WqC×C? 得到的query拼接后的矩陣
補(bǔ)充一個(gè)矩陣乘法FLOPs計(jì)算方式,假設(shè)有如下兩個(gè)矩陣做矩陣乘法:
A a × b ? B b × c A^{a\times b} \cdot B^{b \times c} Aa×b?Bb×c
這兩個(gè)矩陣相乘之后,FLOPs為: a × b × c a \times b \times c a×b×c。
所以根據(jù)矩陣運(yùn)算的計(jì)算量公式可以得到生成 Q Q Q 的計(jì)算量為 h w × C × C hw \times C \times C hw×C×C,生成 K K K 和 V V V 同理都是 h w C 2 hwC^2 hwC2,那么總共是 3 h w C 2 3hwC^2 3hwC2。接下來(lái) Q Q Q 和 K T K^T KT 相乘:
X h w × h w = Q h w × C ? K T ( C × h w ) X^{hw \times hw} = Q^{hw \times C} \cdot K^{T(C \times hw)} Xhw×hw=Qhw×C?KT(C×hw)
對(duì)應(yīng)計(jì)算量為 ( h w ) 2 C (hw)^2C (hw)2C。
接下來(lái)忽略除以 d \sqrtozvdkddzhkzd d? 以及 s o f t m a x {\rm softmax} softmax的計(jì)算量,假設(shè)得到 Λ h w × h w \Lambda^{hw \times hw} Λhw×hw,最后還要乘以 V V V:
B h w × C = Λ h w × h w ? V h w × C B^{hw \times C} = \Lambda^{hw \times hw} \cdot V^{hw \times C} Bhw×C=Λhw×hw?Vhw×C
對(duì)應(yīng)的計(jì)算量為 ( h w ) 2 C (hw)^2C (hw)2C。
那么對(duì)應(yīng)單頭的Self-Attention模塊,總共的計(jì)算量為:
3 h w C 2 Q , K , V + ( h w ) 2 C Q K T + ( h w ) 2 C ? V = 3 h w C 2 + 2 ( h w ) 2 C \underset{Q, K, V}{3hwC^2} + \underset{QK^T}{(hw)^2C} + \underset{\cdot V}{(hw)^2C} = 3hwC^2 + 2(hw)^2C Q,K,V3hwC2?+QKT(hw)2C?+?V(hw)2C?=3hwC2+2(hw)2C
在實(shí)際使用過程中,使用的是多頭的Multi-head Self-Attention模塊(MSA),在之前的文章中有進(jìn)行過實(shí)驗(yàn)對(duì)比,多頭注意力模塊相比單頭注意力模塊的計(jì)算量?jī)H多了最后一個(gè)融合矩陣 W O W_O WO? 的計(jì)算量 h w C 2 hwC^2 hwC2。
O h w × C = B h w × C ? W O C × C O^{hw \times C} = B^{hw \times C} \cdot W^{C \times C}_O Ohw×C=Bhw×C?WOC×C?
對(duì)應(yīng)的計(jì)算量為 h w C 2 hwC^2 hwC2。
所以總共加起來(lái)是: 4 h w C 2 + 2 ( h w ) 2 C 4hwC^2 + 2(hw)^2C 4hwC2+2(hw)2C。
3.2 W-MSA計(jì)算量推導(dǎo)
對(duì)于W-MSA模塊首先要將特征圖劃分到一個(gè)個(gè)窗口(Window)中,假設(shè)每個(gè)窗口的寬高都是 M M M,那么總共會(huì)得到 h M × w M \frac {h} {M} \times \frac {w}{M} Mh?×Mw? 個(gè)窗口,然后對(duì)每個(gè)窗口內(nèi)使用多頭注意力模塊(MSA)。
剛剛計(jì)算高為 h h h,寬為 w w w,深度為 C C C 的特征圖的計(jì)算量為 4 h w C 2 + 2 ( h w ) 2 C 4hwC^2 + 2(hw)^2C 4hwC2+2(hw)2C,這里每個(gè)窗口的高為 M M M 寬為 M M M,帶入公式得:
4 ( M C ) 2 + 2 ( M ) 4 C 4(MC)^2 + 2(M)^4C 4(MC)2+2(M)4C
又因?yàn)橛?span id="ozvdkddzhkzd" class="katex--inline"> h M × w M \frac {h} {M} \times \frac {w}{M} Mh?×Mw? 個(gè)窗口,則:
F L O P s ( W - M S A ) = h M × w M × [ 4 ( M C ) 2 + 2 ( M ) 4 C ] = 4 h w C 2 + 2 M 2 h w C \begin{aligned} {\rm FLOPs(W{\text -}MSA)} & = \frac{h}{M} \times \frac{w}{M} \times [4(MC)^2 + 2(M)^4C] \\ & = 4hwC^2 + 2M^2 hw C \end{aligned} FLOPs(W-MSA)?=Mh?×Mw?×[4(MC)2+2(M)4C]=4hwC2+2M2hwC?
故使用W-MSA的計(jì)算量為 4 h w C 2 + 2 M 2 h w C 4hwC^2 + 2M^2 hw C 4hwC2+2M2hwC。
3.3 計(jì)算量對(duì)比
F L O P s ( M S A ) = 4 h w C 2 + 2 ( h w ) 2 C F L O P s ( W - M S A ) = 4 h w C 2 + 2 M 2 h w C \begin{aligned} {\rm FLOPs}({\rm MSA}) & = 4 h w C^2 + 2(hw)^2 C \\ {\rm FLOPs}({\rm W{\text -}MSA}) & = 4hwC^2 + 2M^2 hwC \end{aligned} FLOPs(MSA)FLOPs(W-MSA)?=4hwC2+2(hw)2C=4hwC2+2M2hwC?
其中:
- h , w h, w h,w代表特征圖尺寸
- C C C代表特征圖深度
- M M M代表每個(gè)Window的大小
為了直觀看出兩者計(jì)算的不同,假設(shè)輸入圖片的shape為 X ∈ R 112 × 112 × 128 \mathcal{X}\in {\mathbb R}^{112 \times 112 \times 128} X∈R112×112×128,Window的尺寸為 M M M。(輸入輸出通道數(shù)不變),則兩種MSA的計(jì)算量如下:
| MSA (M) | 41104.1792 | 41104.1792 | 41104.1792 | 41104.1792 | 41104.1792 | 41104.1792 | 41104.1792 |
| W-MSA (M) | 825.2948 | 834.9286 | 850.9850 | 873.4638 | 902.3652 | 937.6891 | 979.4355 |
| Δ \Delta Δ | -97.9922% | -97.9688% | -97.9297% | -97.8750% | -97.8047% | -97.7188% | -97.6172% |
從表中的數(shù)據(jù)可以看到,W-MSA相比MSA而已,切割窗口的設(shè)計(jì)可以為模型省出巨大的計(jì)算量。
4. Shifted Window Multi-head Self-Attention(SW-MSA)
圖2. 所提出的Swin Transformer架構(gòu)中計(jì)算自我注意力的移位窗口方法的說(shuō)明。在第 l l l 層(左),采用了一個(gè)常規(guī)的窗口劃分方案,在每個(gè)窗口內(nèi)計(jì)算自我注意力。在接下來(lái)的第 l + 1 l+1 l+1 層(右),窗口分區(qū)被轉(zhuǎn)移,產(chǎn)生了新的窗口。新窗口中的自我注意計(jì)算跨越了第 l l l 層中先前窗口的邊界,提供了它們之間的聯(lián)系。
前面有說(shuō),采用W-MSA模塊時(shí),只會(huì)在每個(gè)窗口內(nèi)進(jìn)行自注意力計(jì)算(MSA),所以窗口與窗口之間是無(wú)法進(jìn)行信息傳遞的。為了解決這個(gè)問題,作者引入了Shifted Windows Multi-Head Self-Attention(SW-MSA)模塊,即進(jìn)行偏移的W-MSA。如上圖所示,左側(cè)使用的是剛剛講的W-MSA(假設(shè)是第 l l l 層),那么根據(jù)之前介紹的W-MSA和SW-MSA是成對(duì)使用的,那么第 l + 1 l+1 l+1 層使用的就是SW-MSA(右側(cè)圖)。根據(jù)左右兩幅圖對(duì)比能夠發(fā)現(xiàn)窗口(Windows)發(fā)生了偏移(可以理解成窗口從左上角分別向右側(cè)和下方各偏移了 ? M 2 ? \left \lfloor \frac {M} {2} \right \rfloor ?2M??個(gè)像素)。
4.1 SW-MSA移動(dòng)窗口示意圖
看下偏移后的窗口(右側(cè)圖),比如對(duì)于第一行第2列的 2 × 4 2\times 4 2×4 的窗口,它能夠使第 l l l 層的第一排的兩個(gè)窗口信息進(jìn)行交流。再比如,第二行第二列的 4 × 4 4\times 4 4×4 的窗口,他能夠使第 l l l 層的四個(gè)窗口信息進(jìn)行交流,其他的同理。那么這就解決了不同窗口之間無(wú)法進(jìn)行信息交流的問題。
根據(jù)上圖,可以發(fā)現(xiàn)通過將窗口進(jìn)行偏移后,由原來(lái)的 4 4 4 個(gè)窗口變成 9 9 9 個(gè)窗口了。后面又要對(duì)每個(gè)窗口內(nèi)部進(jìn)行MSA,這樣做感覺又變麻煩了。
對(duì)于新生成的 9 個(gè)Window,如果想要實(shí)現(xiàn)并行計(jì)算,那么就需要對(duì)邊上 8 個(gè)Window進(jìn)行填充,填充到 4 × 4 4 \times 4 4×4 大小。如果我們使用這種策略,那么我們就相當(dāng)于是計(jì)算了 9 9 9 個(gè) 4 × 4 4 \times 4 4×4 大小Window,計(jì)算量又增加了。
為了解決這個(gè)麻煩,作者又提出而了Efficient batch computation for shifted configuration,一種更加高效的計(jì)算方法。下面是原論文給的示意圖。
圖4. 移位窗口分區(qū)中自我注意(SW-MSA)的高效批量計(jì)算方法的說(shuō)明。
感覺不太好描述,然后霹靂巴拉WZ重新繪制了該圖。下圖左側(cè)是剛剛通過偏移窗口后得到的新窗口,右側(cè)是為了方便大家理解,對(duì)每個(gè)窗口加上了一個(gè)標(biāo)識(shí)。然后0對(duì)應(yīng)的窗口標(biāo)記為區(qū)域A,3和6對(duì)應(yīng)的窗口標(biāo)記為區(qū)域B,1和2對(duì)應(yīng)的窗口標(biāo)記為區(qū)域C。
接下來(lái)對(duì)劃分的區(qū)域進(jìn)行了2次平移,如下圖所示。
移動(dòng)完畢后,我們對(duì)Window重新進(jìn)行劃分,如下圖所示。
移動(dòng)完后,4是一個(gè)單獨(dú)的窗口;將5和3合并成一個(gè)窗口;7和1合并成一個(gè)窗口;8, 6, 2, 0合并成一個(gè)窗口。這樣又和原來(lái)一樣是 4 4 4 個(gè) 4 × 4 4 \times 4 4×4 的窗口了,在對(duì)這4個(gè)4×4的Window進(jìn)行W-MSA計(jì)算的話能夠保證計(jì)算量是一樣的。
4.2 masked MSA
但是我們直接簡(jiǎn)單粗暴地在每個(gè)Window中進(jìn)行W-MSA計(jì)算(其實(shí)就是MSA計(jì)算)的話,就會(huì)引入一個(gè)新的問題。
對(duì)于第一個(gè)4×4的Window來(lái)說(shuō)其實(shí)沒有影響,因?yàn)樗旧砭褪且粋€(gè)4×4的Window,但對(duì)于B來(lái)說(shuō),這個(gè)Window是由兩個(gè)分開的區(qū)域組合在一起的,而且5和3本來(lái)就不是相鄰的兩個(gè)區(qū)域,如果我們強(qiáng)行MSA計(jì)算的話,其實(shí)是有問題的。所以我們希望在B中個(gè)Window中可以單獨(dú)計(jì)算區(qū)域5的MSA和區(qū)域3的MSA。
那么具體是怎么實(shí)現(xiàn)的呢?
在論文中,使用的不是原本的MSA而是masked MSA即帶蒙板mask的MSA,這樣就能夠通過設(shè)置蒙板來(lái)隔絕不同區(qū)域的信息了。
關(guān)于mask如何使用,可以看下下面這幅圖,下圖是以上面的區(qū)域5和區(qū)域3為例。
對(duì)于該窗口內(nèi)的每一個(gè)像素(或稱token,patch)在進(jìn)行MSA計(jì)算時(shí),都要先生成對(duì)應(yīng)的 q u e r y ( q ) query(q) query(q), k e y ( k ) key(k) key(k), v a l u e ( v ) value(v) value(v)。假設(shè)對(duì)于上圖的像素0而言,得到 q 0 q^0 q0 后要與每一個(gè)像素的 k k k 進(jìn)行匹配(match)。
假設(shè) α 0 , 0 \alpha _{0,0} α0,0? 代表 q 0 q^0 q0 與像素0對(duì)應(yīng)的 k 0 k^0 k0 進(jìn)行匹配的結(jié)果,那么同理可以得到 α 0 , 0 \alpha _{0,0} α0,0? 至 α 0 , 15 \alpha _{0,15} α0,15?。按照普通的MSA計(jì)算,接下來(lái)就是SoftMax操作了。
但對(duì)于這里的masked MSA,像素0是屬于區(qū)域5的,我們只想讓它和區(qū)域5內(nèi)的像素進(jìn)行匹配。那么我們可以將像素0與區(qū)域3中的所有像素匹配結(jié)果都減去100(例如 α 0 , 2 , α 0 , 3 , α 0 , 6 , α 0 , 7 \alpha _{0,2}, \alpha _{0,3}, \alpha _{0,6}, \alpha _{0,7} α0,2?,α0,3?,α0,6?,α0,7? 等等),由于 α \alpha α 的值都很小,一般都是零點(diǎn)幾的數(shù),將其中一些數(shù)減去 100 100 100 后再通過SoftMax得到對(duì)應(yīng)的權(quán)重都等于 0 0 0 了。所以對(duì)于像素0而言實(shí)際上還是只和區(qū)域5內(nèi)的像素進(jìn)行了MSA。
對(duì)于其他像素也是同理,具體代碼是怎么實(shí)現(xiàn)的,后面會(huì)在代碼講解中進(jìn)行詳解。
注意,在計(jì)算完后還要把數(shù)據(jù)給挪回到原來(lái)的位置上(例如上述的A,B,C區(qū)域)。
4.3 masked MSA例子
5. Relative Position Bias
5.1 Relative position bias的效果
關(guān)于相對(duì)位置偏執(zhí),論文里也沒有細(xì)講,就說(shuō)了參考的哪些論文,然后說(shuō)使用了相對(duì)位置偏執(zhí)后給夠帶來(lái)明顯的提升。根據(jù)原論文中的表4可以看出,在ImageNet數(shù)據(jù)集上如果不使用任何位置偏執(zhí),top-1為 80.1 % 80.1\% 80.1%,但使用了相對(duì)位置偏執(zhí)(rel. pos.)后top-1為 83.3 % 83.3\% 83.3%,提升還是很明顯的。
第一二行:
- 第一行:全部使用W-MSA模塊,不使用SW-MSA,那么ImageNet Top-1準(zhǔn)確率可以達(dá)到80.2%
- 第二行:除了W-MSA模塊,還使用了SW-MSA模塊,那么ImageNet Top-1準(zhǔn)確率可以達(dá)到81.3%,而且在COCO和分割任務(wù)的性能也得到提升。
這說(shuō)明窗口與窗口之間的信息交互是非常有必要的。
- 如果加了絕對(duì)位置(abs. pos.)后,雖然在ImageNet數(shù)據(jù)集上的top- n n n增加了,但在COCO和分割任務(wù)上的性能降低了。所以絕對(duì)位置編碼效果并不好。
- 如果使用本文使用的相對(duì)位置偏置(rel. pos.),那么在ImageNet top準(zhǔn)確率最好的情況下,COCO和分割任務(wù)上的性能都提升最多。這也說(shuō)明了,使用相對(duì)位置偏置(relative position bias)是最合理的。
5.2 定義及解釋
原版的MSA計(jì)算公式如下:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V {\rm Attention}(Q, K, V) = {\rm softmax}(\frac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=softmax(dk??QKT?)V
在Swin Transformer中,給出的公式為:
A t t e n t i o n ( Q , K , V ) = S o f t M a x ( Q K T d k + B ) V {\rm Attention}(Q, K, V) = {\rm SoftMax}(\frac{QK^T}{\sqrt{d_k}} + B)V Attention(Q,K,V)=SoftMax(dk??QKT?+B)V
那這個(gè)相對(duì)位置偏執(zhí)是加在哪的呢,根據(jù)論文中提供的公式可知是在 Q Q Q 和 K K K 進(jìn)行匹配并除以 d \sqrt d d? 后加上了相對(duì)位置偏執(zhí) B B B。
由于論文中并沒有詳解講解這個(gè)相對(duì)位置偏執(zhí),所以霹靂吧啦WZ根據(jù)閱讀源碼做了簡(jiǎn)單的總結(jié)。
如下圖,假設(shè)輸入的特征圖高寬都為 2 2 2,那么首先我們可以構(gòu)建出每個(gè)像素的絕對(duì)位置(左下方的矩陣),對(duì)于每個(gè)像素的絕對(duì)位置是使用行號(hào)和列號(hào)表示的。
比如藍(lán)色的像素對(duì)應(yīng)的是第0行第0列所以絕對(duì)位置索引是 ( 0 , 0 ) (0,0) (0,0),接下來(lái)再看看相對(duì)位置索引。
首先看下藍(lán)色的像素,在藍(lán)色像素使用 q q q 與所有像素 k k k 進(jìn)行匹配過程中,是以藍(lán)色像素為參考點(diǎn)。然后用藍(lán)色像素的絕對(duì)位置索引與其他位置索引進(jìn)行相減,就得到其他位置相對(duì)藍(lán)色像素的相對(duì)位置索引。例如黃色像素的絕對(duì)位置索引是 ( 0 , 1 ) (0,1) (0,1),則它相對(duì)藍(lán)色像素的相對(duì)位置索引為 ( 0 , 0 ) ? ( 0 , 1 ) = ( 0 , ? 1 ) (0,0)?(0,1)=(0,?1) (0,0)?(0,1)=(0,?1),這里是嚴(yán)格按照源碼中來(lái)講的,請(qǐng)不要杠😂。那么同理可以得到其他位置相對(duì)藍(lán)色像素的相對(duì)位置索引矩陣。
同樣,也能得到相對(duì)黃色,紅色以及綠色像素的相對(duì)位置索引矩陣。接下來(lái)將每個(gè)相對(duì)位置索引矩陣按行展平,并拼接在一起可以得到下面的 4 × 4 4\times 4 4×4 矩陣 。
請(qǐng)注意,我這里描述的一直是相對(duì)位置索引,并不是相對(duì)位置偏執(zhí)參數(shù)(并不是公式中的那個(gè) B B B)。因?yàn)楹竺嫖覀儠?huì)根據(jù)相對(duì)位置索引去取對(duì)應(yīng)的參數(shù)。
比如說(shuō)黃色像素是在藍(lán)色像素的右邊,所以相對(duì)藍(lán)色像素的相對(duì)位置索引為 ( 0 , ? 1 ) (0,?1) (0,?1)。綠色像素是在紅色像素的右邊,所以相對(duì)紅色像素的相對(duì)位置索引為 ( 0 , ? 1 ) (0,?1) (0,?1)??梢园l(fā)現(xiàn)這兩者的相對(duì)位置索引都是 ( 0 , ? 1 ) (0,?1) (0,?1),所以他們使用的相對(duì)位置偏執(zhí)參數(shù)都是一樣的。
5.3 源碼的操作
其實(shí)講到這基本已經(jīng)講完了,但在源碼中作者為了方便把二維索引給轉(zhuǎn)成了一維索引。具體這么轉(zhuǎn)的呢,有人肯定想到,簡(jiǎn)單啊直接把行、列索引相加不就變一維了嗎?
比如上面的相對(duì)位置索引中有 ( 0 , ? 1 ) (0,?1) (0,?1) 和 ( ? 1 , 0 ) (?1,0) (?1,0) 在二維的相對(duì)位置索引中明顯是代表不同的位置,但如果簡(jiǎn)單相加都等于 -1 那不就出問題了嗎?
- (0, -1) -> 0 + (-1) = -1
- (-1, 0) -> -1 + 0 = -1
這說(shuō)明如果直接相加,那么位置索引就沒有了(會(huì)有明明位置不同,但索引值相同的情況)!
接下來(lái)我們看看源碼中是怎么做的。
5.3.1 第一步
首先在原始的相對(duì)位置索引上加上 ( M ? 1 ) (M-1) (M?1) ( M M M 為窗口的大小,在本示例中 M = 2 M=2 M=2),加上之后索引中就不會(huì)有負(fù)數(shù)了。如下圖所示:
5.3.2 第二步
接著將所有的行標(biāo)都乘上2M-1。如下圖:
5.3.3 第三步
最后將行標(biāo)和列標(biāo)進(jìn)行相加。
這樣就得到一元相對(duì)位置索引矩陣。這個(gè)矩陣即保證了相對(duì)位置關(guān)系,而且不會(huì)出現(xiàn)上述 0 + (-1) = (-1) + 0 的問題了。
5.4 Relative Position Bias Table
剛剛上面也說(shuō)了,之前計(jì)算的是相對(duì)位置索引,并不是相對(duì)位置偏執(zhí)參數(shù)。真正使用到的可訓(xùn)練參數(shù) B ^ \hat{B} B^ 是保存在relative position bias table表里的,這個(gè)表的長(zhǎng)度是等于 ( 2 M ? 1 ) × ( 2 M ? 1 ) (2M-1) \times (2M-1) (2M?1)×(2M?1) 的。那么上述公式中的相對(duì)位置偏執(zhí)參數(shù) B B B 是根據(jù)上面的相對(duì)位置索引表根據(jù)查relative position bias table表得到的,如下圖所示。
剛才我們求的是索引,并不是用到的值,用到值需要通過求得的索引去查表得到。
看圖說(shuō)話, 我們發(fā)現(xiàn)relative position index這個(gè)矩陣的一共有9個(gè)數(shù),而relative position bias table的個(gè)數(shù)也是9個(gè)。圖中也寫了,矩陣的大小為 ( 2 M ? 1 ) × ( 2 M ? 1 ) (2M-1)\times (2M-1) (2M?1)×(2M?1)。
那么為什么是 ( 2 M ? 1 ) × ( 2 M ? 1 ) (2M-1)\times (2M-1) (2M?1)×(2M?1)呢?看下面這張圖:
M M M 是Window的大小,不是數(shù)量,這里我就懶得改了
6. 模型詳細(xì)配置參數(shù)
首先回憶下Swin Transformer的網(wǎng)絡(luò)架構(gòu):
圖3:(a)Swin Transformer(Swin-T)的結(jié)構(gòu);(b)兩個(gè)連續(xù)的Swin Transformer區(qū)塊(用公式(3)表示)。W-MSA和SW-MSA是多頭自我注意模塊,分別具有常規(guī)和移位的窗口配置。
下圖(表7)是原論文中給出的關(guān)于不同Swin Transformer的配置,T(Tiny),S(Small),B(Base),L(Large),其中:
- win. sz. 7×7表示使用的窗口(Window)的大小
- dim表示特征圖的channel深度(或者說(shuō)token的向量長(zhǎng)度)
- head表示多頭注意力模塊中head的個(gè)數(shù)
參考:
總結(jié)
以上是生活随笔為你收集整理的Swin Transformer理论讲解的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 5个超棒的自我提升App
- 下一篇: java 自旋锁实现