pip3 install https://dl.google.com/coral/python/tflite_runtime-2.1.0.post1-cp36-cp36m-win_amd64.whl
本机是 Intel 的 CPU,但安装 win_amd64 的 wheel 也可以正常运行:这里的 amd64 指的是 x86-64(64 位)架构,Intel 的 64 位 CPU 同样属于该架构;况且 Windows 下也没有 x86 的选项。
将基于 TensorFlow 2.3.1 训练得到的模型直接转换为 TensorFlow Lite 模型(可参考官方指南)。我的 TensorFlow 模型是 PB(SavedModel)格式,转换代码如下:
import tensorflow as tf
if __name__ == '__main__':
    # Placeholder: location of the exported SavedModel handed to
    # TFLiteConverter.from_saved_model().
    int_path = "TensorFlow_model_path_pb"
    # Placeholder: destination file for the converted model. Fill this in
    # before running.
    out_path = ""
    if not out_path:
        # Fail before the (potentially slow) conversion instead of hitting an
        # obscure open("") error afterwards.
        raise ValueError("out_path is empty: set it to the .tflite file to write")

    # Converting a SavedModel to a TensorFlow Lite model.
    converter = tf.lite.TFLiteConverter.from_saved_model(int_path)
    # Post-training quantization. `post_training_quantize` is the TF 1.x
    # spelling; the TF 2.x converter is configured through `optimizations`.
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_model = converter.convert()
    # Report the size of the serialized model rather than dumping its raw bytes
    # to the console.
    print('TFLite model size: {} bytes'.format(len(tflite_model)))
    # Context manager guarantees the file is flushed and closed.
    with open(out_path, "wb") as f:
        f.write(tflite_model)
其中有一个选项是Normalization and quantization parameters,官方给出了两者的区别:
Normalization is a common data preprocessing technique in machine learning. The goal of normalization is to change the values to a common scale, without distorting differences in the ranges of values.
Model quantization is a technique that allows for reduced precision representations of weights and optionally, activations for both storage and computation.
In terms of preprocessing and post-processing, normalization and quantization are two independent steps. Here are the details.
模型转换好之后,就可以用 TensorFlow Lite 模型来做预测了:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
import numpy as np
import tensorflow as tf
import random
import yaml
import time
import warnings
import cv2
# Suppress Python warnings for the rest of the run.
warnings.filterwarnings('ignore')

# Placeholder: path of the converted TensorFlow Lite model file.
model_path = "Tensorflow_lite_model_path"

# Build the TFLite interpreter for the model and allocate its tensors.
interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()

# Fetch the input and output tensor descriptions and print both, input first.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
for tensor_details in (input_details, output_details):
    print(tensor_details)
def dataloder(img_root, txt_path):
    """Read an annotation file and return image paths with their labels.

    Each non-blank line of ``txt_path`` is split on ``';'``: the first field
    is the image file name (joined onto ``img_root``) and the last field is
    the integer class label.

    Args:
        img_root: Directory that the image names are relative to.
        txt_path: Path of the ``;``-separated annotation file.

    Returns:
        A ``(img_paths, img_labels)`` tuple of two parallel lists.

    Raises:
        ValueError: If the last field of a non-blank line is not an integer.
    """
    img_paths = []
    img_labels = []
    with open(txt_path, 'r') as reader:
        # Stream the file line by line instead of loading it all at once.
        for line in reader:
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing empty line) carry no sample and
                # would otherwise crash on int('').
                continue
            fields = line.split(';')
            img_paths.append(os.path.join(img_root, fields[0]))
            img_labels.append(int(fields[-1]))
    return img_paths, img_labels
