单独调用开源验证码深度学习项目captcha_trainer生成的pb模型

鲜于华容
2023-12-01
代码如下。注意:加载 pb 模型使用的是 TensorFlow 1.x 的图/会话(Graph/Session)接口,在 TensorFlow 2.0 环境下需要通过 tf.compat.v1 来调用这些接口。
import tensorflow as tf
import numpy as np
import cv2


def load_pb(pb_path, img_array, input_tensor_name, out_tensor_name):
    """Run one prediction through a frozen TensorFlow ``.pb`` graph.

    Only ``tf.compat.v1`` graph/session APIs are used. The top-level
    ``tf.GraphDef`` / ``tf.Session`` names no longer exist in TensorFlow 2.x,
    so this works in a TF 2.x environment (and in TF 1.x releases that ship
    ``tf.compat.v1``). The graph is imported into a fresh ``tf.Graph``, so
    repeated calls do not pile nodes into the global default graph.

    Args:
        pb_path: Path to the serialized ``GraphDef`` (frozen model) file.
        img_array: Value fed to the input tensor.
        input_tensor_name: Full tensor name including the output index,
            e.g. ``"conv2d_1_input:0"``.
        out_tensor_name: Full name of the tensor to fetch,
            e.g. ``"dense_decoded:0"``.

    Returns:
        The evaluated output tensor (as returned by ``Session.run``).
    """
    tf_v1 = tf.compat.v1

    graph_def = tf_v1.GraphDef()
    with open(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        # name="" keeps the original node names, so names like "input:0"
        # resolve without an "import/" prefix.
        tf_v1.import_graph_def(graph_def, name="")
        # To discover available tensor names while debugging:
        #     print([node.name for node in graph.as_graph_def().node])
        # No variable initializer is run: an imported GraphDef does not
        # register anything in the global-variables collection, so the
        # original initializer was a no-op.
        with tf_v1.Session(graph=graph) as sess:
            input_x = graph.get_tensor_by_name(input_tensor_name)
            output = graph.get_tensor_by_name(out_tensor_name)
            return sess.run(output, feed_dict={input_x: img_array})


# Index -> character table of the trained model. Index 0 decodes to the
# empty string; presumably it is a placeholder class in the model's charset.
_CHARSET = {0: '', 1: '0', 2: '1', 3: '2', 4: '3', 5: '4', 6: '5', 7: '6', 8: '7', 9: '8', 10: '9'}


def _preprocess(img_bgr):
    """Turn a BGR image into the (1, 64, 32, 3) float batch fed to the model."""
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    # dsize is (width, height): the resulting array has shape (32, 64).
    gray = cv2.resize(gray, (64, 32))
    gray = gray.astype(np.float32) / 255
    # Transpose to (64, 32); presumably the model consumes the image
    # width-first (e.g. as the LSTM time axis) -- mirrors the original code.
    gray = gray.T[..., np.newaxis]  # (64, 32, 1)
    # The batch fed to the graph has 3 channels; replicate the gray channel.
    return np.tile(gray, (1, 1, 3))[np.newaxis]


def _decode(indices, charset):
    """Map dense-decoded class indices to text.

    Skips -1 (presumably the padding value of the dense decode) and
    ``len(charset)`` (presumably the CTC blank class). Any other index
    missing from ``charset`` raises ``KeyError``.
    """
    skipped = {-1, len(charset)}
    return ''.join(charset[i] for i in map(int, indices) if i not in skipped)


def predict(image_path='0211_11ce565a-0ab8-41dd-b467-ce9ff0790753.png',
            pb_path='test-CNNX-LSTM-H64-CTC-C3_10000.pb'):
    """Recognize the captcha in ``image_path`` with the frozen model ``pb_path``.

    Args:
        image_path: Captcha image file. Defaults to the original sample file.
        pb_path: Frozen model file. Defaults to the original model file.

    Returns:
        The recognized text. It is also printed, as before.

    Raises:
        ValueError: if the file cannot be decoded as an image.
    """
    # np.fromfile + imdecode instead of cv2.imread: a common workaround so
    # paths containing non-ASCII (e.g. Chinese) characters work on Windows.
    # IMREAD_COLOR always yields 8-bit 3-channel BGR, so the gray conversion
    # also works for single-channel, alpha and 16-bit PNGs.
    img = cv2.imdecode(np.fromfile(image_path, dtype=np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError('could not decode image: {!r}'.format(image_path))

    x_data = _preprocess(img)
    decoded = load_pb(pb_path, x_data, 'input:0', 'dense_decoded:0')
    text = _decode(decoded[0], _CHARSET)
    print(text)
    return text

欢迎加入小白交流群1135165504,一起学习共同进步,只交流深度学习相关不吹水,有资源大家一起分享

 类似资料: