TF之LSTM:利用LSTM算法对mnist手写数字图片数据集(TF函数自带)训练、评估(偶尔100%准确度,交叉熵验证)
生活随笔
收集整理的這篇文章主要介紹了
TF之LSTM:利用LSTM算法对mnist手写数字图片数据集(TF函数自带)训练、评估(偶尔100%准确度,交叉熵验证)
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
TF之LSTM:利用LSTM算法對mnist手寫數字圖片數據集(TF函數自帶)訓練、評估(偶爾100%準確度,交叉熵驗證)
?
?
目錄
輸出結果
設計思路
代碼設計
?
?
?
輸出結果
第 0 accuracy 0.125 第 20 accuracy 0.6484375 第 40 accuracy 0.78125 第 60 accuracy 0.9296875 第 80 accuracy 0.8671875 第 100 accuracy 0.90625 第 120 accuracy 0.8671875 第 140 accuracy 0.8671875 第 160 accuracy 0.8671875 第 180 accuracy 0.921875 第 200 accuracy 0.890625 第 220 accuracy 0.953125 第 240 accuracy 0.921875 第 260 accuracy 0.9296875 第 280 accuracy 0.9140625 第 300 accuracy 0.921875 第 320 accuracy 0.9609375 第 340 accuracy 0.953125 第 360 accuracy 0.984375 第 380 accuracy 0.921875 第 400 accuracy 0.9453125 第 420 accuracy 0.921875 第 440 accuracy 0.9296875 第 460 accuracy 0.96875 第 480 accuracy 0.984375 第 500 accuracy 0.96875 第 520 accuracy 0.953125 第 540 accuracy 0.96875 第 560 accuracy 0.953125 第 580 accuracy 0.9921875 第 600 accuracy 0.984375 第 620 accuracy 0.953125 第 640 accuracy 0.953125 第 660 accuracy 0.9921875 第 680 accuracy 0.96875 第 700 accuracy 0.9765625 第 720 accuracy 0.96875 第 740 accuracy 0.9921875 第 760 accuracy 0.984375 第 780 accuracy 0.953125?
設計思路
?
代碼設計
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets('MNIST_data', one_hot=True)lr=0.001 training_iters=100000 batch_size=128 n_inputs=28 n_steps=28 n_hidden_units=128 n_classes=10 x=tf.placeholder(tf.float32, [None,n_steps,n_inputs]) y=tf.placeholder(tf.float32, [None,n_classes])weights ={'in':tf.Variable(tf.random_normal([n_inputs,n_hidden_units])),'out':tf.Variable(tf.random_normal([n_hidden_units,n_classes])),} biases ={'in':tf.Variable(tf.constant(0.1,shape=[n_hidden_units,])),'out':tf.Variable(tf.constant(0.1,shape=[n_classes,])),}def RNN(X,weights,biases): X=tf.reshape(X,[-1,n_inputs])X_in=tf.matmul(X,weights['in'])+biases['in'] X_in=tf.reshape(X_in,[-1,n_steps,n_hidden_units])lstm_cell=tf.nn.rnn_cell.BasicLSTMCell(n_hidden_units,forget_bias=1.0,state_is_tuple=True)__init__state=lstm_cell.zero_state(batch_size, dtype=tf.float32)outputs,states=tf.nn.dynamic_rnn(lstm_cell,X_in,initial_state=__init__state,time_major=False)outputs=tf.unpack(tf.transpose(outputs, [1,0,2]))results=tf.matmul(outputs[-1],weights['out'])+biases['out']return resultspred =RNN(x,weights,biases) cost =tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y)) train_op=tf.train.AdamOptimizer(lr).minimize(cost) correct_pred=tf.equal(tf.argmax(pred,1),tf.argmax(y,1)) accuracy=tf.reduce_mean(tf.cast(correct_pred,tf.float32)) <br> with tf.Session() as sess: sess.run(init)step=0while step*batch_size < training_iters: batch_xs,batch_ys=mnist.train.next_batch(batch_size)batch_xs=batch_xs.reshape([batch_size,n_steps,n_inputs])sess.run([train_op],feed_dict={x:batch_xs,y:batch_ys,})if step%20==0: print(sess.run(accuracy,feed_dict={x:batch_xs,y:batch_ys,}))step+=1?
?
相關文章
TF之LSTM:利用LSTM算法對mnist手寫數字圖片數據集訓練、評估(偶爾100%準確度)
總結
以上是生活随笔為你收集整理的TF之LSTM:利用LSTM算法对mnist手写数字图片数据集(TF函数自带)训练、评估(偶尔100%准确度,交叉熵验证)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: AI:一个20年程序猿的学习资料大全—前
- 下一篇: 成功解决h5py\_init_.py:2