Generating Custom Training Data with PyTorch Dataset and DataLoader

Published: 2024/4/15

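Contents

Generating Custom Training Data with PyTorch Dataset and DataLoader

1. torch.utils.data.Dataset

2. torch.utils.data.DataLoader

3. Generating custom training data with Dataset and DataLoader

3.1 Defining a custom Dataset

3.2 Producing training batches with DataLoader

3.3 Appendix: image_processing.py

3.4 Complete code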


1. torch.utils.data.Dataset

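datasets is the collection of dataset source code defined in PyTorch. Below is the basic skeleton of a custom Dataset. Initialization goes in __init__(), and the two methods __getitem__() and __len__() must be overridden: __getitem__() returns one training sample, such as an image and its label, and __len__() returns the number of samples.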

from torch.utils import data


class CustomDataset(data.Dataset):  # must inherit from data.Dataset
    def __init__(self):
        # TODO
        # 1. Initialize file path or list of file names.
        pass

    def __getitem__(self, index):
        # TODO
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.Transform).
        # 3. Return a data pair (e.g. image and label).
        # Note: step 1 reads a single sample, not a batch.
        pass

    def __len__(self):
        # You should change 0 to the total size of your dataset.
        return 0
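To make the skeleton concrete, here is a minimal sketch of a dataset that fills in the three methods. The class name SquaresDataset and its contents are invented for illustration, and the import fallback is an assumption added only so that the snippet also runs where PyTorch is not installed.

```python
try:
    from torch.utils.data import Dataset
except ImportError:  # assumption: plain-class fallback so the sketch runs without PyTorch
    Dataset = object


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the pair (i, i * i)."""

    def __init__(self, size):
        # Initialization: here there are no files to list, only a size to remember.
        self.size = size

    def __getitem__(self, index):
        # Return ONE sample (input, label), never a whole batch.
        if not 0 <= index < self.size:
            raise IndexError(index)
        return index, index * index

    def __len__(self):
        # Total number of samples in the dataset.
        return self.size


dataset = SquaresDataset(5)
print(len(dataset))  # number of samples
print(dataset[3])    # one (input, label) pair
```

Batching is deliberately absent here: grouping samples is the job of DataLoader, described in the next section.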

2. torch.utils.data.DataLoader

DataLoader accepts the following parameters:

  • dataset (Dataset): the dataset to load from.
  • batch_size (int, optional): number of samples per batch.
  • shuffle (bool, optional): if True, the data is reshuffled at the start of every epoch.
  • sampler (Sampler, optional): a custom strategy for drawing samples from the dataset. If it is specified, shuffle must be False.
  • batch_sampler (Sampler, optional): like sampler, but returns the indices of one batch at a time. It is mutually exclusive with batch_size, shuffle, sampler, and drop_last: once it is specified, none of those may be set.
  • num_workers (int, optional): number of subprocesses used for data loading. 0 means the data is loaded in the main process. (default: 0)
  • collate_fn (callable, optional): a function that merges a list of samples into a mini-batch.
  • pin_memory (bool, optional): if True, the data loader copies tensors into CUDA pinned memory before returning them.
  • drop_last (bool, optional): controls what happens to the last, incomplete batch. If True, it is dropped: with batch_size set to 64 and 100 samples per epoch, the final 36 samples are discarded. If False (the default), the last batch is kept and is simply smaller.
  • timeout (numeric, optional): if positive, the maximum time to wait for a batch to be collected from the worker processes; if the batch does not arrive in time, loading fails with an error rather than continuing. Should always be non-negative. (default: 0)
  • worker_init_fn (callable, optional): per-worker initialization function. If not None, it is called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)
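How batch_size, shuffle, and drop_last interact is easiest to see on indices alone. The following is a rough pure-Python sketch of the kind of index batches a loader works through; iter_batch_indices is a hypothetical helper written for this illustration, not part of PyTorch, and it does not reproduce DataLoader's actual random number handling.

```python
import random


def iter_batch_indices(num_samples, batch_size, shuffle=False, drop_last=False, seed=None):
    """Yield lists of sample indices, one list per batch (hypothetical helper)."""
    indices = list(range(num_samples))
    if shuffle:
        # Reorder the indices once per pass over the data.
        random.Random(seed).shuffle(indices)
    for start in range(0, num_samples, batch_size):
        batch = indices[start:start + batch_size]
        if drop_last and len(batch) < batch_size:
            break  # discard the incomplete final batch
        yield batch


kept = [len(b) for b in iter_batch_indices(100, 64, drop_last=False)]
dropped = [len(b) for b in iter_batch_indices(100, 64, drop_last=True)]
print(kept)     # [64, 36]
print(dropped)  # [64]
```

With 100 samples and batch_size=64, keeping the last batch gives sizes 64 and 36, while drop_last=True leaves a single batch of 64, which matches the example in the parameter list.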

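3. Generating custom training data with Dataset and DataLoader

Suppose a TXT file lists the images and their labels in the following format, where the first column is the image file name and the second column is the label:

0.jpg 0
1.jpg 1
2.jpg 2
3.jpg 3
4.jpg 4
5.jpg 5
6.jpg 6
7.jpg 7
8.jpg 8
9.jpg 9

The file can also hold multi-label data, for example:

0.jpg 0 10
1.jpg 1 11
2.jpg 2 12
3.jpg 3 13
4.jpg 4 14
5.jpg 5 15
6.jpg 6 16
7.jpg 7 17
8.jpg 8 18
9.jpg 9 19

The ten original images are stored under ./dataset/images. We can then define a custom Dataset that parses this file and reads the images, and use the DataLoader class to produce batches of training data.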
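As a small illustration of the label file format described above, the sketch below parses both the single-label and the multi-label layout into (name, labels) pairs. parse_label_lines is a hypothetical helper, not the article's read_file method; unlike a plain split(' '), it tolerates blank lines and repeated spaces, which is a design choice made here rather than something the article requires.

```python
def parse_label_lines(lines):
    """Parse 'name label [label ...]' lines into (name, [int labels]) tuples."""
    records = []
    for line in lines:
        fields = line.split()  # splits on any run of whitespace
        if not fields:         # skip blank lines
            continue
        name = fields[0]
        labels = [int(value) for value in fields[1:]]
        records.append((name, labels))
    return records


single = parse_label_lines(["0.jpg 0\n", "1.jpg 1\n"])
multi = parse_label_lines(["0.jpg 0 10\n", "\n", "1.jpg  1 11\n"])
print(single)  # [('0.jpg', [0]), ('1.jpg', [1])]
print(multi)   # [('0.jpg', [0, 10]), ('1.jpg', [1, 11])]
```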


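3.1 Defining a custom Dataset

First, define a TorchDataset class that reads the image data and produces the labels.

Pay attention to the initialization function: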

import os

import numpy as np
import torch
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from utils import image_processing


class TorchDataset(Dataset):
    def __init__(self, filename, image_dir, resize_height=256, resize_width=256, repeat=1):
        '''
        :param filename: TXT data file, one sample per line: image_name.jpg label1_id label2_id
        :param image_dir: image directory; image_dir + image_name.jpg is the full image path
        :param resize_height: if None, the height is not resized
        :param resize_width: if None, the width is not resized.
               PS: when exactly one of resize_height and resize_width is None,
               the image is resized keeping its aspect ratio
        :param repeat: number of times the whole sample set is repeated (default: one pass);
               when repeat is None the data loops "infinitely" (< sys.maxsize)
        '''
        self.image_label_list = self.read_file(filename)
        self.image_dir = image_dir
        self.len = len(self.image_label_list)
        self.repeat = repeat
        self.resize_height = resize_height
        self.resize_width = resize_width

        # Set up the preprocessing transforms.
        # class torchvision.transforms.ToTensor
        # Converts a PIL.Image or numpy.ndarray of shape (H, W, C) with pixel values in [0, 255]
        # into a torch.FloatTensor of shape (C, H, W) normalized to [0.0, 1.0].
        self.toTensor = transforms.ToTensor()

        # class torchvision.transforms.Normalize(mean, std)
        # Operates on a torch.*Tensor: given per-channel means (R, G, B) and standard
        # deviations (R, G, B), it normalizes with channel = (channel - mean) / std.
        # self.normalize = transforms.Normalize()

    def __getitem__(self, i):
        index = i % self.len  # wrap around so repeated passes map back onto the samples
        # print("i={},index={}".format(i, index))
        image_name, label = self.image_label_list[index]
        image_path = os.path.join(self.image_dir, image_name)
        img = self.load_data(image_path, self.resize_height, self.resize_width, normalization=False)
        img = self.data_preproccess(img)
        label = np.array(label)
        return img, label

    def __len__(self):
        if self.repeat is None:
            data_len = 10000000
        else:
            data_len = len(self.image_label_list) * self.repeat
        return data_len

    def read_file(self, filename):
        image_label_list = []
        with open(filename, 'r') as f:
            lines = f.readlines()
            for line in lines:
                # rstrip removes trailing whitespace (\n, \r, \t and ' ')
                content = line.rstrip().split(' ')
                name = content[0]
                labels = []
                for value in content[1:]:
                    labels.append(int(value))
                image_label_list.append((name, labels))
        return image_label_list

    def load_data(self, path, resize_height, resize_width, normalization):
        '''
        Load one image.
        :param path:
        :param resize_height:
        :param resize_width:
        :param normalization: whether to normalize
        :return:
        '''
        image = image_processing.read_image(path, resize_height, resize_width, normalization)
        return image

    def data_preproccess(self, data):
        '''
        Preprocess the data.
        :param data:
        :return:
        '''
        data = self.toTensor(data)
        return data
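The repeat mechanism rests on two lines: __len__ reports a length larger than the real sample count, and __getitem__ folds the index back with a modulo. The sketch below isolates that idea on a plain list so it can be run without any image files. RepeatedList is a hypothetical stand-in for TorchDataset, and the import fallback is an assumption that only serves to let the snippet run without PyTorch.

```python
try:
    from torch.utils.data import Dataset
except ImportError:  # assumption: plain-class fallback so the sketch runs without PyTorch
    Dataset = object


class RepeatedList(Dataset):
    """Hypothetical stand-in for TorchDataset that wraps a list instead of image files."""

    def __init__(self, items, repeat=1):
        self.items = items
        self.repeat = repeat

    def __getitem__(self, i):
        # Same wrap-around as TorchDataset: indices past one pass map back to the start.
        return self.items[i % len(self.items)]

    def __len__(self):
        if self.repeat is None:
            return 10000000  # "practically infinite", the constant used in the article
        return len(self.items) * self.repeat


ds = RepeatedList(["a", "b", "c"], repeat=2)
print(len(ds))                          # 6
print([ds[i] for i in range(len(ds))])  # ['a', 'b', 'c', 'a', 'b', 'c']
```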

3.2 Producing training batches with DataLoader

if __name__ == '__main__':
    train_filename = "../dataset/train.txt"
    # test_filename = "../dataset/test.txt"
    image_dir = '../dataset/images'

    epoch_num = 2        # number of passes over the whole sample set
    batch_size = 7       # number of samples per training batch
    train_data_nums = 10
    # Total number of iterations: batches per epoch (rounded up) times the number of epochs.
    max_iterate = (train_data_nums + batch_size - 1) // batch_size * epoch_num

    train_data = TorchDataset(filename=train_filename, image_dir=image_dir, repeat=1)
    # test_data = TorchDataset(filename=test_filename, image_dir=image_dir, repeat=1)
    train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=False)
    # test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)

    # [1] Iterate by epoch, with TorchDataset's repeat=1
    for epoch in range(epoch_num):
        for batch_image, batch_label in train_loader:
            image = batch_image[0, :]
            image = image.numpy()  # image = np.array(image)
            image = image.transpose(1, 2, 0)  # channel order from [c,h,w] to [h,w,c]
            image_processing.cv_show_image("image", image)
            print("batch_image.shape:{},batch_label:{}".format(batch_image.shape, batch_label))
            # batch_x, batch_y = Variable(batch_x), Variable(batch_y)

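The iteration code above uses two nested for loops, where epoch_num is the number of passes over the whole sample set; with epoch_num=2, every sample is visited twice. This has one drawback: when the total number of samples train_data_nums is not evenly divisible by batch_size, the last batch of each epoch is smaller than batch_size. Here train_data_nums=10 and batch_size=7, so the first iteration yields 7 samples and the second, for lack of samples, yields only 3.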
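The iteration count behind this can be worked out with a ceiling division. The sketch below is a small arithmetic check; the function names are invented for illustration. It also shows why integer division is the safer way to write the formula: applying float division first and int() afterwards can overcount for some settings.

```python
def iterations_per_epoch(num_samples, batch_size):
    """Ceiling division: number of batches needed to cover every sample once."""
    return (num_samples + batch_size - 1) // batch_size


def total_iterations(num_samples, batch_size, epochs):
    """Batches per epoch times the number of epochs."""
    return iterations_per_epoch(num_samples, batch_size) * epochs


print(total_iterations(10, 7, 2))  # 4: two batches (7 + 3 samples) per epoch, two epochs
print(total_iterations(10, 4, 4))  # 12: three batches per epoch, four epochs
print(int((10 + 4 - 1) / 4 * 4))   # 13: float division before int() overcounts here
```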

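We would like every iteration to produce a batch of the same size, so we can iterate as follows. The TorchDataset class was written with looping in mind: setting repeat to None makes the data loop indefinitely. It is used like this: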

    # (continues inside the if __name__ == '__main__': block)
    # For the two methods below, TorchDataset is created with repeat=None so that the data
    # loops "infinitely"; max_iterate decides when to leave the loop.
    train_data = TorchDataset(filename=train_filename, image_dir=image_dir, repeat=None)
    train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=False)

    # [2] Second iteration method
    for step, (batch_image, batch_label) in enumerate(train_loader):
        image = batch_image[0, :]
        image = image.numpy()  # image = np.array(image)
        image = image.transpose(1, 2, 0)  # channel order from [c,h,w] to [h,w,c]
        image_processing.cv_show_image("image", image)
        print("step:{},batch_image.shape:{},batch_label:{}".format(step, batch_image.shape, batch_label))
        # batch_x, batch_y = Variable(batch_x), Variable(batch_y)
        if step + 1 >= max_iterate:  # step starts at 0, so this stops after exactly max_iterate batches
            break

    # [3] Third iteration method
    # data_iter = iter(train_loader)  # create the iterator once; a fresh iterator per step would restart at the first batch
    # for step in range(max_iterate):
    #     batch_image, batch_label = next(data_iter)
    #     image = batch_image[0, :]
    #     image = image.numpy()  # image = np.array(image)
    #     image = image.transpose(1, 2, 0)  # channel order from [c,h,w] to [h,w,c]
    #     image_processing.cv_show_image("image", image)
    #     print("batch_image.shape:{},batch_label:{}".format(batch_image.shape, batch_label))
    #     # batch_x, batch_y = Variable(batch_x), Variable(batch_y)
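What the endless-dataset approach does to the sample order can be previewed without PyTorch. The sketch below is an assumption-laden illustration: fixed_size_batches is a hypothetical helper that imitates a sequential, unshuffled pass over a wrapped-around dataset, using the article's numbers (10 samples, batch size 7, 4 steps).

```python
from itertools import cycle, islice


def fixed_size_batches(num_samples, batch_size, max_steps):
    """Yield max_steps batches of indices, wrapping around the dataset as needed."""
    stream = cycle(range(num_samples))  # 0, 1, ..., n-1, 0, 1, ...
    for _ in range(max_steps):
        yield list(islice(stream, batch_size))


batches = list(fixed_size_batches(num_samples=10, batch_size=7, max_steps=4))
for step, batch in enumerate(batches):
    print(step, batch)
```

Every batch now holds exactly 7 indices, and the second batch already mixes the end of one pass with the start of the next (7, 8, 9, 0, 1, 2, 3). Epoch boundaries therefore no longer line up with batch boundaries, which is the price paid for constant batch sizes.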

3.3 Appendix: image_processing.py

The code above uses image_processing, an image-processing module I wrote that wraps basic operations such as reading and displaying images:

# -*-coding: utf-8 -*-
"""
    @Project: IntelligentManufacture
    @File   : image_processing.py
    @Author : panjq
    @E-mail : pan_jinquan@163.com
    @Date   : 2019-02-14 15:34:50
"""

import os
import glob
import cv2
import numpy as np
import matplotlib.pyplot as plt


def show_image(title, image):
    '''
    Display an RGB image with matplotlib.
    :param title: image title
    :param image: image data
    :return:
    '''
    # plt.figure("show_image")
    # print(image.dtype)
    plt.imshow(image)
    plt.axis('on')    # use 'off' to hide the axes
    plt.title(title)  # image title
    plt.show()


def cv_show_image(title, image):
    '''
    Display an RGB image with OpenCV.
    :param title: image title
    :param image: input RGB image
    :return:
    '''
    channels = image.shape[-1]
    if channels == 3:
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)  # convert RGB to BGR, the order OpenCV expects
    cv2.imshow(title, image)
    cv2.waitKey(0)


def read_image(filename, resize_height=None, resize_width=None, normalization=False):
    '''
    Read an image; by default the result is uint8 in [0, 255].
    :param filename:
    :param resize_height:
    :param resize_width:
    :param normalization: whether to normalize to [0., 1.0]
    :return: the RGB image data
    '''
    bgr_image = cv2.imread(filename)
    # bgr_image = cv2.imread(filename,cv2.IMREAD_IGNORE_ORIENTATION|cv2.IMREAD_COLOR)
    if bgr_image is None:
        print("Warning: file does not exist: {}".format(filename))
        return None
    if len(bgr_image.shape) == 2:  # grayscale image: convert to three channels
        print("Warning:gray image", filename)
        bgr_image = cv2.cvtColor(bgr_image, cv2.COLOR_GRAY2BGR)

    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)  # convert BGR to RGB
    # show_image(filename,rgb_image)
    # rgb_image=Image.open(filename)
    rgb_image = resize_image(rgb_image, resize_height, resize_width)
    rgb_image = np.asanyarray(rgb_image)
    if normalization:
        # do not write: rgb_image=rgb_image/255
        rgb_image = rgb_image / 255.0
    # show_image("src resize image",image)
    return rgb_image


def fast_read_image_roi(filename, orig_rect, ImreadModes=cv2.IMREAD_COLOR, normalization=False):
    '''
    Fast image read that returns only a region of interest.
    :param filename: image path
    :param orig_rect: region of interest (rect) in the original image
    :param ImreadModes: IMREAD_UNCHANGED
                        IMREAD_GRAYSCALE
                        IMREAD_COLOR
                        IMREAD_ANYDEPTH
                        IMREAD_ANYCOLOR
                        IMREAD_LOAD_GDAL
                        IMREAD_REDUCED_GRAYSCALE_2
                        IMREAD_REDUCED_COLOR_2
                        IMREAD_REDUCED_GRAYSCALE_4
                        IMREAD_REDUCED_COLOR_4
                        IMREAD_REDUCED_GRAYSCALE_8
                        IMREAD_REDUCED_COLOR_8
                        IMREAD_IGNORE_ORIENTATION
    :param normalization: whether to normalize
    :return: the region of interest (ROI)
    '''
    # With an IMREAD_REDUCED mode, the rect has to be scaled accordingly
    scale = 1
    if ImreadModes == cv2.IMREAD_REDUCED_GRAYSCALE_2 or ImreadModes == cv2.IMREAD_REDUCED_COLOR_2:
        scale = 1 / 2
    elif ImreadModes == cv2.IMREAD_REDUCED_GRAYSCALE_4 or ImreadModes == cv2.IMREAD_REDUCED_COLOR_4:
        scale = 1 / 4
    elif ImreadModes == cv2.IMREAD_REDUCED_GRAYSCALE_8 or ImreadModes == cv2.IMREAD_REDUCED_COLOR_8:
        scale = 1 / 8
    rect = np.array(orig_rect) * scale
    rect = rect.astype(int).tolist()
    bgr_image = cv2.imread(filename, flags=ImreadModes)

    if bgr_image is None:
        print("Warning: file does not exist: {}".format(filename))
        return None
    if len(bgr_image.shape) == 3:
        rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)  # convert BGR to RGB
    else:
        rgb_image = bgr_image  # grayscale image
    rgb_image = np.asanyarray(rgb_image)
    if normalization:
        # do not write: rgb_image=rgb_image/255
        rgb_image = rgb_image / 255.0
    roi_image = get_rect_image(rgb_image, rect)
    # show_image_rect("src resize image",rgb_image,rect)
    # cv_show_image("reROI",roi_image)
    return roi_image


def resize_image(image, resize_height, resize_width):
    '''
    :param image:
    :param resize_height:
    :param resize_width:
    :return:
    '''
    image_shape = np.shape(image)
    height = image_shape[0]
    width = image_shape[1]
    if (resize_height is None) and (resize_width is None):  # wrong: resize_height and resize_width is None
        return image
    if resize_height is None:
        resize_height = int(height * resize_width / width)
    elif resize_width is None:
        resize_width = int(width * resize_height / height)
    image = cv2.resize(image, dsize=(resize_width, resize_height))
    return image


def scale_image(image, scale):
    '''
    :param image:
    :param scale: (scale_w,scale_h)
    :return:
    '''
    image = cv2.resize(image, dsize=None, fx=scale[0], fy=scale[1])
    return image


def get_rect_image(image, rect):
    '''
    :param image:
    :param rect: [x,y,w,h]
    :return:
    '''
    x, y, w, h = rect
    cut_img = image[y:(y + h), x:(x + w)]
    return cut_img


def scale_rect(orig_rect, orig_shape, dest_shape):
    '''
    When an image is resized, the corresponding rectangle has to be resized too.
    :param orig_rect: rect in the original image, [x,y,w,h]
    :param orig_shape: shape of the original image, [h,w]
    :param dest_shape: shape of the resized image, [h,w]
    :return: the resized rectangle
    '''
    new_x = int(orig_rect[0] * dest_shape[1] / orig_shape[1])
    new_y = int(orig_rect[1] * dest_shape[0] / orig_shape[0])
    new_w = int(orig_rect[2] * dest_shape[1] / orig_shape[1])
    new_h = int(orig_rect[3] * dest_shape[0] / orig_shape[0])
    dest_rect = [new_x, new_y, new_w, new_h]
    return dest_rect


def show_image_rect(win_name, image, rect):
    '''
    :param win_name:
    :param image:
    :param rect:
    :return:
    '''
    x, y, w, h = rect
    point1 = (x, y)
    point2 = (x + w, y + h)
    cv2.rectangle(image, point1, point2, (0, 0, 255), thickness=2)
    cv_show_image(win_name, image)


def rgb_to_gray(image):
    image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    return image


def save_image(image_path, rgb_image, toUINT8=True):
    if toUINT8:
        rgb_image = np.asanyarray(rgb_image * 255, dtype=np.uint8)
    if len(rgb_image.shape) == 2:  # grayscale image: convert to three channels
        bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_GRAY2BGR)
    else:
        bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
    cv2.imwrite(image_path, bgr_image)


def combime_save_image(orig_image, dest_image, out_dir, name, prefix):
    '''
    Naming convention: out_dir/name_prefix.jpg
    :param orig_image:
    :param dest_image:
    :param image_path:
    :param out_dir:
    :param prefix:
    :return:
    '''
    dest_path = os.path.join(out_dir, name + "_" + prefix + ".jpg")
    save_image(dest_path, dest_image)

    dest_image = np.hstack((orig_image, dest_image))
    save_image(os.path.join(out_dir, "{}_src_{}.jpg".format(name, prefix)), dest_image)
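The aspect-ratio rule inside resize_image is easier to check when it is separated from OpenCV. The sketch below reproduces only the size arithmetic; resolve_size is a hypothetical helper introduced for this illustration, and the 480 by 640 input is an assumed example size.

```python
def resolve_size(height, width, resize_height=None, resize_width=None):
    """Return the (height, width) that resize_image would resize to."""
    if resize_height is None and resize_width is None:
        return height, width  # no resizing requested
    if resize_height is None:
        # Only the width is given: scale the height by the same factor.
        resize_height = int(height * resize_width / width)
    elif resize_width is None:
        # Only the height is given: scale the width by the same factor.
        resize_width = int(width * resize_height / height)
    return resize_height, resize_width


print(resolve_size(480, 640))                     # (480, 640)
print(resolve_size(480, 640, resize_width=320))   # (240, 320)
print(resolve_size(480, 640, resize_height=120))  # (120, 160)
print(resolve_size(480, 640, 256, 256))           # (256, 256)
```

Note that the result is in (height, width) order, whereas cv2.resize takes dsize as (width, height), which is why resize_image passes the two values swapped.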

3.4 Complete code

The complete file, dataset.py, consists of the TorchDataset class from section 3.1 followed by the __main__ block from section 3.2, including the second and third iteration methods, under this file header:

# -*-coding: utf-8 -*-
"""
    @Project: pytorch-learning-tutorials
    @File   : dataset.py
    @Author : panjq
    @E-mail : pan_jinquan@163.com
    @Date   : 2019-03-07 18:45:06
"""
