PyTorch框架:(4)如何去构建数据
接PyTorch框架:(3)
1、最基本的方法
(1)使用模塊
模塊1:TensorDataset、模塊2:DataLoader
自己去構(gòu)造數(shù)據(jù)集,然后一個batch一個batch的取數(shù)據(jù),自己去寫構(gòu)造數(shù)據(jù)太麻煩,可以自動讓其把數(shù)據(jù)源給我們構(gòu)建好,這兩個模塊就是來幫我們完成這個事的。
第一步把x_train和y_train傳進(jìn)去,使用TensorDataset自動的幫我們組件dataset即(train_ds);
DataLoader是得搭配一下,先把數(shù)據(jù)轉(zhuǎn)化為TensorDataset所支持的格式,然后采用DataLoader讀進(jìn)來,DataLoader的意思就是你把數(shù)據(jù)交給我,然后你告訴我一個batch_size有多少,然后你要取數(shù)據(jù)的時候我就幫你一個batch一個batch的取數(shù)據(jù),這樣方便一些。shuffle=True表示要不要重新洗牌;
?(2)定義一個get_data方法,需要傳進(jìn)來當(dāng)前的數(shù)據(jù)集,后邊做了一個return,就是按照一個Batch取數(shù)據(jù)就完事了;
?(3)訓(xùn)練函數(shù)
?自己定義一個訓(xùn)練方法,def fit方法,實際的去執(zhí)行訓(xùn)練的操作。傳進(jìn)來的參數(shù):
steps:一共迭代多少次。
model:就是定義的model,就是自己寫個類,把model傳進(jìn)來。
loss_func:使用的f.中的損失。
opt:優(yōu)化器是什么。
train_dl:實際數(shù)據(jù)傳進(jìn)來。
valid_dl:實際數(shù)據(jù)傳進(jìn)來。
Batch Normalization和Dropout在訓(xùn)練的時候一般都會加這兩項,讓模型過擬合的更低;在測試的時候一般就不加這兩個東西了。所以為了有這兩個區(qū)分,如果此時是訓(xùn)練,那么在訓(xùn)練的時候加上model.train();下邊不是訓(xùn)練就是走一次前向傳播,看一下對于當(dāng)前模型來說他的一個效果,他的損失等于多少,把損失拿過來,我也不需要進(jìn)行參數(shù)更新,不需要計算梯度,也不需要訓(xùn)練的過程,所以這一塊我再額外的指定一下,這塊不需要加Batch Normalization和Dropout,他不是一個訓(xùn)練的過程,所以在前邊加上model.eval()。
所以見到這兩個就是表示:model.train()強調(diào)的是你的訓(xùn)練過程,把該加的加進(jìn)去;model.eval()強調(diào)的是測試過程,只需要得到結(jié)果,不需要把沒用的都加進(jìn)去。
?loss_batch做的事情:如果你傳進(jìn)來一個優(yōu)化器,優(yōu)化器求梯度,求完梯度更新,更新完之后置0,然后返回結(jié)果。這里不光計算一個loss值還要去計算他實際的梯度值是多少,要進(jìn)行參數(shù)的更新。
上述相當(dāng)于把每個模塊都準(zhǔn)備好了,實際訓(xùn)練模型的時候不用把每個函數(shù)都也在一個sell當(dāng)中,下面三行就搞定了:
?第一步:拿到數(shù)據(jù)getdata。
第二步:拿到模型和優(yōu)化器。(模型就是自己的類Mnist_NN)
第三步:執(zhí)行fit函數(shù)。(fit函數(shù)的第三個參數(shù)表示損失函數(shù)是如何計算的,在損失函數(shù)計算當(dāng)中還加入了梯度的更新,第四個使用什么樣的優(yōu)化器去更新我當(dāng)前的結(jié)果)
2、復(fù)雜的方法
暫定
總結(jié)
以上是生活随笔為你收集整理的PyTorch框架:(4)如何去构建数据的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: PyTorch框架:(3)使用PyTor
- 下一篇: PyTorch框架:(5)使用PyTor