TransUnet官方代码训练自己数据集(彩色RGB3通道图像的分割)
***************************************************
碼字不易,收藏之余,別忘了給我點(diǎn)個(gè)贊吧!
***************************************************
---------Start
官方代碼:https://github.com/Beckschen/TransUNet
目的:訓(xùn)練5個(gè)類別的汽車部件分割任務(wù)(測(cè)試在另一篇博客中)
CSDN數(shù)據(jù)集免費(fèi)下載
實(shí)現(xiàn)效果:
1. github下載代碼,并解壓。
項(xiàng)目里的文件可能跟你下載的不一樣,不急后面會(huì)講到!
2. 配置數(shù)據(jù)集(盡最大努力還原官方數(shù)據(jù)集的格式)。
通常自己手上的數(shù)據(jù)集分images和labels文件夾,分別存放著原始圖像和對(duì)應(yīng)的mask圖像,如下圖所示; mask圖像中的像素有0,1,2,3,4 分別代表背景,車身,輪子,車燈,窗戶,一共五個(gè)類別,所以這里顯示全黑色,肉眼看不出差別!通過閱讀官方讀取數(shù)據(jù)的代碼,我們需要將一張圖像和其對(duì)應(yīng)的標(biāo)簽合并轉(zhuǎn)化成一個(gè).npz文件.
官方數(shù)據(jù)集格式,data文件夾,Synapse文件夾,test_vol_h5文件夾,train_npz文件夾手動(dòng)創(chuàng)建!
轉(zhuǎn)化數(shù)據(jù)集的代碼如下,會(huì)將images中的圖像和labels中的標(biāo)簽生成一個(gè).npz文件。
def npz():#圖像路徑path = r'G:\dataset\car-segmentation\train\images\*.png'#項(xiàng)目中存放訓(xùn)練所用的npz文件路徑path2 = r'G:\dataset\Unet\TransUnet-ori\data\Synapse\train_npz\\'for i,img_path in enumerate(glob.glob(path)):#讀入圖像image = cv2.imread(img_path)image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)#讀入標(biāo)簽label_path = img_path.replace('images','labels')label = cv2.imread(label_path,flags=0)#保存npznp.savez(path2+str(i),image=image,label=label)print('------------',i)# 加載npz文件# data = np.load(r'G:\dataset\Unet\Swin-Unet-ori\data\Synapse\train_npz\0.npz', allow_pickle=True)# image, label = data['image'], data['label']print('ok')生成的文件在 data\Synapse\train_npz文件夾中,如下圖,也可以自己定義生成的路徑,然后把文件復(fù)制到data\Synapse\train_npz文件中。
data\Synapse\train_npz文件夾中存放的是訓(xùn)練集樣本,按照同樣的方式生成測(cè)試集樣本,存放在data\Synapse\test_vol_h5文件夾中。
我的訓(xùn)練集203個(gè)樣本,測(cè)試集3個(gè)樣本。npz文件生成完成之后,找到train.txt和test_vol.txt,手動(dòng)將文件里面的內(nèi)容清空,split_data.py這個(gè)文件直接無(wú)視。自己寫一個(gè)函數(shù)讀取train_npz中所有的文件名稱,然后將文件名稱寫入train.txt文件,一個(gè)名稱一行,如下圖所示。同理可完成test_vol.txt文件制作。
至此,數(shù)據(jù)集制作完畢!!!代碼會(huì)先去train.txt文件中讀取訓(xùn)練樣本的名稱,然后根據(jù)名稱再去train_npz文件夾下讀取npz文件。所以每一步都很重要,必須正確!
3. 下載預(yù)訓(xùn)練權(quán)重
官方下載地址
CSDN下載地址[推薦]
進(jìn)入網(wǎng)站后,點(diǎn)擊imagenet21k文件夾。
下載這個(gè)權(quán)重文件即可。
手動(dòng)創(chuàng)建如下多個(gè)文件夾,存放剛剛下載完畢的權(quán)重,注意名稱跟我的保持一致!
至此,預(yù)訓(xùn)練權(quán)重已下載完畢。
4. 修改讀取文件的方法
找到datasets/dataset_synapse.py文件中的Synapse_dataset類,修改__getitem__函數(shù)。
def __getitem__(self, idx):if self.split == "train":slice_name = self.sample_list[idx].strip('\n')data_path = self.data_dir+"/"+slice_name+'.npz'data = np.load(data_path)image, label = data['image'], data['label']else:slice_name = self.sample_list[idx].strip('\n')data_path = self.data_dir+"/"+slice_name+'.npz'data = np.load(data_path)image, label = data['image'], data['label']image = torch.from_numpy(image.astype(np.float32))image = image.permute(2,0,1)label = torch.from_numpy(label.astype(np.float32))sample = {'image': image, 'label': label}if self.transform:sample = self.transform(sample)sample['case_name'] = self.sample_list[idx].strip('\n')return sample找到datasets/dataset_synapse.py文件中的RandomGenerator類,修改__call__函數(shù)。
def __call__(self, sample):image, label = sample['image'], sample['label']if random.random() > 0.5:image, label = random_rot_flip(image, label)elif random.random() > 0.5:image, label = random_rotate(image, label)x, y,_ = image.shapeif x != self.output_size[0] or y != self.output_size[1]:image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y,1), order=3) # why not 3?label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0)image = torch.from_numpy(image.astype(np.float32))image = image.permute(2,0,1)label = torch.from_numpy(label.astype(np.float32))sample = {'image': image, 'label': label.long()}return sample至此,數(shù)據(jù)讀取的部分已經(jīng)修改完畢!
5. 配置訓(xùn)練參數(shù)
認(rèn)真檢查各個(gè)參數(shù)是否正確,這里的路徑都是 ‘./’(當(dāng)前目錄下),不是"…/",訓(xùn)練時(shí),batch_size通常大于1,我這里設(shè)置有誤!類別數(shù)可根據(jù)你的任務(wù)定!
圖片大小設(shè)置,越大越耗顯存。
6. 修改trainer.py文件
設(shè)置trainer.py文件中的DataLoader函數(shù)中的num_workers=0
至此,所有代碼修改完畢!
總結(jié):以上修改內(nèi)容針對(duì)彩色圖像的分割任務(wù), 由于僅文字表述某些操作存在局限性,故只能簡(jiǎn)略應(yīng)答,有任何問題可下方留言評(píng)論。
總結(jié)
以上是生活随笔為你收集整理的TransUnet官方代码训练自己数据集(彩色RGB3通道图像的分割)的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: JSR303注解字段校验
- 下一篇: 面向对象---抽象和封装