Contents: Preface · 1. U-Net · 2. U-Net variants (2.1 3D-Unet, 2.2 Attention U-Net, 2.3 Inception U-Net, 2.4 Residual U-Net, 2.5 Recurrent U-Net, 2.6 Dense U-Net, 2.7 U-Net++, 2.8 Adversarial U-Net, 2.9 Ensemble U-Net, 2.10 Comparison With Other Architectures) · Summary
Preface
U-Net is a segmentation architecture developed mainly for image segmentation tasks, and it is highly practical in medical image segmentation. This post is meant as a starting point for researchers who want to explore U-Net. U-Net-based architectures hold considerable potential and value in medical image analysis; the growth of U-Net papers since 2017 confirms its standing as a deep learning technique for medical imaging, and it is expected to remain one of the main ways forward. As I once heard somewhere (I forget where), medical image segmentation mainly has to cope with variation in object position and size: the expansive path's output carries some positional information, and the contracting path's output refines that localization; at the same time, pooling gives a pyramid-like treatment of scale, which largely addresses the size problem. That, roughly, is why U-Net works so well. This post mainly follows the survey below, organized together with its references. It is updated irregularly; I am still a beginner, so corrections are welcome. Survey: https://ieeexplore.ieee.org/document/9446143
1. U-Net
U-Net has two parts and is symmetric about its middle. The first, on the left, is a contracting path with a typical CNN structure: each step applies two consecutive 3×3 convolutions with ReLU activations followed by a max pooling layer, and the number of feature channels is doubled after each downsampling. The second is an expansive path: each step applies a 2×2 up-convolution, concatenates the upsampled feature map with the correspondingly cropped feature map from the contracting path (cropped so the two sizes match), and then applies two consecutive 3×3 convolutions with ReLU. Each up-convolution halves the number of feature channels and doubles the feature-map size. A final 1×1 convolution reduces the feature maps to the desired number of channels and produces the segmentation map. The original network used unpadded convolutions, so width and height shrank by 2 after every convolution; nowadays padding is usually added so the size stays constant after each convolution, which removes the need for cropping (previously the contracting-path maps had to be cropped to match the upsampled maps; the original paper argued the image borders carry little information, so cropping does little harm). If you are unfamiliar with convolution arithmetic, see: https://gitcode.net/mirrors/vdumoulin/conv_arithmetic?utm_source=csdn_github_accelerator
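The padded-vs-unpadded size bookkeeping above can be checked with the standard convolution output formula W2 = (W1 − F + 2P)/S + 1. A minimal sketch (the helper name and the 572 tile size are illustrative; 572 is the input size used in the original paper):

```python
def conv_out_size(w, kernel, padding=0, stride=1):
    # W2 = (W1 - F + 2P) / S + 1
    return (w - kernel + 2 * padding) // stride + 1

# Original U-Net (no padding): each 3x3 conv shrinks the map by 2
w = 572                       # input tile size in the original paper
w = conv_out_size(w, 3)       # 570
w = conv_out_size(w, 3)       # 568

# Padded variant used in the code below: size is preserved
p = conv_out_size(572, 3, padding=1)   # 572
```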
import torch
import torchvision.transforms.functional
from torch import nn

'''
Two 3x3 convolution layers.
Every step of both the contracting and the expansive path uses two 3x3
convolution layers, each followed by a ReLU activation.
The U-Net paper uses 0 padding; here 1 padding is used so that the final
feature maps do not need to be cropped.
'''

class DoubleConvolution(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        # in_channels: number of input channels; out_channels: number of output channels
        super().__init__()
        # First 3x3 convolution: as the U-Net architecture diagram shows,
        # this layer changes the channel count to out_channels
        self.first = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.act1 = nn.ReLU()
        # Second 3x3 convolution: the channel count stays the same
        self.second = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.act2 = nn.ReLU()

    def forward(self, x: torch.Tensor):
        x = self.first(x)
        x = self.act1(x)
        x = self.second(x)
        return self.act2(x)


class DownSample(nn.Module):
    # Downsampling: each step of the contracting path downsamples the
    # feature map with a 2x2 max pooling layer
    def __init__(self):
        super().__init__()
        self.pool = nn.MaxPool2d(2)

    def forward(self, x: torch.Tensor):
        return self.pool(x)


class UpSample(nn.Module):
    # Upsampling: each step of the expansive path uses a 2x2 up-convolution
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        '''
        For a normal convolution, the output size follows from the input size,
        the kernel size (F = kernel_size), the stride (S = stride) and the zero
        padding (P, 0 by default here): W2 = (W1 - F + 2P) / S + 1. A transposed
        convolution inverts this relation, giving W2 = (W1 - 1) * S - 2P + F, so
        with S = 2, F = 2, P = 0 the spatial size doubles.
        For more on transposed convolutions see: https:
        '''

    def forward(self, x: torch.Tensor):
        return self.up(x)


class CropAndConcat(nn.Module):
    # Crop and concatenate feature maps: at each step of the expansive path,
    # the corresponding feature map from the contracting path is concatenated
    # with the current feature map
    def forward(self, x: torch.Tensor, contracting_x: torch.Tensor):
        # torchvision.transforms.functional.center_crop(img: Tensor, output_size: List[int]):
        # img is the image to center-crop, output_size is the size after cropping
        contracting_x = torchvision.transforms.functional.center_crop(
            contracting_x, [x.shape[2], x.shape[3]])
        x = torch.cat([x, contracting_x], dim=1)
        return x


class UNet(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # Double convolutions of the contracting path; starting from 64,
        # the number of features doubles at each step
        self.down_conv = nn.ModuleList([DoubleConvolution(i, o) for i, o in
                                        [(in_channels, 64), (64, 128), (128, 256), (256, 512)]])
        self.down_sample = nn.ModuleList([DownSample() for _ in range(4)])  # repeated 4 times
        # Bottom of the U-Net: the two lowest-resolution layers
        self.middle_conv = DoubleConvolution(512, 1024)
        self.up_sample = nn.ModuleList([UpSample(i, o) for i, o in
                                        [(1024, 512), (512, 256), (256, 128), (128, 64)]])
        self.up_conv = nn.ModuleList([DoubleConvolution(i, o) for i, o in
                                      [(1024, 512), (512, 256), (256, 128), (128, 64)]])
        self.concat = nn.ModuleList([CropAndConcat() for _ in range(4)])
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor):
        pass_through = []
        for i in range(len(self.down_conv)):   # contracting path
            x = self.down_conv[i](x)           # two 3x3 convolution layers
            pass_through.append(x)             # save the output for the skip connection
            x = self.down_sample[i](x)         # downsample
        x = self.middle_conv(x)
        for i in range(len(self.up_conv)):     # expansive path
            x = self.up_sample[i](x)
            # Receive the matching contracting-path output; pop removes and
            # returns the last element, so pass_through acts as a stack
            x = self.concat[i](x, pass_through.pop())
            x = self.up_conv[i](x)
        x = self.final_conv(x)
        return x
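As a quick sanity check on the U-Net implementation above, here is a pure-Python trace of the channel count and spatial size after each double convolution for a 64×64 input (a sketch; the helper is not part of the original code, and it assumes the padded convolutions that preserve size):

```python
def unet_trace(size, in_channels):
    # (channels, size) after each DoubleConvolution in the padded U-Net:
    # convs preserve size, 2x2 max pooling halves it, 2x2 up-conv doubles it
    trace = []
    for _, c_out in [(in_channels, 64), (64, 128), (128, 256), (256, 512)]:
        trace.append((c_out, size))   # contracting double conv
        size //= 2                    # DownSample
    trace.append((1024, size))        # middle_conv at the bottleneck
    for _, c_out in [(1024, 512), (512, 256), (256, 128), (128, 64)]:
        size *= 2                     # UpSample
        trace.append((c_out, size))   # concat doubles channels, up_conv halves them back
    return trace

t = unet_trace(64, 1)
```

The first entry is (64, 64), the bottleneck is (1024, 4), and the final entry is back at (64, 64), matching the symmetric U shape.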
2. U-Net Variants
2.1 3D-Unet
3D-Unet replaces every 2D operation in U-Net with its 3D counterpart. The paper also uses on-the-fly elastic deformation for data augmentation. Paper:
https://arxiv.org/abs/1606.06650
Why use 3D images? Because volumetric data provides additional contextual information. 3D U-Net is a straightforward extension of the U-Net framework that supports volumetric segmentation. Its core structure is the same as U-Net's, with a contracting and an expansive path, except that all 2D operations are replaced by their 3D counterparts: 3D convolutions, 3D max pooling and 3D up-convolutions, producing a volumetric segmentation. The difference between a 3D and a 2D convolution, shown in the figure below, is that the 3D convolution also covers the depth dimension.
In many biomedical applications, only a few annotated examples are needed to train a network that generalizes reasonably well, because each image already contains repetitive structures with corresponding variation. 3D U-Net has been applied successfully in the biomedical domain; for example, the following paper builds a network that allows abstract, multi-level segmentation during diagnosis: 3D U-net with Multi-level Deep Supervision: Fully Automatic Segmentation of Proximal Femur in 3D MR Images
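One practical consequence of swapping 2D operations for 3D ones is parameter growth: a 3×3×3 kernel carries three times the weights of a 3×3 kernel. A quick sketch (the helper and the 1→64 channel numbers are illustrative, not taken from the papers):

```python
def conv_params(kernel, c_in, c_out, dims):
    # weights of an n-D convolution (kernel**dims per input/output channel pair)
    # plus one bias per output channel
    return (kernel ** dims) * c_in * c_out + c_out

p2d = conv_params(3, 1, 64, dims=2)   # 3x3 Conv2d on a 1-channel input: 640 parameters
p3d = conv_params(3, 1, 64, dims=3)   # 3x3x3 Conv3d on the same input: 1792 parameters
```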
2.2 Attention U-Net
Paper: https://arxiv.org/abs/1804.03999 It proposes the attention gate (AG) model for medical image processing, which learns automatically to focus on target structures of varying shape and size. The structure of Attention U-Net is shown in the figure below. The model is built on U-Net; as the figure shows, the difference is that during decoding, the features taken from the encoder pass through an attention gate before being used by the decoder.
The code is as follows:
class AttU_Net(nn.Module):
    def __init__(self, img_ch=3, output_ch=1):
        super(AttU_Net, self).__init__()
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Conv1 = conv_block(ch_in=img_ch, ch_out=64)
        self.Conv2 = conv_block(ch_in=64, ch_out=128)
        self.Conv3 = conv_block(ch_in=128, ch_out=256)
        self.Conv4 = conv_block(ch_in=256, ch_out=512)
        self.Conv5 = conv_block(ch_in=512, ch_out=1024)
        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)
        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)
        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)
        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)
        self.Up_conv3 = conv_block(ch_in=256, ch_out=128)
        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)
        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)
        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # encoding path
        x1 = self.Conv1(x)
        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)
        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)
        x4 = self.Maxpool(x3)
        x4 = self.Conv4(x4)
        x5 = self.Maxpool(x4)
        x5 = self.Conv5(x5)
        # decoding path with attention gates on the skip connections
        d5 = self.Up5(x5)
        x4 = self.Att5(g=d5, x=x4)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.Up_conv5(d5)
        d4 = self.Up4(d5)
        x3 = self.Att4(g=d4, x=x3)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_conv4(d4)
        d3 = self.Up3(d4)
        x2 = self.Att3(g=d3, x=x2)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_conv3(d3)
        d2 = self.Up2(d3)
        x1 = self.Att2(g=d2, x=x1)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_conv2(d2)
        d1 = self.Conv_1x1(d2)
        return d1
The model breaks the task down into localization and segmentation. AGs suppress responses in irrelevant background regions: an attention coefficient α ∈ [0, 1] identifies salient image regions and prunes feature responses, keeping only those relevant to the task at hand. AGs are merged into the standard U-Net architecture to highlight the salient features passed through the skip connections. Gating with information extracted at a coarser scale removes the irrelevant and noisy responses that skip connections would otherwise propagate.
The structure of the AG is shown in the figure below:
class Attention_block(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(Attention_block, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int))
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int))
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)    # attention coefficients in [0, 1]
        return x * psi
2.3 Inception U-Net
Most image-processing algorithms convolve with filters of a fixed size, but tuning a model to find the right filter size is usually tedious; moreover, fixed-size filters only work well when the salient parts of the image are of similar size, not when their shapes and sizes vary widely. One remedy is a deeper network; another is the Inception network. The structure of an Inception block is shown below (the figure and the following description come from Mu Li's Dive into Deep Learning, https://zh-v2.d2l.ai/). An Inception block consists of four parallel paths. The first three use convolution layers with window sizes 1×1, 3×3 and 5×5 to extract information at different spatial scales. The middle two paths first apply a 1×1 convolution to the input to reduce the number of channels and thus the model's complexity. The fourth path applies 3×3 max pooling followed by a 1×1 convolution to change the channel count. All four paths use suitable padding so that input and output share the same height and width; finally the outputs of the paths are concatenated along the channel dimension to form the block's output. The hyperparameters usually tuned in an Inception block are the per-layer output channel counts.
import torch
import torch.nn.functional as F
from torch import nn

class Inception(nn.Module):
    # c1--c4 are the output channel counts of the four parallel paths
    def __init__(self, in_channels, c1, c2, c3, c4, **kwargs):
        super(Inception, self).__init__(**kwargs)
        # path 1: a single 1x1 convolution
        self.p1_1 = nn.Conv2d(in_channels, c1, kernel_size=1)
        # path 2: 1x1 convolution followed by a 3x3 convolution
        self.p2_1 = nn.Conv2d(in_channels, c2[0], kernel_size=1)
        self.p2_2 = nn.Conv2d(c2[0], c2[1], kernel_size=3, padding=1)
        # path 3: 1x1 convolution followed by a 5x5 convolution
        self.p3_1 = nn.Conv2d(in_channels, c3[0], kernel_size=1)
        self.p3_2 = nn.Conv2d(c3[0], c3[1], kernel_size=5, padding=2)
        # path 4: 3x3 max pooling followed by a 1x1 convolution
        self.p4_1 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.p4_2 = nn.Conv2d(in_channels, c4, kernel_size=1)

    def forward(self, x):
        p1 = F.relu(self.p1_1(x))
        p2 = F.relu(self.p2_2(F.relu(self.p2_1(x))))
        p3 = F.relu(self.p3_2(F.relu(self.p3_1(x))))
        p4 = F.relu(self.p4_2(self.p4_1(x)))
        # concatenate the four paths along the channel dimension
        return torch.cat((p1, p2, p3, p4), dim=1)
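Since the four paths preserve height and width and are concatenated along dim=1, the block's output channel count is simply the sum of the per-path outputs. A quick sanity check (the channel numbers below are those of GoogLeNet's first Inception block, used here only as an example):

```python
def inception_out_channels(c1, c2, c3, c4):
    # channels after torch.cat((p1, p2, p3, p4), dim=1): paths 2 and 3
    # contribute the channels of their second (post-reduction) convolution
    return c1 + c2[1] + c3[1] + c4

out = inception_out_channels(64, (96, 128), (16, 32), 32)   # 64 + 128 + 32 + 32 = 256
```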
Let us see how this is applied to U-Net via the code of "Dense-Inception U-net for medical image segmentation" (DIU-Net), whose model is shown below. The paper combines the inception module with dense connections and builds the network on top of U-Net. The architecture consists of an analysis path and a synthesis path, built from four kinds of blocks: Inception-Res blocks, Dense-Inception blocks, down-sample blocks and up-sample blocks. Three Inception-Res blocks, one Dense-Inception block and four down-sample blocks form the analysis path; three Inception-Res blocks, one Dense-Inception block and four up-sample blocks form the synthesis path. A single Dense-Inception block sits in the middle of the model and contains more inception layers than the other parts. The structures of the individual blocks follow. Inception-Res block: it mostly uses 1×1 and 3×3 convolutions, with AdaptiveAvgPool2d replacing fully connected layers, which speeds things up while matching the effect of fully connected layers and large kernels. Another useful pattern: halving the image size while growing the channel count geometrically preserves features well. Several small kernels beat one large one: a 5×5 convolution can be replaced by two 3×3 convolutions, so the 5×5 convolution of the Inception block is replaced by two 3×3 convolutions, as in the figure below (from "Rethinking the Inception Architecture for Computer Vision").
The code below follows the figure directly.
# Keras layers assumed imported, e.g.:
# from tensorflow.keras.layers import Conv2D, BatchNormalization, concatenate, Add

def inception_res_block_down(inputs, numFilters):
    c1 = Conv2D(numFilters, (1, 1), padding='same', kernel_initializer='he_normal')(inputs)
    # three parallel branches, each starting from a 1x1 convolution of c1
    c11 = Conv2D(numFilters, (1, 1), padding='same', kernel_initializer='he_normal')(c1)
    c12 = Conv2D(numFilters, (1, 1), padding='same', kernel_initializer='he_normal')(c1)
    c13 = Conv2D(numFilters, (1, 1), padding='same', kernel_initializer='he_normal')(c1)
    c11 = BatchNormalization()(c11)
    # branch 2: one 3x3 convolution
    c12 = Conv2D(numFilters, (3, 3), padding='same', kernel_initializer='he_normal')(c12)
    c12 = BatchNormalization()(c12)
    # branch 3: two 3x3 convolutions, replacing a 5x5
    c13 = Conv2D(numFilters, (3, 3), padding='same', kernel_initializer='he_normal')(c13)
    c13 = BatchNormalization()(c13)
    c13 = Conv2D(numFilters, (3, 3), padding='same', kernel_initializer='he_normal')(c13)
    c13 = BatchNormalization()(c13)
    inception_module = concatenate([c11, c12, c13], axis=3)
    concat = Conv2D(numFilters, (1, 1), padding='same', kernel_initializer='he_normal')(inception_module)
    out = Add()([concat, c1])   # residual connection
    return out
Dense-Inception block, down-sample block and up-sample block: see the corresponding figures in the paper.
2.4 Residual U-Net
This variant is based on ResNet. Training very deep networks is hard: as depth grows, accuracy degrades. In principle a deeper network has more layers and could learn more, but SGD fails to find a good solution, i.e. the network simply does not train. (See the figure in the last ten minutes of Mu Li's paper-reading video on Bilibili; I can't type the formulas here.)
ResNet trains relatively fast, mainly because its gradients are well preserved. In most networks the gradient shrinks layer by layer because it is a product of many factors, whereas ResNet also carries over the previous layer's gradient, so gradients stay healthy. For details on ResNet, see the paper "Deep Residual Learning for Image Recognition", Mu Li's video, and the Zhihu article "CNN models you must know: ResNet". The R2U-Net architecture is shown below; the code and figures come from https://github.com/LeeJunHyun/Image_Segmentation
Look at its RRCNN_block: the return value is the input (after a channel-changing 1×1 convolution) plus the output of two stacked recurrent blocks, i.e. input + output. Note the difference between a recurrent conv block and a residual conv unit. The residual conv unit is the ResNet idea just described: the input connects directly to the output. The recurrent conv block loops t times, so the n-th pass incorporates information from the preceding passes.
class RRCNN_block(nn.Module):
    def __init__(self, ch_in, ch_out, t=2):
        super(RRCNN_block, self).__init__()
        self.RCNN = nn.Sequential(
            Recurrent_block(ch_out, t=t),
            Recurrent_block(ch_out, t=t))
        self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x = self.Conv_1x1(x)   # change the channel count
        x1 = self.RCNN(x)
        return x + x1          # residual connection: input + output
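To see the recurrence and the residual addition together, here is a scalar analogue of RRCNN_block (a sketch only: each convolution is replaced by multiplication with a single weight w, which is of course not what the real block computes, but the data flow mirrors the loop in Recurrent_block, shown in section 2.5):

```python
def recurrent_unit(x, w, t=2):
    # scalar analogue of Recurrent_block.forward: x1 = f(x) on the first
    # iteration, then x1 = f(x + x1) on every iteration
    for i in range(t):
        if i == 0:
            x1 = w * x
        x1 = w * (x + x1)
    return x1

def rrcnn(x, w, t=2):
    # RRCNN_block analogue: two stacked recurrent units plus the residual input
    # (the 1x1 projection is the identity in this scalar sketch)
    x1 = recurrent_unit(recurrent_unit(x, w, t), w, t)
    return x + x1
```

With w = 0.5 and x = 1.0, the first recurrent unit yields 0.875 and the full block 1.765625, showing how the residual term keeps the input present in the output.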
The figure above shows three consecutive ResNet blocks with skip connections; the skip signal is combined with the output by element-wise addition. The most common ResNet implementations skip over two layers (as in the figure) or three. The model code follows.
class R2U_Net(nn.Module):
    def __init__(self, img_ch=3, output_ch=1, t=2):
        super(R2U_Net, self).__init__()
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Upsample = nn.Upsample(scale_factor=2)
        self.RRCNN1 = RRCNN_block(ch_in=img_ch, ch_out=64, t=t)
        self.RRCNN2 = RRCNN_block(ch_in=64, ch_out=128, t=t)
        self.RRCNN3 = RRCNN_block(ch_in=128, ch_out=256, t=t)
        self.RRCNN4 = RRCNN_block(ch_in=256, ch_out=512, t=t)
        self.RRCNN5 = RRCNN_block(ch_in=512, ch_out=1024, t=t)
        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512, t=t)
        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256, t=t)
        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128, t=t)
        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64, t=t)
        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # encoding path
        x1 = self.RRCNN1(x)
        x2 = self.Maxpool(x1)
        x2 = self.RRCNN2(x2)
        x3 = self.Maxpool(x2)
        x3 = self.RRCNN3(x3)
        x4 = self.Maxpool(x3)
        x4 = self.RRCNN4(x4)
        x5 = self.Maxpool(x4)
        x5 = self.RRCNN5(x5)
        # decoding path
        d5 = self.Up5(x5)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.Up_RRCNN5(d5)
        d4 = self.Up4(d5)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_RRCNN4(d4)
        d3 = self.Up3(d4)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_RRCNN3(d3)
        d2 = self.Up2(d3)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_RRCNN2(d2)
        d1 = self.Conv_1x1(d2)
        return d1
2.5 Recurrent U-Net
The figure above shows the structure of an RNN: its current output depends not only on the current input x_t but also on h_{t-1}, which carries earlier information. A recurrent neural network was originally designed for sequential data such as text or audio. The network is built so that a node's output changes based on that node's own previous outputs, i.e. a feedback loop, in contrast to a traditional feed-forward network. This feedback loop, also called a recurrent connection, creates an internal state or memory that gives a node the temporal property of changing its output over discrete time steps. Extended to a whole layer, this lets the network process contextual information from earlier data. The next figure shows a recurrent neural network in which the second and third layers are recurrent: each neuron in a recurrent layer receives, at each discrete time step, feedback from its own output along with new information from the previous layer, and produces a new output accordingly.
$y_{ijk}^{l}(t) = (w_{k}^{f})^{T} x_{l}^{f(i,j)}(t) + (w_{k}^{r})^{T} x_{l}^{r(i,j)}(t-1) + b_{k}$

where $x_{l}^{f}(t)$ is the feed-forward input and $x_{l}^{r}(t-1)$ the recurrent input of layer $l$, $w_{k}^{f}$ are the feed-forward weights, $w_{k}^{r}$ the recurrent weights, and $b_{k}$ the bias of the $k$-th feature map. The code below is the Recurrent_block from R2U-Net. It loops t times: the current output is the convolution of the current input plus the convolution of the previous step's output, so it carries information from earlier steps.
class Recurrent_block(nn.Module):
    def __init__(self, ch_out, t=2):
        super(Recurrent_block, self).__init__()
        self.t = t
        self.ch_out = ch_out
        self.conv = nn.Sequential(
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True))

    def forward(self, x):
        for i in range(self.t):
            if i == 0:
                x1 = self.conv(x)
            x1 = self.conv(x + x1)
        return x1
2.6 Dense U-net
In DenseNet, every layer takes the feature maps of all preceding layers as input, and its own feature maps are used as input by all subsequent layers. Advantages: dense connections alleviate the vanishing-gradient problem, strengthen feature propagation, encourage feature reuse, and substantially reduce the number of parameters. DenseNets tend to yield consistent accuracy improvements without any sign of performance degradation or overfitting, and they can serve as good feature extractors for a variety of computer-vision tasks built on convolutional features. To ensure maximum information flow between layers, all layers with matching feature-map sizes are connected directly to one another; to preserve the feed-forward nature, each layer obtains additional inputs from all preceding layers and passes its own feature maps on to all subsequent layers. A concatenation unit receives the feature maps from all previous layers and passes them to the next layer, ensuring that any given layer has contextual information from every earlier layer in the block. Below we use the paper "Bi-Directional ConvLSTM U-Net with Densley Connected Convolutions" and its code, https://github.com/rezazad68/BCDU-Net. The figure below shows that paper's model (useful when reading the code); only the bottom of this model is a dense block, whereas in a full Dense U-Net every block is a dense block. The dense structure corresponds to the parts of the code marked #D1, #D2, #D3: D3's input is the concatenation of D1 and D2.
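The channel bookkeeping behind dense connectivity is easy to state: if a block's input has k0 channels and each layer adds k feature maps (the "growth rate" in DenseNet terms), then layer l sees k0 + l·k input channels. A small sketch (the helper name and the 64/32 numbers are illustrative):

```python
def dense_input_channels(k0, growth, num_layers):
    # channels seen by each layer of a dense block: layer l receives the block
    # input plus the outputs of the l preceding layers, all concatenated
    return [k0 + l * growth for l in range(num_layers)]

chans = dense_input_channels(64, 32, 4)   # [64, 96, 128, 160]
```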
def BCDU_net_D3(input_size=(256, 256, 1)):
    N = input_size[0]
    inputs = Input(input_size)
    conv1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(inputs)
    conv1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool1)
    conv2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool2)
    conv3 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv3)
    drop3 = Dropout(0.5)(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    # D1
    conv4 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool3)
    conv4_1 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv4)
    drop4_1 = Dropout(0.5)(conv4_1)
    # D2
    conv4_2 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(drop4_1)
    conv4_2 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv4_2)
    conv4_2 = Dropout(0.5)(conv4_2)
    merge_dense = concatenate([conv4_2, drop4_1], axis=3)
    # D3
    conv4_3 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge_dense)
    conv4_3 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv4_3)
    drop4_3 = Dropout(0.5)(conv4_3)
    up6 = Conv2DTranspose(256, kernel_size=2, strides=2, padding='same', kernel_initializer='he_normal')(drop4_3)
    up6 = BatchNormalization(axis=3)(up6)
    up6 = Activation('relu')(up6)
    x1 = Reshape(target_shape=(1, np.int32(N / 4), np.int32(N / 4), 256))(drop3)
    x2 = Reshape(target_shape=(1, np.int32(N / 4), np.int32(N / 4), 256))(up6)
    merge6 = concatenate([x1, x2], axis=1)
    merge6 = ConvLSTM2D(filters=128, kernel_size=(3, 3), padding='same', return_sequences=False,
                        go_backwards=True, kernel_initializer='he_normal')(merge6)
    conv6 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge6)
    conv6 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv6)
    up7 = Conv2DTranspose(128, kernel_size=2, strides=2, padding='same', kernel_initializer='he_normal')(conv6)
    up7 = BatchNormalization(axis=3)(up7)
    up7 = Activation('relu')(up7)
    x1 = Reshape(target_shape=(1, np.int32(N / 2), np.int32(N / 2), 128))(conv2)
    x2 = Reshape(target_shape=(1, np.int32(N / 2), np.int32(N / 2), 128))(up7)
    merge7 = concatenate([x1, x2], axis=1)
    merge7 = ConvLSTM2D(filters=64, kernel_size=(3, 3), padding='same', return_sequences=False,
                        go_backwards=True, kernel_initializer='he_normal')(merge7)
    conv7 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge7)
    conv7 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv7)
    up8 = Conv2DTranspose(64, kernel_size=2, strides=2, padding='same', kernel_initializer='he_normal')(conv7)
    up8 = BatchNormalization(axis=3)(up8)
    up8 = Activation('relu')(up8)
    x1 = Reshape(target_shape=(1, N, N, 64))(conv1)
    x2 = Reshape(target_shape=(1, N, N, 64))(up8)
    merge8 = concatenate([x1, x2], axis=1)
    merge8 = ConvLSTM2D(filters=32, kernel_size=(3, 3), padding='same', return_sequences=False,
                        go_backwards=True, kernel_initializer='he_normal')(merge8)
    conv8 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge8)
    conv8 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv8)
    conv8 = Conv2D(2, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv8)
    conv9 = Conv2D(1, 1, activation='sigmoid')(conv8)
    model = Model(input=inputs, output=conv9)
    model.compile(optimizer=Adam(lr=1e-4), loss='binary_crossentropy', metrics=['accuracy'])
    return model
In short, the paper makes full use of U-Net, bidirectional ConvLSTM (BConvLSTM) and densely connected convolutions, and uses batch normalization to speed up convergence. It shows that by putting BConvLSTM in the skip connections and inserting densely connected convolution blocks, the network captures more discriminative information and produces more accurate segmentations. The next code comes from https://github.com/THUHoloLab/Dense-U-net. Once you know the dense-block code, you simply replace the corresponding convolutions inside U-Net. The first layer takes the input input_tensor (call it the layer-0 output) and convolves it. The second layer's input is the concatenation of input_tensor and the first layer's output x1; the third layer's input is the concatenation of input_tensor, x1 and the second layer's output x2; and so on.
def dens_block(input_tensor, nb_filter):
    x1 = Conv_Block(input_tensor, nb_filter)
    add1 = concatenate([x1, input_tensor], axis=-1)
    x2 = Conv_Block(add1, nb_filter)
    add2 = concatenate([x1, input_tensor, x2], axis=-1)
    x3 = Conv_Block(add2, nb_filter)
    return x3
def unet(input_shape=(512, 512, 3)):
    inputs = Input(input_shape)
    x = Conv2D(32, 7, kernel_initializer='he_normal', padding='same', strides=1,
               use_bias=False, kernel_regularizer=l2(1e-4))(inputs)
    down1 = dens_block(x, nb_filter=64)
    pool1 = MaxPooling2D(pool_size=(2, 2))(down1)
    down2 = dens_block(pool1, nb_filter=64)
    pool2 = MaxPooling2D(pool_size=(2, 2))(down2)
    down3 = dens_block(pool2, nb_filter=128)
    pool3 = MaxPooling2D(pool_size=(2, 2))(down3)
    down4 = dens_block(pool3, nb_filter=256)
    pool4 = MaxPooling2D(pool_size=(2, 2))(down4)
    conv5 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool4)
    conv5 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv5)
    drop5 = Dropout(0.5)(conv5)
    up6 = UpSampling2D(size=(2, 2))(drop5)
    add6 = concatenate([down4, up6], axis=3)
    up6 = dens_block(add6, nb_filter=256)
    up7 = UpSampling2D(size=(2, 2))(up6)
    add7 = concatenate([down3, up7], axis=3)
    up7 = dens_block(add7, nb_filter=128)
    up8 = UpSampling2D(size=(2, 2))(up7)
    add8 = concatenate([down2, up8], axis=-1)
    up8 = dens_block(add8, nb_filter=64)
    up9 = UpSampling2D(size=(2, 2))(up8)
    add9 = concatenate([down1, up9], axis=-1)
    up9 = dens_block(add9, nb_filter=64)
    conv10 = Conv2D(32, 7, strides=1, activation='relu', padding='same', kernel_initializer='he_normal')(up9)
    conv10 = Conv2D(3, 1, activation='sigmoid')(conv10)
    model = Model(input=inputs, output=conv10)
    print(model.summary())
    return model
2.7 U-Net++
Paper: UNet++: A Nested U-Net Architecture for Medical Image Segmentation. Reference: the Zhihu series 研習(xí)U-Net (highly recommended), which explains where UNet++ comes from; the simplified U-Net structure figure below is taken from it. The author asks: why does U-Net have exactly 4 levels? Experiments with U-Nets of depth 1 to 4 show that deeper is not always better, so shallow and deep levels each carry their own information: in feature extraction, shallow layers capture simple image features, while deep layers, with their larger receptive fields and many convolutions, capture more abstract features. Both matter, yet U-Net only turns back after level 4, i.e. it only exploits deep features. The author then tries merging the depth-1 to depth-4 U-Nets into one network, as in the next figure, so they share encoder parameters; but then gradients can only propagate along the 4-level U-Net and never through the decoders of the depth-1 to depth-3 U-Nets, because the loss L connects to x0,4 and x0,4 is not connected to the intermediate structures, so the network cannot be trained. He proposes two fixes. The first, shown next, replaces the long connections with short ones; gradients can then flow, but the advantages of long skip connections are lost. Those advantages are: 1. fighting the vanishing-gradient problem; 2. learning pyramid-level features; 3. recovering information lost in downsampling. Keeping the advantages of long connections while making the network trainable gives the UNet++ structure, with both long and short connections, quite similar to dense connectivity. The second fix for the training problem is deep supervision. The author added deep supervision to U-Net, to the merged network above and to UNet++, and UNet++ performed best. Deep supervision also enables pruning: train with the full UNet++ and, at test time, cut off the rightmost levels to speed up inference and reduce the parameter count. To verify that the accuracy gain comes from the structure rather than simply from more parameters, the author also built a wide U-Net with more convolution parameters per level. Figure 1: (a) UNet++ consists of an encoder and a decoder connected by a series of nested dense convolution blocks. The main idea behind UNet++ is to bridge the semantic gap between encoder and decoder feature maps before fusion; for example, the semantic gap between (X0,0, X1,3) is bridged by a dense convolution block with three convolution layers. In the figure, black indicates the original U-Net, green and blue the dense convolution blocks on the skip pathways, and red deep supervision; the red, green and blue components distinguish UNet++ from U-Net. (b) Detailed analysis of the first skip pathway of UNet++. (c) UNet++, if trained with deep supervision, can be pruned at inference time.
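The nested skip pathways determine each node's input channel count: node (i, j) concatenates the j same-level feature maps with the upsampled map from level i+1. A small bookkeeping sketch (the helper is illustrative; the filters list matches the 64-to-1024 channel schedule used in the implementation below):

```python
def nested_in_channels(filters, i, j):
    # Input channels of node conv{i}_{j} in UNet++ (for j >= 1): j same-level
    # maps of filters[i] channels each, concatenated with the upsampled map
    # from level i+1, which has filters[i+1] channels.
    return filters[i] * j + filters[i + 1]

filters = [64, 128, 256, 512, 1024]
c01 = nested_in_channels(filters, 0, 1)   # conv0_1: 64 + 128 = 192
c04 = nested_in_channels(filters, 0, 4)   # conv0_4: 64*4 + 128 = 384
```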
class conv_block_nested(nn.Module):
    def __init__(self, in_ch, mid_ch, out_ch):
        super(conv_block_nested, self).__init__()
        self.activation = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_ch, mid_ch, kernel_size=3, padding=1, bias=True)
        self.bn1 = nn.BatchNorm2d(mid_ch)
        self.conv2 = nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1, bias=True)
        self.bn2 = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.activation(x)
        x = self.conv2(x)
        x = self.bn2(x)
        output = self.activation(x)
        return output
class NestedUNet(nn.Module):
    """Implementation of this paper: https://arxiv.org/pdf/1807.10165.pdf"""
    def __init__(self, in_ch=3, out_ch=1):
        super(NestedUNet, self).__init__()
        n1 = 64
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv0_0 = conv_block_nested(in_ch, filters[0], filters[0])
        self.conv1_0 = conv_block_nested(filters[0], filters[1], filters[1])
        self.conv2_0 = conv_block_nested(filters[1], filters[2], filters[2])
        self.conv3_0 = conv_block_nested(filters[2], filters[3], filters[3])
        self.conv4_0 = conv_block_nested(filters[3], filters[4], filters[4])
        self.conv0_1 = conv_block_nested(filters[0] + filters[1], filters[0], filters[0])
        self.conv1_1 = conv_block_nested(filters[1] + filters[2], filters[1], filters[1])
        self.conv2_1 = conv_block_nested(filters[2] + filters[3], filters[2], filters[2])
        self.conv3_1 = conv_block_nested(filters[3] + filters[4], filters[3], filters[3])
        self.conv0_2 = conv_block_nested(filters[0] * 2 + filters[1], filters[0], filters[0])
        self.conv1_2 = conv_block_nested(filters[1] * 2 + filters[2], filters[1], filters[1])
        self.conv2_2 = conv_block_nested(filters[2] * 2 + filters[3], filters[2], filters[2])
        self.conv0_3 = conv_block_nested(filters[0] * 3 + filters[1], filters[0], filters[0])
        self.conv1_3 = conv_block_nested(filters[1] * 3 + filters[2], filters[1], filters[1])
        self.conv0_4 = conv_block_nested(filters[0] * 4 + filters[1], filters[0], filters[0])
        self.final = nn.Conv2d(filters[0], out_ch, kernel_size=1)

    def forward(self, x):
        x0_0 = self.conv0_0(x)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, self.Up(x1_0)], 1))
        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, self.Up(x2_0)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.Up(x1_1)], 1))
        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0, self.Up(x3_0)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.Up(x2_1)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.Up(x1_2)], 1))
        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, self.Up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.Up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.Up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.Up(x1_3)], 1))
        output = self.final(x0_4)
        return output
2.8 Adversarial U-Net
2.9 Ensemble U-Net
2.10 Comparison With Other Architectures
Summary
Nothing here yet.