日韩性视频-久久久蜜桃-www中文字幕-在线中文字幕av-亚洲欧美一区二区三区四区-撸久久-香蕉视频一区-久久无码精品丰满人妻-国产高潮av-激情福利社-日韩av网址大全-国产精品久久999-日本五十路在线-性欧美在线-久久99精品波多结衣一区-男女午夜免费视频-黑人极品ⅴideos精品欧美棵-人人妻人人澡人人爽精品欧美一区-日韩一区在线看-欧美a级在线免费观看

歡迎訪問 生活随笔!

生活随笔

當前位置: 首頁 > 编程资源 > 综合教程 >内容正文

综合教程

java导出mpp格式_tensorflow 模型导出总结

發布時間:2023/12/3 综合教程 47 生活家
生活随笔 收集整理的這篇文章主要介紹了 java导出mpp格式_tensorflow 模型导出总结 小編覺得挺不錯的,現在分享給大家,幫大家做個參考.

  • Checkpoints
    • 導出成CKPT
    • 加載CKPT
  • SavedModel
    • 導出為SavedModel
    • 加載SavedModel
      • Python 加載
      • JAVA 加載
      • CLI 加載
  • Frozen Graph
    • 導出為pb
      • python
      • CLI轉換工具
    • 模型加載
      • Python 加載
      • Java 加載
  • HDF5
    • HDF5導出
    • HDF5加載
  • tfLite
    • TFlite轉換
    • TFLite 加載
  • ref

tensorflow 1.0 以及2.0 提供了多種不同的模型導出格式,例如說有checkpoint,SavedModel,Frozen GraphDef,Keras model(HDF5) 以及用于移動端,嵌入式的TFLite。 本文主要講解了前4中導出格式,分別介紹了四種的導出的各種方式,以及加載,涉及了python以及java的實現。TFLite由于項目中沒有涉及,之后會補充。

模型導出主要包含了:參數以及網絡結構的導出,不同的導出格式可能是分別導出,或者是整合成一個獨立的文件。

  • 參數和網絡結構分開保存:checkpoint, SavedModel
  • 只保存權重:HDF5(可選)
  • 參數和網絡結構保存在一個文件:Frozen GraphDef,HDF5(可選)

在tensorflow 1.0中,可以見下圖,主要有三種主要的API,Keras,Estimator,以及Legacy即最初的session模型,其中tf.Keras主要保存為HDF5,Estimator保存為SavedModel,而Lagacy主要保存的是Checkpoint,并且可以通過freeze_graph,將模型變量凍結,得到Frozen GradhDef的文件。這三種格式的模型,都可以通過TFLite Converter導出為 .tflite 的模型文件,用于安卓/ios/嵌入式設備的serving。

在tensorflow 2.0中,推薦使用SavedModel進行模型的保存,所以keras默認導出格式是SavedModel,也可以通過顯性使用 .h5 后綴,使得保存的模型格式為HDF5 。 此外其他low level API,都支持導出為SavedModel格式,以及Concrete Functions。Concrete Function是一個簽名函數,有固定格式的輸入和輸出。 最終轉化成Flatbuffer,服務端運行結束。

checkpint 的導出是網絡結構和參數權重分開保存的。
其組成:

checkpoint # 列出該目錄下,保存的所有的checkpoint列表,下面有具體的例子
events.out.tfevents.1583930869.prod-cloudserver-gpu169 # tensorboad可視化所需文件,可以直觀看出模型的結構
'''
model.ckpt-13000表示前綴,代表第13000 global steps時的保存結果,我們在指定checkpoint加載時,也只需要說明前綴即可。
'''
model.ckpt-13000.index # 代表了參數名
model.ckpt-13000.data-00000-of-00001 # 代表了參數值
model.ckpt-13000.meta # 代表了網絡結構

所以一個checkpoint 組成是由兩個部分,三個文件組成,其中網絡結構部分(meta文件),以及參數部分(參數名:index,參數值:data)

其中checkpoint文件中

