生活随笔
收集整理的這篇文章主要介紹了
tensorflow--模型的保存和提取
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
參考:
TensorFlow:保存和提取模型
最全Tensorflow模型保存和提取的方法——附實例
模型的保存會覆蓋,后一次保存的模型會覆蓋上一次保存的模型。最多保存近5次結果。應當保存效果最優時候的模型,而不是訓練最后一次的模型。所以應該在每次進行模型性能評估后與保存的目前最后效果比較,如果性能更好則進行模型的保存。模型的復用,當你想用別的性能評估指標的時候,不需要再次訓練模型來獲得指標值,可以提取最優模型直接計算新指標的值。
sess
=tf
.InteractiveSession
()
sess
.run
(tf
.global_variables_initializer
())is_train
=False
saver
=tf
.train
.Saver
(max_to_keep
=3)
if is_train
:max_acc
=0f
=open('ckpt/acc.txt','w')for i
in range(100):batch_xs
, batch_ys
= mnist
.train
.next_batch
(100)sess
.run
(train_op
, feed_dict
={x
: batch_xs
, y_
: batch_ys
})val_loss
,val_acc
=sess
.run
([loss
,acc
], feed_dict
={x
: mnist
.test
.images
, y_
: mnist
.test
.labels
})print('epoch:%d, val_loss:%f, val_acc:%f'%(i
,val_loss
,val_acc
))f
.write
(str(i
+1)+', val_acc: '+str(val_acc
)+'\n')if val_acc
>max_acc
:max_acc
=val_accsaver
.save
(sess
,'ckpt/mnist.ckpt',global_step
=i
+1)f
.close
()
else:model_file
=tf
.train
.latest_checkpoint
('ckpt/')saver
.restore
(sess
,model_file
)val_loss
,val_acc
=sess
.run
([loss
,acc
], feed_dict
={x
: mnist
.test
.images
, y_
: mnist
.test
.labels
})print('val_loss:%f, val_acc:%f'%(val_loss
,val_acc
))
sess
.close
()
實操:
說明:
Social Attentional Memory Network 是一個推薦系統的模型,代碼中沒有模型保存和提取操作,數據量也算是小的,可以下載下來練習一下如何實際操作。SAMN 是我用這個模型進行的練習,可以參考,代碼后面標注 lly 的是我寫的或者修改的內容。
步驟:
先在原代碼的主目錄的下面建一個文件夾 model 。第一次進行訓練,進入目錄執行 python SAMN.py ,其中參數 is_train = True。
訓練完后發現model文件夾下面多了五個模型,最后一次保存的模型為最后模型,出現在第171次迭代的時候,即epoch=170
然后在控制臺可以看到,epoch=170時候的評估結果:
迭代第 166 次的損失為:26.586210:迭代第 167 次的損失為:26.567725:迭代第 168 次的損失為:26.586499:迭代第 169 次的損失為:26.571110:迭代第 170 次的損失為:26.668282:recall--------------------------------------------------------------------------------0.16846666666666665 0.19796666666666665 0.22703333333333334 0.24936666666666668 0.2713666666666667ndcg----------------------------------------------------------------------------------0.103169807535364 0.11131981364691529 0.11824016391770284 0.12317271387061263 0.12777428228959994save epoch 170
第二次使用保存好的模型,先將 SAMN.py 文件的參數 is_train 改為 False,再執行文件。
執行完后可以看到控制臺輸出的評估結果和之前訓練的時候的結果一樣,證明操作成功。(最優結果我只保留了k=[10, 20, 50]的情況)
總結
以上是生活随笔為你收集整理的tensorflow--模型的保存和提取的全部內容,希望文章能夠幫你解決所遇到的問題。
如果覺得生活随笔網站內容還不錯,歡迎將生活随笔推薦給好友。