13 Tensorflow机制(翻译)
??? 代碼: tensorflow/examples/tutorials/mnist/
??? 本文的目的是來展示如何使用Tensorflow訓(xùn)練和評估手寫數(shù)字識別問題。本文的觀眾是那些對使用Tensorflow進(jìn)行機器學(xué)習(xí)感興趣的人。
??? 本文的目的并不是講解機器學(xué)習(xí)。
??? 請確認(rèn)您已經(jīng)安裝了Tensorflow。
?
??? 教程文件
| 文件 | 作用 |
| mnist.py | 用來創(chuàng)建一個完全連接的MNIST模型。 |
| fully_connected_feed.py | 使用下載的數(shù)據(jù)集訓(xùn)練模型。 |
??? 運行fully_connected_feed.py文件開始訓(xùn)練。
python fully_connected_feed.py?
??? 準(zhǔn)備數(shù)據(jù)
??? MNIST是機器學(xué)習(xí)的一個經(jīng)典問題。這個問題是識別28*28像素圖片上的數(shù)字,從0到9。
??? 更多信息,請參考Yann LeCun's MNIST page?或者 Chris Olah's visualizations of MNIST。
?
??? 數(shù)據(jù)下載
??? 在run_training()方法之前,input_data.read_data_sets()方法可以讓數(shù)據(jù)下載到本機訓(xùn)練文件夾,解壓數(shù)據(jù)并返回一個DataSet實例。
data_sets = input_data.read_data_sets(FLAGS.train_dir, FLAGS.fake_data)??? 注意:fake_data是用來進(jìn)行單元測試的,讀者可以忽略。
| 數(shù)據(jù)集 | 作用 |
| data_sets.train | 55000圖片和標(biāo)簽,用來訓(xùn)練。 |
| data_sets.validation | 5000圖片和標(biāo)簽,用來在迭代中校驗?zāi)P蜏?zhǔn)確度。 |
| data_sets.test | 10000圖片和標(biāo)簽,用來測試訓(xùn)練模型準(zhǔn)確度。 |
???
??? 輸入和占位符
??? placeholder_inputs()函數(shù)創(chuàng)建兩個tf.placeholder,用來定義輸入的形狀,包括fetch_size。
images_placeholder = tf.placeholder(tf.float32, shape=(batch_size, mnist.IMAGE_PIXELS)) labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size))??? 在訓(xùn)練循環(huán)中,圖片和標(biāo)簽數(shù)據(jù)集會被切分成batch_size大小,跟占位符匹配,然后通過feed_dict參數(shù)傳遞到sess.run()方法中。
?
????創(chuàng)建圖
??? 創(chuàng)建占位符后,mnist.py文件中會通過三個步驟來創(chuàng)建圖:inference(), loss(), 和training()。
??? inference層
??? inference()函數(shù)創(chuàng)建圖,返回預(yù)測結(jié)果。
??? 它把圖片占位符當(dāng)作輸入,并在上面構(gòu)建一對完全連接的層,使用ReLU激活后,連接一個10個節(jié)點的線性層。
??? 每一層都位于tf.name_scope聲明的命名空間中。
with tf.name_scope('hidden1'):??? 在該命名空間中,權(quán)重和偏置會產(chǎn)生tf.Variable實例,并具有所需的形狀。
weights = tf.Variable(tf.truncated_normal([IMAGE_PIXELS, hidden1_units], stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))), name='weights') biases = tf.Variable(tf.zeros([hidden1_units]), name='biases')??? 例如,這些會在hidden1命名空間中創(chuàng)建,那么權(quán)重的唯一名稱為“hidden1/weights”。
????每個變量使用初始化器作為構(gòu)造函數(shù)。
??? 通常,權(quán)重會使用tf.truncated_normal(截尾正態(tài)分布)作為初始化器,它是一個2D張量,第一個參數(shù)表示該層中的神經(jīng)元數(shù),第二個表示它連接的層中的神經(jīng)元數(shù)。再第一層hidden1中,權(quán)限矩陣的大小是[圖片像素, hidden1神經(jīng)元數(shù)],因為該權(quán)重連接圖片輸入。tf.truncated_normal初始化器會根據(jù)平均值和標(biāo)準(zhǔn)差產(chǎn)生一些隨機數(shù)。
??然后,偏置會使用tf.zeros作為初始化器,保證開始時所有數(shù)都是0。它們的形狀跟它們連接的層的神經(jīng)元一樣。
??該圖的三個主要運算:兩個tf.nn.relu操作(包括隱層中的一個tf.matmul操作)和一個額外的tf.matmul操作。然后依次創(chuàng)建,連接到輸入占位符或上一層的輸出張量上。
?
hidden1 = tf.nn.relu(tf.matmul(images, weights) + biases)?
hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases) logits = tf.matmul(hidden2, weights) + biases??? 最后,logits張量包含輸出結(jié)果。
?
??? 損失
??? loss()函數(shù)通過添加所需的損失操作來進(jìn)一步構(gòu)建圖形。
??? 首先,將labels_placeholder的值轉(zhuǎn)換為64位整數(shù)。 然后,添加tf.nn.sparse_softmax_cross_entropy_with_logits操作,以自動從labels_placeholder產(chǎn)生標(biāo)簽,并將inference()函數(shù)的輸出與這些標(biāo)簽進(jìn)行比較。
???
labels = tf.to_int64(labels) cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits, name='xentropy')??? 然后使用tf.reduce_mean將batch維度(第一維)的交叉熵的平均數(shù)作為總損耗。
loss = tf.reduce_mean(cross_entropy, name='xentropy_mean')??? 然后返回包含損失值的張量。
??? 注意:交叉熵是信息論中的一個想法,它使我們能夠描述神經(jīng)網(wǎng)絡(luò)的預(yù)測有多糟糕。有關(guān)更多信息,請閱讀博客文章Visual Information Theory(http://colah.github.io/posts/2015-09-Visual-Information/) ??? 訓(xùn)練 ??? training()函數(shù)通過梯度下降法計算最小損失。??? 首先,它從loss()函數(shù)中獲取損失張量,并將其傳遞給tf.summary.scalar,該函數(shù)用于在與tf.summary.FileWriter一起使用時將事件生成摘要。 ??? tf.summary.scalar('loss', loss)
??? 接下來,我們實例化一個tf.train.GradientDescentOptimizer,進(jìn)行梯度下降算法。
optimizer = tf.train.GradientDescentOptimizer(learning_rate)??? 然后,我們定義一個變量,用來作為全局訓(xùn)練步驟的計數(shù)器,并且tf.train.Optimizer.minimize op用于更新系統(tǒng)中的可訓(xùn)練權(quán)重,并增加全局步長。 通常,這個操作被稱為train_op. 它是由TensorFlow會話運行的,以便引導(dǎo)一個完整的訓(xùn)練步驟。
global_step = tf.Variable(0, name='global_step', trainable=False) train_op = optimizer.minimize(loss, global_step=global_step)???
??? 訓(xùn)練模型
??? 構(gòu)建圖形后,可以在full_connected_feed.py中由用戶代碼控制的循環(huán)中進(jìn)行迭代訓(xùn)練和評估。
??? 圖
??? 在run_training()函數(shù)的頂部,其中的命令指示所有構(gòu)建的操作都與默認(rèn)的全局tf.Graph實例相關(guān)聯(lián)。
with tf.Graph().as_default():??? tf.Graph是可以作為一組一起執(zhí)行的操作的集合。 大多數(shù)TensorFlow用戶只需要依賴于單個默認(rèn)圖形。
??? 更復(fù)雜的使用多個圖形是可能的,但超出了這個簡單教程的范圍。
??? 會話
??? 一旦所有的構(gòu)建準(zhǔn)備工作已經(jīng)完成并且生成了所有必要的操作,就會創(chuàng)建一個tf.Session來運行圖形。
sess = tf.Session()??? 或者,可以將會話生成到某個作用域中:
with tf.Session() as sess:??? 會話的空參數(shù)表示此代碼將附加到默認(rèn)本地會話(或創(chuàng)建尚未創(chuàng)建)。
??? 在創(chuàng)建會話之后,所有的tf.Variable實例都通過在初始化操作中調(diào)用tf.Session.run來初始化。
??? tf.Session.run方法將進(jìn)行參數(shù)傳遞操作。在這個調(diào)用中,只進(jìn)行變量的初始值。 圖的其余部分都不在這里運行; 這在下面的訓(xùn)練循環(huán)中運行。
?
??? 訓(xùn)練循環(huán)
??? 在會話初始化變量后,可以開始訓(xùn)練。
??? 用戶代碼控制每一步的訓(xùn)練,最簡單的循環(huán)可以是:
??? 但是,本教程稍微復(fù)雜一些,因為它還必須分割每個步驟的輸入數(shù)據(jù),以匹配先前生成的占位符。
???
??? 數(shù)據(jù)輸入到圖
??? 對于每個步驟,代碼將生成一個Feed字典,其中包含一組數(shù)據(jù),用于訓(xùn)練,由其所對應(yīng)的占位符操作輸入。
??? 在fill_feed_dict()函數(shù)中,查詢給定的DataSet用于其下一個batch_size圖像和標(biāo)簽集,填充與占位符匹配的張量,其中包含下一個圖像和標(biāo)簽。
??? 然后生成一個python字典對象,其中占位符作為鍵,代表性的Feed張量作為值。??
feed_dict = {images_placeholder: images_feed,labels_placeholder: labels_feed, }??? 這將被傳遞給sess.run()函數(shù)的feed_dict參數(shù),以供該訓(xùn)練循環(huán)使用。
???
??? 檢查狀態(tài)
??? 該代碼指定在運行調(diào)用中獲取的兩個值:[train_op,loss]。
for step in xrange(FLAGS.max_steps):feed_dict = fill_feed_dict(data_sets.train,images_placeholder,labels_placeholder)_, loss_value = sess.run([train_op, loss],feed_dict=feed_dict)??? 因為要獲取兩個值,所以sess.run()返回一個包含兩個項的元組。 要提取的值列表中的每個Tensor對應(yīng)于返回的元組中的numpy數(shù)組,在該訓(xùn)練步驟中填充該張量的值。 由于train_op是沒有輸出值的操作,返回的元組中的相應(yīng)元素為None,因此被丟棄。 然而,如果模型在訓(xùn)練過程中發(fā)生分歧,則損失張量的值可能變?yōu)镹aN,因此我們捕獲該值用于記錄。
??? 假設(shè)沒有NaN,訓(xùn)練運行良好,訓(xùn)練循環(huán)還會每100個步驟打印一個簡單的狀態(tài)文本,讓用戶知道訓(xùn)練狀態(tài)。
???
??? 狀態(tài)可視化
????為了輸出TensorBoard使用的事件文件,在圖形構(gòu)建階段,所有的摘要(在這種情況下只有一個)被收集到一個Tensor中。
summary = tf.summary.merge_all()??? 然后在創(chuàng)建會話之后,可以將tf.summary.FileWriter實例化為寫入事件文件,其中包含圖形本身和摘要的值。
summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)??? 最后,每次評估摘要并將輸出傳遞給add_summary()函數(shù)時,事件文件將被更新為新的摘要值。
summary_str = sess.run(summary, feed_dict=feed_dict) summary_writer.add_summary(summary_str, step)??? 當(dāng)寫入事件文件時,可以針對訓(xùn)練文件夾運行TensorBoard,以顯示摘要中的值。
??? 注意:有關(guān)如何構(gòu)建和運行Tensorboard的更多信息,請參閱隨附的教程Tensorboard:可視化學(xué)習(xí)。
???
??? 保存檢查點
??? 為了輸出一個檢查點文件,可以用于稍后恢復(fù)模型進(jìn)行進(jìn)一步的訓(xùn)練或評估,我們實例化一個tf.train.Saver。
saver = tf.train.Saver()??? 在訓(xùn)練循環(huán)中,將定期調(diào)用tf.train.Saver.save方法,將訓(xùn)練中各變量的值寫入檢查點文件。
???
saver.save(sess, FLAGS.train_dir, global_step=step)??? 在稍后的某些時候,可以使用tf.train.Saver.restore方法來重新加載模型參數(shù)來恢復(fù)訓(xùn)練。
saver.restore(sess, FLAGS.train_dir)???
??? 評估模型
??? 每一步,代碼將嘗試針對訓(xùn)練和測試數(shù)據(jù)集來評估模型。 do_eval()函數(shù)被執(zhí)行三次,用于訓(xùn)練,驗證和測試數(shù)據(jù)集。
print('Training Data Eval:') do_eval(sess,eval_correct,images_placeholder,labels_placeholder,data_sets.train) print('Validation Data Eval:') do_eval(sess,eval_correct,images_placeholder,labels_placeholder,data_sets.validation) print('Test Data Eval:') do_eval(sess,eval_correct,images_placeholder,labels_placeholder,data_sets.test)??? 請注意,更復(fù)雜的使用通常會將data_sets.test隔離,以便在大量超參數(shù)調(diào)整后才能進(jìn)行檢查。 然而,為了簡單的小MNIST問題,我們對所有數(shù)據(jù)進(jìn)行評估。
???
??? 構(gòu)建評估圖
??? 在進(jìn)入訓(xùn)練循環(huán)之前,評估操作應(yīng)該是通過調(diào)用mnist.py中的evaluate()函數(shù),使用與loss()函數(shù)相同的參數(shù)構(gòu)建的。
eval_correct = mnist.evaluation(logits, labels_placeholder)??? 評估函數(shù)簡單地生成一個tf.nn.in_top_k操作,如果真正的標(biāo)簽可以在K個最可能的預(yù)測中找到,那么可以自動對每個模型輸出進(jìn)行評分。 在這種情況下,我們將K的值設(shè)置為1,以便僅對真實標(biāo)簽考慮預(yù)測是否正確。?
eval_correct = tf.nn.in_top_k(logits, labels, 1)???
??? 評估輸出
??? 然后可以創(chuàng)建一個填充feed_dict的循環(huán),并針對eval_correct op調(diào)用sess.run()來評估給定數(shù)據(jù)集上的模型。
for step in xrange(steps_per_epoch):feed_dict = fill_feed_dict(data_set,images_placeholder,labels_placeholder)true_count += sess.run(eval_correct, feed_dict=feed_dict)??? true_count變量簡單地累加了in_top_k op已經(jīng)確定為正確的所有預(yù)測。 從那里可以從簡單地除以實例的總數(shù)來計算精度。
precision = true_count / num_examples print(' ?Num examples: %d ?Num correct: %d ?Precision @ 1: %0.04f' %(num_examples, true_count, precision))?
?
?? 原文:《TensorFlow Mechanics 101》:https://www.tensorflow.org/get_started/mnist/mechanics
?
???
?
轉(zhuǎn)載于:https://www.cnblogs.com/tengge/p/6920670.html
總結(jié)
以上是生活随笔為你收集整理的13 Tensorflow机制(翻译)的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: HTML多行代码搞定微信8.0的炸裂特效
- 下一篇: 06-Flutter移动电商实战-dio