model_checkpoint_path: "model.ckpt-16329"
all_model_checkpoint_paths: "model.ckpt-13000"
all_model_checkpoint_paths: "model.ckpt-14000"
all_model_checkpoint_paths: "model.ckpt-15000"
all_model_checkpoint_paths: "model.ckpt-16000"
all_model_checkpoint_paths: "model.ckpt-16329"

使用tensorboard --logdir PATH_TO_CHECKPOINT: tensorboard 會調用events.out.tfevents.*
文件,并生成tensorboard,例如下圖

導出成CKPT

  • tensorflow 1.0
# in tensorflow 1.0
saver = tf.train.Saver()
saver.save(sess=session, save_path=args.save_path)
  • estimator
# estimator
"""
通過 RunConfig 配置多少時間或者多少個steps 保存一次模型,默認600s 保存一次。
具體參考 https://zhuanlan.zhihu.com/p/112062303
"""
run_config = tf.estimator.RunConfig(model_dir=FLAGS.output_dir, # 模型保存路徑session_config=config,save_checkpoints_steps=FLAGS.save_checkpoints_steps, # 多少steps保存一次ckptkeep_checkpoint_max=1)
estimator = tf.estimator.Estimator(model_fn=model_fn,config=run_config,params=None
)

關于estimator的介紹可以參考

https://zhuanlan.zhihu.com/p/112062303?zhuanlan.zhihu.com

加載CKPT

  • tf1.0
    ckpt加載的腳本如下,加載完后,session就會是保存的ckpt了。
# tf1.0
session = tf.Session()
session.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.restore(sess=session, save_path=args.save_path)  # 讀取保存的模型
  • 對于estimator 會自動load output_dir 中的最新的ckpt。
  • 我們常用的model_file = tf.train.latest_checkpoint(FLAGS.output_dir) 獲取最新的ckpt

SavedModel

SavedModel 格式是tensorflow 2.0 推薦的格式,他很好地支持了tf-serving等部署,并且可以簡單被python,java等調用。

一個 SavedModel 包含了一個完整的 TensorFlow program, 包含了 weights 以及 計算圖 computation. 它不需要原本的模型代碼就可以加載所以很容易在 TFLite, TensorFlow.js, TensorFlow Serving, or TensorFlow Hub 上部署。

通常SavedModel由以下幾個部分組成

├── assets/ # 所需的外部文件,例如說初始化的詞匯表文件,一般無
├── assets.extra/ # TensorFlow graph 不需要的文件, 例如說給用戶知曉的如何使用SavedModel的信息. Tensorflow 不使用這個目錄下的文件。
├── saved_model.pb # 保存的是MetaGraph的網絡結構
├── variables # 參數權重,包含了所有模型的變量(tf.Variable objects)參數├── variables.data-00000-of-00001└── variables.index

導出為SavedModel

  • tf 1.0 方式
"""tf1.0"""
x = tf.placeholder(tf.float32, [None, 784], name="myInput")
y = tf.nn.softmax(tf.matmul(x, W) + b, name="myOutput")
tf.saved_model.simple_save(sess,export_dir,inputs={"myInput": x},outputs={"myOutput": y})

simple_save 是對于普通的tf 模型導出的最簡單的方式,只需要補充簡單的必要參數,有很多參數被省略,其中最重要的參數是tagtag 是用來區別不同的 MetaGraphDef,這是在加載模型所需要的參數。其默認值是tag_constants.SERVING (“serve”).
對于某些節點,如果沒有辦法直接加name,那么可以采用 tf.identity, 為節點加名字,例如說CRF的輸出,以及使用dataset后,無法直接加input的name,都可以采用這個方式:

def addNameToTensor(someTensor, theName):return tf.identity(someTensor, name=theName)
  • estimator 方式
"""estimator"""
def serving_input_fn():label_ids = tf.placeholder(tf.int32, [None], name='label_ids')input_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_ids')input_mask = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_mask')segment_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='segment_ids')input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({'label_ids': label_ids,'input_ids': input_ids,'input_mask': input_mask,'segment_ids': segment_ids,})return input_fnif do_export:estimator._export_to_tpu = Falseestimator.export_saved_model(Flags.export_dir, serving_input_fn)
  • 保存多個 MetaGraphDef's
