PyTorch数据加载处理
PyTorch數(shù)據(jù)加載處理
PyTorch提供了許多工具來簡(jiǎn)化和希望數(shù)據(jù)加載,使代碼更具可讀性。
1.下載安裝包
? scikit-image:用于圖像的IO和變換
? pandas:用于更容易地進(jìn)行csv解析
from future import print_function, division
import os
import torch
import pandas as pd #用于更容易地進(jìn)行csv解析
from skimage import io, transform #用于圖像的IO和變換
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
忽略警告
import warnings
warnings.filterwarnings(“ignore”)
plt.ion() # interactive mode
2.下載數(shù)據(jù)集
從此處下載數(shù)據(jù)集, 數(shù)據(jù)存于“data / faces /”的目錄中。這個(gè)數(shù)據(jù)集實(shí)際上是imagenet數(shù)據(jù)集標(biāo)注為face的圖片當(dāng)中在 dlib 面部檢測(cè) (dlib’s pose estimation) 表現(xiàn)良好的圖片。要處理的是一個(gè)面部姿態(tài)的數(shù)據(jù)集。也就是按如下方式標(biāo)注的人臉:
2.1 數(shù)據(jù)集注釋
數(shù)據(jù)集是按如下規(guī)則打包成的csv文件:
image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, … ,part_67_x,part_67_y
0805personali01.jpg,27,83,27,98, … 84,134
1084239450_e76e00b7e7.jpg,70,236,71,257, … ,128,312
3.讀取數(shù)據(jù)集
將csv中的標(biāo)注點(diǎn)數(shù)據(jù)讀入(N,2)數(shù)組中,其中N是特征點(diǎn)的數(shù)量。讀取數(shù)據(jù)代碼如下:
landmarks_frame = pd.read_csv(‘data/faces/face_landmarks.csv’)
n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:].as_matrix()
landmarks = landmarks.astype(‘float’).reshape(-1, 2)
print(‘Image name: {}’.format(img_name))
print(‘Landmarks shape: {}’.format(landmarks.shape))
print(‘First 4 Landmarks: {}’.format(landmarks[:4]))
3.1 數(shù)據(jù)結(jié)果
輸出:
Image name: person-7.jpg
Landmarks shape: (68, 2)
First 4 Landmarks: [[32. 65.]
[33. 76.]
[34. 86.]
[34. 97.]]
4 編寫函數(shù)
寫一個(gè)簡(jiǎn)單的函數(shù),來展示一張圖片和對(duì)應(yīng)的標(biāo)注點(diǎn)作為例子。
def show_landmarks(image, landmarks):
“”“顯示帶有地標(biāo)的圖片”""
plt.imshow(image)
plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker=’.’, c=‘r’)
plt.pause(0.001) # pause a bit so that plots are updated
plt.figure()
show_landmarks(io.imread(os.path.join(‘data/faces/’, img_name)),
landmarks)
plt.show()
函數(shù)展示結(jié)果如下圖所示:
5.數(shù)據(jù)集類
torch.utils.data.Dataset是表示數(shù)據(jù)集的抽象類,因此自定義數(shù)據(jù)集應(yīng)繼承Dataset并覆蓋以下方法 * len 實(shí)現(xiàn) len(dataset) 返還數(shù)據(jù)集的尺寸。 * __getitem__用來獲取一些索引數(shù)據(jù),例如 dataset[i] 中的(i)。
5.1 建立數(shù)據(jù)集類
為面部數(shù)據(jù)集創(chuàng)建一個(gè)數(shù)據(jù)集類。將在 __init__中讀取csv的文件內(nèi)容,在 __getitem__中讀取圖片。這么做是為了節(jié)省內(nèi)存空間。只有在需要用到圖片的時(shí)候才讀取,而不是一開始就把圖片全部存進(jìn)內(nèi)存里。
數(shù)據(jù)樣本將按這樣一個(gè)字典{‘image’: image, ‘landmarks’: landmarks}組織。 數(shù)據(jù)集類將添加一個(gè)可選參數(shù)transform ,以方便對(duì)樣本進(jìn)行預(yù)處理。下面會(huì)看到,什么時(shí)候需要用到transform參數(shù)。 __init__方法如下圖所示:
class FaceLandmarksDataset(Dataset):
“”“面部標(biāo)記數(shù)據(jù)集.”""
def __init__(self, csv_file, root_dir, transform=None):"""csv_file(string):帶注釋的csv文件的路徑。root_dir(string):包含所有圖像的目錄。transform(callable, optional):一個(gè)樣本上的可用的可選變換"""self.landmarks_frame = pd.read_csv(csv_file)self.root_dir = root_dirself.transform = transformdef __len__(self):return len(self.landmarks_frame)def __getitem__(self, idx):img_name = os.path.join(self.root_dir,self.landmarks_frame.iloc[idx, 0])image = io.imread(img_name)landmarks = self.landmarks_frame.iloc[idx, 1:]landmarks = np.array([landmarks])landmarks = landmarks.astype('float').reshape(-1, 2)sample = {'image': image, 'landmarks': landmarks}if self.transform:sample = self.transform(sample)return sample
6.數(shù)據(jù)可視化
實(shí)例化這個(gè)類并遍歷數(shù)據(jù)樣本。將會(huì)打印出前四個(gè)例子的尺寸,并展示標(biāo)注的特征點(diǎn)。 代碼如下圖所示:
face_dataset = FaceLandmarksDataset(csv_file=‘data/faces/face_landmarks.csv’,
root_dir=‘data/faces/’)
fig = plt.figure()
for i in range(len(face_dataset)):
sample = face_dataset[i]
print(i, sample['image'].shape, sample['landmarks'].shape)ax = plt.subplot(1, 4, i + 1)
plt.tight_layout()
ax.set_title('Sample #{}'.format(i))
ax.axis('off')
show_landmarks(**sample)if i == 3:plt.show()break
數(shù)據(jù)結(jié)果:
6.1 圖形展示結(jié)果
6.2 控制臺(tái)輸出結(jié)果:
0 (324, 215, 3) (68, 2)
1 (500, 333, 3) (68, 2)
2 (250, 258, 3) (68, 2)
3 (434, 290, 3) (68, 2)
7.數(shù)據(jù)變換
通過上面的例子會(huì)發(fā)現(xiàn)圖片并不是同樣的尺寸。絕大多數(shù)神經(jīng)網(wǎng)絡(luò)都假定圖片的尺寸相同。因此需要做一些預(yù)處理。創(chuàng)建三個(gè)轉(zhuǎn)換: * Rescale:縮放圖片 * RandomCrop:對(duì)圖片進(jìn)行隨機(jī)裁剪。這是一種數(shù)據(jù)增強(qiáng)操作 * ToTensor:把numpy格式圖片轉(zhuǎn)為torch格式圖片 (需要交換坐標(biāo)軸)。
把們寫成可調(diào)用的類的形式,而不是簡(jiǎn)單的函數(shù),這樣就不需要每次調(diào)用時(shí)傳遞一遍參數(shù)。只需要實(shí)現(xiàn)__call__方法,必 要的時(shí)候?qū)崿F(xiàn) __init__方法。可以這樣調(diào)用這些轉(zhuǎn)換:
tsfm = Transform(params)
transformed_sample = tsfm(sample)
觀察下面這些轉(zhuǎn)換是如何應(yīng)用在圖像和標(biāo)簽上的。
class Rescale(object):
“”"將樣本中的圖像重新縮放到給定大小。.
Args:output_size(tuple或int):所需的輸出大小。 如果是元組,則輸出為與output_size匹配。 如果是int,則匹配較小的圖像邊緣到output_size保持縱橫比相同。
"""def __init__(self, output_size):assert isinstance(output_size, (int, tuple))self.output_size = output_sizedef __call__(self, sample):image, landmarks = sample['image'], sample['landmarks']h, w = image.shape[:2]if isinstance(self.output_size, int):if h > w:new_h, new_w = self.output_size * h / w, self.output_sizeelse:new_h, new_w = self.output_size, self.output_size * w / helse:new_h, new_w = self.output_sizenew_h, new_w = int(new_h), int(new_w)img = transform.resize(image, (new_h, new_w))# h and w are swapped for landmarks because for images,# x and y axes are axis 1 and 0 respectivelylandmarks = landmarks * [new_w / w, new_h / h]return {'image': img, 'landmarks': landmarks}
class RandomCrop(object):
“”"隨機(jī)裁剪樣本中的圖像.
Args:output_size(tuple或int):所需的輸出大小。 如果是int,方形裁剪是。
"""def __init__(self, output_size):assert isinstance(output_size, (int, tuple))if isinstance(output_size, int):self.output_size = (output_size, output_size)else:assert len(output_size) == 2self.output_size = output_sizedef __call__(self, sample):image, landmarks = sample['image'], sample['landmarks']h, w = image.shape[:2]new_h, new_w = self.output_sizetop = np.random.randint(0, h - new_h)left = np.random.randint(0, w - new_w)image = image[top: top + new_h,left: left + new_w]landmarks = landmarks - [left, top]return {'image': image, 'landmarks': landmarks}
class ToTensor(object):
“”“將樣本中的ndarrays轉(zhuǎn)換為Tensors.”""
def __call__(self, sample):image, landmarks = sample['image'], sample['landmarks']# 交換顏色軸因?yàn)? numpy包的圖片是: H * W * C# torch包的圖片是: C * H * Wimage = image.transpose((2, 0, 1))return {'image': torch.from_numpy(image),'landmarks': torch.from_numpy(landmarks)}
8.組合轉(zhuǎn)換
接下來把這些轉(zhuǎn)換應(yīng)用到一個(gè)例子上。
要把圖像的短邊調(diào)整為256,然后隨機(jī)裁剪(randomcrop)為224大小的正方形。也就是說,打算組合一個(gè)Rescale和 RandomCrop的變換。 可以調(diào)用一個(gè)簡(jiǎn)單的類 torchvision.transforms.Compose來實(shí)現(xiàn)這一操作。具體實(shí)現(xiàn)如下圖:
scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256),
RandomCrop(224)])
在樣本上應(yīng)用上述的每個(gè)變換。
fig = plt.figure()
sample = face_dataset[65]
for i, tsfrm in enumerate([scale, crop, composed]):
transformed_sample = tsfrm(sample)
ax = plt.subplot(1, 3, i + 1)
plt.tight_layout()
ax.set_title(type(tsfrm).__name__)
show_landmarks(**transformed_sample)
plt.show()
? 輸出效果:
9.迭代數(shù)據(jù)集
讓把這些整合起來以創(chuàng)建一個(gè)帶組合轉(zhuǎn)換的數(shù)據(jù)集。每次這個(gè)數(shù)據(jù)集被采樣時(shí): * 及時(shí)地從文件中讀取圖片 * 對(duì)讀取的圖片應(yīng)用轉(zhuǎn)換 * 。由于其中一步操作是隨機(jī)的 (randomcrop) , 數(shù)據(jù)被增強(qiáng)了。
可以像之前那樣,使用for i in range循環(huán),來對(duì)所有創(chuàng)建的數(shù)據(jù)集執(zhí)行同樣的操作。
transformed_dataset = FaceLandmarksDataset(csv_file=‘data/faces/face_landmarks.csv’,
root_dir=‘data/faces/’,
transform=transforms.Compose([
Rescale(256),
RandomCrop(224),
ToTensor()
]))
for i in range(len(transformed_dataset)):
sample = transformed_dataset[i]
print(i, sample['image'].size(), sample['landmarks'].size())if i == 3:break
? 輸出結(jié)果:
0 torch.Size([3, 224, 224]) torch.Size([68, 2])
1 torch.Size([3, 224, 224]) torch.Size([68, 2])
2 torch.Size([3, 224, 224]) torch.Size([68, 2])
3 torch.Size([3, 224, 224]) torch.Size([68, 2])
但是,對(duì)所有數(shù)據(jù)集簡(jiǎn)單的使用for循環(huán)犧牲了許多功能,尤其是: * 批量處理數(shù)據(jù) * 打亂數(shù)據(jù) * 使用多線程multiprocessingworker 并行加載數(shù)據(jù)。
torch.utils.data.DataLoader是一個(gè)提供上述所有這些功能的迭代器。下面使用的參數(shù)必須是清楚的。一個(gè)值得關(guān)注的參數(shù)是collate_fn, 可以通過來決定如何對(duì)數(shù)據(jù)進(jìn)行批處理。但是絕大多數(shù)情況下默認(rèn)值就能運(yùn)行良好。
dataloader = DataLoader(transformed_dataset, batch_size=4,
shuffle=True, num_workers=4)
輔助功能:顯示批次
def show_landmarks_batch(sample_batched):
“”“Show image with landmarks for a batch of samples.”""
images_batch, landmarks_batch =
sample_batched[‘image’], sample_batched[‘landmarks’]
batch_size = len(images_batch)
im_size = images_batch.size(2)
grid_border_size = 2
grid = utils.make_grid(images_batch)
plt.imshow(grid.numpy().transpose((1, 2, 0)))for i in range(batch_size):plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size + (i + 1) * grid_border_size,landmarks_batch[i, :, 1].numpy() + grid_border_size,s=10, marker='.', c='r')plt.title('Batch from dataloader')
for i_batch, sample_batched in enumerate(dataloader):
print(i_batch, sample_batched[‘image’].size(),
sample_batched[‘landmarks’].size())
# 觀察第4批次并停止。
if i_batch == 3:plt.figure()show_landmarks_batch(sample_batched)plt.axis('off')plt.ioff()plt.show()break
? 輸出
0 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
1 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
2 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
3 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
10.后記:torchvision
本文學(xué)習(xí)了如何構(gòu)造和使用數(shù)據(jù)集類(datasets),轉(zhuǎn)換(transforms)和數(shù)據(jù)加載器(dataloader)。torchvision包提供了常用的數(shù)據(jù)集類(datasets)和轉(zhuǎn)換(transforms)。可能不需要自己構(gòu)造這些類。torchvision中還有一個(gè)更常用的數(shù)據(jù)集類ImageFolder。 假定了數(shù)據(jù)集是以如下方式構(gòu)造的:
root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
.
.
.
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png
其中’ants’,bees’等是分類標(biāo)簽。在PIL.Image中,也可以使用類似的轉(zhuǎn)換(transforms)例如RandomHorizontalFlip,Scale。利用這些可以按如下的方式創(chuàng)建一個(gè)數(shù)據(jù)加載器(dataloader) :
import torch
from torchvision import transforms, datasets
data_transform = transforms.Compose([
transforms.RandomSizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
hymenoptera_dataset = datasets.ImageFolder(root=‘hymenoptera_data/train’,
transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
batch_size=4, shuffle=True,
num_workers=4)
總結(jié)
以上是生活随笔為你收集整理的PyTorch数据加载处理的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: PyTorch 数据并行处理
- 下一篇: VGG16迁移学习实现