def decode_and_resize(file_path, img_label):
    """Load one PNG file, resize it and divide its pixel values by 255.

    The target size is read from the module-level ``args`` config
    (``args['train']['data_loader']['resize_shape']``), so ``args`` must be
    defined before this function runs.

    Args:
        file_path: Path of the PNG file to read.
        img_label: Label of the image; it is cast to ``tf.int64``.

    Returns:
        A ``(image, label, file_path)`` tuple.
    """
    raw_bytes = tf.io.read_file(file_path)
    # NOTE(review): no `channels` argument is given, so the channel count
    # follows the file - confirm all inputs match what the model expects.
    pixels = tf.image.decode_png(raw_bytes)
    target_size = args['train']['data_loader']['resize_shape']
    # Presumably maps 8-bit pixel values into [0, 1] - confirm against training.
    scaled = tf.image.resize(pixels, target_size) / 255.0
    return scaled, tf.cast(img_label, dtype=tf.int64), file_path
if __name__ == '__main__':
    # Evaluate the TFLite model on the test set described by ./config.yml and
    # report accuracy, timing and the paths of misclassified images.
    config_path = './config.yml'
    with open(config_path, 'r') as f:
        # safe_load: the config is plain data, and yaml.load() without an
        # explicit Loader is unsafe and rejected by recent PyYAML releases.
        args = yaml.safe_load(f)

    test_filenames, test_labels = dataloder(args['test']['test_img'], args['test']['test_txt'])
    test_filenames = tf.constant(test_filenames)
    test_labels = tf.constant(test_labels)
    test_dataset = tf.data.Dataset.from_tensor_slices((test_filenames, test_labels))
    test_dataset = test_dataset.map(
        map_func=decode_and_resize,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    test_dataset = test_dataset.batch(args['train']['data_loader']['batch_size'])
    test_dataset = test_dataset.prefetch(tf.data.experimental.AUTOTUNE)

    cnt = 0            # number of predictions evaluated
    correct = 0        # number of predictions matching the label
    sum_time = 0       # accumulated seconds spent in set_tensor/invoke/get_tensor
    wrong_imgPaths = []
    loop_start = time.time()
    for images, labels, filepath in test_dataset:
        # Hand the interpreter a plain NumPy array rather than an eager tensor.
        batch = images.numpy()
        start = time.time()
        # Load the batch into the model's input tensor.
        interpreter.set_tensor(input_details[0]['index'], batch)
        # Run inference.
        interpreter.invoke()
        output_data = interpreter.get_tensor(output_details[0]['index'])
        end = time.time()
        sum_time += (end - start)

        labels = labels.numpy()
        paths = filepath.numpy()
        for i, pred in enumerate(output_data):
            cnt += 1
            if np.argmax(pred) == labels[i]:
                correct += 1
                print("right+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
            else:
                wrong_imgPaths.append(paths[i])
                print("wrong-----------------------------------------------------------")
        if cnt:
            # Running average of inference time per prediction so far.
            print('pred_time: ', sum_time / cnt)
    loop_end = time.time()

    # Wall-clock time of the whole loop, data loading included (end - start,
    # so the value is positive).
    print('for_time', loop_end - loop_start)
    print('correct:{}/{}'.format(correct, cnt))
    if cnt:
        print('acc: ', correct / cnt)
        print('sum_time: ', sum_time)
        print('pred_time: ', sum_time / cnt)
    else:
        # An empty test set would otherwise raise ZeroDivisionError.
        print('no test samples were evaluated')
    print('wrong images paths: ')
    for path in wrong_imgPaths:
        print(path)
配置文件
# Configuration read by the evaluation script via ./config.yml.
# The script reads train.data_loader.batch_size, train.data_loader.resize_shape,
# test.test_img and test.test_txt, so those keys must be nested as shown.
# NOTE(review): the nesting of base_model, num_classes and log_dir under
# `train` is reconstructed from a flattened listing - confirm against the
# training code.
train:
  data_loader:
    batch_size: 1
    resize_shape: [600, 600]
  base_model:
    input_shape: [600, 600, 3]
    trainable: True
  num_classes: 2
  log_dir: "./logs"
test:
  # Directory containing the test images (placeholder, fill in).
  test_img: ''
  # ';'-separated annotation file listing image names and labels (placeholder).
  test_txt: ''
infer:
  infer_img: './data/sample/infer_data/'