import tensorflow.python.saved_model
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model.signature_def_utils_impl import predict_signature_def
builder = saved_model.builder.SavedModelBuilder(export_path)signature = predict_signature_def(inputs={'myInput': x},outputs={'myOutput': y})
""" using custom tag instead of: tags=[tag_constants.SERVING] """
builder.add_meta_graph_and_variables(sess=sess,tags=["myTag"],signature_def_map={'predict': signature})
builder.save()
  • ckpt轉SavedModel
def get_saved_model(bert_config, num_labels, use_one_hot_embeddings):tf_config = tf.compat.v1.ConfigProto()tf_config.gpu_options.allow_growth = Truemodel_file = tf.train.latest_checkpoint(FLAGS.output_dir)with tf.Graph().as_default(), tf.Session(config=tf_config) as tf_sess:label_ids = tf.placeholder(tf.int32, [None], name='label_ids')input_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_ids')input_mask = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_mask')segment_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='segment_ids')loss, per_example_loss, probabilities, predictions = create_model(bert_config, False, input_ids, input_mask, segment_ids, label_ids,num_labels, use_one_hot_embeddings)saver = tf.train.Saver()print("restore;{}".format(model_file))saver.restore(tf_sess, model_file)tf.saved_model.simple_save(tf_sess,FLAGS.output_dir,inputs={'label_ids': label_ids,'input_ids': input_ids,'input_mask': input_mask,'segment_ids': segment_ids,},outputs={"probabilities": probabilities})
  • frozen graph to savedModel
import tensorflow as tf
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constantsexport_dir = 'inference/pb2saved'
graph_pb = 'inference/robert_tiny_clue/frozen_model.pb'builder = tf.saved_model.builder.SavedModelBuilder(export_dir)with tf.gfile.GFile(graph_pb, "rb") as f:graph_def = tf.GraphDef()graph_def.ParseFromString(f.read())sigs = {}with tf.Session(graph=tf.Graph()) as sess:# name="" is important to ensure we don't get spurious prefixingtf.import_graph_def(graph_def, name="")g = tf.get_default_graph()input_ids = sess.graph.get_tensor_by_name("input_ids:0")input_mask = sess.graph.get_tensor_by_name("input_mask:0")segment_ids = sess.graph.get_tensor_by_name("segment_ids:0")probabilities = g.get_tensor_by_name("loss/pred_prob:0")sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = tf.saved_model.signature_def_utils.predict_signature_def({"input_ids": input_ids,"input_mask": input_mask,"segment_ids": segment_ids}, {"probabilities": probabilities})builder.add_meta_graph_and_variables(sess,[tag_constants.SERVING],signature_def_map=sigs)builder.save()
  • tf.keras 2.0
model.save('saved_model/my_model')  
"""saved as SavedModel by default"""

加載SavedModel

對于在java中加載SavedModel,我們首先需要知道我們模型輸入和輸出,可以通過以下的腳本在terminal中運行 saved_model_cli show --dir SavedModel路徑 --all 得到類似以下的結果

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:signature_def['serving_default']:The given SavedModel SignatureDef contains the following input(s):inputs['input_ids'] tensor_info:dtype: DT_INT32shape: (-1, 128)name: input_ids:0inputs['input_mask'] tensor_info:dtype: DT_INT32shape: (-1, 128)name: input_mask:0inputs['label_ids'] tensor_info:dtype: DT_INT32shape: (-1)name: label_ids:0inputs['segment_ids'] tensor_info:dtype: DT_INT32shape: (-1, 128)name: segment_ids:0The given SavedModel SignatureDef contains the following output(s):outputs['probabilities'] tensor_info:dtype: DT_FLOATshape: (-1, 7)name: loss/pred_prob:0Method name is: tensorflow/serving/predict

首先我們可以看到有inputs,以及outputs,分別是一個key為string,value為tensor的字典,每個tensor都有各自的名字。

Python 加載

