涉及文件:
tensorflow/lite/kernels/modeling/util.sc.h
|-- PrintMatricesInfo
|-- PrintMatrix
|-- PrintMatrices
tensorflow/lite/kernels/cpu_backend_gemm.h
tensorflow/lite/kernels/cpu_backend_gemmlowp.h
Gemm
函数的重载定义
// So in the LHS MxK matrix, the depth is K and the width in M.
// And in the RHS KxN matrix, the depth is K and the width in N.
//
// This is illustrated in this picture:
//
// RHS width
// <----------------->
// +-----------------+ ^
// | RHS | | Depth
// +-----------------+ v
// ^ +--+ +-----------------+
// | |L | | |
// LHS width | |H | | Result |
// | |S | | |
// v +--+ +-----------------+
// <-->
// Depth
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
typename DstScalar, QuantizationFlavor quantization_flavor>
void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
CpuBackendContext* context) {
// Special path for 16x8 quant gemm.
template <QuantizationFlavor quantization_flavor>
void Gemm(const MatrixParams<int8_t>& lhs_params, const int8_t* lhs_data,
const MatrixParams<int16_t>& rhs_params, const int16_t* rhs_data,
const MatrixParams<int16_t>& dst_params, int16_t* dst_data,
const GemmParams<int32_t, int16, quantization_flavor>& params,
CpuBackendContext* context) {
// Special path for gemm with raw accumulator case. i.e. AccumScalar ==
// DstScalar == int32 case.
template <typename LhsScalar, typename RhsScalar,
QuantizationFlavor quantization_flavor>
void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
const MatrixParams<int32_t>& dst_params, int32_t* dst_data,
const GemmParams<int32_t, int32_t, quantization_flavor>& params,
CpuBackendContext* context) {
由此可见,Gemm暴露出来的接口为:
所有的算子都由以下四个函数定义
void* (*init)(TfLiteContext* context, const char* buffer, size_t length);
void (*free)(TfLiteContext* context, void* buffer);
TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node);
TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node);
其中TfLiteContext
提供错误报告功能和对全局对象(包括所有张量)的访问,TfLiteNode
允许实现访问其输入和输出。详细信息请参阅common.h
。这两个为所有激活函数的接口(至于其他函数是否使用同样的接口,目前并不确定)。
当解释器加载模型时,它会为计算图中的每个节点调用一次 init()
。如果运算在计算图中被多次使用,则会多次调用给定的 init()
。对于自定义运算,将提供配置缓冲区,其中包含将参数名称映射到它们的值的 flexbuffer
。内置运算的缓冲区为空,因为解释器已经解析了运算参数。需要状态的内核实现应在此处对其进行初始化,并将所有权转移给调用者。对于每个 init()
调用,都会有一个相应的 free()
调用,允许实现释放它们可能在 init()
中分配的缓冲区。每当调整输入张量的大小时,解释器都将遍历计算图以通知更改的实现。这使它们有机会调整其内部缓冲区的大小、检查输入形状和类型的有效性,以及重新计算输出形状。这一切都通过 prepare()
完成,且实现可以使用 node->user_data
访问它们的状态。最后,每次运行推断时,解释器都会遍历调用 invoke()
的计算图,同样,此处的状态也可作为 node->user_data
使用 [1]。
Softmax
的具体实现思路:softmax
函数会默认最后一个channel
为logits
来进行计算,输入的channel数量可以是三维(普通CNN
网络)也可能是四维(transformer
里会用到)Tensorflow-lite
将除了最后一个channel之外的维度全部打平,然后将一个二维数组传输到核心函数中去处理。当然,不同的数据类型(float,int16,int8)会被送入不同的函数去实现。
template <KernelType kernel_type>
TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
SoftmaxOpData* data = reinterpret_cast<SoftmaxOpData*>(node->user_data);
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); // TF_LITE_ENSURE_EQ - 判断输入node是否为1
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TF_LITE_ENSURE(context, NumDimensions(input) >= 1); // TF_LITE_ENSURE - 判断第二个参数是否为真
if (input->type == kTfLiteInt8 && output->type == kTfLiteInt8) {
TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128);
TF_LITE_ENSURE_NEAR(context, output->params.scale, 1.f / 256,
(0.001f * 1.f / 256));
} else if (input->type == kTfLiteInt16 && output->type == kTfLiteInt16) {
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
TF_LITE_ENSURE_NEAR(context, output->params.scale, 1.f / 32768,
(0.001f * 1.f / 32768));
}
if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
if (kernel_type == kReference) {
const int kScaledDiffIntegerBits = 5;
int input_left_shift;
tflite::PreprocessSoftmaxScaling(
static_cast<double>(params->beta),
static_cast<double>(input->params.scale), kScaledDiffIntegerBits,
&data->params.input_multiplier, &input_left_shift);
data->params.input_left_shift = input_left_shift;
data->params.diff_min =
-1.0 * tflite::CalculateInputRadius(kScaledDiffIntegerBits,
input_left_shift);
} else {
switch (output->type) {
case kTfLiteUInt8:
case kTfLiteInt8:
#ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
// Only apply when both input & output are uint8/int8 & build with
// clang on aarch64.
// TODO(b/143709993): Port to ARMv7 and other platforms.
data->params.uint8_table1 = data->uint8_table1;
data->params.uint8_table2 = data->uint8_table2;
optimized_ops::PopulateSoftmaxUInt8LookupTable(
&data->params, input->params.scale, params->beta);
break;
#endif
case kTfLiteInt16:
default:
data->params.table = data->table;
optimized_ops::PopulateSoftmaxLookupTable(
&data->params, input->params.scale, params->beta);
}
data->params.zero_point = output->params.zero_point;
data->params.scale = output->params.scale;
}
} else if (input->type == kTfLiteInt16) {
TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
data->params.exp_lut = data->exp_lut;
// exp LUT only used on nagative values
// we consider exp(-10.0) is insignificant to accumulation
gen_lut<double, int16_t, int16_t>(
[](double value) { return std::exp(value); }, -10.0, 0.0, -1.0, 1.0,
data->params.exp_lut);
data->params.one_over_one_plus_x_lut = data->one_over_one_plus_x_lut;
gen_lut<double, int16_t, int16_t>(
[](double value) { return 1.0 / (1.0 + value); }, 0.0, 1.0, -1.0, 1.0,
data->params.one_over_one_plus_x_lut);
data->params.zero_point = output->params.zero_point;
data->params.scale = output->params.scale;
double input_scale_beta_rescale =
input->params.scale * params->beta /
(10.0 / 65535.0); // scale the input_diff such that [-65535, 0]
// correspond to [-10.0, 0.0]
QuantizeMultiplier(input_scale_beta_rescale, &data->params.input_multiplier,
&data->params.input_left_shift);
}
return context->ResizeTensor(context, output,
TfLiteIntArrayCopy(input->dims));
}
template <KernelType kernel_type>
TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
SoftmaxOpData* data = reinterpret_cast<SoftmaxOpData*>(node->user_data);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
switch (input->type) {
case kTfLiteFloat32: {
return SoftmaxFloat(context, input, output, params, kernel_type);
}
case kTfLiteUInt8: {
switch (output->type) {
case kTfLiteUInt8:
return SoftmaxQuantized<uint8_t, uint8_t>(context, input, output,
data, kernel_type);
case kTfLiteInt16:
return SoftmaxQuantized<uint8_t, int16_t>(context, input, output,
data, kernel_type);
default:
TF_LITE_KERNEL_LOG(context,
"Only uint8_t and int16_t outputs are supported "
"with uint8_t inputs currently, got %s.",
TfLiteTypeGetName(output->type));
return kTfLiteError;
}
}
case kTfLiteInt8: {
switch (output->type) {
case kTfLiteInt8:
return SoftmaxQuantized<int8_t, int8_t>(context, input, output, data,
kernel_type);
case kTfLiteInt16:
return SoftmaxQuantized<int8_t, int16_t>(context, input, output, data,
kernel_type);
default:
TF_LITE_KERNEL_LOG(context,
"Only int8_t and int16_t outputs are supported "
"with int8_t inputs currently, got %s.",
TfLiteTypeGetName(output->type));
return kTfLiteError;
}
}
case kTfLiteInt16: {
return SoftmaxQuantized<int16_t, int16_t>(context, input, output, data,
kernel_type);
}
default:
TF_LITE_KERNEL_LOG(context,
"Only float32, uint8_t, Int8_t, Int16_t are supported "
"currently, got %s.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
}
实际情况不会像这里这么复杂,可以参考[2]的第二步内容,Prepare
函数来对输入的格式大小等等进行检查
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TfLiteTensor* output = GetOutput(context, node, 0);
output->type = input->type;
return context->ResizeTensor(context, output,
TfLiteIntArrayCopy(input->dims));
}
然后使用Eval
函数来实现具体算法。在下面的函数中,GetInput
函数用来读出输入数据(可能是一维数组,也可能是二维数组),GetOutput
读出输出数据(可能是一维数组,也可能是二维数组)的指针。
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const int elements = NumElements(input);
const float* in = input->data.f;
const float* in_end = in + elements;
float* out = output->data.f;
for (; in < in_end; ++in, ++out) {
*out = std::min(std::max(0.f, *in), 1.f);
}
return kTfLiteOk;
}
[1] 自定义算子
[2] 【手撕 - 深度学习】TF Lite 魔改:添加自定义 op
[3] Loading a TensorFlow-Lite model in Python with Custom Operators
[4] Custom operators