tensorflow saver_TensorFlow: Model Persistence
TensorFlow提供了一個非常簡單的API來保存和還原神經網絡模型。這個API就是tf.train.Saver類。以下代碼給出了保存TensorFlow計算圖的方法:
import tensorflow as tf# 聲明兩個變量并計算它們的和 v1 = tf.Variable(tf.constant(1.0, shape=[1]), name = "v1") v2 = tf.Variable(tf.constant(2.0, shape=[1]), name = "v2") result = v1 + v2init_op = tf.global_variables_initializer() saver = tf.train.Saver()with tf.Session() as sess:sess.run(init_op)# 將模型保存到Saved_model目錄下saver.save(sess, "Saved_model/model.ckpt")上面的代碼實現了持久化一個簡單的TensorFlow模型的功能。雖然上述程序只指定了一個文件路徑,但這個目錄下會出現多個文件。原書中說會生成3個文件,分別是
1) model.ckpt.meta —— 保存了TensorFlow計算圖的結構
2) model.ckpt —— 保存了TensorFlow程序中每個變量的取值
3) checkpoint —— 保存了一個目錄下所有的模型文件列表
但我運行的結果是出現了4個文件,不知和系統是否有關。我用的是OpenSUSE(Linux)系統。
加載已保存的模型:
import tensorflow as tfv1 = tf.Variable(tf.constant(1.0, shape=[1]), name = "v1") v2 = tf.Variable(tf.constant(2.0, shape=[1]), name = "v2") result = v1 + v2init_op = tf.global_variables_initializer() saver = tf.train.Saver()with tf.Session() as sess:saver.restore(sess, "Saved_model/model.ckpt")print(sess.run(result))運行結果如下:
[ 3.]這段加在模型的代碼和前面保存模型的代碼幾乎是一樣的——也是先定義了TensorFlow計算圖上的所有運算,并聲明了一個tf.train.Saver類。兩段代碼唯一的區別是,在加在模型的代碼中沒有運行變量的初始化過程,而是將變量的值通過已經保存的模型加載進來。
如果不希望重復定義圖上的運算,也可以直接加在已經持久化的圖,代碼如下:
import tensorflow as tf# 直接加載持久化的圖 saver = tf.train.import_meta_graph("Saved_model/model.ckpt.meta")with tf.Session() as sess:saver.restore(sess, "Saved_model/model.ckpt")print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0"))) # [3.]# 運行結果: # [ 3.]- 保存滑動平均模型
- variables_to_restore函數的使用樣例
為了方便加載時重命名滑動平均變量,tf.train.ExponentialMovingAverage類提供了variables_to_restore函數來生成tf.train.Saver類所需要的變量重命名字典
import tensorflow as tfv = tf.Variable(0, dtype=tf.float32, name="v") ema = tf.train.ExponentialMovingAverage(0.99) print(ema.variables_to_restore())# 運行結果: # {'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}# 注意整個saver和上一段代碼片中的saver的區別,這里就不用以變量重命名的方式載入ema了 saver = tf.train.Saver(ema.variables_to_restore()) with tf.Session() as sess:saver.restore(sess, "Saved_model/model2.ckpt")print(sess.run(v))# 運行結果: # 0.0999999- 保存為pb格式
使用tf.train.Saver會保存運行TensorFlow程序所需要的全部信息,然而有時并不需要某些信息。比如在測試或離線預測時,只需要知道如何從神經網絡的輸入層經過前向傳播計算得到輸出層即可,而不需要類似于變量初始化、模型保存等輔助節點的信息。在第6章介紹遷移學習時,會遇到類似的情況。而且,將變量取值和計算圖結構分成不同的文件存儲有時也不方便,于是TensorFlow提供了convert_variables_to_constants函數,通過這個函數可以將計算圖中的變量及其取值通過常量的方式保存,這樣整個TensorFlow計算圖可以統一存放在一個文件中。下面的程序提供了一個樣例:
import tensorflow as tf from tensorflow.python.framework import graph_utilv1 = tf.Variable(tf.constant(1.0, shape=[1]), name = "v1") v2 = tf.Variable(tf.constant(2.0, shape=[1]), name = "v2") result = v1 + v2init_op = tf.global_variables_initializer() with tf.Session() as sess:sess.run(init_op)# 導出當前計算圖中的GraphDef部分,只需要這一部分就可以完成從輸入層到輸出層的計算過程graph_def = tf.get_default_graph().as_graph_def()# 將圖中的變量及其取值轉化為常量,同時將圖中不必要的節點去掉# 如果只關心程序中定義的某些計算時,和這些計算無關的節點就沒有必要導出并保存了# 在下面一行代碼中,最后一個參數['add']給出了需要保存的節點名稱# 注意add節點是上面定義的兩個變量相加的操作,其后面沒有:0# 而張量的名稱后面有:0,表示的是某個計算節點的第一個輸出output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add'])with tf.gfile.GFile("Saved_model/combined_model.pb", "wb") as f:f.write(output_graph_def.SerializeToString())# 運行結果: # Converted 2 variables to const ops.- 加載pb格式的文件
總結
以上是生活随笔為你收集整理的tensorflow saver_TensorFlow: Model Persistence的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 稀疏自编码器_基于tf实现稀疏自编码和在
- 下一篇: charles 安装 ssl_charl