swin_transformer for image classification (verified working)
Video tutorial by Bilibili creator 霹雳吧啦Wz: "12.2 使用Pytorch搭建Swin-Transformer網絡" (Building the Swin Transformer network with PyTorch)
Tutorial link: https://www.bilibili.com/video/BV1yg411K7Yc?spm_id_from=333.999.0.0
Code for the Swin Transformer image-classification task:
https://github.com/Ydjiao/deep-learning-for-image-processing/tree/master/pytorch_classification/swin_transformer
The dataset has 5 classes (daisy, dandelion, roses, sunflowers, tulips).

Output of train.py:

Output of predict.py:
```
C:\Users\deep\anaconda3\envs\swin\python.exe D:/***/swin_transformer_flower/predict.py
class: daisy        prob: 1.0
class: dandelion    prob: 6.63e-06
class: roses        prob: 1.39e-05
class: sunflowers   prob: 4.82e-05
class: tulips       prob: 2.21e-06
Traceback (most recent call last):
  File "D:/jiaoyidi/swin_transformer_flower/predict.py", line 69, in <module>
    main()
  File "D:/jiaoyidi/swin_transformer_flower/predict.py", line 65, in main
    plt.show()
  File "C:\Users\deep\anaconda3\envs\swin\lib\site-packages\matplotlib\pyplot.py", line 368, in show
    return _backend_mod.show(*args, **kwargs)
  File "C:\Users\deep\anaconda3\envs\swin\lib\site-packages\matplotlib\backend_bases.py", line 3544, in show
    cls.mainloop()
  File "C:\Users\deep\anaconda3\envs\swin\lib\site-packages\matplotlib\backends\_backend_tk.py", line 958, in mainloop
    first_manager.window.mainloop()
  File "C:\Users\deep\anaconda3\envs\swin\lib\tkinter\__init__.py", line 1283, in mainloop
    self.tk.mainloop(n)
KeyboardInterrupt

Process finished with exit code 1
```

(The traceback is just a KeyboardInterrupt raised when the matplotlib window was closed; the prediction itself ran fine.)

My own implementation:
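Before going through the files one by one, a quick smoke test is worth running. The sketch below is only a sanity check (it assumes the model.py listed next is saved in the working directory): it builds the tiny variant and confirms that the logits have one entry per class.

```python
import torch
from model import swin_tiny_patch4_window7_224

model = swin_tiny_patch4_window7_224(num_classes=5)
x = torch.randn(1, 3, 224, 224)  # dummy 224x224 RGB image
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # expected: torch.Size([1, 5])
```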
model.py
```python
# -*- coding: utf-8 -*-
"""
Created on Tue Dec  7 10:37:32 2021

@author: admin
"""
""" Swin Transformer
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`
    - https://arxiv.org/pdf/2103.14030

Code/weights from https://github.com/microsoft/Swin-Transformer
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import numpy as np
from typing import Optional


def drop_path_f(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path_f(x, self.drop_prob, self.training)


def window_partition(x, window_size: int):
    """Split the feature map into non-overlapping windows of size window_size.

    Args:
        x: (B, H, W, C)
        window_size (int): window size (M)

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # permute (swap dims 2 and 3): [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C]
    # view: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B*num_windows, Mh, Mw, C]
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size: int, H: int, W: int):
    """Merge windows back into a feature map. H and W are the sizes before partitioning.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size (M)
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    # view: [B*num_windows, Mh, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C]
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    # permute: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B, H//Mh, Mh, W//Mw, Mw, C]
    # view: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H, W, C]
    # contiguous() makes the tensor contiguous in memory again
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding"""
    def __init__(self,
                 patch_size=4,     # downsampling factor
                 in_c=3,           # number of input image channels
                 embed_dim=96,     # embedding dimension after projection
                 norm_layer=None): # e.g. nn.LayerNorm
        super().__init__()
        patch_size = (patch_size, patch_size)
        self.patch_size = patch_size
        self.in_chans = in_c
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        # get the height and width of the input image
        _, _, H, W = x.shape

        # padding: if H or W is not an integer multiple of patch_size, pad the input
        pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)
        if pad_input:
            # to pad the last 3 dimensions (see the F.pad docs):
            # (W_left, W_right, H_top, H_bottom, C_front, C_back)
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1],
                          0, self.patch_size[0] - H % self.patch_size[0],
                          0, 0))

        # downsample by a factor of patch_size
        x = self.proj(x)
        _, _, H, W = x.shape
        # flatten: [B, C, H, W] -> [B, C, HW]  (flatten H and W)
        # transpose: [B, C, HW] -> [B, HW, C]
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W


class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        # the concatenated feature has 4*dim channels (four interleaved sub-grids);
        # after LayerNorm the linear layer halves that to 2*dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """
        x: B, H*W, C
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        # padding: if H or W is odd, pad the feature map
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            # to pad the last 3 dimensions, starting from the last dimension and moving forward:
            # (C_front, C_back, W_left, W_right, H_top, H_bottom)
            # note the tensor layout here is [B, H, W, C], hence the difference from the official docs
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        # gather the four interleaved 2x2 sub-grids (the colored squares in the video's diagram)
        x0 = x[:, 0::2, 0::2, :]  # [B, H/2, W/2, C]
        x1 = x[:, 1::2, 0::2, :]  # [B, H/2, W/2, C]
        x2 = x[:, 0::2, 1::2, :]  # [B, H/2, W/2, C]
        x3 = x[:, 1::2, 1::2, :]  # [B, H/2, W/2, C]
        # concatenate along the channel dimension
        x = torch.cat([x0, x1, x2, x3], -1)  # [B, H/2, W/2, 4*C]
        x = x.view(B, -1, 4 * C)  # [B, H/2*W/2, 4*C]

        x = self.norm(x)
        x = self.reduction(x)  # [B, H/2*W/2, 2*C]

        return x


class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # [Mh, Mw]
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # [2*Mh-1 * 2*Mw-1, nH]

        # the following lines compute the relative position index:
        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        # coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))  # [2, Mh, Mw]
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # [2, Mh, Mw]
        coords_flatten = torch.flatten(coords, 1)  # [2, Mh*Mw], flatten from dim 1
        # [2, Mh*Mw, 1] - [2, 1, Mh*Mw]
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # [2, Mh*Mw, Mh*Mw]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # [Mh*Mw, Mh*Mw, 2]
        # collapse the 2D relative positions into a single 1D index
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # [Mh*Mw, Mh*Mw]
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask: Optional[torch.Tensor] = None):
        """
        Args:
            x: input features with shape of (num_windows*B, Mh*Mw, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        # [batch_size*num_windows, Mh*Mw, total_embed_dim]
        B_, N, C = x.shape
        # qkv(): -> [batch_size*num_windows, Mh*Mw, 3 * total_embed_dim]
        # reshape: -> [batch_size*num_windows, Mh*Mw, 3, num_heads, embed_dim_per_head]
        # permute: -> [3, batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        # transpose: -> [batch_size*num_windows, num_heads, embed_dim_per_head, Mh*Mw]
        # @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, Mh*Mw]
        q = q * self.scale  # similarity matrix scaled by 1/sqrt(d)
        attn = (q @ k.transpose(-2, -1))
        # look up the bias vectors from the relative position bias table
        # relative_position_bias_table.view: [Mh*Mw*Mh*Mw,nH] -> [Mh*Mw,Mh*Mw,nH]
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # [nH, Mh*Mw, Mh*Mw]
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # mask: [nW, Mh*Mw, Mh*Mw]
            nW = mask.shape[0]  # num_windows
            # attn.view: [batch_size, num_windows, num_heads, Mh*Mw, Mh*Mw]
            # mask.unsqueeze: [1, nW, 1, Mh*Mw, Mh*Mw]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        # @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
        # transpose: -> [batch_size*num_windows, Mh*Mw, num_heads, embed_dim_per_head]
        # reshape: -> [batch_size*num_windows, Mh*Mw, total_embed_dim]
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,
            attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, attn_mask):
        H, W = self.H, self.W
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # pad the feature map to multiples of the window size
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x
            attn_mask = None

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # [nW*B, Mh, Mw, C]
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # [nW*B, Mh*Mw, C]

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # [nW*B, Mh*Mw, C]

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)  # [nW*B, Mh, Mw, C]
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # [B, H', W', C]

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        if pad_r > 0 or pad_b > 0:
            # remove the padding added earlier
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x


class BasicLayer(nn.Module):
    """
    A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.window_size = window_size
        self.use_checkpoint = use_checkpoint
        self.shift_size = window_size // 2

        # build blocks: even-indexed blocks use W-MSA, odd-indexed blocks use SW-MSA
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else self.shift_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def create_mask(self, x, H, W):
        # calculate attention mask for SW-MSA
        # make sure Hp and Wp are integer multiples of window_size
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        # same channel layout as the feature map, which makes the later window_partition easy
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # [1, Hp, Wp, 1]
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # [nW, Mh, Mw, 1]
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)  # [nW, Mh*Mw]
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1]
        # positions in the same region (difference 0) stay 0; all others are set to -100
        # [nW, Mh*Mw, Mh*Mw]
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def forward(self, x, H, W):
        attn_mask = self.create_mask(x, H, W)  # [nW, Mh*Mw, Mh*Mw]
        for blk in self.blocks:
            blk.H, blk.W = H, W
            if not torch.jit.is_scripting() and self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask)
        if self.downsample is not None:
            x = self.downsample(x, H, W)
            H, W = (H + 1) // 2, (W + 1) // 2

        return x, H, W


class SwinTransformer(nn.Module):
    r""" Swin Transformer
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
          https://arxiv.org/pdf/2103.14030

    Args:
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """

    def __init__(self, patch_size=4,
                 in_chans=3,                 # number of input image channels
                 num_classes=1000,           # number of classes
                 embed_dim=96,               # channel dim after the first (linear embedding) stage
                 depths=(2, 2, 6, 2),        # number of Swin Transformer blocks in each stage
                 num_heads=(3, 6, 12, 24),   # number of attention heads per stage
                 window_size=7,              # window size used by W-MSA and SW-MSA
                 mlp_ratio=4.,               # expansion factor of the first Linear layer in the MLP
                 qkv_bias=True,              # whether to use a bias on query, key, value
                 drop_rate=0.,
                 attn_drop_rate=0.,          # drop rate used inside MSA
                 drop_path_rate=0.1,         # stochastic depth rate in the blocks, increasing linearly
                 norm_layer=nn.LayerNorm,
                 patch_norm=True,
                 use_checkpoint=False, **kwargs):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        # channels of the stage-4 feature map, e.g. 96 * 2**(4-1)
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches;
        # corresponds to Patch Partition + Linear Embedding in the paper figure
        self.patch_embed = PatchEmbed(
            patch_size=patch_size, in_c=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule: the rate grows linearly from 0 to drop_path_rate,
        # one value per block (torch.linspace spans start to end over sum(depths) steps)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # build layers: this loop creates all four stages
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            # note the stage built here differs slightly from the paper figure:
            # a stage here does not contain its own patch_merging layer but the next stage's
            layers = BasicLayer(dim=int(embed_dim * 2 ** i_layer),  # C, 2C, 4C, 8C
                                depth=depths[i_layer],
                                num_heads=num_heads[i_layer],
                                window_size=window_size,
                                mlp_ratio=self.mlp_ratio,
                                qkv_bias=qkv_bias,
                                drop=drop_rate,
                                attn_drop=attn_drop_rate,
                                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                                norm_layer=norm_layer,
                                # the last stage has no PatchMerging, hence the check
                                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                                use_checkpoint=use_checkpoint)
            self.layers.append(layers)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        # x: [B, L, C]
        x, H, W = self.patch_embed(x)
        x = self.pos_drop(x)

        for layer in self.layers:
            x, H, W = layer(x, H, W)

        x = self.norm(x)  # [B, L, C]
        x = self.avgpool(x.transpose(1, 2))  # [B, C, 1]
        x = torch.flatten(x, 1)
        x = self.head(x)
        return x


def swin_tiny_patch4_window7_224(num_classes: int = 1000, **kwargs):
    # trained ImageNet-1K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=7,
                            embed_dim=96,
                            depths=(2, 2, 6, 2),
                            num_heads=(3, 6, 12, 24),
                            num_classes=num_classes,
                            **kwargs)
    return model


def swin_small_patch4_window7_224(num_classes: int = 1000, **kwargs):
    # trained ImageNet-1K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=7,
                            embed_dim=96,
                            depths=(2, 2, 18, 2),
                            num_heads=(3, 6, 12, 24),
                            num_classes=num_classes,
                            **kwargs)
    return model


def swin_base_patch4_window7_224(num_classes: int = 1000, **kwargs):
    # trained ImageNet-1K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=7,
                            embed_dim=128,
                            depths=(2, 2, 18, 2),
                            num_heads=(4, 8, 16, 32),
                            num_classes=num_classes,
                            **kwargs)
    return model


def swin_base_patch4_window12_384(num_classes: int = 1000, **kwargs):
    # trained ImageNet-1K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=12,
                            embed_dim=128,
                            depths=(2, 2, 18, 2),
                            num_heads=(4, 8, 16, 32),
                            num_classes=num_classes,
                            **kwargs)
    return model


def swin_base_patch4_window7_224_in22k(num_classes: int = 21841, **kwargs):
    # trained ImageNet-22K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=7,
                            embed_dim=128,
                            depths=(2, 2, 18, 2),
                            num_heads=(4, 8, 16, 32),
                            num_classes=num_classes,
                            **kwargs)
    return model


def swin_base_patch4_window12_384_in22k(num_classes: int = 21841, **kwargs):
    # trained ImageNet-22K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=12,
                            embed_dim=128,
                            depths=(2, 2, 18, 2),
                            num_heads=(4, 8, 16, 32),
                            num_classes=num_classes,
                            **kwargs)
    return model


def swin_large_patch4_window7_224_in22k(num_classes: int = 21841, **kwargs):
    # trained ImageNet-22K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=7,
                            embed_dim=192,
                            depths=(2, 2, 18, 2),
                            num_heads=(6, 12, 24, 48),
                            num_classes=num_classes,
                            **kwargs)
    return model


def swin_large_patch4_window12_384_in22k(num_classes: int = 21841, **kwargs):
    # trained ImageNet-22K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth
    model = SwinTransformer(in_chans=3,
                            patch_size=4,
                            window_size=12,
                            embed_dim=192,
                            depths=(2, 2, 18, 2),
                            num_heads=(6, 12, 24, 48),
                            num_classes=num_classes,
                            **kwargs)
    return model
```
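The windowing helpers at the top of model.py are easiest to verify with a round trip. A small sketch (tensor sizes chosen arbitrarily; it only relies on the window_partition and window_reverse functions above):

```python
import torch
from model import window_partition, window_reverse

x = torch.randn(2, 14, 14, 96)                  # [B, H, W, C], H and W divisible by the window size
windows = window_partition(x, window_size=7)    # [2 * (14/7)*(14/7), 7, 7, 96] = [8, 7, 7, 96]
restored = window_reverse(windows, 7, 14, 14)   # back to [2, 14, 14, 96]
print(windows.shape, torch.equal(x, restored))  # True: both ops are pure reshapes, so the round trip is lossless
```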
my_dataset.py

```python
# -*- coding: utf-8 -*-
"""
Created on Tue Dec  7 10:38:17 2021

@author: admin
"""
from PIL import Image
import torch
from torch.utils.data import Dataset


class MyDataSet(Dataset):
    """Custom dataset"""

    def __init__(self, images_path: list, images_class: list, transform=None):
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, item):
        img = Image.open(self.images_path[item])
        # mode 'RGB' is a color image, 'L' is a grayscale image
        if img.mode != 'RGB':
            raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
        label = self.images_class[item]

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    @staticmethod
    def collate_fn(batch):
        # for the official default_collate implementation, see
        # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
        images, labels = tuple(zip(*batch))

        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels
```
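A usage sketch for the dataset class (the two image paths here are hypothetical placeholders; any RGB images work):

```python
from torch.utils.data import DataLoader
from torchvision import transforms
from my_dataset import MyDataSet

tf = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
dataset = MyDataSet(images_path=["./data/flower_photos/daisy/a.jpg",   # hypothetical paths
                                 "./data/flower_photos/roses/b.jpg"],
                    images_class=[0, 2],
                    transform=tf)
loader = DataLoader(dataset, batch_size=2, collate_fn=MyDataSet.collate_fn)
for images, labels in loader:
    print(images.shape, labels)  # torch.Size([2, 3, 224, 224]) tensor([0, 2])
```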
predict.py

```python
# -*- coding: utf-8 -*-
"""
Created on Tue Dec  7 10:40:32 2021

@author: admin
"""
import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

# import the model to use
from model import swin_tiny_patch4_window7_224 as create_model


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    img_size = 224
    data_transform = transforms.Compose(
        [transforms.Resize(int(img_size * 1.14)),
         transforms.CenterCrop(img_size),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # load image; img_path is the path of the image to predict
    img_path = "../1.jpg"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)

    json_file = open(json_path, "r")
    class_indict = json.load(json_file)

    # create model; num_classes=5 is the number of classes
    model = create_model(num_classes=5).to(device)
    # load model weights
    model_weight_path = "./weights/model-9.pth"
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()
```
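predict.py expects the class_indices.json that read_split_data() in utils.py writes out during training. For the flower dataset the file should contain the mapping below (class names taken from the prediction output earlier; this sketch just reproduces the expected contents, it is not the generation code itself):

```python
import json

# Expected contents of class_indices.json for the 5-class flower dataset
# (indices follow the alphabetical folder order used by read_split_data).
class_indices = {"0": "daisy", "1": "dandelion", "2": "roses",
                 "3": "sunflowers", "4": "tulips"}
print(json.dumps(class_indices, indent=4))
```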
train.py

```python
# -*- coding: utf-8 -*-
"""
Created on Tue Dec  7 10:41:26 2021

@author: admin
"""
import os
import argparse

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

from my_dataset import MyDataSet
from model import swin_tiny_patch4_window7_224 as create_model
from utils import read_split_data, train_one_epoch, evaluate


def main(args):
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    if os.path.exists("./weights") is False:
        os.makedirs("./weights")

    tb_writer = SummaryWriter()

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)

    img_size = 224
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(img_size),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(int(img_size * 1.143)),
                                   transforms.CenterCrop(img_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    # instantiate the training dataset
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])

    # instantiate the validation dataset
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)

    model = create_model(num_classes=args.num_classes).to(device)

    if args.weights != "":
        assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights)
        weights_dict = torch.load(args.weights, map_location=device)["model"]
        # delete the weights related to the classification head, since the
        # number of classes differs from the pretraining dataset
        for k in list(weights_dict.keys()):
            if "head" in k:
                del weights_dict[k]
        print(model.load_state_dict(weights_dict, strict=False))

    if args.freeze_layers:
        for name, para in model.named_parameters():
            # when True, freeze all weights except the classification head
            if "head" not in name:
                para.requires_grad_(False)
            else:
                print("training {}".format(name))

    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.AdamW(pg, lr=args.lr, weight_decay=5E-2)

    for epoch in range(args.epochs):
        # train
        train_loss, train_acc = train_one_epoch(model=model,
                                                optimizer=optimizer,
                                                data_loader=train_loader,
                                                device=device,
                                                epoch=epoch)

        # validate
        val_loss, val_acc = evaluate(model=model,
                                     data_loader=val_loader,
                                     device=device,
                                     epoch=epoch)

        tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
        tb_writer.add_scalar(tags[0], train_loss, epoch)
        tb_writer.add_scalar(tags[1], train_acc, epoch)
        tb_writer.add_scalar(tags[2], val_loss, epoch)
        tb_writer.add_scalar(tags[3], val_acc, epoch)
        tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)

        torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--lr', type=float, default=0.0001)

    # root directory of the dataset
    # http://download.tensorflow.org/example_images/flower_photos.tgz
    parser.add_argument('--data-path', type=str,
                        default="./data/flower_photos")

    # path of the pretrained weights; set to an empty string to skip loading
    parser.add_argument('--weights', type=str, default='./swin_tiny_patch4_window7_224.pth',
                        help='initial weights path')
    # whether to freeze layers
    parser.add_argument('--freeze-layers', type=bool, default=False)
    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')

    opt = parser.parse_args()

    main(opt)
```
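Training is normally launched from the command line, but the same run can be started programmatically by building the argument namespace by hand. A sketch using the defaults from the argparse block above (adjust data_path and weights to your machine):

```python
from argparse import Namespace
from train import main

# Mirrors: python train.py --num_classes 5 --epochs 10 --batch-size 8 --lr 0.0001
opt = Namespace(num_classes=5, epochs=10, batch_size=8, lr=0.0001,
                data_path="./data/flower_photos",
                weights="./swin_tiny_patch4_window7_224.pth",
                freeze_layers=False, device="cuda:0")
main(opt)
```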
utils.py

```python
# -*- coding: utf-8 -*-
"""
Created on Tue Dec  7 10:42:24 2021

@author: admin
"""
import os
import sys
import json
import pickle
import random

import torch
from tqdm import tqdm

import matplotlib.pyplot as plt


def read_split_data(root: str, val_rate: float = 0.2):
    random.seed(0)  # fix the seed so the random split is reproducible
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # traverse the folders; each folder corresponds to one class
    flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
    # sort for a consistent order
    flower_class.sort()
    # build the class name -> numeric index mapping
    class_indices = dict((k, v) for v, k in enumerate(flower_class))
    json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images_path = []   # paths of all training images
    train_images_label = []  # labels of the training images
    val_images_path = []     # paths of all validation images
    val_images_label = []    # labels of the validation images
    every_class_num = []     # number of samples per class
    supported = [".jpg", ".JPG", ".png", ".PNG"]  # supported file extensions
    # traverse the files in each folder
    for cla in flower_class:
        cla_path = os.path.join(root, cla)
        # collect the paths of all files with a supported extension
        images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
                  if os.path.splitext(i)[-1] in supported]
        # index of this class
        image_class = class_indices[cla]
        # record the number of samples of this class
        every_class_num.append(len(images))
        # randomly sample validation images at the given ratio
        val_path = random.sample(images, k=int(len(images) * val_rate))

        for img_path in images:
            if img_path in val_path:  # sampled paths go into the validation set
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:  # the rest go into the training set
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))

    plot_image = False
    if plot_image:
        # bar chart of the number of samples per class
        plt.bar(range(len(flower_class)), every_class_num, align='center')
        # replace the x ticks 0,1,2,3,4 with the class names
        plt.xticks(range(len(flower_class)), flower_class)
        # add value labels on top of the bars
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        plt.xlabel('image class')
        plt.ylabel('number of images')
        plt.title('flower class distribution')
        plt.show()

    return train_images_path, train_images_label, val_images_path, val_images_label


def plot_data_loader_image(data_loader):
    batch_size = data_loader.batch_size
    plot_num = min(batch_size, 4)

    json_path = './class_indices.json'
    assert os.path.exists(json_path), json_path + " does not exist."
    json_file = open(json_path, 'r')
    class_indices = json.load(json_file)

    for data in data_loader:
        images, labels = data
        for i in range(plot_num):
            # [C, H, W] -> [H, W, C]
            img = images[i].numpy().transpose(1, 2, 0)
            # undo the Normalize transform
            img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
            label = labels[i].item()
            plt.subplot(1, plot_num, i+1)
            plt.xlabel(class_indices[str(label)])
            plt.xticks([])  # remove x-axis ticks
            plt.yticks([])  # remove y-axis ticks
            plt.imshow(img.astype('uint8'))
        plt.show()


def write_pickle(list_info: list, file_name: str):
    with open(file_name, 'wb') as f:
        pickle.dump(list_info, f)


def read_pickle(file_name: str) -> list:
    with open(file_name, 'rb') as f:
        info_list = pickle.load(f)
    return info_list


def train_one_epoch(model, optimizer, data_loader, device, epoch):
    model.train()
    loss_function = torch.nn.CrossEntropyLoss()
    accu_loss = torch.zeros(1).to(device)  # accumulated loss
    accu_num = torch.zeros(1).to(device)   # accumulated number of correctly predicted samples
    optimizer.zero_grad()

    sample_num = 0
    data_loader = tqdm(data_loader)
    for step, data in enumerate(data_loader):
        images, labels = data
        sample_num += images.shape[0]

        pred = model(images.to(device))
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels.to(device)).sum()

        loss = loss_function(pred, labels.to(device))
        loss.backward()
        accu_loss += loss.detach()

        data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
                                                                               accu_loss.item() / (step + 1),
                                                                               accu_num.item() / sample_num)

        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)

        optimizer.step()
        optimizer.zero_grad()

    return accu_loss.item() / (step + 1), accu_num.item() / sample_num


@torch.no_grad()
def evaluate(model, data_loader, device, epoch):
    loss_function = torch.nn.CrossEntropyLoss()

    model.eval()

    accu_num = torch.zeros(1).to(device)   # accumulated number of correctly predicted samples
    accu_loss = torch.zeros(1).to(device)  # accumulated loss

    sample_num = 0
    data_loader = tqdm(data_loader)
    for step, data in enumerate(data_loader):
        images, labels = data
        sample_num += images.shape[0]

        pred = model(images.to(device))
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels.to(device)).sum()

        loss = loss_function(pred, labels.to(device))
        accu_loss += loss

        data_loader.desc = "[valid epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
                                                                               accu_loss.item() / (step + 1),
                                                                               accu_num.item() / sample_num)

    return accu_loss.item() / (step + 1), accu_num.item() / sample_num
```
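With val_rate=0.2, roughly 20% of each class is sampled into the validation set. A quick sanity check of the split (it assumes the flower_photos folder from the download link in train.py has been extracted):

```python
from utils import read_split_data

train_paths, train_labels, val_paths, val_labels = read_split_data(
    root="./data/flower_photos", val_rate=0.2)
print(len(train_paths), len(val_paths))  # roughly an 80/20 split, reproducible thanks to random.seed(0)
```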
Environment setup

The key packages are torch 1.7.1, torchvision 0.8.2 and timm 0.3.2:

```
(swin) C:\Users\deep>pip list
Package                 Version
----------------------- ---------
absl-py                 1.0.0
apex                    0.1
cachetools              4.2.4
certifi                 2021.10.8
charset-normalizer      2.0.9
colorama                0.4.4
cycler                  0.11.0
fonttools               4.28.3
google-auth             2.3.3
google-auth-oauthlib    0.4.6
grpcio                  1.42.0
idna                    3.3
importlib-metadata      4.8.2
kiwisolver              1.3.2
Markdown                3.3.6
matplotlib              3.5.1
mkl-fft                 1.0.12
mkl-random              1.1.1
numpy                   1.21.4
oauthlib                3.1.1
olefile                 0.46
opencv-python           4.4.0.46
packaging               21.3
Pillow                  8.4.0
pip                     21.2.4
protobuf                3.19.1
pyasn1                  0.4.8
pyasn1-modules          0.2.8
pyparsing               3.0.6
python-dateutil         2.8.2
PyYAML                  6.0
requests                2.26.0
requests-oauthlib       1.3.0
rsa                     4.8
setuptools              59.5.0
six                     1.16.0
tensorboard             2.7.0
tensorboard-data-server 0.6.1
tensorboard-plugin-wit  1.8.0
termcolor               1.1.0
timm                    0.3.2
torch                   1.7.1
torchvision             0.8.2
tqdm                    4.62.3
typing-extensions       3.10.0.2
urllib3                 1.26.7
Werkzeug                2.0.2
wheel                   0.37.0
wincertstore            0.2
yacs                    0.1.8
zipp                    3.6.0
```