莫烦python博客_莫烦Python 4
莫煩Python 4
新建模板小書匠
RNN Classifier 循環神經網絡
問題描述
使用RNN對MNIST里面的圖片進行分類
關鍵
SimpleRNN()參數
batch_input_shape
使用狀態RNN的注意事項
可以將RNN設置為‘stateful’,意味著由每個batch計算出的狀態都會被重用于初始化下一個batch的初始狀態。狀態RNN假設連續的兩個batch之中,相同下標的元素有一一映射關系。
要啟用狀態RNN,請在實例化層對象時指定參數stateful=True,并在Sequential模型使用固定大小的batch:通過在模型的第一層傳入batch_size=(…)和input_shape來實現。在函數式模型中,對所有的輸入都要指定相同的batch_size。
如果要將循環層的狀態重置,請調用.reset_states(),對模型調用將重置模型中所有狀態RNN的狀態。對單個層調用則只重置該層的狀態。
(samples,timesteps,input_dim)
代碼
'''
RNN Classifier 循環神經網絡
'''
import numpy as np
np.random.seed(1337)
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import SimpleRNN, Activation, Dense
from keras.optimizers import Adam
time_step = 28
input_size = 28
batch_size = 50
output_size = 10
cell_size = 50
LR = 0.001
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.reshape(-1, 28, 28) / 255. # normalize
X_test = X_test.reshape(-1, 28, 28) / 255. # normalize
y_train = np_utils.to_categorical(y_train, num_classes=10)
y_test = np_utils.to_categorical(y_test, num_classes=10)
model = Sequential()
model.add(
SimpleRNN(
batch_input_shape=(None, time_step, input_size),
units=cell_size
)
)
model.add(
Dense(output_size)
)
model.add(Activation('softmax'))
adam = Adam(LR)
model.compile(
optimizer=adam,
loss='categorical_crossentropy',
metrics=['accuracy']
)
model.summary()
model.fit(X_train, y_train, batch_size=batch_size, epochs=2, verbose=2, validation_data=(X_test, y_test))
結果
Model: "sequential_2"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
simple_rnn_2 (SimpleRNN) (None, 50) 3950
_________________________________________________________________
dense_2 (Dense) (None, 10) 510
_________________________________________________________________
activation_2 (Activation) (None, 10) 0
=================================================================
Total params: 4,460
Trainable params: 4,460
Non-trainable params: 0
_________________________________________________________________
Train on 60000 samples, validate on 10000 samples
Epoch 1/2
- 12s - loss: 0.6643 - accuracy: 0.7966 - val_loss: 0.4501 - val_accuracy: 0.8550
Epoch 2/2
- 9s - loss: 0.3220 - accuracy: 0.9087 - val_loss: 0.2445 - val_accuracy: 0.9359
總結
以上是生活随笔為你收集整理的莫烦python博客_莫烦Python 4的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: .NET Framework 1.1安装
- 下一篇: Python植物大战僵尸源代码及素材