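# 当前位置: 首页 > 工具软件 > Slim Select > 使用案例 >
# (Page breadcrumb: Home > Tools > Slim Select > Use cases)
#
# tensorflow-slim非量化分类图片模型测试
# (TensorFlow-Slim non-quantized image classification model test)
#
# Author: 鲜于阳成
# Date: 2023-12-01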
import os
import math
import tensorflow as tf

from nets import nets_factory
from preprocessing import preprocessing_factory
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

# Module import side effect: create an interactive session whose GPU memory
# grows on demand (allow_growth) instead of being reserved up front.
# NOTE(review): `session` is not referenced anywhere else in the visible code.
# This looks like the common workaround where the first session created in the
# process fixes the GPU allocator options for later sessions. Confirm that
# before removing or moving it.
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)
# Alias for the TF-Slim library used by main(); tf.contrib is a TF 1.x API.
slim = tf.contrib.slim

# Human-readable class names, indexed by the class id the network predicts.
# NOTE(review): the order must match the label ids used when training the
# checkpoint in ./log/. Confirm against the training label map.
CLASS_LABELS = (
    "chunbai",
    "chunhei",
    "hongse",
    "huierxian or huishiban",
    "huiyudian or huipendian",
    "qilinhua",
)


def class_label(class_id):
    """Return the display label for a predicted class id.

    Args:
        class_id: Integer-like class index (e.g. a NumPy integer from argmax).

    Returns:
        The matching entry of ``CLASS_LABELS``.

    Raises:
        ValueError: If ``class_id`` is outside ``[0, len(CLASS_LABELS))``.
    """
    index = int(class_id)
    # Reject negatives explicitly; tuple indexing would silently wrap them.
    if not 0 <= index < len(CLASS_LABELS):
        raise ValueError('No label for predicted class id {}'.format(index))
    return CLASS_LABELS[index]


def main(_):
    """Classify one image with a TF-Slim model restored from a checkpoint.

    Builds ``model_name`` in inference mode and restores the latest checkpoint
    under ``checkpoint_path``. It then runs one forward pass on ``test_path``
    and prints ``"<test_path> <label>"``.

    Args:
        _: Command-line arguments passed by ``tf.app.run``; unused.

    Raises:
        ValueError: If ``checkpoint_path`` is a directory without a checkpoint,
            or if the predicted class id has no label.
    """
    checkpoint_path = './log/'
    test_path = './test/5_4.jpg'
    # Keep the classifier head size in sync with the label table.
    num_classes = len(CLASS_LABELS)
    model_name = 'resnet_v2_152'
    preprocessing_name = None
    test_image_size = None
    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        # Adds a global_step variable to this graph, so the Saver below also
        # restores it from the checkpoint (same as the original script).
        slim.get_or_create_global_step()

        # Select the model and its matching preprocessing, both in eval mode.
        network_fn = nets_factory.get_network_fn(
            model_name,
            num_classes,
            is_training=False)
        preprocessing_name = preprocessing_name or model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name,
            is_training=False)
        test_image_size = test_image_size or network_fn.default_image_size

        # Resolve a checkpoint directory to its most recent checkpoint prefix;
        # latest_checkpoint returns None when the directory has none.
        if tf.gfile.IsDirectory(checkpoint_path):
            checkpoint_dir = checkpoint_path
            checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
            if checkpoint_path is None:
                raise ValueError(
                    'No checkpoint found in {}'.format(checkpoint_dir))

        # Read the raw JPEG bytes; the context manager closes the file.
        with open(test_path, 'rb') as image_file:
            image_data = image_file.read()

        # Build the inference graph once: decode -> preprocess -> batch of 1.
        image = tf.image.decode_jpeg(image_data, channels=3)
        processed_image = image_preprocessing_fn(
            image, test_image_size, test_image_size)
        logits, _ = network_fn(tf.expand_dims(processed_image, 0))
        prediction_op = tf.argmax(logits, 1)
        saver = tf.train.Saver()

        session_config = ConfigProto()
        session_config.gpu_options.allow_growth = True
        with tf.Session(config=session_config) as sess:
            saver.restore(sess, checkpoint_path)
            # A single forward pass; the original also ran an unused second
            # pass through logits.eval().
            predictions = sess.run(prediction_op)

        print('{} {}'.format(test_path, class_label(predictions[0])))
if __name__ == "__main__":
    # tf.app.run parses command-line flags, then calls main(argv);
    # main is passed explicitly instead of being looked up on __main__.
    tf.app.run(main=main)
# 类似资料: (Similar materials — trailing text from the scraped page)