使用MATLAB的trainNetwork设计一个简单的LSTM神经网络
文章目錄
- 前言
- 一、數據集
- 二、網絡結構
- 三、測試程序
前言
借助MATLAB的deepNetworkDesigner搭一個簡單的LSTM,數據集使用mnist手寫數字識別數據集。
一、數據集
mnist數據集包括60000組訓練數據和對應的標簽,10000組測試數據和對應標簽。每個數據都是一個28x28的矩陣,可以將其看做28x28像素的灰度圖像(黑底白字)。而LSTM的輸入應當是一個序列,我們可以把矩陣的每一行當做一幀,把圖像分為28幀輸入到LSTM。
數據集可以在我上傳的資源里找到。
數據的格式是這樣的:
XTrain,即訓練圖像,是一個60000x1的cell,cell的每一個元素是一個28x28的矩陣。矩陣的每一列為一幀。直接將矩陣以圖片顯示是這樣的:
這不是某希臘字母,而是手寫數字3。我們希望按行輸入,而MATLAB按列讀取,因此我做了個轉置。再轉置一下就能看到正常的圖像:
標簽的格式為:
可以直接通過categorical函數實現數值到categorical的轉換,比如:
輸入訓練數據的方式不唯一,我用的只是其中一種,詳情見MathWorks官網:trainNetwork
二、網絡結構
使用一層128個隱藏節點的LSTM,一層全連接,輸出使用softmax。網絡的輸入是一個序列,輸出是標簽,在MATLAB中,此網絡可以這樣描述:
layers = [ ...sequenceInputLayer(inputSize) %sequence輸入lstmLayer(numHiddenUnits,'OutputMode','last') %lstmfullyConnectedLayer(numClasses) %全連接softmaxLayer %softmaxclassificationLayer]; %label輸出三、測試程序
完整的測試程序如下:
clear clc %加載數據 load('.\mnist_data_mat\XTrain.mat') load('.\mnist_data_mat\YTrain.mat') load('.\mnist_data_mat\XTest.mat') load('.\mnist_data_mat\YTest.mat')%設置參數 inputSize = 28; %28個輸入節點 numHiddenUnits = 128; %128個隱藏節點 numClasses = 10; %10種分類結果layers = [ ...sequenceInputLayer(inputSize) %sequence輸入lstmLayer(numHiddenUnits,'OutputMode','last') %lstmfullyConnectedLayer(numClasses) %全連接softmaxLayer %softmaxclassificationLayer]; %label輸出options = trainingOptions('adam', ...'ExecutionEnvironment','cpu', ...'MaxEpochs',5, ...'MiniBatchSize',60, ...'GradientThreshold',1, ...'Verbose',false, ...'Plots','training-progress');net=trainNetwork(XTrain,YTrain,layers, options); %訓練Y_pred = classify(net, XTest); %測試 accy = sum(Y_pred == YTest) / length(YTest); %計算準確度準確度為97.73%
options里的參數可以修改一下,我用同樣結構的網絡不同的參數做出了98.74%的準確度,仍有提升空間。這里為了節省訓練時間犧牲了一些精度。
訓練好的網絡也上傳到了資源里。
總結
以上是生活随笔為你收集整理的使用MATLAB的trainNetwork设计一个简单的LSTM神经网络的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: android webview 长按菜单
- 下一篇: 软件工程基础作业 可行性与需求分析