import os
import math
import tensorflow as tf
from nets import nets_factory
from preprocessing import preprocessing_factory
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
# Session options that let TensorFlow grow GPU memory on demand instead of
# reserving it all up front.
config = ConfigProto()
config.gpu_options.allow_growth = True
# An InteractiveSession is created here at import time with the config above.
# The `session` name is not referenced in the visible code; it is presumably
# created so the allow_growth setting is applied to the GPU before any other
# session starts (a common TF1 workaround). NOTE(review): confirm before
# removing or moving this line.
session = InteractiveSession(config=config)
# Shorthand for TF-Slim; tf.contrib exists only in TensorFlow 1.x.
slim = tf.contrib.slim
# Human-readable class names, indexed by the class id the network predicts.
# NOTE(review): the order must match the label ids used when the checkpoint
# was trained; confirm against the training dataset's label map.
_LABELS = (
    "chunbai",
    "chunhei",
    "hongse",
    "huierxian or huishiban",
    "huiyudian or huipendian",
    "qilinhua",
)


def main(_):
    """Classify one JPEG with a trained slim model and print its label.

    Builds ``model_name`` in inference mode, restores weights from the newest
    checkpoint in ``checkpoint_path`` (or from ``checkpoint_path`` itself when
    it is not a directory), runs the image through the model's preprocessing
    and network, and prints ``"<test_path> <label>"``.

    Args:
        _: argv forwarded by ``tf.app.run``; unused.

    Raises:
        ValueError: if ``checkpoint_path`` is a directory with no checkpoint.
    """
    checkpoint_path = './log/'
    test_path = './test/5_4.jpg'
    # Keep the classifier head size in sync with the label table.
    num_classes = len(_LABELS)
    model_name = 'resnet_v2_152'
    preprocessing_name = None  # None -> use the model's own preprocessing
    test_image_size = None  # None -> use the network's default input size

    tf.logging.set_verbosity(tf.logging.INFO)

    with tf.Graph().as_default():
        # Adds a global-step variable to the graph, so the Saver below also
        # restores it. NOTE(review): presumably required because the training
        # checkpoint contains global_step; confirm.
        slim.get_or_create_global_step()

        network_fn = nets_factory.get_network_fn(
            model_name,
            num_classes,
            is_training=False)

        preprocessing_name = preprocessing_name or model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name,
            is_training=False)

        test_image_size = test_image_size or network_fn.default_image_size

        if tf.gfile.IsDirectory(checkpoint_path):
            latest = tf.train.latest_checkpoint(checkpoint_path)
            if latest is None:
                # Fail early with a clear message rather than passing None
                # to Saver.restore.
                raise ValueError(
                    'No checkpoint found in directory %s' % checkpoint_path)
            checkpoint_path = latest

        # Read the raw JPEG bytes and close the file handle promptly.
        with tf.gfile.GFile(test_path, 'rb') as image_file:
            image_bytes = image_file.read()

        image = tf.image.decode_jpeg(image_bytes, channels=3)
        processed_image = image_preprocessing_fn(
            image, test_image_size, test_image_size)
        # Add a batch dimension of size 1.
        processed_images = tf.expand_dims(processed_image, 0)
        logits, _ = network_fn(processed_images)
        predictions = tf.argmax(logits, 1)

        # The Saver must be created after the network so it sees its variables.
        # `config` is the module-level session config (GPU allow_growth).
        with tf.Session(config=config) as sess:
            saver = tf.train.Saver()
            saver.restore(sess, checkpoint_path)
            predicted = sess.run(predictions)

    # argmax over num_classes == len(_LABELS) logits, so the id is in range.
    label = _LABELS[int(predicted[0])]
    print('{} {}'.format(test_path, label))
if __name__ == '__main__':
    # Parse TF flags, then hand control (and leftover argv) to main().
    tf.app.run(main=main)