PyTorch基础(四)-----数据加载和预处理
前言
之前已經(jīng)簡(jiǎn)單講述了PyTorch的Tensor、Autograd、torch.nn和torch.optim包,通過(guò)這些我們已經(jīng)可以簡(jiǎn)單的搭建一個(gè)網(wǎng)絡(luò)模型,但這是不夠的,我們還需要大量的數(shù)據(jù),眾所周知,數(shù)據(jù)是深度學(xué)習(xí)的靈魂,深度學(xué)習(xí)的模型是由數(shù)據(jù)“喂”出來(lái)的,這篇我們來(lái)講述一下數(shù)據(jù)的加載和預(yù)處理。
- 首先,我們要引入torch包
一、數(shù)據(jù)的加載
PyTorch通過(guò)torch.utils.data對(duì)一般常用的數(shù)據(jù)加載進(jìn)行了封裝,可以很容易地實(shí)現(xiàn)多線程數(shù)據(jù)預(yù)讀和批量加載。
1.1 Dataset
Dataset是一個(gè)抽象類(lèi),為了能夠方便的讀取,需要將要使用的數(shù)據(jù)包裝為Dataset類(lèi)。自定義的Dataset類(lèi)需要繼承它并且實(shí)現(xiàn)2個(gè)成員方法:
- 1.__getitem__():該方法定義用索引(0-len(self))獲取一條數(shù)據(jù)或一個(gè)樣本
- 2.__len__():該方法返回?cái)?shù)據(jù)集的總長(zhǎng)度
下面我們使用Kaggle上的一個(gè)競(jìng)賽bluebook for bulldozers自定義一個(gè)數(shù)據(jù)集,為了方便介紹,我們使用里面的數(shù)據(jù)字典來(lái)做說(shuō)明
- 首先,我們需要引用相關(guān)的包
- 自定義一個(gè)數(shù)據(jù)集
- 至此,我們的數(shù)據(jù)集已經(jīng)定義完成了,我們可以實(shí)例化一個(gè)對(duì)象來(lái)訪問(wèn)
- 我們可以直接使用如下命令查看數(shù)據(jù)集數(shù)據(jù)
- 使用索引可以直接訪問(wèn)對(duì)應(yīng)的數(shù)據(jù)
自定義的數(shù)據(jù)集已經(jīng)創(chuàng)建好了,下面我們使用官方提供的數(shù)據(jù)載入器,讀取數(shù)據(jù)
1.2 DataLoader
DataLoader為我們提供了對(duì)Dataset的讀取操作,常用參數(shù)有:batch_size(每個(gè)batch的大小)、shuffle(是否進(jìn)行shuffle操作)、num_workers(加載數(shù)據(jù)時(shí)使用幾個(gè)子進(jìn)程)。下面做一個(gè)簡(jiǎn)單的演示:
dl = torch.utils.data.DataLoader(ds_demo,batch_size = 10,shuffle = True,num_workers = 0)DataLoader返回的是一個(gè)可迭代對(duì)象,我們可以使用迭代器分次獲取數(shù)據(jù)
idata=iter(dl) print(next(idata))常見(jiàn)的用法是使用for循環(huán)對(duì)其進(jìn)行遍歷
for i, data in enumerate(dl):print(i,data)# 為了節(jié)約空間,這里只循環(huán)一遍break至此,我們已經(jīng)可以通過(guò)dataset定義數(shù)據(jù)集,并使用DataLorder載入和遍歷數(shù)據(jù)集。
二、torchvision包
torchvision 是PyTorch中專(zhuān)門(mén)用來(lái)處理圖像的庫(kù),PyTorch官網(wǎng)的安裝教程中最后的pip install torchvision 就是安裝這個(gè)包。
torchvision已經(jīng)預(yù)先實(shí)現(xiàn)了常用圖像數(shù)據(jù)集,包括前面使用過(guò)的CIFAR-10,ImageNet、COCO、MNIST、LSUN等數(shù)據(jù)集,可通過(guò)torchvision.datasets方便的調(diào)用。
- 這里總結(jié)一下torchvision已經(jīng)預(yù)裝的數(shù)據(jù)集:
| MNIST |
| COCO |
| CIFAR-10 |
| ImageNet |
| Captions |
| Detection |
| LSUN |
| ImageFolder |
| Imagenet-12 |
| STL10 |
| SVHN |
| PhotoTour |
PyTorch中自帶的數(shù)據(jù)集由2個(gè)上層api提供,分別是torchvision和torchtext
- torchvision提供了對(duì)圖像數(shù)據(jù)處理的相關(guān)數(shù)據(jù)和api
- 數(shù)據(jù)位置:torchvision.datasets;例如:torchvision.datasets.MNIST
- torchtext提供了對(duì)文本數(shù)據(jù)處理的相關(guān)數(shù)據(jù)和api
- 數(shù)據(jù)位置:torchtext.datasets;例如:torchtext.datasets.IMDB
下面我們做一個(gè)簡(jiǎn)單的演示
- 首先,我們要引入torchvision包
2.1 torchvision.models
torchvision不僅提供了常用的圖像數(shù)據(jù)集,而且還提供了一些訓(xùn)練好的網(wǎng)絡(luò)模型,可以加載之后直接使用,或者繼續(xù)進(jìn)行遷移學(xué)習(xí)。torchvision.models模塊的子模塊中包含以下模型:
| AlexNet |
| VGG |
| ResNet |
| SqueezeNet |
| DenseNet |
我們直接可以使用訓(xùn)練好的模型,當(dāng)然這個(gè)與datasets相同,都是需要從服務(wù)器下載的。
- 首先,我們需要導(dǎo)入torchvision.models
- 直接使用
2.2 torchvision.tranforms
transforms 模塊提供了一般的圖像轉(zhuǎn)換操作類(lèi),用作數(shù)據(jù)處理和數(shù)據(jù)增強(qiáng)
- 首先,我們需要引入torchvision.tranforms,然后做一個(gè)簡(jiǎn)單的演示
肯定有人會(huì)問(wèn):(0.485, 0.456, 0.406), (0.2023, 0.1994, 0.2010) 這幾個(gè)數(shù)字是什么意思?
官方的這個(gè)帖子有詳細(xì)的說(shuō)明: https://discuss.pytorch.org/t/normalization-in-the-mnist-example/457/21 這些都是根據(jù)ImageNet訓(xùn)練的歸一化參數(shù),可以直接使用,我們認(rèn)為這個(gè)是固定值就可以。
到這里,我們已經(jīng)完成了PyTorch的基本內(nèi)容介紹。
參考文獻(xiàn)
https://github.com/zergtant/pytorch-handbook/blob/master/chapter2
總結(jié)
以上是生活随笔為你收集整理的PyTorch基础(四)-----数据加载和预处理的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 手把手教你做用户画像
- 下一篇: BRD、MRD 和 PRD 之间的区别与