【小白学习PyTorch教程】十九、 基于torch实现UNet 图像分割模型
@Author:Runsen
在圖像領域,除了分類,CNN 今天還用于更高級的問題,如圖像分割、對象檢測等。圖像分割是計算機視覺中的一個過程,其中圖像被分割成代表圖像中每個不同類別的不同段。
上面圖片一段代表貓,另一段代表背景。
從自動駕駛汽車到衛星,圖像分割在許多領域都很有用。其中最重要的是醫學成像。
UNet 是一種卷積神經網絡架構,在 CNN 架構幾乎沒有變化的情況下進行了擴展。它的發明是為了處理生物醫學圖像,其目標不僅是對是否存在感染進行分類,而且還要識別感染區域。
UNet
論文:https://arxiv.org/abs/1505.04597
UNet結構看起來像一個“U”,該架構由三部分組成:收縮部分、瓶頸部分和擴展部分。收縮段由許多收縮塊組成。每個塊接受一個輸入,應用兩個 3X3 卷積層,然后是 2X2 最大池化。每個塊之后的內核或特征圖的數量加倍,以便架構可以有效地學習復雜的結構。最底層介于收縮層和膨脹層之間。它使用兩個 3X3 CNN 層,然后是 2X2 上卷積層。
每個塊將輸入傳遞給兩個 3X3 CNN 層,然后是一個 2X2 上采樣層。同樣在每個塊之后,卷積層使用的特征圖數量減半以保持對稱性。然而,每次輸入也會附加相應收縮層的特征圖。此操作將確保在收縮圖像時學習的特征將用于重建它。擴展塊的數量與收縮塊的數量相同。之后,生成的映射通過另一個 3X3 CNN 層,特征映射的數量等于所需的片段數量。
torch實現
使用的數據集是:https://www.kaggle.com/paultimothymooney/chiu-2015
這個數據集用于分割糖尿病性黃斑水腫的光學相干斷層掃描圖像的圖像。
對于mat的數據,使用scipy.io.loadmat進行加載
下面使用 Pytorch 框架實現了 UNet 模型,代碼來源下面的Github:https://github.com/Hsankesara/DeepResearch
import torch from torch import nn import torch.nn.functional as F import torch.optim as optimclass UNet(nn.Module):def contracting_block(self, in_channels, out_channels, kernel_size=3):block = torch.nn.Sequential(torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=out_channels),torch.nn.ReLU(),torch.nn.BatchNorm2d(out_channels),torch.nn.Conv2d(kernel_size=kernel_size, in_channels=out_channels, out_channels=out_channels),torch.nn.ReLU(),torch.nn.BatchNorm2d(out_channels),)return blockdef expansive_block(self, in_channels, mid_channel, out_channels, kernel_size=3):block = torch.nn.Sequential(torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel),torch.nn.ReLU(),torch.nn.BatchNorm2d(mid_channel),torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel),torch.nn.ReLU(),torch.nn.BatchNorm2d(mid_channel),torch.nn.ConvTranspose2d(in_channels=mid_channel, out_channels=out_channels, kernel_size=3, stride=2, padding=1, output_padding=1))return blockdef final_block(self, in_channels, mid_channel, out_channels, kernel_size=3):block = torch.nn.Sequential(torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel),torch.nn.ReLU(),torch.nn.BatchNorm2d(mid_channel),torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel),torch.nn.ReLU(),torch.nn.BatchNorm2d(mid_channel),torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=out_channels, padding=1),torch.nn.ReLU(),torch.nn.BatchNorm2d(out_channels),)return blockdef __init__(self, in_channel, out_channel):super(UNet, self).__init__()#Encodeself.conv_encode1 = self.contracting_block(in_channels=in_channel, out_channels=64)self.conv_maxpool1 = torch.nn.MaxPool2d(kernel_size=2)self.conv_encode2 = self.contracting_block(64, 128)self.conv_maxpool2 = torch.nn.MaxPool2d(kernel_size=2)self.conv_encode3 = self.contracting_block(128, 256)self.conv_maxpool3 = torch.nn.MaxPool2d(kernel_size=2)# Bottleneckself.bottleneck = torch.nn.Sequential(torch.nn.Conv2d(kernel_size=3, in_channels=256, out_channels=512),torch.nn.ReLU(),torch.nn.BatchNorm2d(512),torch.nn.Conv2d(kernel_size=3, in_channels=512, out_channels=512),torch.nn.ReLU(),torch.nn.BatchNorm2d(512),torch.nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=3, stride=2, padding=1, output_padding=1))# Decodeself.conv_decode3 = self.expansive_block(512, 256, 128)self.conv_decode2 = self.expansive_block(256, 128, 64)self.final_layer = self.final_block(128, 64, out_channel)def crop_and_concat(self, upsampled, bypass, crop=False):if crop:c = (bypass.size()[2] - upsampled.size()[2]) // 2bypass = F.pad(bypass, (-c, -c, -c, -c))return torch.cat((upsampled, bypass), 1)def forward(self, x):# Encodeencode_block1 = self.conv_encode1(x)encode_pool1 = self.conv_maxpool1(encode_block1)encode_block2 = self.conv_encode2(encode_pool1)encode_pool2 = self.conv_maxpool2(encode_block2)encode_block3 = self.conv_encode3(encode_pool2)encode_pool3 = self.conv_maxpool3(encode_block3)# Bottleneckbottleneck1 = self.bottleneck(encode_pool3)# Decodedecode_block3 = self.crop_and_concat(bottleneck1, encode_block3, crop=True)cat_layer2 = self.conv_decode3(decode_block3)decode_block2 = self.crop_and_concat(cat_layer2, encode_block2, crop=True)cat_layer1 = self.conv_decode2(decode_block2)decode_block1 = self.crop_and_concat(cat_layer1, encode_block1, crop=True)final_layer = self.final_layer(decode_block1)return final_layer上面代碼中的 UNet 模塊代表了 UNet 的整個架構。contraction_block和expansive_block分別用于創建收縮段和膨脹段。該函數crop_and_concat將收縮層的輸出與新的擴展層輸入相加。
unet = Unet(in_channel=1,out_channel=2) #out_channel represents number of segments desired criterion = torch.nn.CrossEntropyLoss() optimizer = torch.optim.SGD(unet.parameters(), lr = 0.01, momentum=0.99) optimizer.zero_grad() outputs = unet(inputs) # permute such that number of desired segments would be on 4th dimension outputs = outputs.permute(0, 2, 3, 1) m = outputs.shape[0] # Resizing the outputs and label to caculate pixel wise softmax loss outputs = outputs.resize(m*width_out*height_out, 2) labels = labels.resize(m*width_out*height_out) loss = criterion(outputs, labels) loss.backward() optimizer.step()對于該數據集解決標準教程代碼:https://www.kaggle.com/hsankesara/unet-image-segmentation
總結
以上是生活随笔為你收集整理的【小白学习PyTorch教程】十九、 基于torch实现UNet 图像分割模型的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 西安火车站可以带自热米饭吗?
- 下一篇: 100克香菇多糖需要多少香菇粉的提取?