当前位置: 首页 > 工具软件 > Template Lite > 使用案例 >

TensorFlow-lite添加自定义算子

终洛华
2023-12-01

TFLITE-SOC GEMM接口分析

涉及文件:
tensorflow/lite/kernels/modeling/util.sc.h
|-- PrintMatricesInfo
|-- PrintMatrix
|-- PrintMatrices
tensorflow/lite/kernels/cpu_backend_gemm.h
tensorflow/lite/kernels/cpu_backend_gemmlowp.h

cpu_backend_gemm.h

Gemm函数的重载定义

//  So in the LHS MxK matrix, the depth is K and the width in M.
// And in the RHS KxN matrix, the depth is K and the width in N.
//
// This is illustrated in this picture:
//
//                             RHS width
//                        <----------------->
//                        +-----------------+ ^
//                        |       RHS       | | Depth
//                        +-----------------+ v
//                 ^ +--+ +-----------------+
//                 | |L | |                 |
//       LHS width | |H | |      Result     |
//                 | |S | |                 |
//                 v +--+ +-----------------+
//                   <-->
//                   Depth

template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar, QuantizationFlavor quantization_flavor>
void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
          const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
          const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
          const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
          CpuBackendContext* context) {

// Special path for 16x8 quant gemm.
template <QuantizationFlavor quantization_flavor>
void Gemm(const MatrixParams<int8_t>& lhs_params, const int8_t* lhs_data,
          const MatrixParams<int16_t>& rhs_params, const int16_t* rhs_data,
          const MatrixParams<int16_t>& dst_params, int16_t* dst_data,
          const GemmParams<int32_t, int16, quantization_flavor>& params,
          CpuBackendContext* context) {

// Special path for gemm with raw accumulator case. i.e. AccumScalar ==
// DstScalar == int32 case.
template <typename LhsScalar, typename RhsScalar,
          QuantizationFlavor quantization_flavor>
void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
          const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
          const MatrixParams<int32_t>& dst_params, int32_t* dst_data,
          const GemmParams<int32_t, int32_t, quantization_flavor>& params,
          CpuBackendContext* context) {

由此可见,Gemm暴露出来的接口为:

  1. Input feature map 1 parameters
  2. Input feature map 1 data
  3. Input feature map 2 parameters
  4. Input feature map 2 data
  5. Output feature map parameters
  6. Output feature map data
  7. Gemm parameters(累加器位宽,量化方式)

TFLITE-SOC Softmax接口分析

activation.cc

所有的算子都由以下四个函数定义

void* (*init)(TfLiteContext* context, const char* buffer, size_t length);
void (*free)(TfLiteContext* context, void* buffer);
TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node);
TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node);

其中TfLiteContext提供错误报告功能和对全局对象(包括所有张量)的访问,TfLiteNode允许实现访问其输入和输出。详细信息请参阅common.h。这两个为所有激活函数的接口(至于其他函数是否使用同样的接口,目前并不确定)。

当解释器加载模型时,它会为计算图中的每个节点调用一次 init()。如果运算在计算图中被多次使用,则会多次调用给定的 init()。对于自定义运算,将提供配置缓冲区,其中包含将参数名称映射到它们的值的 flexbuffer。内置运算的缓冲区为空,因为解释器已经解析了运算参数。需要状态的内核实现应在此处对其进行初始化,并将所有权转移给调用者。对于每个 init() 调用,都会有一个相应的 free() 调用,允许实现释放它们可能在 init() 中分配的缓冲区。每当调整输入张量的大小时,解释器都将遍历计算图以通知更改的实现。这使它们有机会调整其内部缓冲区的大小、检查输入形状和类型的有效性,以及重新计算输出形状。这一切都通过 prepare() 完成,且实现可以使用 node->user_data 访问它们的状态。最后,每次运行推断时,解释器都会遍历调用 invoke() 的计算图,同样,此处的状态也可作为 node->user_data 使用 [1]。

Softmax的具体实现思路:softmax函数会默认最后一个channellogits来进行计算,输入的channel数量可以是三维(普通CNN网络)也可能是四维(transformer里会用到)Tensorflow-lite将除了最后一个channel之外的维度全部打平,然后将一个二维数组传输到核心函数中去处理。当然,不同的数据类型(float,int16,int8)会被送入不同的函数去实现。

