Removing missing and empty data with a custom PyTorch Dataset and DataLoader
[Source code on GitHub]: https://github.com/PanJinquan/pytorch-learning-tutorials/tree/master/image_classification/utils
If you find it useful, please give it a "Star".
Contents
1. Problem description
2. The usual solution
3. Another solution: customizing how batch data is returned with the collate_fn() collation function
3.1 PyTorch data handling: Dataset and DataLoader
3.2 A custom collate_fn() function
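1. Problem description
I previously wrote a blog post, "Generating custom training data with pytorch Dataset and DataLoader", but that approach has one problem: it does no data cleaning inside the Dataset. If the data we pass to the Dataset is itself faulty, iteration is bound to fail.
For example, suppose I pass a large number of image paths to the Dataset. If every path is correct and every image exists and is intact, everything obviously runs fine. But if some of the paths point to images that do not exist, reading them through the Dataset and returning them during iteration produces an error like this: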
```
  File "D:\ProgramData\Anaconda3\envs\pytorch-py36\lib\site-packages\torch\utils\data\_utils\collate.py", line 68, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "D:\ProgramData\Anaconda3\envs\pytorch-py36\lib\site-packages\torch\utils\data\_utils\collate.py", line 70, in default_collate
    raise TypeError((error_msg_fmt.format(type(batch[0]))))
TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'NoneType'>
```
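The same failure can be reproduced without any image files. The sketch below assumes a PyTorch version that exports `default_collate` from `torch.utils.data` (1.11 or later, as far as I know; older versions keep it in `torch.utils.data.dataloader`). The exact message text differs between versions, but a `None` sample makes default collation raise a `TypeError`.

```python
import torch
from torch.utils.data import default_collate  # assumed: PyTorch >= 1.11

# Two samples shaped like (image, image_id); the first image failed to load,
# so the Dataset returned None in its place.
batch = [(None, "ddd.jpg"), (torch.zeros(3, 2, 2), "1.jpg")]

try:
    default_collate(batch)
except TypeError as e:
    # default_collate does not know how to merge None values into a tensor
    print("TypeError:", e)
```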
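2. The usual solution
The usual solution is simple and blunt: clean the data before passing it to the Dataset, removing missing images and corrupted data in advance. It is indeed the simplest, bluntest approach.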
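A minimal sketch of this pre-cleaning step, using only the standard library. The `clean_image_list` helper is hypothetical: it drops ids whose file is missing or empty (0 bytes), but it cannot detect a file that exists and is corrupted. Catching those would mean decoding every image up front (for example calling the article's `read_image` inside `try/except`), which is exactly the cost that motivates the approach in section 3.

```python
import os


def clean_image_list(image_dir, image_id_list):
    """Hypothetical pre-cleaning helper: keep only ids whose file exists
    and is non-empty. Corrupted-but-nonempty files are NOT detected."""
    kept = []
    for image_id in image_id_list:
        path = os.path.join(image_dir, image_id)
        if os.path.isfile(path) and os.path.getsize(path) > 0:
            kept.append(image_id)
        else:
            print("skipping missing or empty file:", path)
    return kept
```

Call it once before building the dataset, e.g. `image_id_list = clean_image_list(image_dir, image_id_list)`.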
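3. Another solution: customizing how batch data is returned with the collate_fn() collation function
We want the Dataset to process whatever it is given: if a file does not exist or is otherwise abnormal, it returns None, and the DataLoader then removes every sample that is None. This guarantees that every batch the DataLoader yields during iteration is valid. For example, when reading images with batch_size=5, if one (or more) of the images does not exist, the returned batch should have the missing data filtered out, i.e. it should contain 5-1=4 samples.
That is exactly the feature I want to implement: returned batches automatically drop invalid data.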
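3.1 PyTorch data handling: Dataset and DataLoader
PyTorch provides two data-handling classes, Dataset and DataLoader:
```python
from torch.utils.data import Dataset, DataLoader
```
Dataset defines how data is read and preprocessed, while DataLoader loads the data and produces training batches.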
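A minimal sketch of that division of labor (the `SquaresDataset` class is a hypothetical toy, not part of the article): a Dataset only has to implement `__len__` and `__getitem__`, and DataLoader groups the individual samples into batches with its default collate function.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is (tensor([i, i*i]), i % 2)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # The default sampler uses this to know which indices exist
        return self.n

    def __getitem__(self, i):
        # Reading and preprocessing of a single sample belong here
        x = torch.tensor([i, i * i], dtype=torch.float32)
        return x, i % 2


loader = DataLoader(SquaresDataset(10), batch_size=4, shuffle=False)
for x, y in loader:
    # x is stacked to shape (batch, 2); the int labels become a 1-D tensor
    print(x.shape, y)
```

With 10 samples and batch_size=4, the last batch holds only the remaining 2 samples.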
Of the parameters accepted by torch.utils.data.DataLoader, the one we need here is the collate_fn() callback.
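For reference, the snippet below passes a few commonly used DataLoader parameters explicitly (`dataset`, `batch_size`, `shuffle`, `num_workers`, `drop_last`, `collate_fn`) and uses a trivial collate function to show what the callback actually receives: a list of up to `batch_size` samples, each exactly as returned by `dataset[i]`. The toy data and the `inspect_collate` name are illustrative only.

```python
from torch.utils.data import DataLoader

# A plain Python list works as a map-style dataset (it has __getitem__ and __len__)
samples = [("a", 1), ("b", 2), ("c", 3), ("d", 4), ("e", 5)]


def inspect_collate(batch):
    # DataLoader hands collate_fn a list of samples; whatever this function
    # returns is what the training loop receives for that step.
    return batch


loader = DataLoader(
    samples,
    batch_size=2,      # samples per batch
    shuffle=False,     # keep index order
    num_workers=0,     # load in the main process
    drop_last=False,   # keep the final, smaller batch
    collate_fn=inspect_collate,
)
for batch in loader:
    print(batch)
```

Because collate_fn sees the whole list of samples before anything is stacked, it is the natural place to drop invalid ones.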
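3.2 A custom collate_fn() function
The collate_fn() parameter of torch.utils.data.DataLoader controls how individual samples are merged into a batch. It defaults to the default_collate function, but default_collate raises an error when the batch contains None or similar values. We therefore need a custom collate_fn().
The fix is simple: add the following lines to the original default_collate function. They check whether image is None and, if so, remove that sample from the batch, so iteration no longer fails.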
```python
# Added: if image is None, drop the sample from the batch so iteration does not fail
if isinstance(batch, list):
    batch = [(image, image_id) for (image, image_id) in batch if image is not None]
if batch == []:
    return (None, None)
```
dataset_collate.py:
```python
# -*- coding: utf-8 -*-
"""
@Project: pytorch-learning-tutorials
@File   : dataset_collate.py
@Author : panjq
@E-mail : pan_jinquan@163.com
@Date   : 2019-06-07 17:09:13
"""
r"""Contains definitions of the methods used by the _DataLoaderIter workers to
collate samples fetched from dataset into Tensor(s).

These **needs** to be in global scope since Py2 doesn't support serializing
static methods.
"""
import collections.abc as container_abcs
import re

import torch

# torch._six (container_abcs, string_classes, int_classes) no longer exists in
# recent PyTorch releases, so the equivalents come from the standard library.
string_classes = (str, bytes)
int_classes = int

_use_shared_memory = False
r"""Whether to use shared memory in default_collate"""

np_str_obj_array_pattern = re.compile(r'[SaUO]')

error_msg_fmt = "batch must contain tensors, numbers, dicts or lists; found {}"

numpy_type_map = {
    'float64': torch.DoubleTensor,
    'float32': torch.FloatTensor,
    'float16': torch.HalfTensor,
    'int64': torch.LongTensor,
    'int32': torch.IntTensor,
    'int16': torch.ShortTensor,
    'int8': torch.CharTensor,
    'uint8': torch.ByteTensor,
}


def collate_fn(batch):
    '''
    collate_fn (callable, optional): merges a list of samples to form a mini-batch.
    Based on torch's default_collate, the DataLoader's default collate function, which
    fails when the batch contains None. Fix: drop every sample whose image is None.
    The filtering is done only here, on the top-level list of (image, image_id) samples;
    the recursive calls go to _collate, which works on fields, not on samples.
    :param batch: list of (image, image_id) tuples returned by Dataset.__getitem__
    :return: (batch_image, batch_image_id), or (None, None) if every sample was invalid
    '''
    # Added: if image is None, drop the sample from the batch so iteration does not fail
    if isinstance(batch, list):
        batch = [(image, image_id) for (image, image_id) in batch if image is not None]
    if batch == []:
        return (None, None)
    return _collate(batch)


def _collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""
    elem_type = type(batch[0])
    if isinstance(batch[0], torch.Tensor):
        out = None
        if _use_shared_memory:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = batch[0].storage()._new_shared(numel)
            out = batch[0].new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(error_msg_fmt.format(elem.dtype))
            return _collate([torch.from_numpy(b) for b in batch])
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
    elif isinstance(batch[0], float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(batch[0], int_classes):
        return torch.tensor(batch)
    elif isinstance(batch[0], string_classes):
        return batch
    elif isinstance(batch[0], container_abcs.Mapping):
        return {key: _collate([d[key] for d in batch]) for key in batch[0]}
    elif isinstance(batch[0], tuple) and hasattr(batch[0], '_fields'):  # namedtuple
        return type(batch[0])(*(_collate(samples) for samples in zip(*batch)))
    elif isinstance(batch[0], container_abcs.Sequence):
        transposed = zip(*batch)
        return [_collate(samples) for samples in transposed]

    raise TypeError((error_msg_fmt.format(type(batch[0]))))
```
Test script:
```python
# -*- coding: utf-8 -*-
"""
@Project: pytorch-learning-tutorials
@File   : dataset.py
@Author : panjq
@E-mail : pan_jinquan@163.com
@Date   : 2019-03-07 18:45:06
"""
import os

import cv2
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from utils import dataset_collate


def read_image(path, mode='RGB'):
    '''
    :param path: image path
    :param mode: RGB or L
    :return: PIL.Image
    '''
    return Image.open(path).convert(mode)


class TorchDataset(Dataset):
    def __init__(self, image_id_list, image_dir, resize_height=256, resize_width=256, repeat=1, transform=None):
        '''
        :param image_id_list: list of image file names, e.g. ["1.jpg", "2.jpg"]
        :param image_dir: image directory; os.path.join(image_dir, image_id) is the full image path
        :param resize_height: if None, no resizing
        :param resize_width: if None, no resizing
            PS: if only one of resize_height/resize_width is None, the image is scaled proportionally
            (both are only stored here; the actual resizing is done by transform)
        :param repeat: number of times all samples are repeated, default 1;
            if repeat is None, loop "infinitely" (a large length < sys.maxsize)
        :param transform: preprocessing transform
        '''
        self.image_dir = image_dir
        self.image_id_list = image_id_list
        self.len = len(image_id_list)
        self.repeat = repeat
        self.resize_height = resize_height
        self.resize_width = resize_width
        self.transform = transform

    def __getitem__(self, i):
        index = i % self.len
        # print("i={},index={}".format(i, index))
        image_id = self.image_id_list[index]
        image_path = os.path.join(self.image_dir, image_id)
        img = self.load_data(image_path)
        if img is None:
            # Missing or unreadable image: return None and let collate_fn drop this sample
            return None, image_id
        img = self.data_preproccess(img)
        return img, image_id

    def __len__(self):
        if self.repeat is None:
            data_len = 10000000
        else:
            data_len = len(self.image_id_list) * self.repeat
        return data_len

    def load_data(self, path):
        '''
        Load an image
        :param path: image path
        :return: PIL.Image, or None if the image cannot be read (missing, empty or corrupted file)
        '''
        try:
            image = read_image(path)
        except Exception as e:
            image = None
            print(e)
        # image = image_processing.read_image(path)  # read the image with OpenCV
        return image

    def data_preproccess(self, data):
        '''
        Data preprocessing
        :param data: PIL.Image
        :return: transformed data
        '''
        if self.transform is not None:
            data = self.transform(data)
        return data


if __name__ == '__main__':
    resize_height = 224
    resize_width = 224
    image_id_list = ["1.jpg", "ddd.jpg", "111.jpg", "3.jpg", "4.jpg", "5.jpg", "6.jpg", "7.jpg", "8.jpg", "9.jpg"]
    image_dir = "../dataset/test_images/images"
    # Preprocessing setup.
    # transforms.ToTensor converts a PIL.Image or numpy.ndarray of shape (H, W, C) with pixel values
    # in [0, 255] into a torch.FloatTensor of shape (C, H, W) normalized to [0.0, 1.0].
    train_transform = transforms.Compose([
        transforms.Resize(size=(resize_height, resize_width)),
        # transforms.RandomHorizontalFlip(),  # random horizontal flip
        transforms.RandomCrop(size=(resize_height, resize_width), padding=4),  # random crop
        transforms.ToTensor(),  # (H,W,C) -> (C,H,W), scaled to [0.0, 1.0] as torch.FloatTensor
        # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # per-channel mean (R,G,B) and std (R,G,B)
    ])
    epoch_num = 2  # number of passes over the whole dataset
    batch_size = 5  # number of samples per training batch
    train_data_nums = 10
    max_iterate = (train_data_nums + batch_size - 1) // batch_size * epoch_num  # total iterations (ceil division)
    train_data = TorchDataset(image_id_list=image_id_list,
                              image_dir=image_dir,
                              resize_height=resize_height,
                              resize_width=resize_width,
                              repeat=1,
                              transform=train_transform)
    # The default default_collate raises an error:
    # train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=False)
    # Use the custom collate_fn instead
    train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=False,
                              collate_fn=dataset_collate.collate_fn)
    # [1] Iterate by epoch; TorchDataset's repeat=1
    for epoch in range(epoch_num):
        for step, (batch_image, batch_label) in enumerate(train_loader):
            if batch_image is None and batch_label is None:
                # Every sample in this batch was invalid
                print("batch_image:{},batch_label:{}".format(batch_image, batch_label))
                continue
            image = batch_image[0, :]
            image = image.numpy()
            image = image.transpose(1, 2, 0)  # channels [c,h,w] -> [h,w,c]
            # Note: PIL loads RGB while OpenCV displays BGR, so colors appear swapped
            cv2.imshow("image", image)
            cv2.waitKey(2000)
            print("batch_image.shape:{},batch_label:{}".format(batch_image.shape, batch_label))
```
Explanation of the output:
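With batch_size=5, the input list is image_id_list=["1.jpg","ddd.jpg","111.jpg","3.jpg","4.jpg","5.jpg","6.jpg","7.jpg","8.jpg","9.jpg"], where "ddd.jpg" and "111.jpg" do not exist, and images are resized to 224×224 (resize_height=resize_width=224). Normally every batch would have the shape torch.Size([5, 3, 224, 224]), but because "ddd.jpg" and "111.jpg" do not exist they are filtered out, so the first batch of each epoch becomes torch.Size([3, 3, 224, 224]):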
```
[Errno 2] No such file or directory: '../dataset/test_images/images\\ddd.jpg'
[Errno 2] No such file or directory: '../dataset/test_images/images\\111.jpg'
batch_image.shape:torch.Size([3, 3, 224, 224]),batch_label:('1.jpg', '3.jpg', '4.jpg')
batch_image.shape:torch.Size([5, 3, 224, 224]),batch_label:('5.jpg', '6.jpg', '7.jpg', '8.jpg', '9.jpg')
[Errno 2] No such file or directory: '../dataset/test_images/images\\ddd.jpg'
[Errno 2] No such file or directory: '../dataset/test_images/images\\111.jpg'
batch_image.shape:torch.Size([3, 3, 224, 224]),batch_label:('1.jpg', '3.jpg', '4.jpg')
batch_image.shape:torch.Size([5, 3, 224, 224]),batch_label:('5.jpg', '6.jpg', '7.jpg', '8.jpg', '9.jpg')
```
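The dataset_collate.py file copies the 2019 default_collate source, which depends on PyTorch internals that change between releases. A shorter alternative, sketched below under the assumption that your PyTorch exports `default_collate` from `torch.utils.data` (1.11 or later), is to filter the list of samples first and then delegate to the built-in function. `collate_skip_none` is a hypothetical name; unlike the article's version it also drops samples that are `None` as a whole.

```python
import torch
from torch.utils.data import default_collate  # assumed: PyTorch >= 1.11


def collate_skip_none(batch):
    """Drop samples that are None or whose first field (the image) is None,
    then let PyTorch's own default_collate assemble the batch."""
    batch = [s for s in batch if s is not None and s[0] is not None]
    if not batch:
        # Same sentinel as the article's collate_fn, so the training loop can `continue`
        return None, None
    return default_collate(batch)


if __name__ == "__main__":
    batch = [(torch.zeros(3, 4, 4), "1.jpg"), (None, "ddd.jpg"), (torch.ones(3, 4, 4), "3.jpg")]
    images, ids = collate_skip_none(batch)
    print(images.shape, ids)
```

It would slot into the test script as `DataLoader(dataset=train_data, batch_size=batch_size, shuffle=False, collate_fn=collate_skip_none)`. Because the filtering happens once, on the top-level list of samples, the recursive collation of tuples, dicts and NumPy arrays is left entirely to PyTorch.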