Pytorch教程(十六):FashionMNIST数据集DataSet DataLoader
torchvision.datasets
由于MNIST數據集太簡單,簡單的網絡就可以達到99%以上的top one準確率,也就是說在這個數據集上表現較好的網絡,在別的任務上表現不一定好。因此zalando research的工作人員建立了fashion mnist數據集,該數據集由衣服、鞋子等服飾組成,包含70000張圖像,其中60000張訓練圖像加10000張測試圖像,圖像大小為28x28,單通道,共分10個類,如下圖,每3行表示一個類。
所以我們通過torchvison來處理FashionMNIST數據集:
這樣我們就完成了FashionMNIST數據的提取和轉換。
如果這個過程中報錯:ImportError: IProgress not found. Please update jupyter and ipywidgets.。一般是jupyter的版本有些低了,可能是你默認的環境,所以重裝以下就好了:
# 可以先用你的環境 conda activate xx # 卸載jupyter: pip install --upgrade jupyter訪問單獨某個訓練數據:
torchvision.dataloader
dataloader使我們能夠訪問數據并提供查詢功能。
train_loader = torch.utils.data.DataLoader(train_set, batch_size=10)
通過train_loader方式得到的batch包含圖像的張量是4維的張量,形狀是[10, 1, 28, 28],這告訴我們有10個圖像,他們都有1個單獨的顏色通道,高度寬度都是28;對于包含標簽的張量,他的長度是10,每10個圖像為一批數據。
現在讓我們看看如何使用torchvison.utils.make_grid函數一次性的畫出整批圖像:
我們可以看到,我們已經使用torchvision.utils.make_grid函數創建了一個網絡,我們把圖像張量作為第一個參數,nrow=10這樣我們所有的圖像就會沿著一行顯示,nrow參數指定每一行的圖像數量,因為我們的batch_size=10,這就給我們了一排圖像,我們使用np.transpose(grid, (1,2,0)),這樣軸就滿足了圖像的功能需要的規格。
現在我們知道了dataset和dataloader之間如何交互的了。現在試試如何批量處理數據:
總結
以上是生活随笔為你收集整理的Pytorch教程(十六):FashionMNIST数据集DataSet DataLoader的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 夏普电视画面如何设置
- 下一篇: sklearn 学习曲线Learning