生活随笔
收集整理的這篇文章主要介紹了
42_ResNet (深度残差网络)---学习笔记
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
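1.39. ResNet (深度殘差網絡)
Why is it called "Residual"?
Instead of learning the full mapping H(x) directly, each block learns only the residual F(x) = H(x) - x and outputs F(x) + x through a shortcut connection.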
import torch
.nn
as nn
import torch
.nn
.functional
as F
class ResBlk(nn.Module):
    """Residual block: two 3x3 conv + batch-norm stages plus a shortcut.

    The main path is conv3x3 -> BN -> ReLU -> conv3x3 -> BN. Its output is
    added to the shortcut branch, which is the identity when
    ``ch_in == ch_out`` and a 1x1 convolution followed by batch norm
    otherwise, so that both operands of the sum have ``ch_out`` channels.
    """

    def __init__(self, ch_in, ch_out):
        """
        :param ch_in: number of channels of the incoming feature map
        :param ch_out: number of channels produced by the block
        """
        # nn.Module has to be initialised before any submodule is assigned;
        # without this call the first ``self.conv1 = ...`` raises
        # AttributeError ("cannot assign module before Module.__init__() call").
        super().__init__()

        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        # An empty Sequential returns its input unchanged (identity shortcut).
        self.extra = nn.Sequential()
        if ch_out != ch_in:
            # Project the input to ch_out channels so it can be added to the
            # main path: [b, ch_in, h, w] => [b, ch_out, h, w]
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1),
                nn.BatchNorm2d(ch_out),
            )

    def forward(self, x):
        """
        :param x: input tensor laid out as [b, ch_in, h, w]
        :return: shortcut(x) + main_path(x), with ch_out channels and the
            same spatial size as ``x`` (all convolutions use stride 1)
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Element-wise sum of the shortcut branch and the main path.
        out = self.extra(x) + out
        return out
import torch
from torch
.nn
import functional
as F
from torch
.utils
.data
import DataLoader
from torchvision
import datasets
from torchvision
import transforms
from torch
import nn
, optim
class ResBlk(nn.Module):
    """ResNet building block.

    Main path: conv3x3 -> BN -> ReLU -> conv3x3 -> BN. The result is summed
    with a shortcut branch stored in ``self.extra``: an empty ``Sequential``
    (identity) when the channel counts match, otherwise a 1x1 convolution
    followed by batch norm that maps ``ch_in`` channels to ``ch_out``.
    """

    def __init__(self, ch_in, ch_out):
        """
        :param ch_in: channel count of the input feature map
        :param ch_out: channel count of the output feature map
        """
        super(ResBlk, self).__init__()

        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(ch_in, ch_out, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(ch_out)

        # Shortcut layers are only needed when the channel count changes;
        # with no layers the Sequential passes its input straight through.
        shortcut_layers = []
        if ch_out != ch_in:
            shortcut_layers.append(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1)
            )
            shortcut_layers.append(nn.BatchNorm2d(ch_out))
        self.extra = nn.Sequential(*shortcut_layers)

    def forward(self, x):
        """
        :param x: tensor of shape [b, ch, h, w]
        :return: shortcut(x) + main_path(x)
        """
        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = F.relu(hidden)

        residual = self.bn2(self.conv2(hidden))

        return self.extra(x) + residual
class ResNet18(nn.Module):
    """Small residual classifier: conv stem, two ``ResBlk`` stages, linear head.

    The linear head expects ``32 * 32 * 32`` flattened features and emits 10
    logits. NOTE(review): this presumably corresponds to 32x32 inputs (e.g.
    CIFAR-10) with 32 channels after ``blk2`` -- confirm against ``ResBlk``.
    """

    def __init__(self):
        super(ResNet18, self).__init__()

        # Stem: 3 input channels -> 16 feature channels, followed by batch norm.
        stem_conv = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        stem_norm = nn.BatchNorm2d(16)
        self.conv1 = nn.Sequential(stem_conv, stem_norm)

        # Residual stages.
        self.blk1 = ResBlk(16, 16)
        self.blk2 = ResBlk(16, 32)

        # Classification head over the flattened feature map.
        self.outlayer = nn.Linear(32 * 32 * 32, 10)

    def forward(self, x):
        """
        :param x: input batch; the stem convolution takes 3 input channels
        :return: output of the linear head, one row of 10 values per sample
        """
        features = F.relu(self.conv1(x))
        for stage in (self.blk1, self.blk2):
            features = stage(features)

        # Collapse everything except the batch dimension before the head.
        batch = features.size(0)
        flat = features.view(batch, -1)
        return self.outlayer(flat)
def main():
    """Train ``ResNet18`` on CIFAR-10 and print test accuracy after each epoch.

    Downloads CIFAR-10 into the ``cifar`` directory if needed, trains with
    Adam (lr=1e-3) and cross-entropy loss for 1000 epochs, and after every
    epoch prints the loss of the last training batch and the accuracy on
    the test split.
    """
    batchsz = 32

    # Same preprocessing for both splits.
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ])

    cifar_train = datasets.CIFAR10('cifar', True, transform=transform,
                                   download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)

    cifar_test = datasets.CIFAR10('cifar', False, transform=transform,
                                  download=True)
    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)

    # Peek at one batch to show the tensor shapes. Iterators expose the
    # built-in next() protocol; they have no ``.next()`` method on Python 3.
    x, label = next(iter(cifar_train))
    print('x:', x.shape, 'label:', label.shape)

    # Use the GPU when one is present, otherwise fall back to the CPU
    # instead of crashing on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ResNet18().to(device)

    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    for epoch in range(1000):

        model.train()
        for x, label in cifar_train:
            x, label = x.to(device), label.to(device)

            logits = model(x)
            loss = criteon(logits, label)

            # Standard backprop step: clear old grads, compute new, update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Loss of the final training batch of this epoch.
        print(epoch, 'loss:', loss.item())

        model.eval()
        with torch.no_grad():
            # Evaluation: no gradients are needed.
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                x, label = x.to(device), label.to(device)

                logits = model(x)
                # Predicted class = index of the largest logit per sample.
                pred = logits.argmax(dim=1)
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)

            acc = total_correct / total_num
            print(epoch, 'acc:', acc)


if __name__ == '__main__':
    main()
總結
以上是生活随笔為你收集整理的42_ResNet (深度残差网络)---学习笔记的全部內容,希望文章能夠幫你解決所遇到的問題。
如果覺得生活随笔網站內容還不錯,歡迎將生活随笔推薦給好友。