所有我們有常見兩種方式可以加載savedModel,一種是采用 tf.contrib.predictor.from_saved_model 傳入predictor模型的inputs dict,然后得到 outputs dict。 一種是直接類似tf1.0的方式,采用 tf.saved_model.loader.load, feed tensor然后fetch tensor。

  • 采用predictor
    采用predictor時, 需要傳入的字典名字用的是 inputs的key,而不是tensor的names
predict_fn = tf.contrib.predictor.from_saved_model(args_in_use.model)
prediction = predict_fn({"input_ids": [feature.input_ids],"input_mask": [feature.input_mask],"segment_ids": [feature.segment_ids],})
probabilities = prediction["probabilities"]
  • tf 1.0 采用 loader
    采用loader的方式是采用 session 的feed_dict 方式,該方式feed的是tenor的names,fetch的同樣也是tensor 的names。
    其中feed_dict的key 可以直接是tensor的name,或者是采用 sess.graph.get_tensor_by_name(TENSOR_NAME) 得到的tensor。
with tf.Session(graph=tf.Graph()) as sess:tf.saved_model.loader.load(sess, ["serve"], export_path)graph = tf.get_default_graph()feed_dict = {"input_ids_1:0": [feature.input_ids],"input_mask_1:0": [feature.input_mask],"segment_ids_1:0": [feature.segment_ids]}"""# alternative wayfeed_dict = {sess.graph.get_tensor_by_name("input_ids_1:0"): [feature.input_ids],sess.graph.get_tensor_by_name("input_mask_1:0"):[feature.input_mask],sess.graph.get_tensor_by_name("segment_ids_1:0"):[feature.segment_ids]}"""sess.run('loss/pred_prob:0',feed_dict=feed_dict
  • tf.keras 2.0
    new_model = tf.keras.models.load_model('saved_model/my_model')

JAVA 加載

注意 java加載的時候,如果遇到Op not defined 的錯誤,是需要匹配模型訓練python的tensorflow版本以及java的tensorflow版本的。

所以我們知道我們在tag-set 為serve的tag下,有4個inputs tensors,name分別為input_ids:0, input_mask:0, label_ids:0, segment_ids:0, 輸出為1個,name是 loss/pred_prob:0
并且我們知道這些tensor的類型。

所以我們可以通過下面的java代碼,進行加載,獲得結果。注意我們需要傳入的name中不需要傳入:0

import org.tensorflow.*
SavedModelBundle savedModelBundle = SavedModelBundle.load("./export_path", "serve");
Graph graph = savedModelBundle.graph();Tensor tensor = this.savedModelBundle.session().runner().feed("input_ids", inputIdTensor).feed("input_mask", inputMaskTensor).feed("segment_ids", inputSegmentTensor).fetch("loss/pred_prob").run().get(0);

CLI 加載

$ saved_model_cli show --dir export/1524906774 --tag_set serve --signature_def serving_default
The given SavedModel SignatureDef contains the following input(s):inputs['inputs'] tensor_info:dtype: DT_STRINGshape: (-1)
The given SavedModel SignatureDef contains the following output(s):outputs['classes'] tensor_info:dtype: DT_STRINGshape: (-1, 3)outputs['scores'] tensor_info:dtype: DT_FLOATshape: (-1, 3)
Method name is: tensorflow/serving/classify$ saved_model_cli run --dir export/1524906774 --tag_set serve --signature_def serving_default --input_examples 'inputs=[{"SepalLength":[5.1],"SepalWidth":[3.3],"PetalLength":[1.7],"PetalWidth":[0.5]}]'
Result for output key classes:
[[b'0' b'1' b'2']]
Result for output key scores:
[[9.9919027e-01 8.0969761e-04 1.2872645e-09]]

Frozen Graph

frozen Graphdef 將tensorflow導出的模型的權重都freeze住,使得其都變為常量。并且模型參數和網絡結構保存在同一個文件中,可以在python以及java中自由調用。

導出為pb

python

  • 采用session方式保存frozen graph
"""tf1.0"""
from tensorflow.python.framework.graph_util import convert_variables_to_constantsoutput_graph_def = convert_variables_to_constants(session,session.graph_def,output_node_names=['loss/pred_prob'])
tf.train.write_graph(output_graph_def, args.export_dir, args.model_name, as_text=False)
  • 采用ckpt 轉換成frozen graph
    以下采用bert tensorflow模型做演示
"""
NB:首先我們要在create_model() 函數中,為我們需要的輸出節點取個名字,比如說我們要: probabilities = tf.nn.softmax(logits, axis=-1, name='pred_prob')
""" def get_frozen_model(bert_config, num_labels, use_one_hot_embeddings):tf_config = tf.compat.v1.ConfigProto()tf_config.gpu_options.allow_growth = Trueoutput_node_names = ['loss/pred_prob']model_file = tf.train.latest_checkpoint(FLAGS.output_dir)with tf.Graph().as_default(), tf.Session(config=tf_config) as tf_sess: label_ids = tf.placeholder(tf.int32, [None], name='label_ids')input_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_ids')input_mask = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_mask')segment_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='segment_ids')create_model(bert_config, False, input_ids, input_mask, segment_ids, label_ids,num_labels, use_one_hot_embeddings)saver = tf.train.Saver()print("restore;{}".format(model_file))saver.restore(tf_sess, model_file)tmp_g = tf_sess.graph.as_graph_def()if FLAGS.use_opt:input_tensors = [input_ids, input_mask, segment_ids]dtypes = [n.dtype for n in input_tensors]print('optimize...')tmp_g = optimize_for_inference(tmp_g,[n.name[:-2] for n in input_tensors],output_node_names,[dtype.as_datatype_enum for dtype in dtypes],False)print('freeze...')frozen_graph = tf.graph_util.convert_variables_to_constants(tf_sess, tmp_g, output_node_names)out_graph_path = os.path.join(FLAGS.output_dir, "frozen_model.pb")with tf.io.gfile.GFile(out_graph_path, "wb") as f:f.write(frozen_graph.SerializeToString())      print(f'pb file saved in {out_graph_path}')
  • 采用savedModel 轉換成 frozen graph
