Pytorch教程(十六):FashionMNIST数据集DataSet DataLoader
torchvision.datasets
由于MNIST數(shù)據(jù)集太簡單,簡單的網(wǎng)絡(luò)就可以達到99%以上的top one準確率,也就是說在這個數(shù)據(jù)集上表現(xiàn)較好的網(wǎng)絡(luò),在別的任務(wù)上表現(xiàn)不一定好。因此zalando research的工作人員建立了fashion mnist數(shù)據(jù)集,該數(shù)據(jù)集由衣服、鞋子等服飾組成,包含70000張圖像,其中60000張訓(xùn)練圖像加10000張測試圖像,圖像大小為28x28,單通道,共分10個類,如下圖,每3行表示一個類。
所以我們通過torchvison來處理FashionMNIST數(shù)據(jù)集:
這樣我們就完成了FashionMNIST數(shù)據(jù)的提取和轉(zhuǎn)換。
如果這個過程中報錯:ImportError: IProgress not found. Please update jupyter and ipywidgets.。一般是jupyter的版本有些低了,可能是你默認的環(huán)境,所以重裝以下就好了:
# 可以先用你的環(huán)境 conda activate xx # 卸載jupyter: pip install --upgrade jupyter訪問單獨某個訓(xùn)練數(shù)據(jù):
torchvision.dataloader
dataloader使我們能夠訪問數(shù)據(jù)并提供查詢功能。
train_loader = torch.utils.data.DataLoader(train_set, batch_size=10)
通過train_loader方式得到的batch包含圖像的張量是4維的張量,形狀是[10, 1, 28, 28],這告訴我們有10個圖像,他們都有1個單獨的顏色通道,高度寬度都是28;對于包含標(biāo)簽的張量,他的長度是10,每10個圖像為一批數(shù)據(jù)。
現(xiàn)在讓我們看看如何使用torchvison.utils.make_grid函數(shù)一次性的畫出整批圖像:
我們可以看到,我們已經(jīng)使用torchvision.utils.make_grid函數(shù)創(chuàng)建了一個網(wǎng)絡(luò),我們把圖像張量作為第一個參數(shù),nrow=10這樣我們所有的圖像就會沿著一行顯示,nrow參數(shù)指定每一行的圖像數(shù)量,因為我們的batch_size=10,這就給我們了一排圖像,我們使用np.transpose(grid, (1,2,0)),這樣軸就滿足了圖像的功能需要的規(guī)格。
現(xiàn)在我們知道了dataset和dataloader之間如何交互的了。現(xiàn)在試試如何批量處理數(shù)據(jù):
總結(jié)
以上是生活随笔為你收集整理的Pytorch教程(十六):FashionMNIST数据集DataSet DataLoader的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 夏普电视画面如何设置
- 下一篇: sklearn 学习曲线Learning