转载:tensorflow保存训练后的模型
生活随笔
收集整理的這篇文章主要介紹了
转载:tensorflow保存训练后的模型
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
訓練完一個模型后,為了以后重復使用,通常我們需要對模型的結果進行保存。如果用Tensorflow去實現神經網絡,所要保存的就是神經網絡中的各項權重值。建議可以使用Saver類保存和加載模型的結果。
1、使用tf.train.Saver.save()方法保存模型
tf.train.Saver.save(sess, save_path, global_step=None, latest_filename=None, meta_graph_suffix='meta', write_meta_graph=True, write_state=True)
- sess: 用于保存變量操作的會話。
- save_path: String類型,用于指定訓練結果的保存路徑。
- global_step: 如果提供的話,這個數字會添加到save_path后面,用于構建checkpoint文件。這個參數有助于我們區分不同訓練階段的結果。
2、使用tf.train.Saver.restore方法價值模型
tf.train.Saver.restore(sess, save_path)
- sess: 用于加載變量操作的會話。
- save_path: 同保存模型是用到的的save_path參數。
下面通過一個代碼演示這兩個函數的使用方法
import tensorflow as tf import numpy as npx = tf.placeholder(tf.float32, shape=[None, 1]) y = 4 * x + 4w = tf.Variable(tf.random_normal([1], -1, 1)) b = tf.Variable(tf.zeros([1])) y_predict = w * x + bloss = tf.reduce_mean(tf.square(y - y_predict)) optimizer = tf.train.GradientDescentOptimizer(0.5) train = optimizer.minimize(loss)isTrain = False train_steps = 100 checkpoint_steps = 50 checkpoint_dir = ''saver = tf.train.Saver() # defaults to saving all variables - in this case w and b x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))with tf.Session() as sess:sess.run(tf.initialize_all_variables())if isTrain:for i in xrange(train_steps):sess.run(train, feed_dict={x: x_data})if (i + 1) % checkpoint_steps == 0:saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1)else:ckpt = tf.train.get_checkpoint_state(checkpoint_dir)if ckpt and ckpt.model_checkpoint_path:saver.restore(sess, ckpt.model_checkpoint_path)else:passprint(sess.run(w))print(sess.run(b))轉載于:https://www.cnblogs.com/txq157/p/7242385.html
總結
以上是生活随笔為你收集整理的转载:tensorflow保存训练后的模型的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: hibernate cascade的真正
- 下一篇: 字符串匹配(KMP 算法 含代码)