Converting a PyTorch model to TensorFlow Lite: PyTorch -> ONNX -> TensorFlow -> TensorFlow Lite

施飞雨
2023-12-01

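Many models today are trained with PyTorch, while mobile deployments often use TensorFlow Lite, so PyTorch models frequently need to be converted to TensorFlow Lite.

The conversion path is PyTorch -> ONNX -> TensorFlow -> TensorFlow Lite. This post records the pitfalls I ran into along the way.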

1. PyTorch to ONNX

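This step is fairly simple: PyTorch's built-in exporter, torch.onnx.export, does the job. One thing to watch is the opset version, which can affect the later conversion steps.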

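    import os

    import cv2
    import numpy as np
    import onnxruntime
    import torch

    # `architecture` and `utils` come from the model's own project code;
    # import them as your repository layout requires.

    os.environ['CUDA_VISIBLE_DEVICES']='0'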
    model_path = 'model.pth'
    model = architecture.IMDN_RTC(upscale=2).cuda()
    model_dict = utils.load_state_dict(model_path)
    model.load_state_dict(model_dict, strict=True)

    model.eval()
    for k, v in model.named_parameters():
        v.requires_grad = False

    dummy_input = torch.rand(1, 3, 224, 224).cuda()
    input_names = ["input"]
    #output_names = ["output1", "output2", "output3"]
    output_names = ["output"]
    # Convert with PyTorch's built-in ONNX exporter.
    # With opset 10, the exported model fails at PixelShuffle when run in onnxruntime, so opset 11 is used.
    torch.onnx.export(model, dummy_input, "model.onnx", opset_version=11, verbose=True, 
            input_names=input_names, output_names=output_names, 
            dynamic_axes={'input': [0, 2, 3], 'output': [0, 2, 3]})

    session = onnxruntime.InferenceSession("model.onnx")
    input_name = session.get_inputs()[0].name
    #output_name = session.get_outputs()[0].name
    output_names = [s.name for s in session.get_outputs()]
    input_shape = session.get_inputs()[0].shape

    img = cv2.imread('babyx2.bmp')[:,:,::-1]
    img = np.transpose(img, (2, 0, 1)) / 255.
    img = torch.from_numpy(img).unsqueeze(0).float()
    res = session.run(output_names, {input_name: img.cpu().numpy()})
    tmp = res[0]
    tmp = np.clip(tmp[0], 0, 1)
    img = np.array(tmp*255, dtype=np.uint8)
    img = np.transpose(img, (1, 2, 0))[:,:,::-1]
    cv2.imwrite('tmp.jpg', img)

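After torch.onnx.export returns, the ONNX model has been written to disk; the remaining code tests it with onnxruntime. I recommend testing the output after every conversion step to make sure each stage is correct.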
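
Comparing the saved images by eye only catches gross failures. A numeric comparison of the raw output arrays from two stages (for example the PyTorch output and the onnxruntime output for the same input) is a stricter check. The sketch below assumes only numpy; `compare_outputs` and its default tolerance are hypothetical choices for illustration, not part of any library used in this post.

```python
import numpy as np


def compare_outputs(reference, candidate, atol=1e-4):
    """Compare two model outputs of the same shape.

    Returns (max_abs_diff, ok), where ok is True when every element of
    `candidate` is within `atol` of `reference`.
    """
    reference = np.asarray(reference, dtype=np.float64)
    candidate = np.asarray(candidate, dtype=np.float64)
    if reference.shape != candidate.shape:
        raise ValueError(f"shape mismatch: {reference.shape} vs {candidate.shape}")
    max_abs_diff = float(np.max(np.abs(reference - candidate)))
    return max_abs_diff, max_abs_diff <= atol
```

After the onnxruntime run, this could be called as `compare_outputs(torch_out, res[0])`, where `torch_out` (a name assumed here) is the PyTorch model's output for the same input, converted to a numpy array. A suitable tolerance depends on the model, so the default above is only a starting point.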

2. ONNX to TensorFlow

This step requires onnx-tensorflow (imported as onnx_tf) to do the conversion.

import cv2
import numpy as np
import onnx
import tensorflow as tf
import torch
from onnx_tf.backend import prepare
if __name__ == '__main__':
    onnx_model = onnx.load("model.onnx")  # load onnx model
    tf_rep = prepare(onnx_model)  # prepare tf representation
    tf_rep.export_graph("model.tf")  # export the model

    img = cv2.imread('babyx2.bmp')[:,:,::-1]
    img = np.transpose(img, (2, 0, 1)) / 255.
    img = torch.from_numpy(img).unsqueeze(0).float()
    input = img.numpy()
    if 0:
        output = tf_rep.run(input)  # run the loaded model
        res = output.output[0]
        res = np.clip(res, 0, 1)
        im = np.array(res*255, dtype=np.uint8)
        im1 = np.transpose(im, (1, 2, 0))[:,:,::-1]
        cv2.imwrite('tfres.jpg', im1)
    else:
        saved_model = tf.saved_model.load("model.tf")
        detect_fn = saved_model.signatures["serving_default"]
        output = detect_fn(tf.constant(input))
        tmp = np.array(output['output'])[0]
        res = np.clip(tmp, 0, 1)
        im = np.array(res*255, dtype=np.uint8)
        im1 = np.transpose(im, (1, 2, 0))[:,:,::-1]
        cv2.imwrite('savedmodelres.jpg', im1)

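The conversion itself is just the first three lines of the main block; the rest tests the TensorFlow model to make sure the conversion result is correct.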
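
Every test snippet in this post repeats the same pre- and post-processing: BGR to RGB, HWC to NCHW, scaling to [0, 1], and the reverse on the way out. Factoring that into two helpers keeps the stages comparable. This is a minimal sketch using numpy only; `preprocess` and `postprocess` are hypothetical names, and the layout assumptions (NCHW, RGB, values in [0, 1]) are taken from the test code above.

```python
import numpy as np


def preprocess(bgr_image):
    """HWC uint8 BGR image -> float32 NCHW RGB batch scaled to [0, 1]."""
    rgb = bgr_image[:, :, ::-1]                 # BGR -> RGB
    chw = np.transpose(rgb, (2, 0, 1)) / 255.0  # HWC -> CHW, scale to [0, 1]
    # Add the batch axis and return a contiguous float32 array.
    return np.ascontiguousarray(chw[np.newaxis], dtype=np.float32)


def postprocess(nchw_output):
    """Float NCHW RGB batch -> HWC uint8 BGR image (first item of the batch)."""
    chw = np.clip(nchw_output[0], 0, 1)
    hwc = np.transpose(chw * 255, (1, 2, 0)).astype(np.uint8)
    return np.ascontiguousarray(hwc[:, :, ::-1])  # RGB -> BGR for cv2.imwrite
```

With these, a test reduces to `x = preprocess(cv2.imread('babyx2.bmp'))`, running the model on `x`, and passing the batched output array to `postprocess` before `cv2.imwrite`.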

3. TensorFlow to TensorFlow Lite

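Unexpectedly, this step turned out to be the trickiest. After trying several TensorFlow versions, the conversion finally succeeded with TF 2.5 (see the referenced issue).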

import cv2
import numpy as np
import tensorflow as tf

if __name__ == '__main__':
    # Convert the model
    converter = tf.lite.TFLiteConverter.from_saved_model('model.tf') # path to the SavedModel directory
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
        tf.lite.OpsSet.SELECT_TF_OPS,  # enable TensorFlow ops.
    ]
    tflite_model = converter.convert()

    # Save the model.
    with open('model.tflite', 'wb') as f:
        f.write(tflite_model)

    # test tflite model
    interpreter = tf.lite.Interpreter(model_path='model.tflite')
    #my_signature = interpreter.get_signature_runner()
    img = cv2.imread('babyx2.bmp')[:,:,::-1]
    img = np.transpose(img, (2, 0, 1)) / 255.
    img = img[np.newaxis, :]
    #output = my_signature(tf.constant(img))
    print()

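    # The model was exported with dynamic axes, so fix the input shape before
    # allocating tensors; it must match the shape of input_data set below.
    interpreter.resize_tensor_input(0, [1, 3, 256, 256])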
    interpreter.allocate_tensors()

    # Get input and output tensors.
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Test the model on random input data.
    #input_shape = input_details[0]['shape_signature']
    #input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
    input_data = img.astype(np.float32)
    interpreter.set_tensor(input_details[0]['index'], input_data)

    interpreter.invoke()

    # The function `get_tensor()` returns a copy of the tensor data.
    # Use `tensor()` in order to get a pointer to the tensor.
    output_data = interpreter.get_tensor(output_details[0]['index'])
    res = np.clip(output_data[0], 0, 1)
    im = np.array(res*255, dtype=np.uint8)
    im1 = np.transpose(im, (1, 2, 0))[:,:,::-1]
    cv2.imwrite('tfliteres.jpg', im1)

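One thing to note: converter.target_spec.supported_ops must be set as shown. SELECT_TF_OPS lets ops without a built-in TensorFlow Lite implementation fall back to TensorFlow kernels; without it, those unsupported ops cause the conversion to fail.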
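
The test code above passes a fixed [1, 3, 256, 256] to resize_tensor_input, which only works for an image of that size. If the converted model kept its dynamic axes, the `shape_signature` entry in the interpreter's input details reports those dimensions as -1, and the concrete shape can be derived from the image instead. A minimal sketch in plain Python; `resolve_input_shape` is a hypothetical helper, not a TensorFlow API.

```python
def resolve_input_shape(shape_signature, actual_shape):
    """Replace dynamic dimensions (-1) in a TFLite shape signature.

    Fixed dimensions must agree with `actual_shape`; dynamic ones are
    taken from it. Returns a list suitable for resize_tensor_input.
    """
    if len(shape_signature) != len(actual_shape):
        raise ValueError("rank mismatch between signature and input")
    resolved = []
    for declared, actual in zip(shape_signature, actual_shape):
        declared, actual = int(declared), int(actual)
        if declared != -1 and declared != actual:
            raise ValueError(f"fixed dimension {declared} does not match {actual}")
        resolved.append(actual)
    return resolved
```

With the interpreter from the test code, this could be used as `shape = resolve_input_shape(interpreter.get_input_details()[0]['shape_signature'], img.shape)`, followed by `interpreter.resize_tensor_input(0, shape)` and `interpreter.allocate_tensors()`.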
