當(dāng)前位置:
首頁 >
RNN代码解释pytorch
發(fā)布時間:2025/4/16
48
豆豆
生活随笔
收集整理的這篇文章主要介紹了
RNN代码解释pytorch
小編覺得挺不錯的,現(xiàn)在分享給大家,幫大家做個參考.
簡述
還是跟之前的CNN一樣,都是學(xué)于莫煩Python的。
解釋
- 關(guān)于數(shù)據(jù)導(dǎo)入部分的代碼含義,其實(shí)跟之前的CNN幾乎完全一致。
- 而且還需要部分的源代碼–MNIST(在之前的地方有超鏈接)
- 這些都可以在下面的CNN的鏈接中看到
- 卷積神經(jīng)網(wǎng)絡(luò)CNN入門【pytorch學(xué)習(xí)】
模型含義
這里使用RNN,這是跟之前的CNN唯一的不同的地方,其他的都是完全一致的。
class RNN(nn.Module):def __init__(self):super(RNN, self).__init__()self.rnn = nn.LSTM(input_size=28,hidden_size=64,num_layers=1,batch_first=True)self.out = nn.Linear(64, 10) # fully connected layer, output 10 classesdef forward(self, x):r_out, (h_n, h_c) = self.rnn(x, None) # None 表示 hidden state 會用全0的 state# r_out = [BATCH_SIZE, input_size, hidden_size]# r_out[:, -1, :] = [BATCH_SIZE, hidden_size] '-1',表示選取最后一個時間點(diǎn)的 r_out 輸出out = self.out(r_out[:, -1, :])# out = [BATCH_SIZE, 10]return outrnn = RNN()LSTM參數(shù)解釋
- 輸入?yún)?shù),其實(shí)是表示有多少序列。這里的最小單位,考慮的其實(shí)不是整個圖片的完整全部序列。而是每一行為最小單位的。
- 所以說經(jīng)過LSTM之后,輸出的結(jié)果就是r_out = [BATCH_SIZE, input_size, hidden_size]。 第一個input_size其實(shí)是恰好這個圖片大小是(input_size, input_size)的
out中輸入的有-1
- 會發(fā)現(xiàn)這里有一個數(shù)字-1,其實(shí)就是表示要選最后的一列作為最后的結(jié)果。其實(shí)就是說只看最后的一行。
完整代碼
import osimport torch import torch.nn as nn import torch.utils.data as Data import torchvision# Hyper Parameters EPOCH = 1 # train the training data n times, to save time, we just train 1 epoch BATCH_SIZE = 50 LR = 0.001 # learning rate DOWNLOAD_MNIST = False# Mnist digits dataset if not (os.path.exists('./mnist/')) or not os.listdir('./mnist/'):# not mnist dir or mnist is empyt dirDOWNLOAD_MNIST = Truetrain_data = torchvision.datasets.MNIST(root='./mnist/',train=True, # this is training datatransform=torchvision.transforms.ToTensor(), # Converts a PIL.Image or numpy.ndarray to# torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]download=DOWNLOAD_MNIST, )# Data Loader for easy mini-batch return in training, the image batch shape will be (50, 1, 28, 28) train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)# pick 2000 samples to speed up testing test_data = torchvision.datasets.MNIST(root='./mnist/', train=False) test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000] / 255. # shape from (2000, 28, 28) to (2000, 1, 28, 28), value in range(0,1) test_y = test_data.test_labels[:2000]class RNN(nn.Module):def __init__(self):super(RNN, self).__init__()self.rnn = nn.LSTM(input_size=28,hidden_size=64,num_layers=1,batch_first=True)self.out = nn.Linear(64, 10) # fully connected layer, output 10 classesdef forward(self, x):r_out, (h_n, h_c) = self.rnn(x, None) # None 表示 hidden state 會用全0的 state# r_out = [BATCH_SIZE, input_size, hidden_size]# r_out[:, -1, :] = [BATCH_SIZE, hidden_size] '-1',表示選取最后一個時間點(diǎn)的 r_out 輸出out = self.out(r_out[:, -1, :])# out = [BATCH_SIZE, 10]return outrnn = RNN()optimizer = torch.optim.Adam(rnn.parameters(), lr=LR) # optimize all cnn parameters loss_func = nn.CrossEntropyLoss() # the target label is not one-hottedfor epoch in range(EPOCH):for step, (x, b_y) in enumerate(train_loader): # gives batch datab_x = x.view(-1, 28, 28) # reshape x to (batch, time_step, input_size)output = rnn(b_x) # rnn outputloss = loss_func(output, b_y) # cross entropy lossoptimizer.zero_grad() # clear gradients for this training steploss.backward() # backpropagation, compute gradientsoptimizer.step()test_output = rnn(test_x[:10].view(-1, 28, 28)) pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()print(pred_y, 'prediction number') print(test_y[:10], 'real number')- 結(jié)果:
總結(jié)
以上是生活随笔為你收集整理的RNN代码解释pytorch的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 创建表名作为参数的mysq存储过程【pr
- 下一篇: 【解决办法】pandas画出时序数据(股