现在很多算法都是用 PyTorch 框架训练的,但在移动端部署时又常常使用 TensorFlow Lite,因此需要把 PyTorch 模型转换为 TensorFlow Lite 模型。
将 PyTorch 模型转换到 TensorFlow Lite 的流程是 PyTorch → ONNX → TensorFlow → TensorFlow Lite,本文记录一下其中踩过的坑。
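第一步是把 PyTorch 模型导出为 ONNX。这一步比较简单,直接使用 PyTorch 自带的 torch.onnx.export 接口即可。不过需要注意 opset 版本,它可能会影响后续的转换:例如用 opset 10 导出时,onnxruntime 运行到 PixelShuffle 处会报错,因此下面的代码使用 opset 11。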
import os

import cv2
import numpy as np
import onnxruntime
import torch

# Step 1: export a trained PyTorch model to ONNX, then run the exported model
# with onnxruntime on a real image to check the conversion.
#
# NOTE: `architecture` (model definition) and `utils` (checkpoint loader) are
# project-local modules of the model's repository; they are assumed to be
# importable here the same way the training code imports them.

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

model_path = 'model.pth'
onnx_path = 'model.onnx'
test_image_path = 'babyx2.bmp'

# Build the network, load the trained weights and switch to inference mode.
model = architecture.IMDN_RTC(upscale=2).cuda()
model_dict = utils.load_state_dict(model_path)
model.load_state_dict(model_dict, strict=True)
model.eval()
for param in model.parameters():
    param.requires_grad = False

# The dummy input only drives tracing; batch, height and width are declared
# dynamic via `dynamic_axes`, so other image sizes can be fed afterwards.
dummy_input = torch.rand(1, 3, 224, 224).cuda()
input_names = ["input"]
output_names = ["output"]

# Export with PyTorch's built-in ONNX exporter.
# opset 11 is used on purpose: a model exported with opset 10 fails in
# onnxruntime at the PixelShuffle layer.
torch.onnx.export(model, dummy_input, onnx_path, opset_version=11, verbose=True,
                  input_names=input_names, output_names=output_names,
                  dynamic_axes={'input': [0, 2, 3], 'output': [0, 2, 3]})

# --- Verify the exported model with onnxruntime on a real image ---
# Passing `providers` explicitly avoids the ValueError raised by onnxruntime
# >= 1.9 builds that ship more than one execution provider (e.g. -gpu).
session = onnxruntime.InferenceSession(
    onnx_path, providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name
session_output_names = [out.name for out in session.get_outputs()]

bgr = cv2.imread(test_image_path)
if bgr is None:  # cv2.imread signals a missing/unreadable file with None
    raise FileNotFoundError(f"cannot read test image: {test_image_path}")
# BGR HWC uint8 -> RGB CHW float in [0, 1] -> NCHW float32 batch of one
img = np.transpose(bgr[:, :, ::-1], (2, 0, 1)) / 255.
img = img[np.newaxis].astype(np.float32)

res = session.run(session_output_names, {input_name: img})
# First output, first batch item; clip to [0, 1] and round (not truncate)
# when converting back to 8-bit.
out = np.clip(res[0][0], 0, 1)
out = (out * 255).round().astype(np.uint8)
# RGB CHW -> BGR HWC for OpenCV
cv2.imwrite('tmp.jpg', np.transpose(out, (1, 2, 0))[:, :, ::-1])
执行 torch.onnx.export 后就得到了 ONNX 模型,后面的代码使用 onnxruntime 对转换后的 ONNX 模型进行测试。建议每一步转换后都测试一下转换后模型的输出,确保每一步都是正确的。
第二步是把 ONNX 模型转换为 TensorFlow 模型,需要先安装 onnx-tensorflow(pip install onnx-tf)进行转换。
import cv2
import numpy as np
import onnx
import tensorflow as tf
from onnx_tf.backend import prepare

# Step 2: convert the ONNX model to a TensorFlow SavedModel with onnx-tf,
# then check the result on a real image.

# How to verify: True reloads the exported SavedModel (the artifact the
# TFLite step consumes); False runs the in-memory onnx_tf representation.
TEST_WITH_SAVED_MODEL = True

if __name__ == '__main__':
    # --- Conversion: ONNX -> TensorFlow SavedModel ---
    onnx_model = onnx.load("model.onnx")  # load onnx model
    tf_rep = prepare(onnx_model)  # prepare tf representation
    tf_rep.export_graph("model.tf")  # export the model (SavedModel directory)

    # --- Verification on a real image ---
    bgr = cv2.imread('babyx2.bmp')
    if bgr is None:  # cv2.imread signals a missing/unreadable file with None
        raise FileNotFoundError("cannot read test image: babyx2.bmp")
    # BGR HWC uint8 -> RGB CHW float in [0, 1] -> NCHW float32 batch of one
    img = np.transpose(bgr[:, :, ::-1], (2, 0, 1)) / 255.
    input_data = img[np.newaxis].astype(np.float32)

    if TEST_WITH_SAVED_MODEL:
        saved_model = tf.saved_model.load("model.tf")
        detect_fn = saved_model.signatures["serving_default"]
        # Call the signature with a keyword argument (newer TF versions reject
        # positional ones) and read the input name from the signature itself.
        input_key = next(iter(detect_fn.structured_input_signature[1]))
        output = detect_fn(**{input_key: tf.constant(input_data)})
        # 'output' is the output name given at ONNX export time.
        out = output['output'].numpy()[0]
        result_path = 'savedmodelres.jpg'
    else:
        output = tf_rep.run(input_data)  # run the loaded model
        out = output.output[0]
        result_path = 'tfres.jpg'

    # Clip to [0, 1] and round (not truncate) when converting to 8-bit.
    out = (np.clip(out, 0, 1) * 255).round().astype(np.uint8)
    # RGB CHW -> BGR HWC for OpenCV
    cv2.imwrite(result_path, np.transpose(out, (1, 2, 0))[:, :, ::-1])
真正的转换只有 onnx.load、prepare、export_graph 这三行代码,后面是对导出的 TensorFlow SavedModel 的测试,用来确保转换结果没有问题。
第三步是把 TensorFlow 模型转换为 TensorFlow Lite。没想到这一步比较坑:换了几个 TensorFlow 版本,最终在 TensorFlow 2.5 下才转换成功(参考相关 issue)。
import cv2
import numpy as np
import tensorflow as tf

# Step 3: convert the TensorFlow SavedModel to TensorFlow Lite, then run the
# .tflite model on a real image to check the conversion.

if __name__ == '__main__':
    # --- Conversion: SavedModel -> TFLite ---
    converter = tf.lite.TFLiteConverter.from_saved_model('model.tf')  # path to the SavedModel directory
    # Without SELECT_TF_OPS the conversion fails on ops that have no TFLite
    # builtin kernel; with it, those ops fall back to TensorFlow ("Flex") ops.
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
        tf.lite.OpsSet.SELECT_TF_OPS,  # enable TensorFlow ops.
    ]
    tflite_model = converter.convert()
    # Save the model.
    with open('model.tflite', 'wb') as f:
        f.write(tflite_model)

    # --- Verification on a real image ---
    bgr = cv2.imread('babyx2.bmp')
    if bgr is None:  # cv2.imread signals a missing/unreadable file with None
        raise FileNotFoundError("cannot read test image: babyx2.bmp")
    # BGR HWC uint8 -> RGB CHW float in [0, 1] -> NCHW float32 batch of one
    img = np.transpose(bgr[:, :, ::-1], (2, 0, 1)) / 255.
    input_data = img[np.newaxis].astype(np.float32)

    interpreter = tf.lite.Interpreter(model_path='model.tflite')
    # Resize using the input's real tensor index and the actual image shape,
    # not a hard-coded index 0 / 256x256. This assumes the model keeps the
    # dynamic batch/height/width declared at ONNX export time.
    input_index = interpreter.get_input_details()[0]['index']
    interpreter.resize_tensor_input(input_index, list(input_data.shape))
    interpreter.allocate_tensors()

    interpreter.set_tensor(input_index, input_data)
    interpreter.invoke()
    # Query output details after allocation so they reflect the resized graph.
    # `get_tensor()` returns a copy of the tensor data.
    output_details = interpreter.get_output_details()
    output_data = interpreter.get_tensor(output_details[0]['index'])

    # First batch item; clip to [0, 1] and round (not truncate) to 8-bit.
    out = (np.clip(output_data[0], 0, 1) * 255).round().astype(np.uint8)
    # RGB CHW -> BGR HWC for OpenCV
    cv2.imwrite('tfliteres.jpg', np.transpose(out, (1, 2, 0))[:, :, ::-1])
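这里需要注意一点:必须设置 converter.target_spec.supported_ops 并加入 SELECT_TF_OPS,否则模型中有些 op 不在 TensorFlow Lite 的内置算子中,转换会失败。另外,启用 SELECT_TF_OPS 后,生成的 .tflite 模型在运行时依赖 TensorFlow Select 算子(Flex delegate),在移动端部署时需要额外引入对应的库。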