练习利用LSTM实现手写数字分类任务
生活随笔
收集整理的這篇文章主要介紹了
练习利用LSTM实现手写数字分类任务
小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,幫大家做個(gè)參考.
練習(xí)利用LSTM實(shí)現(xiàn)手寫數(shù)字分類任務(wù)
MNIST數(shù)據(jù)集中圖片大小為28*28.
按照行進(jìn)行展開成28維的特征向量。
考慮到這28個(gè)的向量之間存在著順序依賴關(guān)系,我們可以將他們看成是一個(gè)長(zhǎng)為28的輸入序列,將其輸入到LSTM中,LSTM可以從中提取到序列特征,再將此序列特征用一層全聯(lián)接作為分類器,分類器輸出10種分類類別。
綜合代碼
import tensorflow as tf import numpy as np from tensorflow.contrib.layers import fully_connectedimport input_data mnist = input_data.read_data_sets('MNIST_data/',one_hot = True) #one_hot = True 獨(dú)熱編碼,類似[0,0,0,1,0,0,0,0,0,0]這種形式,等價(jià)于class=3n_inputs = 28 #表示輸入神經(jīng)元的個(gè)數(shù) n_steps = 28 #表示序列長(zhǎng)度 n_neurons = 150 #表示LSTM中隱藏層和輸出層神經(jīng)元呢個(gè)數(shù) n_outputs = 10 #是最終分類器輸出的類別數(shù),mnist數(shù)據(jù)集是10分類任務(wù)learning_rate = 0.01 #優(yōu)化方法的學(xué)習(xí)率X = tf.placeholder(tf.float32,[None,n_steps,n_inputs]) Y_labels = tf.placeholder(tf.int32,[None,n_outputs])basic_cell = tf.contrib.rnn.BasicLSTMCell(n_neurons,forget_bias = 1.0, state_is_tuple = True) #獲取一層LSTM網(wǎng)絡(luò),參數(shù)1是每個(gè)cell的輸出神經(jīng)元個(gè)數(shù),參數(shù)2是遺忘的偏置,參數(shù)3表示雙狀態(tài)outneurons, states = tf.nn.dynamic_rnn(basic_cell,X,dtype = tf.float32) #outneurons得到了輸出序列logits = fully_connected(tf.transpose(outneurons,perm = [1,0,2])[-1], n_outputs,activation_fn = None) #在這里由于outneurons的維度為[batch_size,n_steps,n_inputs]的形式,而我們只需要最后一個(gè)cell對(duì)于所有batch的輸出,因此把前兩個(gè)維度調(diào)換一下,再取用[-1]取到最后一個(gè)cell對(duì)于所有batch的輸出。shape為[batch_size,n_inputs] #將其接到一層全連接網(wǎng)絡(luò)作為分類器得到logitscross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels = Y_labels,logits = logits) loss = tf.reduce_mean(cross_entropy) #對(duì)logits用softmax做歸一化,計(jì)算其對(duì)于樣本labels的交叉熵的和,取均值作為損失函數(shù)lossoptimizer = tf.train.AdamOptimizer(learning_rate = learning_rate) trainop = optimizer.minimize(loss) #申請(qǐng)一個(gè)優(yōu)化器,用來(lái)最后小化損失函數(shù)losscorrect = tf.equal(tf.argmax(logits,1),tf.argmax(Y_labels,1)) #分析正確率accuracy = tf.reduce_mean(tf.cast(correct,tf.float32))batch_size = 64 init = tf.global_variables_initializer() with tf.Session() as sess:init.run()for i in range(10000):x_batch, y_batch = mnist.train.next_batch(batch_size)x_batch = x_batch.reshape([-1,n_steps,n_inputs])sess.run(trainop,feed_dict = {X : x_batch,Y_labels : y_batch})if i % 200 == 0:print('train accuracy =',sess.run(accuracy,feed_dict = {X : x_batch,Y_labels : y_batch}))X_test = mnist.test.images.reshape((-1,n_steps,n_inputs))Y_test = mnist.test.labelsprint('test accuracy =',sess.run(accuracy,feed_dict = {X : X_test,Y_labels : Y_test}))評(píng)估
實(shí)驗(yàn)表明求得得準(zhǔn)確率可達(dá)到99%。
疑問
我將BasicLSTMCell換成BasicRNNCell就無(wú)法訓(xùn)練,這是為什么呢?難道跟LSTM有遺忘們相關(guān)嗎?
總結(jié)
以上是生活随笔為你收集整理的练习利用LSTM实现手写数字分类任务的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 安卓平板主题软件(安卓平板主题)
- 下一篇: VAE(变分自编码器)学习笔记