tensorflow学习笔记(三十四):Saver(保存与加载模型)
Saver
tensorflow 中的 Saver 對(duì)象是用于 參數(shù)保存和恢復(fù)的。如何使用呢?
這里介紹了一些基本的用法。
官網(wǎng)中給出了這么一個(gè)例子:
v1 = tf.Variable(..., name='v1')
v2 = tf.Variable(..., name='v2')
# Pass the variables as a dict:
saver = tf.train.Saver({'v1': v1, 'v2': v2})
# Or pass them as a list.
saver = tf.train.Saver([v1, v2])
# Passing a list is equivalent to passing a dict with the variable op names
# as keys:
saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})
#注意,如果不給Saver傳var_list 參數(shù)的話, 他將已 所有可以保存的 variable作為其var_list的值。
1
2
3
4
5
6
7
8
9
10
11
12
13
這里使用了三種不同的方式來(lái)創(chuàng)建 saver 對(duì)象, 但是它們內(nèi)部的原理是一樣的。我們都知道,參數(shù)會(huì)保存到 checkpoint 文件中,通過(guò)鍵值對(duì)的形式在 checkpoint中存放著。如果 Saver 的構(gòu)造函數(shù)中傳的是 dict,那么在 save 的時(shí)候,checkpoint文件中存放的就是對(duì)應(yīng)的 key-value。如下:
import tensorflow as tf
# Create some variables.
v1 = tf.Variable(1.0, name="v1")
v2 = tf.Variable(2.0, name="v2")
saver = tf.train.Saver({"variable_1":v1, "variable_2": v2})
# Use the saver object normally after that.
with tf.Session() as sess:
tf.global_variables_initializer().run()
saver.save(sess, 'test-ckpt/model-2')
1
2
3
4
5
6
7
8
9
10
我們通過(guò)官方提供的工具來(lái)看一下 checkpoint 中保存了什么
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
print_tensors_in_checkpoint_file("test-ckpt/model-2", None, True)
# 輸出:
#tensor_name: variable_1
#1.0
#tensor_name: variable_2
#2.0
1
2
3
4
5
6
7
8
如果構(gòu)建saver對(duì)象的時(shí)候,我們傳入的是 list, 那么將會(huì)用對(duì)應(yīng) Variable 的 variable.op.name 作為 key。
import tensorflow as tf
# Create some variables.
v1 = tf.Variable(1.0, name="v1")
v2 = tf.Variable(2.0, name="v2")
saver = tf.train.Saver([v1, v2])
# Use the saver object normally after that.
with tf.Session() as sess:
tf.global_variables_initializer().run()
saver.save(sess, 'test-ckpt/model-2')
1
2
3
4
5
6
7
8
9
10
我們?cè)偈褂霉俜焦ぞ叽蛴〕?checkpoint 中的數(shù)據(jù),得到
tensor_name: v1
1.0
tensor_name: v2
2.0
1
2
3
4
。
如果我們現(xiàn)在想將 checkpoint 中v2的值restore到v1 中,v1的值restore到v2中,我們?cè)撛趺醋觯?
這時(shí),我們只能采用基于 dict 的 saver
import tensorflow as tf
# Create some variables.
v1 = tf.Variable(1.0, name="v1")
v2 = tf.Variable(2.0, name="v2")
saver = tf.train.Saver({"variable_1":v1, "variable_2": v2})
# Use the saver object normally after that.
with tf.Session() as sess:
tf.global_variables_initializer().run()
saver.save(sess, 'test-ckpt/model-2')
1
2
3
4
5
6
7
8
9
10
save 部分的代碼如上所示,下面寫 restore 的代碼,和save代碼有點(diǎn)不同。
```python
import tensorflow as tf
# Create some variables.
v1 = tf.Variable(1.0, name="v1")
v2 = tf.Variable(2.0, name="v2")
#restore的時(shí)候,variable_1對(duì)應(yīng)到v2,variable_2對(duì)應(yīng)到v1,就可以實(shí)現(xiàn)目的了。
saver = tf.train.Saver({"variable_1":v2, "variable_2": v1})
# Use the saver object normally after that.
with tf.Session() as sess:
tf.global_variables_initializer().run()
saver.restore(sess, 'test-ckpt/model-2')
print(sess.run(v1), sess.run(v2))
# 輸出的結(jié)果是 2.0 1.0,如我們所望
1
2
3
4
5
6
7
8
9
10
11
12
13
我們發(fā)現(xiàn),其實(shí) 創(chuàng)建 saver對(duì)象時(shí)使用的鍵值對(duì)就是表達(dá)了一種對(duì)應(yīng)關(guān)系:
save時(shí), 表示:variable的值應(yīng)該保存到 checkpoint文件中的哪個(gè) key下
restore時(shí),表示:checkpoint文件中key對(duì)應(yīng)的值,應(yīng)該restore到哪個(gè)variable
其它
一個(gè)快速找到ckpt文件的方式
ckpt = tf.train.get_checkpoint_state(ckpt_dir)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
1
2
3
參考資料
https://www.tensorflow.org/api_docs/python/tf/train/Saver
---------------------
作者:ke1th
來(lái)源:CSDN
原文:https://blog.csdn.net/u012436149/article/details/56665612
版權(quán)聲明:本文為博主原創(chuàng)文章,轉(zhuǎn)載請(qǐng)附上博文鏈接!
總結(jié)
以上是生活随笔為你收集整理的tensorflow学习笔记(三十四):Saver(保存与加载模型)的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 团队项目风险
- 下一篇: Linux配置本地端口映射