将Pytorch模型转为ONNX 时出现shape 为负数情况,该问题还是由于ONNX不支持一些Pytorch的函数。
torch.repeat_interleave();
解决方法:
将repeat_interleave函数替换为interpolate(近邻插值法)
使用torch_nn_func.interpolate()替换torch.repeat_interleave();其中mode = 'nearest'
验证repeat_interleave ( )方式训练的模型、interpolate()方式训练的模型的精度与效率:
使用interpolate( )函数重新训练Pytorch模型,将该模型与源码中的repeat_interleave()方式保存的模型进行精度及效率对比,其精度和效率均未损失;
替换interpolate( )函数方式模型:25ms、abs_rel=0.14;【700张测试图像】
源码repeat_interleave ( )函数模型:32ms、abs_rel=0.14;【700张测试图像】
替换函数后的精度和效率并没有损失。