from tensorflow.python.tools import freeze_graph
from tensorflow.python.saved_model import tag_constantsinput_saved_model_dir = "./1583934987/"
output_node_names = "loss/pred_prob"
input_binary = False
input_saver_def_path = False
restore_op_name = None
filename_tensor_name = None
clear_devices = False
input_meta_graph = False
checkpoint_path = None
input_graph_filename = None
saved_model_tags = tag_constants.SERVING
output_graph_filename='frozen_graph.pb'freeze_graph.freeze_graph(input_graph_filename,input_saver_def_path,input_binary,checkpoint_path,output_node_names,restore_op_name,filename_tensor_name,output_graph_filename,clear_devices,"", "", "",input_meta_graph,input_saved_model_dir,saved_model_tags)
  • HDF5 to pb
from keras import backend as Kdef freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):"""Freezes the state of a session into a pruned computation graph.Creates a new computation graph where variable nodes are replaced byconstants taking their current value in the session. The new graph will bepruned so subgraphs that are not necessary to compute the requestedoutputs are removed.@param session The TensorFlow session to be frozen.@param keep_var_names A list of variable names that should not be frozen,or None to freeze all the variables in the graph.@param output_names Names of the relevant graph outputs.@param clear_devices Remove the device directives from the graph for better portability.@return The frozen graph definition."""graph = session.graphwith graph.as_default():freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))output_names = output_names or []output_names += [v.op.name for v in tf.global_variables()]input_graph_def = graph.as_graph_def()if clear_devices:for node in input_graph_def.node:node.device = ""frozen_graph = tf.graph_util.convert_variables_to_constants(session, input_graph_def, output_names, freeze_var_names)return frozen_graphfrozen_graph = freeze_session(K.get_session(),output_names=[out.op.name for out in model.outputs])tf.train.write_graph(frozen_graph, "some_directory", "my_model.pb", as_text=False)

CLI轉換工具

以下的工具可以快速進行ckpt到pb的轉換,但是不能再原本的基礎上增加tensor 的名字。

freeze_graph --input_checkpoint model.ckpt-16329 --output_graph 0316_roberta.pb --output_node_names loss/pred_prob --checkpoint_version 1 --input_meta_graph model.ckpt-16329.meta --input_binary true

模型加載

獲取frozen graph 中節點名字的腳本如下,但是一般來說,我們的inputs都是我們定義好的placeholders。

import tensorflow as tfdef printTensors(pb_file):"""read pb into graph_def"""with tf.gfile.GFile(pb_file, "rb") as f:graph_def = tf.GraphDef()graph_def.ParseFromString(f.read())"""import graph_def"""with tf.Graph().as_default() as graph:tf.import_graph_def(graph_def)"""print operations"""for op in graph.get_operations():print(op.name)printTensors("path-to-my-pbfile.pb")

得到類似如下的結果

import/input_ids:0
import/input_mask:0
import/segment_ids:0
...
import/loss/pred_prob:0

當我們知道我們要feed以及fetch的節點名稱之后,我們就可以通過python/java加載了。
跟savedModel一樣,對于某些節點,如果沒有辦法直接加name,那么可以采用 tf.identity, 為節點加名字,例如說CRF的輸出,以及使用dataset后,無法直接加input的name,都可以采用這個方式

def addNameToTensor(someTensor, theName):return tf.identity(someTensor, name=theName)

Python 加載

我們保存完frozen graph 模型后,假設我們的模型包含以下的tensors:

那么我們通過python加載的代碼如下, 采用的是session feed和fetch的方式。

with tf.Graph().as_default():output_graph_def = tf.GraphDef()"""load pb model"""with open(args_in_use.model, 'rb') as f:output_graph_def.ParseFromString(f.read())tf.import_graph_def(output_graph_def, name='') #name是必須的"""enter a text and predict"""with tf.Session() as sess:tf.global_variables_initializer().run()input_ids = sess.graph.get_tensor_by_name("input_ids:0")input_mask = sess.graph.get_tensor_by_name("input_mask:0")segment_ids = sess.graph.get_tensor_by_name("segment_ids:0")output = "loss/pred_prob:0"feed_dict = {input_ids: [feature.input_ids],input_mask: [feature.input_mask],segment_ids: [feature.segment_ids],}# 也可以直接使用# feed_dict = {#     "input_ids:0": [feature.input_ids],#     "input_mask:0": [feature.input_mask],#     "segment_ids:0": [feature.segment_ids],# }y_pred_cls = sess.run(output, feed_dict=feed_dict)

Java 加載

對于frozen graph,我們加載的方式和savedModel很類似,首先我們需要先啟動一個session,然后在啟動一個runner(),然后再feed模型的輸入,以及fetch模型的輸出。

注意 java加載的時候,如果遇到Op not defined 的錯誤,是需要匹配模型訓練python的tensorflow版本以及java的tensorflow版本的。

// TensorUtil.class
public static Session generateSession(String modelPath) throws IOException {Preconditions.checkNotNull(modelPath);byte[] graphDef = ByteStreams.toByteArray(TensorUtil.class.getResourceAsStream(modelPath));LOGGER.info("Graph Def Length: {}", graphDef.length);Graph graph = new Graph();graph.importGraphDef(graphDef);return new Session(graph);
}// model.class
this.session = TensorUtil.generateSession(modelPath);Tensor tensor = this.session.runner().feed("input_ids", inputIdTensor).feed("input_mask", inputMaskTensor).feed("segment_ids", inputSegmentTensor).fetch("loss/pred_prob").run().get(0);

HDF5

HDF5 是keras or tf.keras 特有的存儲格式。

HDF5導出

  • 導出整個模型
"""默認1.0 是HDF5,但是2.0中,是SavedModel,所以需要顯性地指定`.h5`后綴"""
model.save('my_model.h5') 
  • 導出模型weights
"""keras 1.0"""
model.save_weights('my_model_weights.h5')

HDF5加載

  • 加載整個模型(無自定義部分)
    • keras1.0
"""keras 1.0"""
from keras.models import load_model
model = load_model(model_path)
    • keras2.0
"""keras 2.0"""
new_model = tf.keras.models.load_model('my_model.h5')
  • 加載整個模型(含自定義部分)
    對于有自定義layers的或者實現的模型加載,需要增加dependencies 的映射字典,例如下面的例子:
    • keras1.0
dependencies = {'MyLayer': MyLayer(), 'auc': auc, 'log_loss': log_loss}
model = load_model(model_path, custom_objects=dependencies, compile=False)
    • keras 2.0
"""
To save custom objects to HDF5, you must do the following:1. Define a get_config method in your object, and optionally a from_config classmethod.
get_config(self) returns a JSON-serializable dictionary of parameters needed to recreate the object.
from_config(cls, config) uses the returned config from get_config to create a new object. By default, this function will use the config as initialization kwargs (return cls(**config)).2. Pass the object to the custom_objects argument when loading the model. The argument must be a dictionary mapping the string class name to the Python class. E.g. tf.keras.models.load_model(path, custom_objects={'CustomLayer': CustomLayer})
"""
  • 加載模型權重
    假設你有了相同的模型構建了,那么直接運行下面的代碼,加載模型
model.load_weights('my_model_weights.h5')

如果你想要做transfer learning,即從其他的已保存的模型中加載部分的模型參數權重,自己目前的模型結構與保存的模型不同,可以通過參數的名字進行加載,加上 by_name=True

model.load_weights('my_model_weights.h5', by_name=True)

tfLite

TFlite轉換

  • savedModel to TFLite
"""
--saved_model_dir:  Type: string. Specifies the full path to the directory containing the SavedModel generated in 1.X or 2.X.
--output_file: Type: string. Specifies the full path of the output file.
"""
tflite_convert --saved_model_dir=1583934987 --output_file=rbt.tflite
  • frozen graph to TFLite
tflite_convert --graph_def_file albert_tiny_zh.pb --input_arrays 'input_ids,input_masks,segment_ids' --output_arrays 'finetune_mrc/add, finetune_mrc/add_1'--input_shapes 1,512:1,512:1,512 --output_file saved_model.tflite --enable_v1_converter --experimental_new_converter
  • HDF5 to TFLite
#--keras_model_file. Type: string. Specifies the full path of the HDF5 file containing the tf.keras model generated in 1.X or 2.X.   
#--output_file: Type: string. Specifies the full path of the output file.
tflite_convert --keras_model_file=h5_dir/ --output_file=rbt.tflite

TFLite 加載

參考 https://www.tensorflow.org/lite/guide/inference
參考 https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/r1/convert/index.md

ref

  • https://zhuanlan.zhihu.com/p/64099452
  • https://zhuanlan.zhihu.com/p/60064947
  • https://zhuanlan.zhihu.com/p/60069860
  • https://medium.com/@jsflo.dev/saving-and-loading-a-tensorflow-model-using-the-savedmodel-api-17645576527
  • https://stackoverflow.com/questions/44329185/convert-a-graph-proto-pb-pbtxt-to-a-savedmodel-for-use-in-tensorflow-serving-o/44329200#44329200
  • https://stackoverflow.com/questions/47029048/tensorflow-how-to-freeze-a-pb-from-the-savedmodel-to-be-used-for-inference-in
  • https://zhuanlan.zhihu.com/p/64099452
  • https://stackoverflow.com/questions/59263406/how-to-find-operation-names-in-tensorflow-graph
  • http://shzhangji.com/cnblogs/2018/05/14/serve-tensorflow-estimator-with-savedmodel/?utm_source=wechat_session&utm_medium=social&utm_oi=613327238821842944&from=singlemessage&isappinstalled=0
  • http://shzhangji.com/cnblogs/2018/05/14/serve-tensorflow-estimator-with-savedmodel/?utm_source=wechat_session&utm_medium=social&utm_oi=613327238821842944&from=singlemessage&isappinstalled=0
  • https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/r1/convert/index.md

總結

以上是生活随笔為你收集整理的java导出mpp格式_tensorflow 模型导出总结的全部內容,希望文章能夠幫你解決所遇到的問題。

如果覺得生活随笔網站內容還不錯,歡迎將生活随笔推薦給好友。