9.1 mnist_softmax 交叉熵多分类器
生活随笔
收集整理的這篇文章主要介紹了
9.1 mnist_softmax 交叉熵多分类器
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
softmax交叉熵多分類器
具體含義不再解釋,這是一個我們比較常用的一個多分類器.深度學習的一大優點就是特征的自動構建,也正是因為該優點,使得分類器層顯得不再那么重要,在Tensorflow的官方源碼中,softmax是很常見的一個多分類器.其調用也十分的簡單.此處再此單獨拿出來介紹,是為了下一步的學習做準備.
使用方法
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))用于損失函數的定義.
源碼分解
引用
# 引用,官網自帶的源碼有很多特殊之處,但是沒啥影響,自己寫的時候,完全沒必要這么多引用 # 額外添加了控制警告消息等級的code from __future__ import absolute_import from __future__ import division from __future__ import print_functionimport argparse import sysfrom tensorflow.examples.tutorials.mnist import input_dataimport tensorflow as tf import osos.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'FLAGS = None讀取數據,定義model
# Import data,此處使用我本機的數據文件,源碼中是先檢測默認位置,沒有則自動下載mnist = input_data.read_data_sets("/home/fonttian/Data/MNIST_data/", one_hot=True)# Create the model,可以看出此處的model非常簡單,就是一層y=Wx+b,你也可以繼續增加層數,或者將其替代為卷積層,但是此處對于展示softmax并沒有什么意義x = tf.placeholder(tf.float32, [None, 784])W = tf.Variable(tf.zeros([784, 10]))b = tf.Variable(tf.zeros([10]))y = tf.matmul(x, W) + b# Define loss and optimizery_ = tf.placeholder(tf.float32, [None, 10])分類器? 損失函數==>優化算法
# 這部分代碼很簡單,一些細節我在之前已經介紹過了.cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)啟動會話,運行輸出
sess = tf.InteractiveSession()tf.global_variables_initializer().run()# Trainfor _ in range(1000):batch_xs, batch_ys = mnist.train.next_batch(100)sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})# Test trained modelcorrect_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))print(sess.run(accuracy, feed_dict={x: mnist.test.images,y_: mnist.test.labels}))main
關于main的部分之前已經有介紹了:http://blog.csdn.net/fontthrone/article/details/76735591
if __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--data_dir', type=str, default='/home/fonttian/Data/MNIST_data',help='Directory for storing input data')FLAGS, unparsed = parser.parse_known_args()tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)運行結果
總結
以上是生活随笔為你收集整理的9.1 mnist_softmax 交叉熵多分类器的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Python自定义:粒子群优化算法
- 下一篇: 2.1 name_scope 简单入门(