Pytorch基础(三)—— DataSet的应用
一、概念
Pytorch的標(biāo)準(zhǔn)數(shù)據(jù)集包括很多種類型,如CIFAR,COCO,KITTI,MNIST等,我們可以在官網(wǎng)查看。當(dāng)然我們也可以做數(shù)據(jù)集,但需要自己標(biāo)注。
二、如何調(diào)用數(shù)據(jù)集
一、調(diào)用torchvision
在程序中調(diào)用torchvision.datasets,下面用程序示例如何下載CIFAR10數(shù)據(jù)集。
import torchvisiontrain_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True) test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)也可以復(fù)制路徑通過其他方式下載,然后將下載文件放入py文件路徑下,可運(yùn)行程序可自動解壓。
如果想顯示數(shù)據(jù)集的圖片,可以直接調(diào)用imshow方法。
如果想通過tensorboard顯示圖片,需要先將圖片格式轉(zhuǎn)化為tensor,然后調(diào)用SummaryWriter類。
二、調(diào)用dataset類
dataset類屬于抽象類,需要通過創(chuàng)建子類來繼承,從而創(chuàng)建數(shù)據(jù)集。
from torch.utils.data import Dataset from PIL import Image import osclass MyData(Dataset):def __init__(self, root_dir, label_dir):self.root_dir = root_dirself.label_dir = label_dirself.path = os.path.join(self.root_dir, self.label_dir)self.img_path = os.listdir(self.path)def __getitem__(self, idx):img_name = self.img_path[idx]img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)img = Image.open(img_item_path)label = self.label_dirreturn img, labeldef __len__(self):return len(self.img_path)_ init _可以初始化子類的基礎(chǔ)參數(shù),可以自定義,相當(dāng)于構(gòu)造函數(shù)。
_ getitem _根據(jù)索引返回數(shù)據(jù)和標(biāo)簽。
_ len _返回數(shù)據(jù)大小
三、加載數(shù)據(jù)集
一般用DataLoader類來加載數(shù)據(jù)集。常見的參數(shù)包括:batch_size, shuffle num_workers。
這些參數(shù)的意義如下:
batch_size:指批大小,在訓(xùn)練時每次在訓(xùn)練集中取batchsize個樣本。
epoch:指使用所有訓(xùn)練集的樣本訓(xùn)練一次。
shuffle :指將訓(xùn)練集進(jìn)行打亂的操作,一般生成數(shù)據(jù)集的時候要shuffle一下圖片順序,防止過擬合。
num_workers:設(shè)定DataLoader要使用多少個子進(jìn)程進(jìn)行加載。
drop_last:指訓(xùn)練集經(jīng)過批處理后剩余的部分?jǐn)?shù)據(jù)的處理模式。ture代表丟棄,false代表繼續(xù)執(zhí)行,只是batch_size會相對變小。
簡單例子:
import torchvision from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWritertest_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())test_load = DataLoader(dataset=test_data, batch_size=64, shuffle=False, num_workers=0, drop_last=False)總結(jié)
以上是生活随笔為你收集整理的Pytorch基础(三)—— DataSet的应用的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 程序员教程第一章第二节
- 下一篇: 语义化版本控制规范(SemVer)