Pytorch自定义数据集
簡述
Pytorch自定義數(shù)據(jù)集方法,應該是用pytorch做算法的最基本的東西。
往往網(wǎng)絡(luò)上給的demo都是基于torch自帶的MNIST的相關(guān)類。所以,為了解決使用其他的數(shù)據(jù)集,在查閱了torch關(guān)于MNIST數(shù)據(jù)集的源碼之后,很容易就可以推廣到了我們自己需要的代碼上。
具體操作如下:
準備工作
需要導入一些包。
from torch.utils.data import Dataset, DataLoader再自定義一個用于當訓練集合的類。
class TrainSet(Dataset):def __init__(self, X, Y):# 定義好 image 的路徑self.X, self.Y = X, Ydef __getitem__(self, index):return self.X[index], self.Y[index]def __len__(self):return len(self.X)數(shù)據(jù)預處理
之后,假設(shè)你的訓練集合為[X,Y],其中X是訓練數(shù)據(jù),Y是對應的數(shù)據(jù)的標簽。
首先,需要知道的是,torch能處理的數(shù)據(jù)只能是torch.Tensor,所以有必要將其他數(shù)據(jù)轉(zhuǎn)換為torch.Tensor。
常見的有幾種數(shù)據(jù):
- np.ndarray
- PIL.Image
- …
如果是圖片數(shù)據(jù),其實也有多種情況,根據(jù)數(shù)據(jù)維度不同,有些是二維圖,有些是三維圖(通俗來講,就是黑白圖和彩圖)。
所以,我先按照數(shù)據(jù)類型的模式將一遍,再補充關(guān)于圖片的處理。
np.ndarray
np.ndarray是非常常見的格式,轉(zhuǎn)成Tensor也非常簡單。
torch.Tensor(array)這樣代碼的返回格式就是一個Tensor。
PIL.Image
import torchvision.transforms as transforms transforms.ToTensor()(image)這樣代碼的返回格式就是一個Tensor。
關(guān)于圖片
- 彩色的三維圖: 上面方法就已經(jīng)完成了對應的數(shù)據(jù)處理的步驟
- 灰白或者是二值的二維圖:就需要將數(shù)據(jù)增加一個維度了(因為往往關(guān)于圖片,所用到的算法都是包括了卷積的步驟,所以要求增加一個維度)
具體操作如下: 明顯,torch.Tensor(X)這樣的步驟,其實是重復了上面的將np.ndarray轉(zhuǎn)成torch.Tensor的步驟。同理可以換成上面的關(guān)于PIL.Image的方法
X_tensor = torch.unsqueeze(torch.Tensor(X), 1) Y_tensor = torch.unsqueeze(torch.Tensor(Y), 1)導入數(shù)據(jù)
建立自己的數(shù)據(jù)集。
mydataset = TrainSet(X_tensor, Y_tensor)再把自己的數(shù)據(jù)集導入到數(shù)據(jù)加載器上:
- batch_size表示用將原數(shù)據(jù)拆分之后,每batch_size個數(shù)據(jù)作為一組數(shù)據(jù)被調(diào)用。shuffle表示數(shù)據(jù)是否被洗牌(即刷新順序,避免訓練的時候多次調(diào)用結(jié)果都遇到同一batch,從而避免誤差)
使用的方式也非常簡單:
for step, (x, y) in enumerate(train_loader):這里的x,y就是每個batch所處理的數(shù)據(jù)。
另外,附上一個我常用的讀取自定義圖片的dataset類
main函數(shù)部分是對數(shù)據(jù)集做測試。
import torch.utils.data as data import glob import os import torchvision.transforms as transforms from PIL import Image import matplotlib.pyplot as plt import numpy as np import torchimport piexif import imghdrclass MyDataset(data.Dataset):def __init__(self, path, Train=True, Len=-1, resize=-1, img_type='png', remove_exif=False):if resize != -1:transform = transforms.Compose([transforms.Resize(resize),transforms.CenterCrop(resize),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))# transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])else:transform = transforms.Compose([transforms.ToTensor(),])img_format = '*.%s' % img_typeif remove_exif:for name in glob.glob(os.path.join(path, img_format)):try:piexif.remove(name) # 去除exifexcept Exception:continue# imghdr.what(img_path) 判斷是否為損壞圖片if Len == -1:self.dataset = [np.array(transform(Image.open(name).convert("RGB"))) for name inglob.glob(os.path.join(path, img_format)) if imghdr.what(name)]else:self.dataset = [np.array(transform(Image.open(name).convert("RGB"))) for name inglob.glob(os.path.join(path, img_format))[:Len] if imghdr.what(name)]self.dataset = np.array(self.dataset)self.dataset = torch.Tensor(self.dataset)self.Train = Traindef __len__(self):return len(self.dataset)def __getitem__(self, idx):return self.dataset[idx]if __name__ == '__main__':path = r'D:\Software\DataSet\faces'dataset = MyDataset(path=path, resize=96, Len=10, img_type='jpg')print(len(dataset))plt.imshow(dataset[0].numpy().transpose(1, 2, 0) * 0.5 + 0.5)plt.show()print(dataset[0].max(), dataset[0].min())總結(jié)
以上是生活随笔為你收集整理的Pytorch自定义数据集的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 删除有序vector中的重复值c++
- 下一篇: 【plt显示Tensor转出来的arra