构建单层单向RNN网络对MNIST数据集分类
生活随笔
收集整理的這篇文章主要介紹了
构建单层单向RNN网络对MNIST数据集分类
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
一、導入數據集
1 import tensorflow as tf 2 import numpy as np 3 #清除默認圖形堆棧并重置全局默認圖形,tf.reset_default_graph函數只適用于當前線程. 4 #當一個tf.Session或者tf.InteractiveSession激活時調用這個函數會導致未定義的行為. 5 #調用此函數后使用任何以前創建的tf.Operation或tf.Tensor對象將導致未定義的行為. 6 tf.reset_default_graph() 7 from tensorflow.examples.tutorials.mnist import input_data 8 9 #mnist是一個輕量級的類,它以numpy數組的形式存儲著訓練,校驗,測試數據集 one_hot表示輸出二值化后的10維 10 mnist = input_data.read_data_sets('MNIST-data',one_hot=True) 11 print('Training data shape:',mnist.train.images.shape) #Training data shape: (55000, 784) 12 print('Test data shape:',mnist.test.images.shape) #Test data shape: (10000, 784) 13 print('Validation data shape:',mnist.validation.images.shape) #Validation data shape: (5000, 784) 14 print('Training label shape:',mnist.train.labels.shape) #Training label shape: (55000, 10)二、定義參數
1 n_input = 28 #LSTM單元輸入節點的個數 2 n_steps = 28 #序列長度 3 n_hidden = 128 #LSTM單元輸出節點個數(即隱藏層個數) 4 n_classes = 10 #類別 5 #定義占位符 6 #batch_size:表示一次的批次樣本數量batch_size n_steps:表示時間序列總數 n_input:表示一個時序具體的數據長度 即一共28個時序,一個時序送入28個數據進入LSTM網絡 7 input_x = tf.placeholder(dtype=tf.float32,shape=[None,n_steps,n_input]) 8 input_y = tf.placeholder(dtype=tf.float32,shape=[None,n_classes])三、①、構建單層靜態LSTM網絡
1 def single_layer_static_lstm(input_x,n_steps,n_hidden): 2 ''' 3 返回靜態單層LSTM單元的輸出,以及cell狀態 4 args: 5 input_x:輸入張量 形狀為[batch_size,n_steps,n_input] 6 n_steps:時序總數 7 n_hidden:LSTM單元輸出的節點個數 即隱藏層節點數 8 ''' 9 ''' 10 #把輸入input_x按列拆分,并返回一個有n_steps個張量組成的list 如batch_sizex28x28的輸入拆成[(batch_size,28),((batch_size,28))....] 11 #如果是調用的是靜態rnn函數,需要這一步處理 即相當于把序列作為第一維度 12 ''' 13 input_x1 = tf.unstack(input_x,num=n_steps,axis=1)#① 14 '''可以看做隱藏層''' 15 lstm_cell = tf.contrib.rnn.BasicLSTMCell(num_units=n_hidden,forget_bias=1.0)#② 16 '''靜態rnn函數傳入的是一個張量list 每一個元素都是一個(batch_size,n_input)大小的張量 ''' 17 hiddens,states = tf.contrib.rnn.static_rnn(cell=lstm_cell, inputs=input_x1, dtype=tf.float32)#③ 18 19 return hiddens,states? ②、構建單層靜態GRU網絡
1 def single_layer_static_gru(input_x,n_steps,n_hidden): 2 ''' 3 返回靜態單層GRU單元的輸出,以及cell狀態 4 5 args: 6 input_x:輸入張量 形狀為[batch_size,n_steps,n_input] 7 n_steps:時序總數 8 n_hidden:gru單元輸出的節點個數 即隱藏層節點數 9 ''' 10 ''' 11 #把輸入input_x按列拆分,并返回一個有n_steps個張量組成的list 如batch_sizex28x28的輸入拆成[(batch_size,28),((batch_size,28))....] 12 #如果是調用的是靜態rnn函數,需要這一步處理 即相當于把序列作為第一維度 13 ''' 14 input_x1 = tf.unstack(input_x,num=n_steps,axis=1) 15 16 '''可以看做隱藏層''' 17 gru_cell = tf.contrib.rnn.GRUCell(num_units=n_hidden) 18 '''靜態rnn函數傳入的是一個張量list 每一個元素都是一個(batch_size,n_input)大小的張量 ''' 19 hiddens,states = tf.contrib.rnn.static_rnn(cell=gru_cell,inputs=input_x1,dtype=tf.float32) 20 21 return hiddens,states③、構建單層動態LSTM網絡
1 def single_layer_dynamic_lstm(input_x,n_steps,n_hidden): 2 ''' 3 返回動態單層LSTM單元的輸出,以及cell狀態 4 5 args: 6 input_x:輸入張量 形狀為[batch_size,n_steps,n_input] 7 n_steps:時序總數 8 n_hidden:LSTM單元輸出的節點個數 即隱藏層節點數 9 ''' 10 #可以看做隱藏層 11 lstm_cell = tf.contrib.rnn.BasicLSTMCell(num_units=n_hidden,forget_bias=1.0) 12 #動態rnn函數傳入的是一個三維張量,[batch_size,n_steps,n_input] 輸出也是這種形狀 13 hiddens,states = tf.nn.dynamic_rnn(cell=lstm_cell,inputs=input_x,dtype=tf.float32) 14 15 #注意這里輸出需要轉置 轉換為時序優先的 16 hiddens = tf.transpose(hiddens,[1,0,2]) 17 return hiddens,states? ④、構建單層動態GRU網絡
1 def single_layer_dynamic_gru(input_x,n_steps,n_hidden): 2 ''' 3 返回動態單層GRU單元的輸出,以及cell狀態 4 5 args: 6 input_x:輸入張量 形狀為[batch_size,n_steps,n_input] 7 n_steps:時序總數 8 n_hidden:gru單元輸出的節點個數 即隱藏層節點數 9 ''' 10 11 #可以看做隱藏層 12 gru_cell = tf.contrib.rnn.GRUCell(num_units=n_hidden) 13 #動態rnn函數傳入的是一個三維張量,[batch_size,n_steps,n_input] 輸出也是這種形狀 14 hiddens,states = tf.nn.dynamic_rnn(cell=gru_cell,inputs=input_x,dtype=tf.float32) 15 16 17 #注意這里輸出需要轉置 轉換為時序優先的 18 hiddens = tf.transpose(hiddens,[1,0,2]) 19 return hiddens,states四、初始化
1 #調用單層靜態LSTM網絡 2 hiddens,states = single_layer_static_lstm(input_x,n_steps,n_hidden) 3 #取LSTM最后一個時序的輸出,然后經過全連接網絡得到輸出值 4 output = tf.contrib.layers.fully_connected(inputs=hiddens[-1],num_outputs=n_classes,activation_fn = tf.nn.softmax) 5 #設置對數似然損失函數 6 #代價函數 J =-(Σy.logaL)/n 表示逐元素乘 7 cost = tf.reduce_mean(-tf.reduce_sum(input_y*tf.log(output),axis=1)) 8 #求解 9 learning_rate = 1e-4 #學習率 10 train = tf.train.AdamOptimizer(learning_rate).minimize(cost) 11 #預測結果評估 12 #tf.argmax(output,1) 按行統計最大值得索引 13 correct = tf.equal(tf.argmax(output,1),tf.argmax(input_y,1)) #返回一個數組 表示統計預測正確或者錯誤 14 accuracy = tf.reduce_mean(tf.cast(correct,tf.float32)) #求準確率五、訓練以及測試
1 with tf.Session() as sess: 2 #使用會話執行圖 3 sess.run(tf.global_variables_initializer()) #初始化變量 4 batch_size = 128 #小批量大小 5 training_step = 5000 #迭代次數 6 display_step = 200 #顯示步數 7 #開始迭代 使用Adam優化的隨機梯度下降法 8 for i in range(training_step): 9 x_batch,y_batch = mnist.train.next_batch(batch_size = batch_size) 10 #Reshape data to get 28 seq of 28 elements 11 x_batch = x_batch.reshape([-1,n_steps,n_input]) 12 13 #開始訓練 14 train.run(feed_dict={input_x:x_batch,input_y:y_batch}) 15 if (i+1) % display_step == 0: 16 #輸出訓練集準確率 17 training_accuracy,training_cost = sess.run([accuracy,cost],feed_dict={input_x:x_batch,input_y:y_batch}) 18 print('Step {0}:Training set accuracy {1},cost {2}.'.format(i+1,training_accuracy,training_cost)) 19 20 21 #全部訓練完成做測試 分成200次,一次測試50個樣本 22 #輸出測試機準確率 如果一次性全部做測試,內容不夠用會出現OOM錯誤。所以測試時選取比較小的mini_batch來測試 23 for i in range(200): 24 x_batch,y_batch = mnist.test.next_batch(batch_size = 50) 25 #Reshape data to get 28 seq of 28 elements 26 x_batch = x_batch.reshape([-1,n_steps,n_input]) 27 test_accuracy,test_cost = sess.run([accuracy,cost],feed_dict={input_x:x_batch,input_y:y_batch}) 28 test_accuracy_list.append(test_accuracy) 29 test_cost_list.append(test_cost) 30 if (i+1)% 20 == 0: 31 print('Step {0}:Test set accuracy {1},cost {2}.'.format(i+1,test_accuracy,test_cost)) 32 print('Test accuracy:',np.mean(test_accuracy_list))六、訓練模型保存
1 saver = tf.train.Saver() 2 saver.save(sess, './model/rnn_model.ckpt')七、調用模型
1 saver = tf.train.Saver() 2 save_path = tf.train.latest_checkpoint('./model/') 3 with tf.Session() as sess: 4 saver.restore(sess=sess, save_path=save_path) 5 feeds = {x:images,y:labels} 6 preds = sess.run(accr, feed_dict=feeds)?
轉載于:https://www.cnblogs.com/chenfeifen/p/11360659.html
總結
以上是生活随笔為你收集整理的构建单层单向RNN网络对MNIST数据集分类的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 有关T-SQL的10个好习惯
- 下一篇: python断点续传代码