【小白学习PyTorch教程】九、基于Pytorch训练第一个RNN模型
「@Author:Runsen」
當閱讀一篇課文時,我們可以根據前面的單詞來理解每個單詞的,而不是從零開始理解每個單詞。這可以稱為記憶。卷積神經網絡模型(CNN)不能實現這種記憶,因此引入了遞歸神經網絡模型(RNN)來解決這一問題。RNN是帶有循環的網絡,允許信息持久存在。
RNN的應用有:
情緒分析(多對一,順序輸入)
機器翻譯(多對多,順序輸入和順序輸出)
語音識別(多對多) 它被廣泛地用于處理序列數據的預測和自然語言處理。針對Vanilla-RNN存在短時記憶(梯度消失問題),引入LSTM和GRU來解決這一問題。特別是LSTM被廣泛應用于深度學習模型中。
本博客介紹了如何通過PyTorch實現RNN和LSTM,并將其應用于比特幣價格預測。
import?torch import?torch.nn?as?nn import?torch.optim?as?optim from?torch.autograd?import?Variable import?torch.utils.data?as?Data from?torch.utils.data?import?DataLoaderimport?torchvision import?torchvision.datasets?as?datasets import?torchvision.transforms?as?transforms import?torchvision.utils?as?vutilsimport?numpy?as?np import?pandas?as?pd import?matplotlib.pyplot?as?plt import?os device?=?torch.device('cuda'?if?torch.cuda.is_available()?else?'cpu') print(device)?#cuda比特幣歷史數據集
將通過比較RNN lstm 的性能,來處理時間序列數據,而不是語言數據。使用的數據來自Kaggle比特幣歷史數據. 比特幣是一種加密區塊鏈貨幣。
data?=?pd.read_csv("EtherPriceHistory(USD).csv") data.tail()然后我將繪制數據,以查看比特幣價格的趨勢。
plt.figure(figsize?=?(12,?8)) plt.plot(data["Date(UTC)"],?data["Value"]) plt.xlabel("Date") plt.ylabel("Price") plt.title("Ethereum?Price?History") plt.show() #?Hyper?parameters threshold?=?116 window?=?30input_size?=?1 hidden_size?=?50 num_layers?=?3 output_size?=?1learning_rate?=?0.001 batch_size?=?16train_data?=?data['Value'][:len(data)?-?threshold] test_data?=?data['Value'][len(data)?-?threshold:]下面的函數是生成一個滑動窗口,create_sequences掃描所有的訓練數據。
def?create_sequences(input_data,?window):length?=?len(input_data)x?=?input_data[0:window].valuesy?=?input_data[1:window+1].valuesfor?i?in?range(1,?length?-?window):x?=?np.vstack((x,?input_data[i:i+window].values))y?=?np.vstack((y,?input_data[i+1:window+1+i].values))sequence?=?torch.from_numpy(x).type(torch.FloatTensor)label?=?torch.from_numpy(y).type(torch.FloatTensor)sequence?=?Data.TensorDataset(sequence,?label)return?sequence train_data?=?create_sequences(train_data,?window) train_loader?=?Data.DataLoader(train_data,?batch_size?=?batch_size,?shuffle?=?False,?drop_last?=?True)建立RNN神經網絡模型
class?RNN(nn.Module):def?__init__(self,?input_size,?hidden_size,?num_layers,?output_size):super(RNN,?self).__init__()self.input_size?=?input_sizeself.hidden_size?=?hidden_sizeself.num_layers?=?num_layersself.output_size?=?output_sizeself.hidden?=?torch.zeros(num_layers,?1,?hidden_size)self.rnn?=?nn.RNN(input_size,?hidden_size,?num_layers,?????????????#?number?of?recurrent?layersbatch_first?=?True,????#?Default:?False#?If?True,?layer?does?not?use?bias?weightsnonlinearity?=?'relu',??#?'tanh'?or?'relu'#dropout?=?0.5)self.fc?=?nn.Linear(hidden_size,?output_size)def?forward(self,?x):#?input?shape?of?(batch,?seq_len,?input_size)#?output?shape?of?(batch,?seq_len,?hidden_size)out,?hidden?=?self.rnn(x,?self.hidden)self.hidden?=?hidden#?output?shape?of?(batch_,?seq_len,?output_size)out?=?self.fc(out)return?outdef?init_hidden(self,?batch_size):#?hidden?shape?of?(num_layers,?batch,?hidden_size)self.hidden?=?torch.zeros(self.num_layers,?batch_size,?self.hidden_size) rnn?=?RNN(input_size,?hidden_size,?num_layers,?output_size).to(device) rnnMSELoss表示均方損失,Adam表示學習率為0.001的Adam優化器。與CNN模型的訓練不同,添加了nn.utils.clip_grad_norm_來防止梯度爆炸問題。
def?train(model,?num_epochs):criterion?=?nn.MSELoss()optimizer?=?optim.Adam(model.parameters(),?lr?=?learning_rate)for?epoch?in?range(num_epochs):for?i,?(sequences,?labels)?in?enumerate(train_loader):model.init_hidden(batch_size)sequences?=?sequences.view(-1,?window,?1)labels?=?labels.view(-1,?window,?1)pred?=?model(sequences)cost?=?criterion(pred[-1],?labels[-1])optimizer.zero_grad()cost.backward()#防止梯度爆炸問題nn.utils.clip_grad_norm_(model.parameters(),?5)optimizer.step()print("Epoch?[%d/%d]?Loss?%.4f"%(epoch+1,?num_epochs,?cost.item()))print("Training?Finished!")train(rnn,?10) def?evaluation(model):model.eval()model.init_hidden(1)val_day?=?30dates?=?data['Date(UTC)'][1049+window:1049+window+val_day]pred_X?=?[]for?i?in?range(val_day):X?=?torch.from_numpy(test_data[i:window+i].values).type(torch.FloatTensor)X?=?X.view(1,?window,?1).to(device)pred?=?model(X)pred?=?pred.reshape(-1)pred?=?pred.cpu().data.numpy()pred_X.append(pred[-1])y?=?test_data[window:window+val_day].valuesplt.figure(figsize?=?(12,?8))plt.plot(dates,?y,?'o-',?alpha?=?0.7,?label?=?'Real')plt.plot(dates,?pred,?'*-',?alpha?=?0.7,?label?=?'Predict')plt.xticks(rotation?=?45)plt.xlabel("Date")plt.ylabel("Ethereum?Price?(USD)")plt.legend()plt.title("Comparison?between?Prediction?and?Real?Ethereum?BitCoin?Price")plt.show()預測價格大致遵循價格變動趨勢,但價格絕對值與實際價格相差不大。因此,考慮到價格的巨大變化,但實際它的預測并不壞。可以通過修改模型參數和超參數來改進。
#?Save?the?model?checkpoint save_path?=?'./model/'if?not?os.path.exists(save_path):os.makedirs(save_path)torch.save(rnn.state_dict(),?'rnn.ckpt') 往期精彩回顧適合初學者入門人工智能的路線及資料下載機器學習及深度學習筆記等資料打印機器學習在線手冊深度學習筆記專輯《統計學習方法》的代碼復現專輯 AI基礎下載機器學習的數學基礎專輯黃海廣老師《機器學習課程》課件合集 本站qq群851320808,加入微信群請掃碼:總結
以上是生活随笔為你收集整理的【小白学习PyTorch教程】九、基于Pytorch训练第一个RNN模型的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Windows平台RTMP推送|轻量级R
- 下一篇: PP视频如何设置默认缓存个数