import tensorflow as tf
import numpy as np
import cv2
def load_pb(pb_path, img_array, input_tensor_name, out_tensor_name):
    """Run inference on a frozen TensorFlow graph stored in a ``.pb`` file.

    The serialized ``GraphDef`` is loaded into a fresh, private graph.
    ``img_array`` is fed into ``input_tensor_name`` and ``out_tensor_name``
    is evaluated.

    Args:
        pb_path: path to the frozen graph (``.pb``) file.
        img_array: value fed to the input tensor.
        input_tensor_name: full tensor name, e.g. ``"conv2d_1_input:0"``.
        out_tensor_name: full name of the tensor to evaluate.

    Returns:
        The value of ``out_tensor_name`` as returned by ``Session.run``.
    """
    # TF 1.x exposes GraphDef/Session/import_graph_def at the top level;
    # TF 2.x only under tf.compat.v1. Fall back to ``tf`` on old 1.x releases.
    tf1 = getattr(getattr(tf, "compat", None), "v1", tf)

    graph_def = tf1.GraphDef()
    with open(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())

    graph = tf.Graph()
    with graph.as_default():
        # name="" keeps the original node names, so "input:0" etc. resolve.
        tf1.import_graph_def(graph_def, name="")

    # A frozen graph has its variables baked in as constants, so no
    # initializer run is needed.
    with tf1.Session(graph=graph) as sess:
        input_x = graph.get_tensor_by_name(input_tensor_name)
        output = graph.get_tensor_by_name(out_tensor_name)
        return sess.run(output, feed_dict={input_x: img_array})
def predict(img_path=r'0211_11ce565a-0ab8-41dd-b467-ce9ff0790753.png',
            pb_path='test-CNNX-LSTM-H64-CTC-C3_10000.pb'):
    """Recognize the digit string in an image using the frozen CTC model.

    Preprocessing steps:
      1. Convert the image to grayscale.
      2. Resize to 64 (width) x 32 (height).
      3. Scale to [0, 1].
      4. Transpose so width comes first.

    The result is fed to the graph's ``input:0`` tensor. The
    ``dense_decoded:0`` output indices are mapped back to characters.

    Args:
        img_path: image file to recognize (defaults to the original sample).
        pb_path: frozen model file (defaults to the original model).

    Returns:
        The decoded text. It is also printed, as before.

    Raises:
        ValueError: if the image file cannot be decoded.
    """
    # np.fromfile + imdecode (rather than cv2.imread) tolerates non-ASCII
    # paths. IMREAD_COLOR always yields 8-bit 3-channel BGR, so BGR2GRAY
    # below works for grayscale/RGBA/16-bit PNGs too. IMREAD_UNCHANGED (-1)
    # could return 1 or 4 channels.
    img = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError('could not decode image: %s' % img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)  # OpenCV decodes as BGR
    img = cv2.resize(img, (64, 32))  # dsize is (width, height) -> shape (32, 64)
    img = img.astype(np.float32) / 255
    img = np.expand_dims(img.T, -1)  # shape (64, 32, 1), width-major

    # The fed batch has 3 channels; assigning the (64, 32, 1) image
    # broadcasts the gray values into every channel.
    x_data = np.ones([1, 64, 32, 3])
    x_data[0] = img

    # Label index -> character. Index 0 maps to ''. Padding (-1) and any
    # index not in the table (e.g. 11, presumably the CTC blank -- confirm
    # against training code) are skipped.
    charset = {0: '', 1: '0', 2: '1', 3: '2', 4: '3', 5: '4',
               6: '5', 7: '6', 8: '7', 9: '8', 10: '9'}
    decoded = load_pb(pb_path, x_data, 'input:0', 'dense_decoded:0')
    text = ''.join(charset.get(int(i), '') for i in decoded[0])
    print(text)
    return text
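# 欢迎加入小白交流群1135165504,一起学习共同进步,只交流深度学习相关不吹水,有资源大家一起分享
# (Welcome to join beginner discussion group 1135165504: learn and improve together; deep-learning topics only, no off-topic chat; resources are shared.)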