Swintransformer详细设计文档
1、文件說明
Model.py:構建模型
My_dataset.py:數據集處理
Predict.py:預測圖片分類類別
Train.py:訓練網絡
Utils.py:
2、項目結構和函數設計
Model.py 的類
class DropPath(nn.Module)def forward(self, x) class PatchEmbed(nn.Module)def forward(self, x) class PatchMerging(nn.Module):def forward(self, x, H, W) class Mlp(nn.Module):def forward(self, x): class WindowAttention(nn.Module):def forward(self, x, mask: Optional[torch.Tensor] = None): class SwinTransformerBlock(nn.Module):def forward(self, x, attn_mask): class BasicLayer(nn.Module):def create_mask(self, x, H, W):def forward(self, x, H, W): class SwinTransformer(nn.Module):def _init_weights(self, m):def forward(self, x)Model.py 的函數
def drop_path_f(x, drop_prob: float = 0., training: bool = False) def window_partition(x, window_size: int) def window_reverse(windows, window_size: int, H: int, W: int) def swin_tiny_patch4_window7_224(num_classes: int = 1000, **kwargs): def swin_small_patch4_window7_224(num_classes: int = 1000, **kwargs): def swin_base_patch4_window7_224(num_classes: int = 1000, **kwargs): def swin_base_patch4_window12_384(num_classes: int = 1000, **kwargs): def swin_base_patch4_window7_224_in22k(num_classes: int = 21841, **kwargs): def swin_base_patch4_window12_384_in22k(num_classes: int = 21841, **kwargs): def swin_large_patch4_window7_224_in22k(num_classes: int = 21841, **kwargs): def swin_large_patch4_window12_384_in22k(num_classes: int = 21841, **kwargs):My_dataset.py只有類
class MyDataSet(Dataset): ---def __len__(self): ---def __getitem__(self, item):@staticmethod ---def collate_fn(batch):Predict.py只有函數
def main(): if __name__ == '__main__':main()Train.py只有函數
def main(args): if __name__ == '__main__':。。。main(opt)Utils.py只有函數
def read_split_data(root: str, val_rate: float = 0.2): def plot_data_loader_image(data_loader): def write_pickle(list_info: list, file_name: str): def read_pickle(file_name: str) -> list: def train_one_epoch(model, optimizer, data_loader, device, epoch): @torch.no_grad() def evaluate(model, data_loader, device, epoch):Swin-Transformer 論文代碼介紹
1 開發環境
? Python 3.6
? torch 1.7.1
? GPU
2 功能設計
實驗數據集的說明:
數據來源
http://download.tensorflow.org/example_images/flower_photos.tgz
5類花的圖片做分類:
3670 images were found in the dataset.
2939 images for training.
731 images for validation.
Daisy:菊花
Dandelion:蒲公英
Roses:玫瑰
Sunflowers:向日葵
Tulips:郁金香
3 、文件說明
Model.py:構建模型
My_dataset.py:數據集處理
Predict.py:預測圖片分類類別
Train.py:訓練網絡
Utils.py:功能類函數
Model.py 的類
DropPath:設置各模塊內的dropout率
PatchEmbed:對圖片像素進行劃分patch
PatchMerging:對圖進行petch的拼接和線性映射
Mlp:SwinTransformerBlock后面一段的使用的
WindowAttention:window內部計算attention
SwinTransformerBlock:構建單個SwinTransformerBlock模型,該模型中含有W-MSA和SW-MSA兩個模塊
SwinTransformer:構建整個分類模型,這個類調用其他類,共同組成整個模型,從Patchpartion到LinearEmbedding(即類PatchEmbed),到四個SwinTransformerBlock,以及在SwinTransformerBlock中使用是否使用PatchMerging,經過四個階段的SwinTransformerBlock之后輸出展平的向量。
Model.py 的函數
window_partition:對特征圖進行劃分,劃分成一個一個沒有重疊的window
window_reverse:將window還原成特征圖
定義各種模型,用于實例化模型
swin_tiny_patch4_window7_224
swin_small_patch4_window7_224
swin_base_patch4_window7_224
swin_base_patch4_window12_384
swin_base_patch4_window7_224_in22k
swin_base_patch4_window12_384_in22k
swin_large_patch4_window7_224_in22k
swin_large_patch4_window12_384_in22k
My_dataset.py只有類
MyDataSet(Dataset):構建獲取數據集中元素和大小的方法
@staticmethod
collate_fn(batch):用于單獨調用使用,將一個批次的圖片轉為向量并拼在一起
Predict.py只有函數
main(): 創建預測圖片類別的函數,展示預測的圖片以及被預測圖片屬于每個類別的概率
if name == ‘main’:
main()
開始預測
Train.py只有函數
main(args)
獲取訓練集和驗證集,對圖片進行處理,調整兩個數據集中圖片的大小,實例化模型,訓練模型,保存模型。
自定義參數,解析參數,調用并執行main(args),訓練分類模型
Utils.py只有函數
read_split_data:讀取圖片和圖片的類別,劃分訓練集和驗證集
train_one_epoch:
定義損失函數:torch.nn.CrossEntropyLoss()
進行一個epoch的訓練,返回損失和精確率
Evaluate
4 流程
運行train.py訓練模型,訓練了個epoch,最高精確率可到96.6%
5 效果演示
運行predict.py對單獨一張圖片進行預測類別
總結
以上是生活随笔為你收集整理的Swintransformer详细设计文档的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Python-类的学习
- 下一篇: Pycharm-列出代码结构