template <KernelType kernel_type>
TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
  SoftmaxOpData* data = reinterpret_cast<SoftmaxOpData*>(node->user_data);

  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); // TF_LITE_ENSURE_EQ - 判断输入node是否为1
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));

  TF_LITE_ENSURE(context, NumDimensions(input) >= 1); // TF_LITE_ENSURE - 判断第二个参数是否为真

  if (input->type == kTfLiteInt8 && output->type == kTfLiteInt8) {
    TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128);
    TF_LITE_ENSURE_NEAR(context, output->params.scale, 1.f / 256,
                        (0.001f * 1.f / 256));
  } else if (input->type == kTfLiteInt16 && output->type == kTfLiteInt16) {
    TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
    TF_LITE_ENSURE_NEAR(context, output->params.scale, 1.f / 32768,
                        (0.001f * 1.f / 32768));
  }

  if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
    if (kernel_type == kReference) {
      const int kScaledDiffIntegerBits = 5;
      int input_left_shift;
      tflite::PreprocessSoftmaxScaling(
          static_cast<double>(params->beta),
          static_cast<double>(input->params.scale), kScaledDiffIntegerBits,
          &data->params.input_multiplier, &input_left_shift);
      data->params.input_left_shift = input_left_shift;
      data->params.diff_min =
          -1.0 * tflite::CalculateInputRadius(kScaledDiffIntegerBits,
                                              input_left_shift);
    } else {
      switch (output->type) {
        case kTfLiteUInt8:
        case kTfLiteInt8:
#ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
          // Only apply when both input & output are uint8/int8 & build with
          // clang on aarch64.
          // TODO(b/143709993): Port to ARMv7 and other platforms.
          data->params.uint8_table1 = data->uint8_table1;
          data->params.uint8_table2 = data->uint8_table2;
          optimized_ops::PopulateSoftmaxUInt8LookupTable(
              &data->params, input->params.scale, params->beta);
          break;
#endif
        case kTfLiteInt16:
        default:
          data->params.table = data->table;
          optimized_ops::PopulateSoftmaxLookupTable(
              &data->params, input->params.scale, params->beta);
      }

      data->params.zero_point = output->params.zero_point;
      data->params.scale = output->params.scale;
    }
  } else if (input->type == kTfLiteInt16) {
    TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
    TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);

    data->params.exp_lut = data->exp_lut;
    // exp LUT only used on nagative values
    // we consider exp(-10.0) is insignificant to accumulation
    gen_lut<double, int16_t, int16_t>(
        [](double value) { return std::exp(value); }, -10.0, 0.0, -1.0, 1.0,
        data->params.exp_lut);
    data->params.one_over_one_plus_x_lut = data->one_over_one_plus_x_lut;
    gen_lut<double, int16_t, int16_t>(
        [](double value) { return 1.0 / (1.0 + value); }, 0.0, 1.0, -1.0, 1.0,
        data->params.one_over_one_plus_x_lut);
    data->params.zero_point = output->params.zero_point;
    data->params.scale = output->params.scale;

    double input_scale_beta_rescale =
        input->params.scale * params->beta /
        (10.0 / 65535.0);  // scale the input_diff such that [-65535, 0]
                           // correspond to [-10.0, 0.0]
    QuantizeMultiplier(input_scale_beta_rescale, &data->params.input_multiplier,
                       &data->params.input_left_shift);
  }

  return context->ResizeTensor(context, output,
                               TfLiteIntArrayCopy(input->dims));
}
template <KernelType kernel_type>
TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
  SoftmaxOpData* data = reinterpret_cast<SoftmaxOpData*>(node->user_data);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));

  switch (input->type) {
    case kTfLiteFloat32: {
      return SoftmaxFloat(context, input, output, params, kernel_type);
    }
    case kTfLiteUInt8: {
      switch (output->type) {
        case kTfLiteUInt8:
          return SoftmaxQuantized<uint8_t, uint8_t>(context, input, output,
                                                    data, kernel_type);
        case kTfLiteInt16:
          return SoftmaxQuantized<uint8_t, int16_t>(context, input, output,
                                                    data, kernel_type);
        default:
          TF_LITE_KERNEL_LOG(context,
                             "Only uint8_t and int16_t outputs are supported "
                             "with uint8_t inputs currently, got %s.",
                             TfLiteTypeGetName(output->type));
          return kTfLiteError;
      }
    }
    case kTfLiteInt8: {
      switch (output->type) {
        case kTfLiteInt8:
          return SoftmaxQuantized<int8_t, int8_t>(context, input, output, data,
                                                  kernel_type);
        case kTfLiteInt16:
          return SoftmaxQuantized<int8_t, int16_t>(context, input, output, data,
                                                   kernel_type);
        default:
          TF_LITE_KERNEL_LOG(context,
                             "Only int8_t and int16_t outputs are supported "
                             "with int8_t inputs currently, got %s.",
                             TfLiteTypeGetName(output->type));
          return kTfLiteError;
      }
    }
    case kTfLiteInt16: {
      return SoftmaxQuantized<int16_t, int16_t>(context, input, output, data,
                                                kernel_type);
    }

    default:
      TF_LITE_KERNEL_LOG(context,
                         "Only float32, uint8_t, Int8_t, Int16_t are supported "
                         "currently, got %s.",
                         TfLiteTypeGetName(input->type));
      return kTfLiteError;
  }
}

实际情况不会像这里这么复杂,可以参考[2]的第二步内容,Prepare函数来对输入的格式大小等等进行检查

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
  const TfLiteTensor* input = GetInput(context, node, 0);
  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
  TfLiteTensor* output = GetOutput(context, node, 0);
  output->type = input->type;
  return context->ResizeTensor(context, output,
                               TfLiteIntArrayCopy(input->dims));
}

然后使用Eval函数来实现具体算法。在下面的函数中,GetInput函数用来读出输入数据(可能是一维数组,也可能是二维数组),GetOutput读出输出数据(可能是一维数组,也可能是二维数组)的指针。

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input = GetInput(context, node, 0);
  TfLiteTensor* output = GetOutput(context, node, 0);
  const int elements = NumElements(input);
  const float* in = input->data.f;
  const float* in_end = in + elements;
  float* out = output->data.f;
  for (; in < in_end; ++in, ++out) {
    *out = std::min(std::max(0.f, *in), 1.f);
  }
  return kTfLiteOk;
}

参考资料

[1] 自定义算子
[2] 【手撕 - 深度学习】TF Lite 魔改:添加自定义 op
[3] Loading a TensorFlow-Lite model in Python with Custom Operators
[4] Custom operators

 类似资料: