1.6.2 TensorFlow Serving介绍
优质
小牛编辑
125浏览
2023-12-01
简介
TensorFlow的模型文件包含了深度学习模型的Graph和所有参数,其实就是checkpoint文件,用户可以加载模型文件继续训练或者对外提供Inference服务。
使用SavedModel导出模型
模型导出方式参考 https://tensorflow.github.io/serving/serving_basic 。
使用方法基本如下。
from tensorflow.python.saved_model import builder as saved_model_builder
export_path_base = sys.argv[-1]
export_path = os.path.join(
compat.as_bytes(export_path_base),
compat.as_bytes(str(FLAGS.model_version)))
print 'Exporting trained model to', export_path
builder = saved_model_builder.SavedModelBuilder(export_path)
builder.add_meta_graph_and_variables(
sess, [tag_constants.SERVING],
signature_def_map={
'predict_images':
prediction_signature,
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
classification_signature,
},
legacy_init_op=legacy_init_op)
builder.save()
可以参考 https://github.com/tobegit3hub/deep_recommend_system/ 提供的可运行代码示例。
./dense_classifier.py --mode savedmodel
使用exporter导出模型
这里有导出TensorFlow serving支持的模型文件例子,可以参考使用 https://github.com/tobegit3hub/deep_recommend_system/blob/master/dense_classifier.py 。
导出的代码也比较简单,用户在inputs和output中填入模型Inference时的输入和输出即可。
from tensorflow.contrib.session_bundle import exporter
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string("model_path", "./model", "The path to export the model")
flags.DEFINE_integer("export_version", 1, "Version number of the model")
# Define the graph
keys_placeholder = tf.placeholder(tf.int32, shape=[None, 1])
keys = tf.identity(keys_placeholder)
# Start the session
# Export the model
print("Exporting trained model to {}".format(FLAGS.model_path))
model_exporter = exporter.Exporter(saver)
model_exporter.init(
sess.graph.as_graph_def(),
named_graph_signatures={
'inputs': exporter.generic_signature({"keys": keys_placeholder, "features": inference_features}),
'outputs': exporter.generic_signature({"keys": keys, "softmax": inference_softmax, "prediction": inference_op})
})
model_exporter.export(FLAGS.model_path, tf.constant(FLAGS.export_version), sess)
print 'Done exporting!'
与SavedModel方法相比,两者都可以直接用TensorFlow Serving加载,我们使用deep_recommend_system导出两种模型方式测试过预测结果一模一样,只是模型文件大小不同。
导入带assert的模型文件
在NLP等场景除了参数文件,还需要导入vocabulary等文件,可以在exporter中设置assets_collection,参考 https://github.com/tensorflow/serving/issues/264 。