深度学习修炼(二)——数据集的加载
文章目錄
- 致謝
- 2 數(shù)據(jù)集的加載
- 2.1 框架數(shù)據(jù)集的加載
- 2.2 自定義數(shù)據(jù)集
- 2.3 準(zhǔn)備數(shù)據(jù)以進(jìn)行數(shù)據(jù)加載器訓(xùn)練
致謝
Pytorch自帶數(shù)據(jù)集介紹_godblesstao的博客-CSDN博客_pytorch自帶數(shù)據(jù)集
2 數(shù)據(jù)集的加載
與sklearn中的datasets自帶數(shù)據(jù)集類似,pytorch框架也為我們提供了數(shù)據(jù)集以便一系列的模型測(cè)試。其數(shù)據(jù)集作為一個(gè)類繼承自父類torch.utils.data.Dataset。
2.1 框架數(shù)據(jù)集的加載
讓我們看看torch為我們提供了什么數(shù)據(jù)集。數(shù)據(jù)集種類如下所示:
-
手寫字符識(shí)別:EMNIST、MNIST、QMNIST、USPS、SVHN、KMNIST、Omniglot
-
實(shí)物分類:Fashion MNIST、CIFAR、LSUN、SLT-10、ImageNet
-
人臉識(shí)別:CelebA
-
場(chǎng)景分類:LSUN、Places365
-
用于object detection:SVHN、VOCDetection、COCODetection
-
用于semantic/instance segmentation:
-
語(yǔ)義分割:Cityscapes、VOCSegmentation
-
語(yǔ)義邊界:SBD
-
用于image captioning:Flickr、COCOCaption
-
用于video classification:HMDB51、Kinetics
-
用于3D reconstruction:PhotoTour
-
用于shadow detectors:SBU
以FashionMNIST數(shù)據(jù)集為例,我們看一下如何加載數(shù)據(jù)集。
torch.datasets.FashionMNIST(root = “data”,train = True,download = True,transform = ToTensor())
- root是存儲(chǔ)訓(xùn)練/測(cè)試數(shù)據(jù)的路徑
- train指定訓(xùn)練或測(cè)試數(shù)據(jù)集,當(dāng)布爾值為True則為訓(xùn)練集,當(dāng)布爾值為False則為測(cè)試集
- download=True從互聯(lián)網(wǎng)下載數(shù)據(jù)(如果無(wú)法在本地獲得)
- transform指定特征轉(zhuǎn)換方式,target_transform指定標(biāo)簽轉(zhuǎn)換方式
數(shù)據(jù)集加載完實(shí)際上是以類的形式存在的,其不同于sklearn中返回的Bunch。
如果我們想要看看數(shù)據(jù)集中有啥要怎么做呢?首先,這個(gè)數(shù)據(jù)集是圖像分類數(shù)據(jù)集,說(shuō)明里面含有的都是圖像,為此,我們可以使用subplots存放這些圖片。對(duì)于這些數(shù)據(jù)集,我們可以像列表一樣手動(dòng)索引。如train_data[index]。
import torch from torch.utils.data import Dataset from torchvision import datasets from torchvision.transforms import ToTensor import matplotlib.pyplot as pltdef load_data():"""加載數(shù)據(jù)集"""# 1 訓(xùn)練數(shù)據(jù)集的加載train_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor())# 2 測(cè)試數(shù)據(jù)集的加載test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor())return train_data, test_datadef show_data(train_data):"""數(shù)據(jù)集可視化"""label_map = {0: "T_Shirt",1: "Trouser",2: "Pullover",3: "Dress",4: "Coat",5: "Sandal",6: "Shirt",7: "Sneaker",8: "Bag",9: "Ankle Boot",}figure = plt.figure(figsize=(8, 8))cols, rows = 3, 3# 從訓(xùn)練集中隨機(jī)抽出九張圖(九個(gè)樣本)for i in range(1, cols * rows + 1):# 設(shè)置索引,索引取值為0到訓(xùn)練集的長(zhǎng)度sample_idx = torch.randint(len(train_data), size=(1,)).item()# 取出對(duì)應(yīng)樣本的圖片和標(biāo)簽img, label = train_data[sample_idx]# 依次畫于事先指定的九宮格圖上figure.add_subplot(rows, cols, i)# 設(shè)置對(duì)應(yīng)圖片的標(biāo)題plt.title(label_map[label])# 關(guān)掉坐標(biāo)軸plt.axis("off")# 展示圖片plt.imshow(img.squeeze(), cmap="gray")# 釋放畫布plt.show()train_data, test_data = load_data() show_data(train_data)out:
上面用到了一個(gè)API,即torch.randint()
torch.randint(low=0, high, size, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor
- 用于取隨機(jī)整數(shù),返回值為張量
- low:int類型,表明要從分布中提取的最低整數(shù)
- high:int類型,表明要從分布中提取的最高整數(shù)1
- size:元組類型,表明輸出張量的形狀
- dtype:返回值張量的數(shù)據(jù)類型
- device:返回張量所需的設(shè)備
- requires_grad:布爾類型,表明是否要對(duì)返回的張量自動(dòng)求導(dǎo)。
如:
torch.randint(3, 5, (3,)) tensor([4, 3, 4])意味生成一個(gè)一維的3元素向量,其中向量中的元素取值從3-5取。
2.2 自定義數(shù)據(jù)集
如果你不想使用框架自帶的數(shù)據(jù)集,那么你可以自己定義一個(gè)數(shù)據(jù)集類。自定義Dataset類必須實(shí)現(xiàn)三個(gè)函數(shù):__ init __ 、 __ len __ 、__ getitem __。其中圖像部分存儲(chǔ)于一個(gè)文件夾中,標(biāo)簽單獨(dú)存儲(chǔ)在CSV文件中。
在接下來(lái)的代碼中,讓我們看看如何創(chuàng)建一個(gè)自定義數(shù)據(jù)集。
import os import pandas as pd from torchvision.io import read_imageclass CustomImageDataset(Dataset):def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transformdef __len__(self):return len(self.img_labels)def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image = read_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label對(duì)于__ init __ 函數(shù)來(lái)說(shuō),包含加載圖像、注釋文件和兩個(gè)轉(zhuǎn)換的目錄,在這里我們不做過(guò)多講解,后面會(huì)詳細(xì)介紹。
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transform對(duì)于__ len __ 函數(shù),其功能是返回?cái)?shù)據(jù)集中的樣本數(shù)。
def __len__(self):return len(self.img_labels)對(duì)于 __ getitem __,其功能是給定索引便能返回對(duì)應(yīng)樣本。
def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image = read_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label在自定義這一部分不用過(guò)多的去了解,用著用著就會(huì)了,就算不會(huì)代碼也是通用,需要用的時(shí)候看一下復(fù)制一下,別搞得自己這么焦慮。
2.3 準(zhǔn)備數(shù)據(jù)以進(jìn)行數(shù)據(jù)加載器訓(xùn)練
在pytorch中,數(shù)據(jù)加載的核心實(shí)際上是torch.utils.data.DataLoader類,它支持對(duì)torch數(shù)據(jù)集的python可迭代,換而言之,DataLoader相當(dāng)于你拿一個(gè)水盆,而dataset相當(dāng)于泉水。DataLoader可以對(duì)小批量數(shù)據(jù)集進(jìn)行處理,處理內(nèi)容包括:
- 地圖樣式和可迭代樣式的數(shù)據(jù)集
- 自定義數(shù)據(jù)集加載順序
- 多進(jìn)程加載數(shù)據(jù)
- 自動(dòng)內(nèi)存固定
其中地圖樣式數(shù)據(jù)集是指自定義數(shù)據(jù)集,而可迭代樣式數(shù)據(jù)集指的是自帶數(shù)據(jù)集。其他詳情對(duì)于初學(xué)者來(lái)說(shuō)很不友好,這里不做過(guò)多解釋,你可以理解為這就是個(gè)科普知識(shí)。
我們來(lái)看一下這個(gè)API吧。
torch.utils.data.DataLoader(數(shù)據(jù)集, batch_size=1, shuffle=False)
- 用于加載樣本并且進(jìn)行批處理
- 數(shù)據(jù)集:要加載的數(shù)據(jù)集
- batch_size:整數(shù)類型,表明每批要加載的樣本數(shù),默認(rèn)為1
- shuffle:布爾類型,表明是否要洗牌
我們利用上面的API來(lái)加載我們上面的Fashion_MNIST吧。
def load_batch_data():"""數(shù)據(jù)集批處理加載器"""train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)return train_dataloader, test_dataloader既然已經(jīng)將樣本導(dǎo)入加載器,那么我們?nèi)绾螐募虞d器中讀取數(shù)據(jù)呢?我們可以根據(jù)需要循環(huán)訪問(wèn)數(shù)據(jù)集。
import torch from torch.utils.data import Dataset from torchvision import datasets from torchvision.transforms import ToTensor import matplotlib.pyplot as plt from torch.utils.data import DataLoaderdef load_data():"""加載數(shù)據(jù)集"""# 1 訓(xùn)練數(shù)據(jù)集的加載train_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor())# 2 測(cè)試數(shù)據(jù)集的加載test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor())return train_data, test_datadef show_data(train_data):"""數(shù)據(jù)集可視化"""label_map = {0: "T_Shirt",1: "Trouser",2: "Pullover",3: "Dress",4: "Coat",5: "Sandal",6: "Shirt",7: "Sneaker",8: "Bag",9: "Ankle Boot",}figure = plt.figure(figsize=(8, 8))cols, rows = 3, 3# 從訓(xùn)練集中隨機(jī)抽出九張圖(九個(gè)樣本)for i in range(1, cols * rows + 1):# 設(shè)置索引,索引取值為0到訓(xùn)練集的長(zhǎng)度sample_idx = torch.randint(len(train_data), size=(1,)).item()# 取出對(duì)應(yīng)樣本的圖片和標(biāo)簽img, label = train_data[sample_idx]# 依次畫于事先指定的九宮格圖上figure.add_subplot(rows, cols, i)# 設(shè)置對(duì)應(yīng)圖片的標(biāo)題plt.title(label_map[label])# 關(guān)掉坐標(biāo)軸plt.axis("off")# 展示圖片plt.imshow(img.squeeze(), cmap="gray")# 釋放畫布plt.show()def load_batch_data():"""數(shù)據(jù)集批處理加載器"""train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)return train_dataloader, test_dataloaderdef show_batch_data():"""循環(huán)訪問(wèn)數(shù)據(jù)加載器"""train_dataloader, test_dataloader = load_batch_data()train_feature, train_labels = next(iter(train_dataloader))print(f"特征大小:{train_feature.size()}")print(f"標(biāo)簽大小:{train_labels.size()}")img = train_feature[0].squeeze()label = train_labels[0]plt.imshow(img, cmap="gray")plt.show()print(f"label:{label}")train_data, test_data = load_data() # show_data(train_data) show_batch_data() 創(chuàng)作挑戰(zhàn)賽新人創(chuàng)作獎(jiǎng)勵(lì)來(lái)咯,堅(jiān)持創(chuàng)作打卡瓜分現(xiàn)金大獎(jiǎng)總結(jié)
以上是生活随笔為你收集整理的深度学习修炼(二)——数据集的加载的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。