日韩av黄I国产麻豆传媒I国产91av视频在线观看I日韩一区二区三区在线看I美女国产在线I麻豆视频国产在线观看I成人黄色短片

歡迎訪問 生活随笔!

生活随笔

當前位置: 首頁 >

Pytorch自定义数据集

發(fā)布時間:2025/4/16 57 豆豆
生活随笔 收集整理的這篇文章主要介紹了 Pytorch自定义数据集 小編覺得挺不錯的,現(xiàn)在分享給大家,幫大家做個參考.

簡述

Pytorch自定義數(shù)據(jù)集方法,應該是用pytorch做算法的最基本的東西。
往往網(wǎng)絡(luò)上給的demo都是基于torch自帶的MNIST的相關(guān)類。所以,為了解決使用其他的數(shù)據(jù)集,在查閱了torch關(guān)于MNIST數(shù)據(jù)集的源碼之后,很容易就可以推廣到了我們自己需要的代碼上。

具體操作如下:

準備工作

需要導入一些包。

from torch.utils.data import Dataset, DataLoader

再自定義一個用于當訓練集合的類。

class TrainSet(Dataset):def __init__(self, X, Y):# 定義好 image 的路徑self.X, self.Y = X, Ydef __getitem__(self, index):return self.X[index], self.Y[index]def __len__(self):return len(self.X)

數(shù)據(jù)預處理

之后,假設(shè)你的訓練集合為[X,Y],其中X是訓練數(shù)據(jù),Y是對應的數(shù)據(jù)的標簽。

首先,需要知道的是,torch能處理的數(shù)據(jù)只能是torch.Tensor,所以有必要將其他數(shù)據(jù)轉(zhuǎn)換為torch.Tensor。

常見的有幾種數(shù)據(jù):

  • np.ndarray
  • PIL.Image

如果是圖片數(shù)據(jù),其實也有多種情況,根據(jù)數(shù)據(jù)維度不同,有些是二維圖,有些是三維圖(通俗來講,就是黑白圖和彩圖)。

所以,我先按照數(shù)據(jù)類型的模式將一遍,再補充關(guān)于圖片的處理。

np.ndarray

np.ndarray是非常常見的格式,轉(zhuǎn)成Tensor也非常簡單。

torch.Tensor(array)

這樣代碼的返回格式就是一個Tensor。

PIL.Image

import torchvision.transforms as transforms transforms.ToTensor()(image)

這樣代碼的返回格式就是一個Tensor。

關(guān)于圖片

  • 彩色的三維圖: 上面方法就已經(jīng)完成了對應的數(shù)據(jù)處理的步驟
  • 灰白或者是二值的二維圖:就需要將數(shù)據(jù)增加一個維度了(因為往往關(guān)于圖片,所用到的算法都是包括了卷積的步驟,所以要求增加一個維度)

具體操作如下: 明顯,torch.Tensor(X)這樣的步驟,其實是重復了上面的將np.ndarray轉(zhuǎn)成torch.Tensor的步驟。同理可以換成上面的關(guān)于PIL.Image的方法

X_tensor = torch.unsqueeze(torch.Tensor(X), 1) Y_tensor = torch.unsqueeze(torch.Tensor(Y), 1)

導入數(shù)據(jù)

建立自己的數(shù)據(jù)集。

mydataset = TrainSet(X_tensor, Y_tensor)

再把自己的數(shù)據(jù)集導入到數(shù)據(jù)加載器上:

  • batch_size表示用將原數(shù)據(jù)拆分之后,每batch_size個數(shù)據(jù)作為一組數(shù)據(jù)被調(diào)用。shuffle表示數(shù)據(jù)是否被洗牌(即刷新順序,避免訓練的時候多次調(diào)用結(jié)果都遇到同一batch,從而避免誤差)
train_loader = DataLoader(mydataset, batch_size=10, shuffle=True)

使用的方式也非常簡單:

for step, (x, y) in enumerate(train_loader):

這里的x,y就是每個batch所處理的數(shù)據(jù)。

另外,附上一個我常用的讀取自定義圖片的dataset類

main函數(shù)部分是對數(shù)據(jù)集做測試。

import torch.utils.data as data import glob import os import torchvision.transforms as transforms from PIL import Image import matplotlib.pyplot as plt import numpy as np import torchimport piexif import imghdrclass MyDataset(data.Dataset):def __init__(self, path, Train=True, Len=-1, resize=-1, img_type='png', remove_exif=False):if resize != -1:transform = transforms.Compose([transforms.Resize(resize),transforms.CenterCrop(resize),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))# transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])else:transform = transforms.Compose([transforms.ToTensor(),])img_format = '*.%s' % img_typeif remove_exif:for name in glob.glob(os.path.join(path, img_format)):try:piexif.remove(name) # 去除exifexcept Exception:continue# imghdr.what(img_path) 判斷是否為損壞圖片if Len == -1:self.dataset = [np.array(transform(Image.open(name).convert("RGB"))) for name inglob.glob(os.path.join(path, img_format)) if imghdr.what(name)]else:self.dataset = [np.array(transform(Image.open(name).convert("RGB"))) for name inglob.glob(os.path.join(path, img_format))[:Len] if imghdr.what(name)]self.dataset = np.array(self.dataset)self.dataset = torch.Tensor(self.dataset)self.Train = Traindef __len__(self):return len(self.dataset)def __getitem__(self, idx):return self.dataset[idx]if __name__ == '__main__':path = r'D:\Software\DataSet\faces'dataset = MyDataset(path=path, resize=96, Len=10, img_type='jpg')print(len(dataset))plt.imshow(dataset[0].numpy().transpose(1, 2, 0) * 0.5 + 0.5)plt.show()print(dataset[0].max(), dataset[0].min())

總結(jié)

以上是生活随笔為你收集整理的Pytorch自定义数据集的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。

如果覺得生活随笔網(wǎng)站內(nèi)容還不錯,歡迎將生活随笔推薦給好友。