Pytorch自定义数据集
簡述
Pytorch自定義數據集方法,應該是用pytorch做算法的最基本的東西。
往往網絡上給的demo都是基于torch自帶的MNIST的相關類。所以,為了解決使用其他的數據集,在查閱了torch關于MNIST數據集的源碼之后,很容易就可以推廣到了我們自己需要的代碼上。
具體操作如下:
準備工作
需要導入一些包。
from torch.utils.data import Dataset, DataLoader再自定義一個用于當訓練集合的類。
class TrainSet(Dataset):def __init__(self, X, Y):# 定義好 image 的路徑self.X, self.Y = X, Ydef __getitem__(self, index):return self.X[index], self.Y[index]def __len__(self):return len(self.X)數據預處理
之后,假設你的訓練集合為[X,Y],其中X是訓練數據,Y是對應的數據的標簽。
首先,需要知道的是,torch能處理的數據只能是torch.Tensor,所以有必要將其他數據轉換為torch.Tensor。
常見的有幾種數據:
- np.ndarray
- PIL.Image
- …
如果是圖片數據,其實也有多種情況,根據數據維度不同,有些是二維圖,有些是三維圖(通俗來講,就是黑白圖和彩圖)。
所以,我先按照數據類型的模式將一遍,再補充關于圖片的處理。
np.ndarray
np.ndarray是非常常見的格式,轉成Tensor也非常簡單。
torch.Tensor(array)這樣代碼的返回格式就是一個Tensor。
PIL.Image
import torchvision.transforms as transforms transforms.ToTensor()(image)這樣代碼的返回格式就是一個Tensor。
關于圖片
- 彩色的三維圖: 上面方法就已經完成了對應的數據處理的步驟
- 灰白或者是二值的二維圖:就需要將數據增加一個維度了(因為往往關于圖片,所用到的算法都是包括了卷積的步驟,所以要求增加一個維度)
具體操作如下: 明顯,torch.Tensor(X)這樣的步驟,其實是重復了上面的將np.ndarray轉成torch.Tensor的步驟。同理可以換成上面的關于PIL.Image的方法
X_tensor = torch.unsqueeze(torch.Tensor(X), 1) Y_tensor = torch.unsqueeze(torch.Tensor(Y), 1)導入數據
建立自己的數據集。
mydataset = TrainSet(X_tensor, Y_tensor)再把自己的數據集導入到數據加載器上:
- batch_size表示用將原數據拆分之后,每batch_size個數據作為一組數據被調用。shuffle表示數據是否被洗牌(即刷新順序,避免訓練的時候多次調用結果都遇到同一batch,從而避免誤差)
使用的方式也非常簡單:
for step, (x, y) in enumerate(train_loader):這里的x,y就是每個batch所處理的數據。
另外,附上一個我常用的讀取自定義圖片的dataset類
main函數部分是對數據集做測試。
import torch.utils.data as data import glob import os import torchvision.transforms as transforms from PIL import Image import matplotlib.pyplot as plt import numpy as np import torchimport piexif import imghdrclass MyDataset(data.Dataset):def __init__(self, path, Train=True, Len=-1, resize=-1, img_type='png', remove_exif=False):if resize != -1:transform = transforms.Compose([transforms.Resize(resize),transforms.CenterCrop(resize),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))# transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])else:transform = transforms.Compose([transforms.ToTensor(),])img_format = '*.%s' % img_typeif remove_exif:for name in glob.glob(os.path.join(path, img_format)):try:piexif.remove(name) # 去除exifexcept Exception:continue# imghdr.what(img_path) 判斷是否為損壞圖片if Len == -1:self.dataset = [np.array(transform(Image.open(name).convert("RGB"))) for name inglob.glob(os.path.join(path, img_format)) if imghdr.what(name)]else:self.dataset = [np.array(transform(Image.open(name).convert("RGB"))) for name inglob.glob(os.path.join(path, img_format))[:Len] if imghdr.what(name)]self.dataset = np.array(self.dataset)self.dataset = torch.Tensor(self.dataset)self.Train = Traindef __len__(self):return len(self.dataset)def __getitem__(self, idx):return self.dataset[idx]if __name__ == '__main__':path = r'D:\Software\DataSet\faces'dataset = MyDataset(path=path, resize=96, Len=10, img_type='jpg')print(len(dataset))plt.imshow(dataset[0].numpy().transpose(1, 2, 0) * 0.5 + 0.5)plt.show()print(dataset[0].max(), dataset[0].min())總結
以上是生活随笔為你收集整理的Pytorch自定义数据集的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 删除有序vector中的重复值c++
- 下一篇: 【plt显示Tensor转出来的arra