TensorFlow实验(3)
模型的保存與恢復(fù)
我們來簡單實現(xiàn)一下模型的保存與恢復(fù)
訓(xùn)練完TensorFlow模型后,可將其保存為文件,以便于預(yù)測新數(shù)據(jù)時直接加載使用。
TensorFlow模型主要包含網(wǎng)絡(luò)的設(shè)計或者圖以及已經(jīng)訓(xùn)練好的網(wǎng)絡(luò)參數(shù)的值。
TensorFlow提供的tf.train.Saver()函數(shù)可以建立一個saver對象,在會話中調(diào)用其save()函數(shù),即可將模型保存起來
save()函數(shù)的用法
| 函數(shù) | 說明 |
| save( ????? sess, ????? sace_path, ????? global_step=None, ????? latest_filename=None, ????? meta_graph_suffix='meta', ????? write_meta_graph=True, ????? write_state=True ) | sess:保存模型,要求必須有一個加載了計算圖的會話,且所有變量已被初始化。 sace_path:模型保存路徑及保存名稱 global_step:如果提供,該數(shù)字會添加到save_path后,用于區(qū)分不同訓(xùn)練階段的結(jié)果 latest_filename:檢查點文件的名稱,默認(rèn)是checkpoint meta_graph_suffix= MetaGraphDef元圖后綴,默認(rèn)為meta write_meta_graph=是否要保存元圖數(shù)據(jù),默認(rèn)為True write_state:是否要保存CheckpointStateProto,默認(rèn)為True |
模型保存
import tensorflow as tf m1 = tf.Variable(tf.constant([[1.0,3.0],[2.0,4.0]],shape=[2,2]),name='m1') m2 = tf.Variable(tf.constant([[2.0,7.0],[3.0,8.0]],shape=[2,2]),name='m2') result = m1 + m2 saver = tf.train.Saver() with tf.Session() as sess:sess.run(tf.global_variables_initializer())print('resulit:',sess.run(result))saver.save(sess,'C:/model/model.ckpt')運(yùn)行程序,當(dāng)前目錄的model文件夾下會產(chǎn)生4個文件:checkpoint,data-00000-of-00001,meta和index
checkpoint:保存模型的權(quán)重、偏置、梯度以及其他保護(hù)變量的二進(jìn)制文件。
data:保存模型的所有變量的值
meta:保存計算圖的結(jié)構(gòu)。當(dāng)meta文件存在時,不在程序中定義模型,直接加載meta可以直接運(yùn)行
index:保存string-string的鍵值對。其中的key值為張量名,value為BundleEntryProto
模型恢復(fù)
模型保存好了以后,載入發(fā)出方便。
在會話中調(diào)用saver的restore()函數(shù),就會從指定的路徑找到模型文件,并覆蓋相關(guān)參數(shù)。
saver.restore()函數(shù)的形式如表
| 函數(shù) | 說明 |
| saver.restore( ??? sess, ??? save_path ) | 從指定的路徑恢復(fù)模型。 sess:用于恢復(fù)參數(shù)模型的會話 save_path:已保存模型的路徑,通常包含模型名字 |
?
總結(jié)
以上是生活随笔為你收集整理的TensorFlow实验(3)的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 64位ubuntu arm-linux-
- 下一篇: java期末考试2013及答案_java