控制论python_[干货]深入浅出LSTM及其Python代码实现
人工神經網絡在近年來大放異彩,在圖像識別、語音識別、自然語言處理與大數據分析領域取得了巨大的成功,而長短期記憶網絡LSTM作為一種特殊的神經網絡模型,它又有哪些特點呢?作為初學者,如何由淺入深地理解LSTM并將其應用到實際工作中呢?本文將由淺入深介紹循環神經網絡RNN和長短期記憶網絡LSTM的基本原理,并基于Pytorch實現一個簡單應用例子,提供完整代碼。
1. 神經網絡簡介
1.1 神經網絡起源
人工神經網絡(Aritificial Neural Networks, ANN)是一種仿生的網絡結構,起源于對人類大腦的研究。人工神經網絡(Aritificial Neural Networks)也常被簡稱為神經網絡(Neural Networks, NN),基本思想是通過大量簡單的神經元之間的相互連接來構造復雜的網絡結構,信號(數據)可以在這些神經元之間傳遞,通過激活不同的神經元和對傳遞的信號進行加權來使得信號被放大或衰減,經過多次的傳遞來改變信號的強度和表現形式。
神經網絡最早起源于20世紀40年代,神經科學家和控制論專家Warren McCulloch和邏輯學家Walter Pitts基于數學和閾值邏輯算法創造了最早的神經網絡計算模型。由于當時的計算資源有限,無法構建層數太多的神經網絡(3層以內),因此神經網絡的應用范圍很局限。隨著計算機技術的發展,神經網絡層數的增加帶來的計算負擔已經可以被現代計算機解決,各位前輩大牛對于神經網絡的理解也進一步加深。歷史上神經網絡的發展大致經歷了三次高潮:20世紀40年代的控制論、20世紀80年代到90年代中期的聯結主義和2006年以來的深度學習。深度學習的出現直接引爆了一部分應用市場,這里有太多的案例可以講,想了解的讀者可參考下面的鏈接:
1.2 傳統神經網絡的缺陷
本文假設讀者已經了解神經網絡的基本原理了,如果有讀者是初次接觸神經網絡的知識,這里分享一篇個人覺得非常適合初學者的文章:
1.2.1 傳統神經網絡的原理回顧
傳統的神經網絡結構可以用下面這張圖表示:
NN.jpg
其中:
輸入層:可以包含多個神經元,可以接收多維的信號輸入(特征信息);
輸出層:可以包含多個神經元,可以輸出多維信號;
隱含層:可以包含多個神經網絡層,每一層包含多個神經元。
每層的神經元與上一層神經元和下一層神經元連接(類似生物神經元的突觸),這些連接通路用于信號傳遞。每個神經元接收來自上一層的信號輸入,使用一定的加和規則將所有的信號輸入匯聚到一起,并使用激活函數將輸入信號激活為輸出信號,再將信號傳遞到下一層。
神經網絡為什么要使用激活函數?不同的激活函數有什么不同的作用?讀者可參考:
所以,影響神經網絡表現能力的主要因素有神經網絡的層數、神經元的個數、神經元之間的連接方式以及神經元所采用的激活函數。神經元之間以不同的連接方式(全連接、部分連接)組合,可以構成不同神經網絡,對于不同的信號處理效果也不一樣。但是,目前依舊沒有一種通用的方法可以根據信號輸入的特征來決定神經網絡的結構,這也是神經網絡模型被稱為黑箱的原因之一,帶來的問題也就是模型的參數不容易調整,也不清楚其中到底發生了什么。因此,在不斷的探索當中,前輩大牛們總結得到了許多經典的神經網絡結構:MLP、BP、FFNN、CNN、RNN等。詳細的介紹見以下鏈接:
神經網絡優點很明顯,給我們提供了構建模型的便利,你大可不用顧及模型本身是如何作用的,只需要按照規則構建網絡,然后使用訓練數據集不斷調整參數,在許多問題上都能得到一個比較“能接受”的結果,然而我們對其中發生了什么是未可知的。在深度學習領域,許多問題都可以通過構建深層的神經網絡模型來解決。這里,我們不對神經網絡的優點做過多闡述。
1.2.2 傳統神經網絡結構的缺陷
從傳統的神經網絡結構我們可以看出,信號流從輸入層到輸出層依次流過,同一層級的神經元之間,信號是不會相互傳遞的。這樣就會導致一個問題,輸出信號只與輸入信號有關,而與輸入信號的先后順序無關。并且神經元本身也不具有存儲信息的能力,整個網絡也就沒有“記憶”能力,當輸入信號是一個跟時間相關的信號時,如果我們想要通過這段信號的“上下文”信息來理解一段時間序列的意思,傳統的神經網絡結構就顯得無力了。與我們人類的理解過程類似,我們聽到一句話時往往需要通過這句話中詞語出現的順序以及我們之前所學的關于這些詞語的意思來理解整段話的意思,而不是簡單的通過其中的幾個詞語來理解。
例如,在自然語言處理領域,我們要讓神經網絡理解這樣一句話:“地球上最高的山是珠穆朗瑪峰”,按照傳統的神經網絡結構,它可能會將這句話拆分為幾個單獨的詞(地球、上、最高的、山、是、珠穆朗瑪峰),分別輸入到模型之中,而不管這幾個詞之間的順序。然而,直觀上我們可以看到,這幾個詞出現的順序是與最終這句話要表達的意思是密切相關的,但傳統的神經網絡結構無法處理這種情況。
因此,我們需要構建具有“記憶”能力的神經網絡模型,用來處理需要理解上下文意思的信號,也就是時間序列數據。循環神經網絡(RNN)就是用來處理這類信號的,RNN之所以能夠有效的處理時間序列數據,主要是基于它比較特殊的運行原理。下面將介紹RNN的構建過程和基本運行原理,然后引入長短期記憶網絡(LSTM)。
2. 循環神經網絡RNN
2.1 RNN的構造過程
RNN是一種特殊的神經網路結構,其本身是包含循環的網絡,允許信息在神經元之間傳遞,如下圖所示:
RNN-rolled.png
圖示是一個RNN結構示意圖,圖中的
表示神經網絡模型,
表示模型的輸入信號,
表示模型的輸出信號,如果沒有
的輸出信號傳遞到
的那個箭頭, 這個網絡模型與普通的神經網絡結構無異。那么這個箭頭做了什么事情呢?它允許
將信息傳遞給
,神經網絡將自己的輸出作為輸入了!這怎么理解啊?作者第一次看到這個圖的時候也是有點懵,讀者可以思考一分鐘。
關鍵在于輸入信號是一個時間序列,跟時間
有關。也就是說,在
時刻,輸入信號
作為神經網絡
的輸入,
的輸出分流為兩部分,一部分輸出給
,一部分作為一個隱藏的信號流被輸入到
中,在下一次時刻輸入信號
時,這部分隱藏的信號流也作為輸入信號輸入到了
中。此時神經網絡
就同時接收了
時刻和
時刻的信號輸入了,此時的輸出信號又將被傳遞到下一時刻的
中。如果我們把上面那個圖根據時間
展開來看,就是:
RNN-unrolled.png
看到了嗎?
時刻的信息輸出給
時刻的模型
了,
時刻的信息輸出給
時刻的模型
了,
。這樣,相當于RNN在時間序列上把自己復制了很多遍,每個模型都對應一個時刻的輸入,并且當前時刻的輸出還作為下一時刻的模型的輸入信號。
這樣鏈式的結構揭示了RNN本質上是與序列相關的,是對于時間序列數據最自然的神經網絡架構。并且理論上,RNN可以保留以前任意時刻的信息。RNN在語音識別、自然語言處理、圖片描述、視頻圖像處理等領域已經取得了一定的成果,而且還將更加大放異彩。在實際使用的時候,用得最多的一種RNN結構是LSTM,為什么是LSTM呢?我們從普通RNN的局限性說起。
2.2 RNN的局限性
RNN利用了神經網絡的“內部循環”來保留時間序列的上下文信息,可以使用過去的信號數據來推測對當前信號的理解,這是非常重要的進步,并且理論上RNN可以保留過去任意時刻的信息。但實際使用RNN時往往遇到問題,請看下面這個例子。
假如我們構造了一個語言模型,可以通過當前這一句話的意思來預測下一個詞語。現在有這樣一句話:“我是一個中國人,出生在普通家庭,我最常說漢語,也喜歡寫漢字。我喜歡媽媽做的菜”。我們的語言模型在預測“我最常說漢語”的“漢語”這個詞時,它要預測“我最長說”這后面可能跟的是一個語言,可能是英語,也可能是漢語,那么它需要用到第一句話的“我是中國人”這段話的意思來推測我最常說漢語,而不是英語、法語等。而在預測“我喜歡媽媽做的菜”的最后的詞“菜”時并不需要“我是中國人”這個信息以及其他的信息,它跟我是不是一個中國人沒有必然的關系。
這個例子告訴我們,想要精確地處理時間序列,有時候我們只需要用到最近的時刻的信息。例如預測“我喜歡媽媽做的菜”最后這個詞“菜”,此時信息傳遞是這樣的:
RNN-shorttermdepdencies.png
“菜”這個詞與“我”、“喜歡”、“媽媽”、“做”、“的”這幾個詞關聯性比較大,距離也比較近,所以可以直接利用這幾個詞進行最后那個詞語的推測。
而有時候我們又需要用到很早以前時刻的信息,例如預測“我最常說漢語”最后的這個詞“漢語”。此時信息傳遞是這樣的:
RNN-longtermdependencies.png
此時,我們要預測“漢語”這個詞,僅僅依靠“我”、“最”、“常”、“說”這幾個詞還不能得出我說的是漢語,必須要追溯到更早的句子“我是一個中國人”,由“中國人”這個詞語來推測我最常說的是漢語。因此,這種情況下,我們想要推測“漢語”這個詞的時候就比前面那個預測“菜”這個詞所用到的信息就處于更早的時刻。
而RNN雖然在理論上可以保留所有歷史時刻的信息,但在實際使用時,信息的傳遞往往會因為時間間隔太長而逐漸衰減,傳遞一段時刻以后其信息的作用效果就大大降低了。因此,普通RNN對于信息的長期依賴問題沒有很好的處理辦法。
為了克服這個問題,Hochreiter等人在1997年改進了RNN,提出了一種特殊的RNN模型——LSTM網絡,可以學習長期依賴信息,在后面的20多年被改良和得到了廣泛的應用,并且取得了極大的成功。
3. 長短時間記憶網絡(LSTM)
3.1 LSTM與RNN的關系
長短期記憶(Long Short Term Memory,LSTM)網絡是一種特殊的RNN模型,其特殊的結構設計使得它可以避免長期依賴問題,記住很早時刻的信息是LSTM的默認行為,而不需要專門為此付出很大代價。
普通的RNN模型中,其重復神經網絡模塊的鏈式模型如下圖所示,這個重復的模塊只有一個非常簡單的結構,一個單一的神經網絡層(例如tanh層),這樣就會導致信息的處理能力比較低。
LSTM3-SimpleRNN.png
而LSTM在此基礎上將這個結構改進了,不再是單一的神經網絡層,而是4個,并且以一種特殊的方式進行交互。
LSTM3-chain.png
粗看起來,這個結構有點復雜,不過不用擔心,接下來我們會慢慢解釋。在解釋這個神經網絡層時我們先來認識一些基本的模塊表示方法。圖中的模塊分為以下幾種:
LSTM2-notation.png
黃色方塊:表示一個神經網絡層(Neural Network Layer);
粉色圓圈:表示按位操作或逐點操作(pointwise operation),例如向量加和、向量乘積等;
單箭頭:表示信號傳遞(向量傳遞);
合流箭頭:表示兩個信號的連接(向量拼接);
分流箭頭:表示信號被復制后傳遞到2個不同的地方。
下面我們將分別介紹這些模塊如何在LSTM中作用。
3.2 LSTM的基本思想
LSTM的關鍵是細胞狀態(直譯:cell state),表示為
,用來保存當前LSTM的狀態信息并傳遞到下一時刻的LSTM中,也就是RNN中那根“自循環”的箭頭。當前的LSTM接收來自上一個時刻的細胞狀態
,并與當前LSTM接收的信號輸入
共同作用產生當前LSTM的細胞狀態
,具體的作用方式下面將詳細介紹。
LSTM3-C-line.png
在LSTM中,采用專門設計的“門”來引入或者去除細胞狀態
中的信息。門是一種讓信息選擇性通過的方法。有的門跟信號處理中的濾波器有點類似,允許信號部分通過或者通過時被門加工了;有的門也跟數字電路中的邏輯門類似,允許信號通過或者不通過。這里所采用的門包含一個
神經網絡層和一個按位的乘法操作,如下圖所示:
LSTM3-gate.png
其中黃色方塊表示
神經網絡層,粉色圓圈表示按位乘法操作。
神經網絡層可以將輸入信號轉換為
到
之間的數值,用來描述有多少量的輸入信號可以通過。
表示“不允許任何量通過”,
表示“允許所有量通過”。
神經網絡層起到類似下圖的
函數所示的作用:
sigmod_function.jpg
其中,橫軸表示輸入信號,縱軸表示經過
以后的輸出信號。
LSTM主要包括三個不同的門結構:遺忘門、記憶門和輸出門。這三個門用來控制LSTM的信息保留和傳遞,最終反映到細胞狀態
和輸出信號
。如下圖所示:
LSTM_gates.png
圖中標示了LSTM中各個門的構成情況和相互之間的關系,其中:
遺忘門由一個
神經網絡層和一個按位乘操作構成;
記憶門由輸入門(input gate)與tanh神經網絡層和一個按位乘操作構成;
輸出門(output gate)與
函數(注意:這里不是
神經網絡層)以及按位乘操作共同作用將細胞狀態和輸入信號傳遞到輸出端。
3.3 遺忘門
顧名思義,遺忘門的作用就是用來“忘記”信息的。在LSTM的使用過程中,有一些信息不是必要的,因此遺忘門的作用就是用來選擇這些信息并“忘記”它們。遺忘門決定了細胞狀態
中的哪些信息將被遺忘。那么遺忘門的工作原理是什么呢?看下面這張圖。
LSTM3-focus-f.png
左邊高亮的結構就是遺忘門了,包含一個
神經網絡層(黃色方框,神經網絡參數為
),接收
時刻的輸入信號
和
時刻LSTM的上一個輸出信號
,這兩個信號進行拼接以后共同輸入到
神經網絡層中,然后輸出信號
,
是一個
到
之間的數值,并與
相乘來決定
中的哪些信息將被保留,哪些信息將被舍棄。可能看到這里有的初學者還是不知道具體是什么意思,我們用一個簡單的例子來說明。
假設
,
,
, 那么遺忘門的輸入信號就是
和
的組合,即
, 然后通過
神經網絡層輸出每一個元素都處于
到
之間的向量
,注意,此時
是一個與
維數相同的向量,此處為3維。如果看到這里還沒有看懂的讀者,可能會有這樣的疑問:輸入信號明明是6維的向量,為什么
就變成了3維呢?這里可能是將
神經網絡層當成了
激活函數了,兩者不是一個東西,初學者在這里很容易混淆。下文所提及的
神經網絡層和
神經網絡層而是類似的道理,他們并不是簡單的
激活函數和
激活函數,在學習時要注意區分。
3.4 記憶門
記憶門的作用與遺忘門相反,它將決定新輸入的信息
和
中哪些信息將被保留。
LSTM3-focus-i.png
如圖所示,記憶門包含2個部分。第一個是包含
神經網絡層(輸入門,神經網絡網絡參數為
)和一個
神經網絡層(神經網絡參數為
)。
神經網絡層的作用很明顯,跟遺忘門一樣,它接收
和
作為輸入,然后輸出一個
到
之間的數值
來決定哪些信息需要被更新;
Tanh神經網絡層的作用是將輸入的
和
整合,然后通過一個
神經網絡層來創建一個新的狀態候選向量
,
的值范圍在
到
之間。
記憶門的輸出由上述兩個神經網絡層的輸出決定,
與
相乘來選擇哪些信息將被新加入到
時刻的細胞狀態
中。
3.5 更新細胞狀態
有了遺忘門和記憶門,我們就可以更新細胞狀態
了。
LSTM3-focus-C.png
這里將遺忘門的輸出
與上一時刻的細胞狀態
相乘來選擇遺忘和保留一些信息,將記憶門的輸出與從遺忘門選擇后的信息加和得到新的細胞狀態
。這就表示
時刻的細胞狀態
已經包含了此時需要丟棄的
時刻傳遞的信息和
時刻從輸入信號獲取的需要新加入的信息
。
將繼續傳遞到
時刻的LSTM網絡中,作為新的細胞狀態傳遞下去。
3.6 輸出門
前面已經講了LSTM如何來更新細胞狀態
, 那么在
時刻我們輸入信號
以后,對應的輸出信號該如何計算呢?
LSTM3-focus-o.png
如上面左圖所示,輸出門就是將
時刻傳遞過來并經過了前面遺忘門與記憶門選擇后的細胞狀態
, 與
時刻的輸出信號
和
時刻的輸入信號
整合到一起作為當前時刻的輸出信號。整合的過程如上圖所示,
和
經過一個
神經網絡層(神經網絡參數為
)輸出一個
到
之間的數值
。
經過一個
函數(注意:這里不是
神經網絡層)到一個在
到
之間的數值,并與
相乘得到輸出信號
,同時
也作為下一個時刻的輸入信號傳遞到下一階段。
其中,
函數是激活函數的一種,函數圖像為:
tanh.png
至此,基本的LSTM網絡模型就介紹完了。如果對LSTM模型還沒有理解到的,可以看一下這個視頻,作者是一個外國小哥,英文講解的,有動圖,方便理解。
3.7 LSTM的一些變體
前面已經介紹了基本的LSTM網絡模型,而實際應用時,我們常常會采用LSTM的一些變體,雖然差異不大,這里不再做詳細介紹,有興趣的讀者可以自行了解。
3.7.1 在門上增加窺視孔
LSTM3-var-peepholes.png
這是2000年Gers和Schemidhuber教授提出的一種LSTM變體。圖中,在傳統的LSTM結構基礎上,每個門(遺忘門、記憶門和輸出門)增加了一個“窺視孔”(Peephole),有的學者在使用時也選擇只對部分門加入窺視孔。
3.7.2 整合遺忘門和輸入門
LSTM3-var-tied.png
與傳統的LSTM不同的是,這個變體不需要分開來確定要被遺忘和記住的信息,采用一個結構搞定。在遺忘門的輸出信號值(
到
之間)上,用
減去該數值來作為記憶門的狀態選擇,表示只更新需要被遺忘的那些信息的狀態。
3.7.3 GRU
改進比較大的一個LSTM變體叫Gated Recurrent Unit (GRU),目前應用較多。結構圖如下
LSTM3-var-GRU.png
GRU主要包含2個門:重置門和更新門。GRU混合了細胞狀態
和隱藏狀態
為一個新的狀態,使用
來表示。 該模型比傳統的標準LSTM模型簡單。
4. 基于Pytorch的LSTM代碼實現
Pytorch是Python的一個機器學習包,與Tensorflow類似,Pytorch非常適合用來構建神經網絡模型,并且已經提供了一些常用的神經網絡模型包,用戶可以直接調用。下面我們就用一個簡單的小例子來說明如何使用Pytorch來構建LSTM模型。
我們使用正弦函數和余弦函數來構造時間序列,而正余弦函數之間是成導數關系,所以我們可以構造模型來學習正弦函數與余弦函數之間的映射關系,通過輸入正弦函數的值來預測對應的余弦函數的值。
正弦函數和余弦函數對應關系圖如下圖所示:
demo_sine_cosine.png
可以看到,每一個函數曲線上,每一個正弦函數的值都對應一個余弦函數值。但其實如果只關心正弦函數的值本身而不考慮當前值所在的時間,那么正弦函數值和余弦函數值不是一一對應關系。例如,當
和
時,
,但在這兩個不同的時刻,
的值卻不一樣,也就是說如果不考慮時間,同一個正弦函數值可能對應了不同的幾個余弦函數值。對于傳統的神經網絡來說,它僅僅基于當前的輸入來預測輸出,對于這種同一個輸入可能對應多個輸出的情況不再適用。
我們取正弦函數的值作為LSTM的輸入,來預測余弦函數的值。基于Pytorch來構建LSTM模型,采用1個輸入神經元,1個輸出神經元,16個隱藏神經元作為LSTM網絡的構成參數,平均絕對誤差(LMSE)作為損失誤差,使用Adam優化算法來訓練LSTM神經網絡。基于Anaconda和Python3.6的完整代碼如下:
# -*- coding:UTF-8 -*-
import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt
# Define LSTM Neural Networks
class LstmRNN(nn.Module):
"""
Parameters:
- input_size: feature size
- hidden_size: number of hidden units
- output_size: number of output
- num_layers: layers of LSTM to stack
"""
def __init__(self, input_size, hidden_size=1, output_size=1, num_layers=1):
super().__init__()
self.lstm = nn.LSTM(input_size, hidden_size, num_layers) # utilize the LSTM model in torch.nn
self.forwardCalculation = nn.Linear(hidden_size, output_size)
def forward(self, _x):
x, _ = self.lstm(_x) # _x is input, size (seq_len, batch, input_size)
s, b, h = x.shape # x is output, size (seq_len, batch, hidden_size)
x = x.view(s*b, h)
x = self.forwardCalculation(x)
x = x.view(s, b, -1)
return x
if __name__ == '__main__':
# create database
data_len = 200
t = np.linspace(0, 12*np.pi, data_len)
sin_t = np.sin(t)
cos_t = np.cos(t)
dataset = np.zeros((data_len, 2))
dataset[:,0] = sin_t
dataset[:,1] = cos_t
dataset = dataset.astype('float32')
# plot part of the original dataset
plt.figure()
plt.plot(t[0:60], dataset[0:60,0], label='sin(t)')
plt.plot(t[0:60], dataset[0:60,1], label = 'cos(t)')
plt.plot([2.5, 2.5], [-1.3, 0.55], 'r--', label='t = 2.5') # t = 2.5
plt.plot([6.8, 6.8], [-1.3, 0.85], 'm--', label='t = 6.8') # t = 6.8
plt.xlabel('t')
plt.ylim(-1.2, 1.2)
plt.ylabel('sin(t) and cos(t)')
plt.legend(loc='upper right')
# choose dataset for training and testing
train_data_ratio = 0.5 # Choose 80% of the data for testing
train_data_len = int(data_len*train_data_ratio)
train_x = dataset[:train_data_len, 0]
train_y = dataset[:train_data_len, 1]
INPUT_FEATURES_NUM = 1
OUTPUT_FEATURES_NUM = 1
t_for_training = t[:train_data_len]
# test_x = train_x
# test_y = train_y
test_x = dataset[train_data_len:, 0]
test_y = dataset[train_data_len:, 1]
t_for_testing = t[train_data_len:]
# ----------------- train -------------------
train_x_tensor = train_x.reshape(-1, 5, INPUT_FEATURES_NUM) # set batch size to 5
train_y_tensor = train_y.reshape(-1, 5, OUTPUT_FEATURES_NUM) # set batch size to 5
# transfer data to pytorch tensor
train_x_tensor = torch.from_numpy(train_x_tensor)
train_y_tensor = torch.from_numpy(train_y_tensor)
# test_x_tensor = torch.from_numpy(test_x)
lstm_model = LstmRNN(INPUT_FEATURES_NUM, 16, output_size=OUTPUT_FEATURES_NUM, num_layers=1) # 16 hidden units
print('LSTM model:', lstm_model)
print('model.parameters:', lstm_model.parameters)
loss_function = nn.MSELoss()
optimizer = torch.optim.Adam(lstm_model.parameters(), lr=1e-2)
max_epochs = 10000
for epoch in range(max_epochs):
output = lstm_model(train_x_tensor)
loss = loss_function(output, train_y_tensor)
loss.backward()
optimizer.step()
optimizer.zero_grad()
if loss.item() < 1e-4:
print('Epoch [{}/{}], Loss: {:.5f}'.format(epoch+1, max_epochs, loss.item()))
print("The loss value is reached")
break
elif (epoch+1) % 100 == 0:
print('Epoch: [{}/{}], Loss:{:.5f}'.format(epoch+1, max_epochs, loss.item()))
# prediction on training dataset
predictive_y_for_training = lstm_model(train_x_tensor)
predictive_y_for_training = predictive_y_for_training.view(-1, OUTPUT_FEATURES_NUM).data.numpy()
# torch.save(lstm_model.state_dict(), 'model_params.pkl') # save model parameters to files
# ----------------- test -------------------
# lstm_model.load_state_dict(torch.load('model_params.pkl')) # load model parameters from files
lstm_model = lstm_model.eval() # switch to testing model
# prediction on test dataset
test_x_tensor = test_x.reshape(-1, 5, INPUT_FEATURES_NUM) # set batch size to 5, the same value with the training set
test_x_tensor = torch.from_numpy(test_x_tensor)
predictive_y_for_testing = lstm_model(test_x_tensor)
predictive_y_for_testing = predictive_y_for_testing.view(-1, OUTPUT_FEATURES_NUM).data.numpy()
# ----------------- plot -------------------
plt.figure()
plt.plot(t_for_training, train_x, 'g', label='sin_trn')
plt.plot(t_for_training, train_y, 'b', label='ref_cos_trn')
plt.plot(t_for_training, predictive_y_for_training, 'y--', label='pre_cos_trn')
plt.plot(t_for_testing, test_x, 'c', label='sin_tst')
plt.plot(t_for_testing, test_y, 'k', label='ref_cos_tst')
plt.plot(t_for_testing, predictive_y_for_testing, 'm--', label='pre_cos_tst')
plt.plot([t[train_data_len], t[train_data_len]], [-1.2, 4.0], 'r--', label='separation line') # separation line
plt.xlabel('t')
plt.ylabel('sin(t) and cos(t)')
plt.xlim(t[0], t[-1])
plt.ylim(-1.2, 4)
plt.legend(loc='upper right')
plt.text(14, 2, "train", size = 15, alpha = 1.0)
plt.text(20, 2, "test", size = 15, alpha = 1.0)
plt.show()
訓練的過程如下:
LSTM_training.png
該模型在訓練集和測試集上的結果如下:
demo_LSTM.png
圖中,紅色虛線的左邊表示該模型在訓練數據集上的表現,右邊表示該模型在測試數據集上的表現。可以看到,使用LSTM構建訓練模型,我們可以僅僅使用正弦函數在
時刻的值作為輸入來準確預測
時刻的余弦函數值,不用額外添加當前的時間信息、速度信息等。
5. 參考鏈接
總結
以上是生活随笔為你收集整理的控制论python_[干货]深入浅出LSTM及其Python代码实现的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 微信小程序没登录跳到登录页怎么做_微信小
- 下一篇: r型聚类典型指标_六种GAN评估指标的综