生活随笔
收集整理的這篇文章主要介紹了
手把手带你撸深度学习经典模型(一)----- UNet
小編覺得挺不錯的,現(xiàn)在分享給大家,幫大家做個參考.
一、前言
經(jīng)過慎重考慮,決定新開一個系列,該系列文章主要的目的就是利用PyTorch、Python實現(xiàn)深度學(xué)習(xí)中的一些經(jīng)典模型,接下來一段時間的安排如下:
本文首先實現(xiàn)UNet,關(guān)于UNet的詳細介紹請移步深度學(xué)習(xí)模型解析系列文章–白話詳解UNet
二、網(wǎng)絡(luò)結(jié)構(gòu)詳解
UNet總體上分為編碼器和解碼器,其中編碼器負責(zé)提取特征信息,解碼器負責(zé)還原特征信息;編碼器主要由4個塊組成,每個塊分別由2個卷積層、1個最大池化層組成。解碼器也是由4個塊組成,每個塊都是由1個上采樣層、2個卷積層組成,詳細信息請見下圖。
三、網(wǎng)絡(luò)組成部分實現(xiàn)
import torch
import torch
.nn
as nn
import torch
.nn
.functional
as F
- 第2步:我們需要自定義一個卷積的基礎(chǔ)塊,該基礎(chǔ)塊由2個卷積層組成。
class DoubleConv(nn
.Module
):"""(convolution => [BN] => ReLU) * 2"""def __init__(self
, in_channels
, out_channels
, mid_channels
=None):super().__init__
()if not mid_channels
:mid_channels
= out_channelsself
.double_conv
= nn
.Sequential
(nn
.Conv2d
(in_channels
, mid_channels
, kernel_size
=3, padding
=1),nn
.BatchNorm2d
(mid_channels
),nn
.ReLU
(inplace
=True),nn
.Conv2d
(mid_channels
, out_channels
, kernel_size
=3, padding
=1),nn
.BatchNorm2d
(out_channels
),nn
.ReLU
(inplace
=True))def forward(self
, x
):return self
.double_conv
(x
)
- 第3步:我們需要自定義一個編碼器的基礎(chǔ)塊,該塊由1個最大池化層和第2步的卷積基礎(chǔ)塊組成。
class Down(nn
.Module
):"""Downscaling with maxpool then double conv"""def __init__(self
, in_channels
, out_channels
):super().__init__
()self
.maxpool_conv
= nn
.Sequential
(nn
.MaxPool2d
(2),DoubleConv
(in_channels
, out_channels
))def forward(self
, x
):return self
.maxpool_conv
(x
)
- 第4步:我們需要自定義一個解碼器的基礎(chǔ)塊,該基礎(chǔ)塊由1個上采樣層和2個卷積層組成。
class Up(nn
.Module
):"""Upscaling then double conv"""def __init__(self
, in_channels
, out_channels
, bilinear
=True):super().__init__
()if bilinear
:self
.up
= nn
.Upsample
(scale_factor
=2, mode
='bilinear', align_corners
=True) self
.conv
= DoubleConv
(in_channels
, out_channels
, in_channels
// 2)else:self
.up
= nn
.ConvTranspose2d
(in_channels
, in_channels
// 2, kernel_size
=2, stride
=2) self
.conv
= DoubleConv
(in_channels
, out_channels
)def forward(self
, x1
, x2
):x1
= self
.up
(x1
)diffY
= x2
.size
()[2] - x1
.size
()[2]diffX
= x2
.size
()[3] - x1
.size
()[3]x1
= F
.pad
(x1
, [diffX
// 2, diffX
- diffX
// 2,diffY
// 2, diffY
- diffY
// 2])x
= torch
.cat
([x2
, x1
], dim
=1)return self
.conv
(x
)
class OutConv(nn
.Module
):def __init__(self
, in_channels
, out_channels
):super(OutConv
, self
).__init__
()self
.conv
= nn
.Conv2d
(in_channels
, out_channels
, kernel_size
=1)def forward(self
, x
):return self
.conv
(x
)
四、網(wǎng)絡(luò)結(jié)構(gòu)實現(xiàn)
- 第1步:我們需要把上述定義的類一股腦的導(dǎo)入到你要定義的網(wǎng)絡(luò)文件中,因為每個人的文件夾不同,這里就不詳細講述。
- 第2步:初始化你的網(wǎng)絡(luò)模型參數(shù)
- 第3步:編寫前向傳播方法
class UNet(nn
.Module
):def __init__(self
, args
, n_channels
, n_classes
, bilinear
=True):super(UNet
, self
).__init__
() self
.n_channels
= n_channelsself
.n_classes
= n_classesself
.bilinear
= bilinear
"""DoubleConv <-> (convolution => [BN] => ReLU) * 2"""self
.inc
= DoubleConv
(n_channels
, 64)self
.down1
= Down
(64, 128)self
.down2
= Down
(128, 256)self
.down3
= Down
(256, 512)factor
= 2 if bilinear
else 1self
.down4
= Down
(512, 1024 // factor
)self
.up1
= Up
(1024, 512 // factor
, bilinear
)self
.up2
= Up
(512, 256 // factor
, bilinear
)self
.up3
= Up
(256, 128 // factor
, bilinear
)self
.up4
= Up
(128, 64, bilinear
)self
.outc
= OutConv
(64, n_classes
)def forward(self
, x
):x1
= self
.inc
(x
)x2
= self
.down1
(x1
)x3
= self
.down2
(x2
)x4
= self
.down3
(x3
)x5
= self
.down4
(x4
)x
= self
.up1
(x5
, x4
)x
= self
.up2
(x
, x3
)x
= self
.up3
(x
, x2
)x
= self
.up4
(x
, x1
)logits
= self
.outc
(x
)return logits
至此,UNet經(jīng)典網(wǎng)絡(luò)結(jié)構(gòu)就編寫好了,是不是非常的簡單呢?如果您覺得寫的還不錯,歡迎一鍵三連,這對我真的幫助很大,非常感謝!我也會繼續(xù)努力,提升文章的質(zhì)量與數(shù)量!
總結(jié)
以上是生活随笔為你收集整理的手把手带你撸深度学习经典模型(一)----- UNet的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
如果覺得生活随笔網(wǎng)站內(nèi)容還不錯,歡迎將生活随笔推薦給好友。