42. Long Short-Term Memory (LSTM) Networks: the LSTM Process and Stock Prediction with Keras
I. Overview
A conventional recurrent neural network (RNN) uses its hidden state as a short-term memory and can make predictions on sequential data. But as the sequence grows longer, the network has to be unrolled over more and more time steps, and during backpropagation the gradient is a product of one factor per time step; that product tends to vanish or explode.

The LSTM is a variant of the RNN that largely mitigates the vanishing- and exploding-gradient problems through its gate structure.

On top of the plain RNN, the LSTM introduces three gates, a cell state that records long-term memory, and a candidate state that summarizes the new knowledge extracted at the current step.
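The vanishing/exploding effect is easy to see numerically: the backpropagated gradient contains one Jacobian-like factor per unrolled step, so a factor consistently below 1 shrinks the product geometrically and a factor above 1 blows it up. A toy illustration (scalar factors standing in for the per-step Jacobians):

```python
# Toy illustration: the gradient through T unrolled steps behaves like a
# product of T per-step factors (scalars here standing in for Jacobians).
for factor in (0.9, 1.1):
    grad = 1.0
    for _ in range(100):   # 100 time steps
        grad *= factor
    print(f"factor={factor}: gradient after 100 steps = {grad:.3e}")

# factor=0.9: gradient after 100 steps = 2.656e-05   (vanishing)
# factor=1.1: gradient after 100 steps = 1.378e+04   (exploding)
```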
II. The LSTM Structure
1. Short-term memory
The short-term memory is the hidden state, i.e. the memory cell of a plain RNN. In the LSTM it is obtained as the Hadamard product of the output gate and the long-term memory passed through a tanh:

$$h_t = o_t \odot \tanh(c_t)$$

2. Cell state (long-term memory)
The long-term memory records the historical information up to the current time step:

$$c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t$$

Here $c_{t-1}$ is the long-term memory from the previous step, $f_t$ is the forget gate, $i_t$ is the input gate, and $\tilde{c}_t$ is the candidate state, which represents the new knowledge extracted during the current step:

$$\tilde{c}_t = \tanh(W_c\,[h_{t-1}, x_t] + b_c)$$

3. Input gate, forget gate, and output gate
All three gates are functions of the current input $x_t$ and the previous step's short-term memory $h_{t-1}$. The forget gate passes the previous hidden state $h_{t-1}$ and the current input $x_t$ through a sigmoid, which maps them into $[0, 1]$ and expresses how much of the previous cell state $c_{t-1}$ should be forgotten:

$$f_t = \sigma(W_f\,[h_{t-1}, x_t] + b_f)$$

The input gate controls how much of the current candidate state $\tilde{c}_t$ should be stored:

$$i_t = \sigma(W_i\,[h_{t-1}, x_t] + b_i)$$

The output gate controls how much of the current cell state $c_t$ is passed on to the hidden state $h_t$:

$$o_t = \sigma(W_o\,[h_{t-1}, x_t] + b_o)$$

III. The LSTM Process
1. First compute the three gates and the candidate state from the previous hidden state $h_{t-1}$ and the current input $x_t$:

$$f_t = \sigma(W_f\,[h_{t-1}, x_t] + b_f),\qquad i_t = \sigma(W_i\,[h_{t-1}, x_t] + b_i),$$
$$o_t = \sigma(W_o\,[h_{t-1}, x_t] + b_o),\qquad \tilde{c}_t = \tanh(W_c\,[h_{t-1}, x_t] + b_c)$$

2. Combine the forget gate and the input gate to update the long-term memory:

$$c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t$$

3. Combine the output gate and the cell state to update the hidden state:

$$h_t = o_t \odot \tanh(c_t)$$
4. Backpropagate through time, updating the weight matrices and biases with gradient descent or another optimizer. A sketch of one forward step follows this list.
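The four equations above fit in a few lines of NumPy. The sketch below runs one LSTM step per sequence element; the weight shapes, sizes, and random initialization are illustrative assumptions, not values from this article:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W, b):
    """One LSTM time step. Each W[k] maps [h_{t-1}, x_t] to a gate pre-activation."""
    z = np.concatenate([h_prev, x_t])        # [h_{t-1}, x_t]
    f = sigmoid(W['f'] @ z + b['f'])         # forget gate
    i = sigmoid(W['i'] @ z + b['i'])         # input gate
    o = sigmoid(W['o'] @ z + b['o'])         # output gate
    c_tilde = np.tanh(W['c'] @ z + b['c'])   # candidate state
    c = f * c_prev + i * c_tilde             # long-term memory update
    h = o * np.tanh(c)                       # short-term memory (hidden state)
    return h, c

# Illustrative sizes: 1 input feature, 4 hidden units.
rng = np.random.default_rng(0)
n_in, n_h = 1, 4
W = {k: rng.normal(scale=0.1, size=(n_h, n_h + n_in)) for k in 'fioc'}
b = {k: np.zeros(n_h) for k in 'fioc'}
h, c = np.zeros(n_h), np.zeros(n_h)
for x_t in np.array([[0.1], [0.5], [0.9]]):  # a toy length-3 sequence
    h, c = lstm_step(x_t, h, c, W, b)
print(h)  # hidden state after the last step
```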
IV. Stock Prediction with Keras + LSTM
Import the dependencies
```python
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tensorflow.keras.layers import Dense, Dropout, LSTM
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_absolute_error, mean_squared_error
```
Read the data

```python
maotai = pd.read_csv('./SH600519.csv')
# Column 2 holds the price series used as the single feature (the opening
# price in this dataset); the first 2126 days are used for training and
# the remaining 300 days for testing.
training_set = maotai.iloc[0:2126, 2:3].values
test_set = maotai.iloc[2126:, 2:3].values
print(training_set.shape, test_set.shape)
```

Output:

```
(2126, 1) (300, 1)
```
Normalization

```python
# Fit the scaler on the training set only, then reuse it on the test set,
# so that no test-set information leaks into the preprocessing and both
# splits live on the same scale.
sc = MinMaxScaler(feature_range=(0, 1))
training_set = sc.fit_transform(training_set)
test_set = sc.transform(test_set)
```
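Refitting the scaler on the test set (i.e. calling `fit_transform` again) would rescale the two splits with different min/max values, so the model would be trained and evaluated in mismatched spaces. A toy check, not from the article, makes the difference visible:

```python
import numpy as np
from sklearn.preprocessing import MinMaxScaler

train = np.array([[10.], [20.], [30.]])
test = np.array([[25.], [40.]])   # the test range extends beyond the training range

sc = MinMaxScaler().fit(train)                      # fitted on training data only
print(sc.transform(test).ravel())                   # [0.75 1.5 ] - training scale
print(MinMaxScaler().fit_transform(test).ravel())   # [0. 1.]     - a different scale
```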
Build the training and test sets

```python
# Each sample uses the previous 60 days' prices to predict day 61's price.
x_train, y_train, x_test, y_test = [], [], [], []
for i in range(60, len(training_set)):
    x_train.append(training_set[i - 60:i, 0])
    y_train.append(training_set[i, 0])
# Shuffle features and labels with the same seed so they stay aligned.
np.random.seed(7)
np.random.shuffle(x_train)
np.random.seed(7)
np.random.shuffle(y_train)
tf.random.set_seed(7)
x_train, y_train = np.array(x_train), np.array(y_train)
# The LSTM layer expects input of shape (samples, time steps, features).
x_train = np.reshape(x_train, (x_train.shape[0], 60, 1))
for i in range(60, len(test_set)):
    x_test.append(test_set[i - 60:i, 0])
    y_test.append(test_set[i, 0])
x_test, y_test = np.array(x_test), np.array(y_test)
x_test = np.reshape(x_test, (x_test.shape[0], 60, 1))
```
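On a toy series it is easy to see what the sliding window produces. A sketch with a 3-step window instead of 60, on made-up numbers rather than the Moutai data:

```python
import numpy as np

series = np.arange(10, 17, dtype=float)   # [10 11 12 13 14 15 16]
window = 3
X = np.array([series[i - window:i] for i in range(window, len(series))])
y = series[window:]
print(X)        # [[10 11 12] [11 12 13] [12 13 14] [13 14 15]]
print(y)        # [13. 14. 15. 16.]
X = X.reshape(X.shape[0], window, 1)      # (samples, time steps, features)
print(X.shape)  # (4, 3, 1)
```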
Build the network

```python
model = tf.keras.Sequential([
    LSTM(80, return_sequences=True),  # emit the full sequence for the next LSTM layer
    Dropout(0.2),
    LSTM(100),                        # emit only the final hidden state
    Dropout(0.2),
    Dense(1)                          # map the hidden state to a single price
])
```
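The first LSTM layer must return its full output sequence because another LSTM layer is stacked on top of it; the second only needs its final hidden state, which the Dense layer maps to one predicted price. To inspect the layer shapes and parameter counts, you can build the model with the input shape and print a summary (a quick check, not part of the original script):

```python
model.build(input_shape=(None, 60, 1))  # batch size unspecified, 60 steps, 1 feature
model.summary()
```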
Configure the network

```python
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
              loss='mean_squared_error')
```
Train the model

```python
history = model.fit(x_train, y_train,
                    batch_size=64,
                    epochs=50,
                    validation_data=(x_test, y_test),
                    validation_freq=1)
```
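The call above always trains for the full 50 epochs. If you would rather stop once the validation loss stops improving, a Keras callback can be passed to `fit`. A minimal sketch; the patience value is an arbitrary choice, and this callback was not used for the run whose log follows:

```python
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',           # watch the validation loss
    patience=5,                   # stop after 5 epochs without improvement
    restore_best_weights=True)    # roll back to the best epoch's weights
history = model.fit(x_train, y_train, batch_size=64, epochs=50,
                    validation_data=(x_test, y_test), callbacks=[early_stop])
```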
Training output

```
Epoch 1/50
33/33 [==============================] - 4s 114ms/step - loss: 0.0135 - val_loss: 0.0110
Epoch 2/50
33/33 [==============================] - 3s 95ms/step - loss: 0.0013 - val_loss: 0.0049
Epoch 3/50
33/33 [==============================] - 3s 99ms/step - loss: 0.0011 - val_loss: 0.0051
Epoch 4/50
33/33 [==============================] - 3s 98ms/step - loss: 0.0013 - val_loss: 0.0057
Epoch 5/50
33/33 [==============================] - 3s 95ms/step - loss: 0.0011 - val_loss: 0.0047
Epoch 6/50
33/33 [==============================] - 3s 99ms/step - loss: 0.0011 - val_loss: 0.0046
Epoch 7/50
33/33 [==============================] - 3s 92ms/step - loss: 0.0011 - val_loss: 0.0046
Epoch 8/50
33/33 [==============================] - 3s 86ms/step - loss: 0.0010 - val_loss: 0.0049
Epoch 9/50
33/33 [==============================] - 3s 84ms/step - loss: 0.0010 - val_loss: 0.0051
Epoch 10/50
33/33 [==============================] - 3s 86ms/step - loss: 0.0010 - val_loss: 0.0051
Epoch 11/50
33/33 [==============================] - 3s 86ms/step - loss: 9.7592e-04 - val_loss: 0.0044
Epoch 12/50
33/33 [==============================] - 3s 87ms/step - loss: 9.6163e-04 - val_loss: 0.0043
Epoch 13/50
33/33 [==============================] - 3s 88ms/step - loss: 0.0011 - val_loss: 0.0041
Epoch 14/50
33/33 [==============================] - 3s 89ms/step - loss: 9.1143e-04 - val_loss: 0.0042
Epoch 15/50
33/33 [==============================] - 3s 89ms/step - loss: 0.0011 - val_loss: 0.0046
Epoch 16/50
33/33 [==============================] - 3s 89ms/step - loss: 8.8493e-04 - val_loss: 0.0040
Epoch 17/50
33/33 [==============================] - 3s 90ms/step - loss: 9.2448e-04 - val_loss: 0.0042
Epoch 18/50
33/33 [==============================] - 3s 91ms/step - loss: 8.7795e-04 - val_loss: 0.0038
Epoch 19/50
33/33 [==============================] - 3s 91ms/step - loss: 7.1217e-04 - val_loss: 0.0045
Epoch 20/50
33/33 [==============================] - 3s 91ms/step - loss: 0.0012 - val_loss: 0.0038
Epoch 21/50
33/33 [==============================] - 3s 93ms/step - loss: 8.5274e-04 - val_loss: 0.0037
Epoch 22/50
33/33 [==============================] - 3s 92ms/step - loss: 9.9773e-04 - val_loss: 0.0052
Epoch 23/50
33/33 [==============================] - 3s 93ms/step - loss: 9.0810e-04 - val_loss: 0.0046
Epoch 24/50
33/33 [==============================] - 3s 93ms/step - loss: 8.4353e-04 - val_loss: 0.0041
Epoch 25/50
33/33 [==============================] - 3s 95ms/step - loss: 8.7846e-04 - val_loss: 0.0037
Epoch 26/50
33/33 [==============================] - 3s 94ms/step - loss: 7.2408e-04 - val_loss: 0.0035
Epoch 27/50
33/33 [==============================] - 3s 95ms/step - loss: 7.8355e-04 - val_loss: 0.0059
Epoch 28/50
33/33 [==============================] - 3s 96ms/step - loss: 8.1942e-04 - val_loss: 0.0035
Epoch 29/50
33/33 [==============================] - 3s 96ms/step - loss: 7.7674e-04 - val_loss: 0.0033
Epoch 30/50
33/33 [==============================] - 3s 95ms/step - loss: 7.3867e-04 - val_loss: 0.0037
Epoch 31/50
33/33 [==============================] - 3s 97ms/step - loss: 7.2609e-04 - val_loss: 0.0033
Epoch 32/50
33/33 [==============================] - 3s 96ms/step - loss: 6.9374e-04 - val_loss: 0.0033
Epoch 33/50
33/33 [==============================] - 3s 96ms/step - loss: 6.3776e-04 - val_loss: 0.0050
Epoch 34/50
33/33 [==============================] - 3s 97ms/step - loss: 7.6443e-04 - val_loss: 0.0036
Epoch 35/50
33/33 [==============================] - 3s 98ms/step - loss: 7.9301e-04 - val_loss: 0.0032
Epoch 36/50
33/33 [==============================] - 3s 97ms/step - loss: 7.7646e-04 - val_loss: 0.0036
Epoch 37/50
33/33 [==============================] - 3s 99ms/step - loss: 8.3467e-04 - val_loss: 0.0033
Epoch 38/50
33/33 [==============================] - 3s 99ms/step - loss: 7.6392e-04 - val_loss: 0.0032
Epoch 39/50
33/33 [==============================] - 3s 99ms/step - loss: 6.3954e-04 - val_loss: 0.0047
Epoch 40/50
33/33 [==============================] - 3s 99ms/step - loss: 7.3498e-04 - val_loss: 0.0034
Epoch 41/50
33/33 [==============================] - 3s 99ms/step - loss: 5.8371e-04 - val_loss: 0.0031
Epoch 42/50
33/33 [==============================] - 3s 99ms/step - loss: 5.7156e-04 - val_loss: 0.0034
Epoch 43/50
33/33 [==============================] - 3s 100ms/step - loss: 6.2417e-04 - val_loss: 0.0030
Epoch 44/50
33/33 [==============================] - 3s 101ms/step - loss: 6.8761e-04 - val_loss: 0.0035
Epoch 45/50
33/33 [==============================] - 4s 108ms/step - loss: 6.7483e-04 - val_loss: 0.0031
Epoch 46/50
33/33 [==============================] - 4s 113ms/step - loss: 6.2236e-04 - val_loss: 0.0031
Epoch 47/50
33/33 [==============================] - 4s 115ms/step - loss: 6.4746e-04 - val_loss: 0.0034
Epoch 48/50
33/33 [==============================] - 4s 112ms/step - loss: 7.4622e-04 - val_loss: 0.0029
Epoch 49/50
33/33 [==============================] - 3s 101ms/step - loss: 6.8864e-04 - val_loss: 0.0028
Epoch 50/50
33/33 [==============================] - 3s 101ms/step - loss: 5.6762e-04 - val_loss: 0.0028
```
Loss curves

```python
loss = history.history['loss']
val_loss = history.history['val_loss']
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend()
plt.title('Loss')
plt.show()
```
Compare the predictions with the real prices

```python
predict_price = model.predict(x_test)
# Map the normalized values back to actual prices.
predict_price = sc.inverse_transform(predict_price)
real_price = sc.inverse_transform(test_set[60:])
plt.plot(real_price, color='red', label='MaoTai Stock Price')
plt.plot(predict_price, color='blue', label='Predicted MaoTai Stock Price')
plt.title('MaoTai Stock Price Prediction')
plt.xlabel('Time')
plt.ylabel('MaoTai Stock Price')
plt.legend()
plt.show()
```
Evaluation metrics (MSE and MAE)

```python
mse = mean_squared_error(real_price, predict_price)
mae = mean_absolute_error(real_price, predict_price)
print('mean_squared_error', mse)
print('mean_absolute_error', mae)
```

Output:

```
mean_squared_error 922.6493975725148
mean_absolute_error 23.789508666992194
```
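A related metric, the root-mean-square error (RMSE), is just the square root of the MSE and is often easier to read because it is in the same units as the price; one extra line computes it from the MSE above:

```python
rmse = np.sqrt(mse)
print('root_mean_squared_error', rmse)  # ≈ 30.375, in price units
```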