pytorch rnn 实现手写字体识别
生活随笔
收集整理的這篇文章主要介紹了
pytorch rnn 实现手写字体识别
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
pytorch rnn 實現手寫字體識別
- 構建 RNN 代碼
- 加載數據
- 使用RNN 訓練 和測試數據
構建 RNN 代碼
import torch import torch.nn as nn from torch.autograd import Variableimport torch.utils.data as Dataimport torchvisionimport matplotlib.pyplot as plttorch.manual_seed(1)#batch size BATCH_SIZE=50 #學習率 LR= 0.001 DOWNLOAD=False #是否訓練 TRAIN =Falseclass RNN(nn.Module):def __init__(self):super(RNN,self).__init__()'''input_size:輸入特征的數目hidden_size:隱層的特征數目num_layers:這個是模型集成的LSTM的個數 記住這里是模型中有多少個LSTM摞起來 一般默認就1個#batch_first: 輸入數據的size為[batch_size, time_step, input_size]還是[time_step, batch_size, input_size]'''self.rnn= nn.LSTM(input_size=28,hidden_size=64,num_layers=3,batch_first=True #batch_first: 輸入數據的size為[batch_size, time_step, input_size]還是[time_step, batch_size, input_size])self.out = nn.Linear(64,10)self.optimizer = torch.optim.Adam(self.parameters(),lr=LR)self.lossFunc= nn.CrossEntropyLoss()def forward(self,x):#x [ batch,28,28]r_out ,(h_n,h_c)= self.rnn(x,None)#r_out [50,28,64] h_n=[1,50,64] h_c =[1,50,64]#r_out 表示 每一次輸入 28 個像素 輸入了 50* 28 次#h_n 表示 每 28*28 為一次 記錄 隱藏層 為 64 所以為 50,64 每28*28為一個記錄 參數print(r_out.size(), h_n.size(),h_c.size())r_out = self.out(r_out[:,-1,:])return r_outdef lossFunction(self,predict ,batchY):loss = self.lossFunc(predict,batchY)self.optimizer.zero_grad()loss.backward()print("loss==",loss.data)self.optimizer.step()加載數據
tranData = torchvision.datasets.MNIST(root="d:/mnist/",train=True,transform=torchvision.transforms.ToTensor(),download=DOWNLOAD)testData = torchvision.datasets.MNIST(root="d:/mnist/",train=False )trainLoader = Data.DataLoader(dataset=tranData,batch_size=BATCH_SIZE,shuffle=True)# 為了節約時間, 我們測試時只測試前2000個 # shape from (2000, 28, 28) to (2000, 1, 28, 28), value in range(0,1) test_x = Variable(torch.unsqueeze(testData.test_data, dim=1), volatile=True).type(torch.FloatTensor)[:2000]/255. test_y = testData.test_labels[:2000]使用RNN 訓練 和測試數據
#構造RNN myRNN= RNN()#訓練數據 if(TRAIN):for epoch in range(3):for step ,(x,y) in enumerate(trainLoader):trainX = Variable(x.view(-1,28,28))print("trainX==",trainX.size())tranY = Variable(y)predict= myRNN(trainX)print("predict==",predict)myRNN.lossFunction(predict,tranY)torch.save(myRNN.state_dict(), "d:/mnist/rnn.pkl") else:myRNN.load_state_dict(torch.load("d:/mnist/rnn.pkl"))#測試數據 testOut = myRNN(test_x[:20].view(-1,28,28))print("testOut==",testOut.size()) #預測值 testPredict = torch.max(testOut,1)[1]print("testPredict==", testPredict.size()) print(testPredict,test_y[:20])總結
以上是生活随笔為你收集整理的pytorch rnn 实现手写字体识别的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 解决通用串行总线控制器里全是叹号问题
- 下一篇: 带 SPI 接口的独立 CAN 控制器,