TransUnet官方代码训练自己数据集(彩色RGB3通道图像的分割)
***************************************************
碼字不易,收藏之余,別忘了給我點個贊吧!
***************************************************
---------Start
官方代碼:https://github.com/Beckschen/TransUNet
目的:訓練5個類別的汽車部件分割任務(測試在另一篇博客中)
CSDN數據集免費下載
實現效果:
1. github下載代碼,并解壓。
項目里的文件可能跟你下載的不一樣,不急后面會講到!
2. 配置數據集(盡最大努力還原官方數據集的格式)。
通常自己手上的數據集分images和labels文件夾,分別存放著原始圖像和對應的mask圖像,如下圖所示; mask圖像中的像素有0,1,2,3,4 分別代表背景,車身,輪子,車燈,窗戶,一共五個類別,所以這里顯示全黑色,肉眼看不出差別!通過閱讀官方讀取數據的代碼,我們需要將一張圖像和其對應的標簽合并轉化成一個.npz文件.
官方數據集格式,data文件夾,Synapse文件夾,test_vol_h5文件夾,train_npz文件夾手動創建!
轉化數據集的代碼如下,會將images中的圖像和labels中的標簽生成一個.npz文件。
def npz():#圖像路徑path = r'G:\dataset\car-segmentation\train\images\*.png'#項目中存放訓練所用的npz文件路徑path2 = r'G:\dataset\Unet\TransUnet-ori\data\Synapse\train_npz\\'for i,img_path in enumerate(glob.glob(path)):#讀入圖像image = cv2.imread(img_path)image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)#讀入標簽label_path = img_path.replace('images','labels')label = cv2.imread(label_path,flags=0)#保存npznp.savez(path2+str(i),image=image,label=label)print('------------',i)# 加載npz文件# data = np.load(r'G:\dataset\Unet\Swin-Unet-ori\data\Synapse\train_npz\0.npz', allow_pickle=True)# image, label = data['image'], data['label']print('ok')生成的文件在 data\Synapse\train_npz文件夾中,如下圖,也可以自己定義生成的路徑,然后把文件復制到data\Synapse\train_npz文件中。
data\Synapse\train_npz文件夾中存放的是訓練集樣本,按照同樣的方式生成測試集樣本,存放在data\Synapse\test_vol_h5文件夾中。
我的訓練集203個樣本,測試集3個樣本。npz文件生成完成之后,找到train.txt和test_vol.txt,手動將文件里面的內容清空,split_data.py這個文件直接無視。自己寫一個函數讀取train_npz中所有的文件名稱,然后將文件名稱寫入train.txt文件,一個名稱一行,如下圖所示。同理可完成test_vol.txt文件制作。
至此,數據集制作完畢!!!代碼會先去train.txt文件中讀取訓練樣本的名稱,然后根據名稱再去train_npz文件夾下讀取npz文件。所以每一步都很重要,必須正確!
3. 下載預訓練權重
官方下載地址
CSDN下載地址[推薦]
進入網站后,點擊imagenet21k文件夾。
下載這個權重文件即可。
手動創建如下多個文件夾,存放剛剛下載完畢的權重,注意名稱跟我的保持一致!
至此,預訓練權重已下載完畢。
4. 修改讀取文件的方法
找到datasets/dataset_synapse.py文件中的Synapse_dataset類,修改__getitem__函數。
def __getitem__(self, idx):if self.split == "train":slice_name = self.sample_list[idx].strip('\n')data_path = self.data_dir+"/"+slice_name+'.npz'data = np.load(data_path)image, label = data['image'], data['label']else:slice_name = self.sample_list[idx].strip('\n')data_path = self.data_dir+"/"+slice_name+'.npz'data = np.load(data_path)image, label = data['image'], data['label']image = torch.from_numpy(image.astype(np.float32))image = image.permute(2,0,1)label = torch.from_numpy(label.astype(np.float32))sample = {'image': image, 'label': label}if self.transform:sample = self.transform(sample)sample['case_name'] = self.sample_list[idx].strip('\n')return sample找到datasets/dataset_synapse.py文件中的RandomGenerator類,修改__call__函數。
def __call__(self, sample):image, label = sample['image'], sample['label']if random.random() > 0.5:image, label = random_rot_flip(image, label)elif random.random() > 0.5:image, label = random_rotate(image, label)x, y,_ = image.shapeif x != self.output_size[0] or y != self.output_size[1]:image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y,1), order=3) # why not 3?label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0)image = torch.from_numpy(image.astype(np.float32))image = image.permute(2,0,1)label = torch.from_numpy(label.astype(np.float32))sample = {'image': image, 'label': label.long()}return sample至此,數據讀取的部分已經修改完畢!
5. 配置訓練參數
認真檢查各個參數是否正確,這里的路徑都是 ‘./’(當前目錄下),不是"…/",訓練時,batch_size通常大于1,我這里設置有誤!類別數可根據你的任務定!
圖片大小設置,越大越耗顯存。
6. 修改trainer.py文件
設置trainer.py文件中的DataLoader函數中的num_workers=0
至此,所有代碼修改完畢!
總結:以上修改內容針對彩色圖像的分割任務, 由于僅文字表述某些操作存在局限性,故只能簡略應答,有任何問題可下方留言評論。
總結
以上是生活随笔為你收集整理的TransUnet官方代码训练自己数据集(彩色RGB3通道图像的分割)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: JSR303注解字段校验
- 下一篇: 面向对象---抽象和封装