从零学习SwinTransformer
論文信息
論文名稱:Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
原論文地址: https://arxiv.org/abs/2103.14030
官方開源代碼地址:https://github.com/microsoft/Swin-Transformer
本篇博客參考文章:從零學(xué)習(xí)SwinTransformer
名詞解答
?M/2?: 表示向下取整
像素通道: 參考-圖像的通道數(shù)問題
描述一個(gè)像素點(diǎn),如果是灰度,那么只需要一個(gè)數(shù)值來描述它,就是單通道。
如果一個(gè)像素點(diǎn),有RGB三種顏色來描述它,就是三通道。
四通道圖像,R、G、B加上一個(gè)A通道,表示透明度。一般叫做alpha通道,表示透明度的。
2通道圖像不常見,通常在程序處理中會(huì)用到,如傅里葉變換,可能會(huì)用到,一個(gè)通道為實(shí)數(shù),一個(gè)通道為虛數(shù),主要是編程方便。
dense prediction理解:
標(biāo)注出圖像中每個(gè)像素點(diǎn)的對(duì)象類別,要求不但給出具體目標(biāo)的位置,還要描繪物體的邊界,如圖像分割、語義分割、邊緣檢測(cè)等等
WindowPatchToken理解:
假設(shè)輸入圖片的尺寸為224X224,先劃分成多個(gè)大小為4x4像素的小片,每個(gè)小片之間沒有交集。224/4=56,那么一共可以劃分56x56個(gè)小片。每一個(gè)小片就叫一個(gè)patch,每一個(gè)patch將會(huì)被看成一個(gè)token,所以patch=token。而一張圖被劃分為7x7個(gè)window,每個(gè)window之間也沒有交集。那么每個(gè)window就會(huì)包含8x8個(gè)patch。
一張圖有224(pixel)X224(pixel)= 56(個(gè)patch)x56(個(gè)patch)x4(pixel)x4(pixel)=7(個(gè)window)x7(個(gè)window)x8(個(gè)patch)x8(個(gè)patch)x4(pixel)x4(pixel)
不懂可看下圖,圖中本想畫出224*224像素的圖片,奈何畫不了,所以就畫一部分就可以說明劃分關(guān)系。其中圖中每個(gè)小方塊中包含4x4個(gè)像素,紅色框起來的8x8個(gè)patch為一個(gè)window,這樣的window一張圖片有7x7個(gè),每個(gè)window之間也沒有交集。那么每個(gè)window就會(huì)包含8x8個(gè)patch。
疑問: 圖中那4x4個(gè)patch組成的叫啥:好像啥也不叫,只是順手畫出來的。哦,下圖中的4x4個(gè)patch組成的東西是一個(gè)window的一個(gè)計(jì)算單元,一個(gè)window中有4x4個(gè)這樣的單元,用于計(jì)算self-attention,意思就是在計(jì)算self-attention時(shí)是8x8個(gè)patch作為一個(gè)單元去跟別人計(jì)算的。
網(wǎng)絡(luò)整體框架
原論文中給出的關(guān)于Swin Transformer(Swin-T)網(wǎng)絡(luò)的架構(gòu)圖。通過圖(a)可以看出整個(gè)框架的基本流程如下:
輸入image
假設(shè)模型的輸入是一張224x224x3 的圖片
Patch Partition詳解
首先將圖片輸入到Patch Partition模塊中進(jìn)行分塊,即每4x4個(gè)相鄰的像素劃分為一個(gè)patch,即在上面畫的圖,將圖片劃分一個(gè)一個(gè)的patch。
詳解每4x4個(gè)像素(3通道)如何展平為1x1個(gè)patch(48通道)?
大概以上的圖形就可示意每4x4個(gè)像素(3通道)如何展平為1x1個(gè)patch(48通道)。
對(duì)224x224個(gè)像素(3通道)的圖片都這樣處理,會(huì)得到(56,56)個(gè)patch,48通道,特征圖大小為(56,56,48)見下圖第二個(gè)帶綠色部分的立體圖。
再?gòu)?fù)述一遍以上的過程:將上面4x4=16個(gè)像素的圖然后在通道(channel)方向展平(flatten)。假設(shè)輸入的是RGB三通道的圖片,那么每個(gè)patch就有4x4=16個(gè)像素,然后每個(gè)像素有R、G、B三個(gè)值,所以展平后是16x3=48,所以通過Patch Partition后圖像的shape由 [H, W, 3]=(224,224,3)變成了 [H/4, W/4, 48]=(56,56,48)。上圖中的最左下角的第一張圖片是將一張圖劃分patch之后的樣子,簡(jiǎn)略示意(224,224)個(gè)pixel,3通道,展平之后展平為(56,56)個(gè)patch,48個(gè)通道。
之后就沒像素什么事了,后來都是在patch上的討論。
Linear Embedding詳解
參考:Swin Transformer全方位解讀【ICCV2021最佳論文】
(56,56,48)————Linear Embedding———>(56,56,96)
這個(gè)層優(yōu)點(diǎn)感覺像是卷積層,只不過是對(duì)通道數(shù)進(jìn)行卷積
Linear Embeding層對(duì)每個(gè)patch的channel數(shù)據(jù)做線性變換,由48變成C=(96),即圖像shape再由 [H/4, W/4, 48]=(56,56,48)變成了 [H/4, W/4, C]=(56,56,96)。每個(gè)圖片被劃分為56x56=3136個(gè)patch,每個(gè)patch又被編碼成96維的向量。
這一步在代碼上實(shí)現(xiàn)十分簡(jiǎn)單,就是一個(gè)Conv2D,把步長(zhǎng)和kernel size都設(shè)置為patch的長(zhǎng)度即可,可看:
這步以后再flatten一下,就可以把56x56x96變?yōu)?196x96。
Swin Transformer Block(W-MSA、SW-MSA)詳解
與標(biāo)準(zhǔn)transformer不同的就是紫色部分的兩個(gè)框,分別是W-MSA和SW-MSA。
W-MSA表示,在window內(nèi)部的Multi-Head Self-Attention,就是把window當(dāng)做獨(dú)立的全局來計(jì)算window中每個(gè)token兩兩注意力。
SW-MSA與W-MSA的一丟丟不一樣,就是將window的覆蓋范圍偏移一下,原文設(shè)置為window的邊長(zhǎng)的一半。
Swin Transformer Block注意這里的Block其實(shí)有兩種結(jié)構(gòu),如圖(b)中所示,這兩種結(jié)構(gòu)的不同之處僅在于一個(gè)使用了W-MSA結(jié)構(gòu),一個(gè)使用了SW-MSA結(jié)構(gòu)。而且這兩個(gè)結(jié)構(gòu)是成對(duì)使用的,先使用一個(gè)W-MSA結(jié)構(gòu)再使用一個(gè)SW-MSA結(jié)構(gòu)。所以你會(huì)發(fā)現(xiàn)堆疊Swin Transformer Block的次數(shù)都是偶數(shù)(因?yàn)槌蓪?duì)使用)。
W-MSA詳解
全稱為Window based Multi-head Self Attention。一張圖平分為7x7個(gè)window,這些window互相都沒有overlap。然后,每個(gè)window包含一定數(shù)量的token,直接對(duì)這些token計(jì)算window內(nèi)部的自注意力。 以分而治之的方法,遠(yuǎn)遠(yuǎn)降低了標(biāo)準(zhǔn)transformer的計(jì)算復(fù)雜度。以第1層為例,7x7個(gè)window,每個(gè)window包含8x8個(gè)patch,相當(dāng)于把標(biāo)準(zhǔn)transformer應(yīng)用在window上,而不是全圖上。
在Swin Transformer中使用了Windows Multi-Head Self-Attention(W-MSA)的概念,將特征圖劃分成了多個(gè)不相交的區(qū)域(Window),Multi-Head Self-Attention只在每個(gè)窗口(Window)內(nèi)進(jìn)行。相對(duì)于Vision Transformer中直接對(duì)整個(gè)(Global)特征圖進(jìn)行Multi-Head Self-Attention,這樣做的目的是能夠減少計(jì)算量的,尤其是在淺層特征圖很大的時(shí)候。這樣做雖然減少了計(jì)算量但也會(huì)隔絕不同窗口之間的信息傳遞,所以在論文中作者又提出了 Shifted Windows Multi-Head Self-Attention(SW-MSA)的概念,通過此方法能夠讓信息在相鄰的窗口中進(jìn)行傳遞。
引入Windows Multi-head Self-Attention(W-MSA)模塊是為了減少計(jì)算量。如下圖所示,左側(cè)使用的是普通的Multi-head Self-Attention(MSA)模塊,對(duì)于feature map中的每個(gè)像素(或稱作token,patch)在Self-Attention計(jì)算過程中需要和所有的像素去計(jì)算。但在圖右側(cè),在使用Windows Multi-head Self-Attention(W-MSA)模塊時(shí),首先將feature map按照MxM(例子中的M=2)大小劃分成一個(gè)個(gè)Windows,然后單獨(dú)對(duì)每個(gè)Windows內(nèi)部進(jìn)行Self-Attention。
兩者的計(jì)算量具體差多少呢?原論文中有給出下面兩個(gè)公式,這里忽略了Softmax的計(jì)算復(fù)雜度。:
h代表feature map的高度
w代表feature map的寬度
C代表feature map的深度
M代表每個(gè)窗口(Windows)的大小
W-MSA模塊計(jì)算量詳解
好像在計(jì)算計(jì)算量的時(shí)候不計(jì)算加法的次數(shù)。hw是第一個(gè)矩陣的行數(shù),第一個(gè)C是一個(gè)行與列的計(jì)算量,第二個(gè)C是后一個(gè)矩陣的列數(shù)個(gè)前面的小計(jì)算量。
即矩陣L(ab)與P(bd)相乘,計(jì)算為abd,即(第一個(gè)矩陣的行數(shù))x(第一個(gè)矩陣的列數(shù))x(第二個(gè)矩陣的列數(shù))。
SW-MSA詳解
在局部window內(nèi)計(jì)算Self-Attention確實(shí)可以極大地降低計(jì)算復(fù)雜度,但是其也缺失了窗口之間的信息交互,降低了模型的表示能力。為了引入Cross-Window Connection,SwinTransformer采用了一種移位窗口劃分的方法來實(shí)現(xiàn)這一目標(biāo),窗口會(huì)在連續(xù)兩個(gè)SwinTransformer Blocks交替移動(dòng),使得不同Windows之間有機(jī)會(huì)進(jìn)行交互。
采用W-MSA模塊時(shí),只會(huì)在每個(gè)窗口內(nèi)進(jìn)行自注意力計(jì)算,所以窗口與窗口之間是無法進(jìn)行信息傳遞的。為了解決這個(gè)問題,作者引入了Shifted Windows Multi-Head Self-Attention(SW-MSA)模塊,即進(jìn)行偏移的W-MSA。如下圖所示,左側(cè)使用的是剛剛講的W-MSA(假設(shè)是第L層),那么根據(jù)之前介紹的W-MSA和SW-MSA是成對(duì)使用的,那么第L+1層使用的就是SW-MSA(右側(cè)圖)。根據(jù)左右兩幅圖對(duì)比能夠發(fā)現(xiàn)窗口(Windows)發(fā)生了偏移(可以理解成窗口從左上角分別向右側(cè)和下方各偏移了?M/2?個(gè)像素)。看下偏移后的窗口(右側(cè)圖),比如對(duì)于第一行第2列的2x4的窗口,它能夠使第L層的第一排的兩個(gè)窗口信息進(jìn)行交流。再比如,第二行第二列的4x4的窗口,他能夠使第L層的四個(gè)窗口信息進(jìn)行交流,其他的同理。那么這就解決了不同窗口之間無法進(jìn)行信息交流的問題。
根據(jù)上圖,可以發(fā)現(xiàn)通過將窗口進(jìn)行偏移后,由原來的4個(gè)窗口變成9個(gè)窗口了。后面又要對(duì)每個(gè)窗口內(nèi)部進(jìn)行MSA,這樣做感覺又變麻煩了。為了解決這個(gè)麻煩,作者又提出而了Efficient batch computation for shifted configuration,一種更加高效的計(jì)算方法。下面是原論文給的示意圖。
下圖左側(cè)是剛剛通過偏移窗口后得到的新窗口,右側(cè)是為了方便大家理解,對(duì)每個(gè)窗口加上了一個(gè)標(biāo)識(shí)。然后0對(duì)應(yīng)的窗口標(biāo)記為區(qū)域A,3和6對(duì)應(yīng)的窗口標(biāo)記為區(qū)域B,1和2對(duì)應(yīng)的窗口標(biāo)記為區(qū)域C。
然后先將區(qū)域A和C移到最下方。
接著,再將區(qū)域A和B移至最右側(cè)。
移動(dòng)完后,4是一個(gè)單獨(dú)的窗口;將5和3合并成一個(gè)窗口;7和1合并成一個(gè)窗口;8、6、2和0合并成一個(gè)窗口。這樣又和原來一樣是4個(gè)4x4的窗口了,所以能夠保證計(jì)算量是一樣的。這里肯定有人會(huì)想,把不同的區(qū)域合并在一起(比如5和3)進(jìn)行MSA,這信息不就亂竄了嗎?是的,為了防止這個(gè)問題,在實(shí)際計(jì)算中使用的是masked MSA即帶蒙板mask的MSA,這樣就能夠通過設(shè)置蒙板來隔絕不同區(qū)域的信息了。關(guān)于mask如何使用,可以看下下面這幅圖,下圖是以上面的區(qū)域5和區(qū)域3為例。
對(duì)于該窗口內(nèi)的每一個(gè)像素(或稱token,patch)在進(jìn)行MSA計(jì)算時(shí),都要先生成對(duì)應(yīng)的query(q),key(k),value(v)。假設(shè)對(duì)于上圖的像素0而言,得到q0后要與每一個(gè)像素的k進(jìn)行匹配(match),假設(shè)α0,0代表q0與像素0對(duì)應(yīng)的k0進(jìn)行匹配的結(jié)果,那么同理可以得到α0,0至α0,15。按照普通的MSA計(jì)算,接下來就是SoftMax操作了。但對(duì)于這里的masked MSA,像素0是屬于區(qū)域5的,我們只想讓它和區(qū)域5內(nèi)的像素進(jìn)行匹配。那么我們可以將像素0與區(qū)域3中的所有像素匹配結(jié)果都減去100(例如α0,2,α0,3,α0,6 ,α 0,7等等),由于α的值都很小,一般都是零點(diǎn)幾的數(shù),將其中一些數(shù)減去100后再通過SoftMax得到對(duì)應(yīng)的權(quán)重都等于0了。所以對(duì)于像素0而言實(shí)際上還是只和區(qū)域5內(nèi)的像素進(jìn)行了MSA。那么對(duì)于其他像素也是同理。注意,在計(jì)算完后還要把數(shù)據(jù)給挪回到原來的位置上(例如上述的A,B,C區(qū)域)。
SW-MSA模塊計(jì)算量詳解
LayerNorm
相對(duì)位置偏置(relative position bias)
Patch Merging詳解
在每個(gè)Stage中首先要通過一個(gè)Patch Merging層進(jìn)行下采樣(Stage1除外)。如下圖所示,假設(shè)輸入Patch Merging的是一個(gè)4x4大小的單通道特征圖(feature map),Patch Merging會(huì)將每個(gè)2x2的相鄰像素劃分為一個(gè)patch,然后將每個(gè)patch中相同位置(同一顏色)像素給拼在一起就得到了4個(gè)feature map。接著將這四個(gè)feature map在深度方向進(jìn)行concat拼接,然后在通過一個(gè)LayerNorm層。最后通過一個(gè)全連接層在feature map的深度方向做線性變化,將feature map的深度由C變成C/2。通過這個(gè)簡(jiǎn)單的例子可以看出,通過Patch Merging層后,feature map的高和寬會(huì)減半,深度會(huì)翻倍。
Relative Position Bias詳解
真正使用到的可訓(xùn)練參數(shù)B是保存在relative position bias table表里的,這個(gè)表的長(zhǎng)度是等于(2M?1)×(2M?1)的。那么上述公式中的相對(duì)位置偏執(zhí)參數(shù)B是根據(jù)上面的相對(duì)位置索引表根據(jù)查relative position bias table表得到的,如上圖所示。
參考:
Swin Transformer全方位解讀【ICCV2021最佳論文】
Swin-Transformer網(wǎng)絡(luò)結(jié)構(gòu)詳解
2021-Swin Transformer Attention機(jī)制的詳細(xì)推導(dǎo)
(附加:CSDN上傳圖片去水印方法)
總結(jié)
以上是生活随笔為你收集整理的从零学习SwinTransformer的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 如何在Python中删除字符串中的所有反
- 下一篇: image caption优秀链接