Using the TensorFlow Lite API

谯皓君
2023-12-01

TensorFlow Lite provides programming APIs in C++, Java and Python, with experimental bindings for several other languages (C, Swift, Objective-C).

TensorFlow Lite supports APIs in a wide range of languages.

Loading a Model in C++

Most of the TensorFlow Lite C++ API lives in the tflite namespace.

The tflite::FlatBufferModel class encapsulates loading a model.

class FlatBufferModel {
 public:
  // Builds a model based on a file.
  // Caller retains ownership of `error_reporter` and must ensure its lifetime
  // is longer than the FlatBufferModel instance.
  // Returns a nullptr in case of failure.
  static std::unique_ptr<FlatBufferModel> BuildFromFile(
      const char* filename,
      ErrorReporter* error_reporter = DefaultErrorReporter());

  // Builds a model based on a pre-loaded flatbuffer.
  // Caller retains ownership of the buffer and should keep it alive until
  // the returned object is destroyed. Caller also retains ownership of
  // `error_reporter` and must ensure its lifetime is longer than the
  // FlatBufferModel instance.
  // Returns a nullptr in case of failure.
  // NOTE: this does NOT validate the buffer so it should NOT be called on
  // invalid/untrusted input. Use VerifyAndBuildFromBuffer in that case
  static std::unique_ptr<FlatBufferModel> BuildFromBuffer(
      const char* caller_owned_buffer, size_t buffer_size,
      ErrorReporter* error_reporter = DefaultErrorReporter());
};

Passing the path of the model file is enough to load it:

std::unique_ptr<tflite::FlatBufferModel> model =
    tflite::FlatBufferModel::BuildFromFile(path_to_model);
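
As a rough sketch (the LoadModel helper and the "model.tflite" path below are illustrative, not part of the TensorFlow Lite API), loading from a file with a failure check might look like this:

#include <memory>

#include "tensorflow/lite/model.h"

// Illustrative helper: load a .tflite file, returning nullptr on failure.
std::unique_ptr<tflite::FlatBufferModel> LoadModel(const char* path) {
  std::unique_ptr<tflite::FlatBufferModel> model =
      tflite::FlatBufferModel::BuildFromFile(path);
  if (model == nullptr) {
    // BuildFromFile returns nullptr if the file is missing or is not a
    // valid FlatBuffer model.
    return nullptr;
  }
  return model;
}

BuildFromBuffer works the same way when the model is already in memory; as the declaration above notes, the caller must keep the buffer alive until the FlatBufferModel is destroyed.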

Loading a Model in Java

For Java, the central API is the Interpreter class.

Initialize the model from a File:

public Interpreter(@NotNull File modelFile);

Or initialize it from a MappedByteBuffer:

public Interpreter(@NotNull MappedByteBuffer mappedByteBuffer);

Running a Model in C++

Running a model in C++ takes the following steps:

  • Build a FlatBufferModel and use it to initialize an Interpreter instance;
  • Optionally resize the input tensors;
  • Set the values of the input tensors;
  • Run inference by calling Invoke();
  • Read the values of the output tensors.

Things to keep in mind when using the Interpreter:

  • Tensors are identified by integer indices, which avoids string comparisons;
  • An Interpreter must not be accessed concurrently from multiple threads;
  • Memory for input and output tensors is allocated by AllocateTensors(), which must be called after resizing tensors (a short sketch follows this list).
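
As a minimal sketch (the input index lookup and the {1, 224, 224, 3} shape below are placeholder assumptions, not values prescribed by the API), the resize-then-allocate pattern looks roughly like this:

// Assumes `interpreter` has been built as in the full example below.
// The shape {1, 224, 224, 3} is only an illustration.
int input_index = interpreter->inputs()[0];
interpreter->ResizeInputTensor(input_index, {1, 224, 224, 3});

// AllocateTensors() must be called again after any resize, before Invoke().
if (interpreter->AllocateTensors() != kTfLiteOk) {
  // Handle the allocation failure.
}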

Putting it all together, a minimal call to TensorFlow Lite looks like this:

// Load the model (BuildFromFile returns nullptr on failure).
std::unique_ptr<tflite::FlatBufferModel> model =
    tflite::FlatBufferModel::BuildFromFile(path_to_model);

// Build the interpreter with the built-in operator resolver.
tflite::ops::builtin::BuiltinOpResolver resolver;
std::unique_ptr<tflite::Interpreter> interpreter;
tflite::InterpreterBuilder(*model, resolver)(&interpreter);

// Resize input tensors here, if desired, then allocate tensor memory.
interpreter->AllocateTensors();

// Fill the first input tensor.
float* input = interpreter->typed_input_tensor<float>(0);
// Fill `input`.

// Run inference.
interpreter->Invoke();

// Read the first output tensor.
float* output = interpreter->typed_output_tensor<float>(0);

 
