当前位置: 首页 > 工具软件 > lite-flow > 使用案例 >

基于Zephyr在微型MCU上使用Tensor Flow Lite Micro做图像分类

宇文卓
2023-12-01

首先需要保证你已经拥有了一个图像分类的模型。

其次我们需要Zephyr RTOS。

这些可以参考如下文章:

基于Stm32F746g_disg平台下移植zephry使用TinyML预测模型_17岁boy的博客-CSDN博客

Zephyr在编译时将二进制文件转化成c语言数组_17岁boy的博客-CSDN博客

什么是Tensor Flow和lite以及数据流图_17岁boy的博客-CSDN博客

TFLite模型文件转C语言文件_17岁boy的博客-CSDN博客

Tensor Flow V2:将Tensor Flow H5模型文件转换为tflite_17岁boy的博客-CSDN博客

Tensor Flow V2:基于Tensor Flow Keras的摄氏度到华氏度温度转换的训练模型_17岁boy的博客-CSDN博客

如果你需要这里也可以使用我的模型:

我训练了两个模型,一个是识别图像中是否有人,一个是识别花的类型,你可以在这里下载到它们:

Tflite_Model.rar-嵌入式文档类资源-CSDN下载

这里我就以这里面的模型为例

下载下来以后解压可以看到两个文件夹:

People

Flowers

这两个文件夹分别包含了人检测模型与花种类模型,这两个模型是使用TF-Slim模型框架训练的

这里我以Flowers模型为例

首先我们在我们的项目文件夹中创建一个model的文件夹

mkdir model

然后将我们刚刚下载模型文件,Flowers文件夹里的flower.tflite放入到这个刚刚创建的model文件夹中

cp ~/Tflite_Model/Flowers/flower.tflite ./model

然后修改cmake文件,添加如下代码:

generate_inc_file_for_target(app model/flower.tflite ${ZEPHYR_BINARY_DIR}/include/generated/model.inc)

这段代码的意思是在编译时将flower.flite转化为数组格式的文件,并放入到bin目录的include里,这样我们在代码中就可以直接包含它了,就不需要将数组转化为数组了,这样做的原因是因为实时操作系统中是没有文件系统的。

然后我们在创建一个文件夹:image

mkdir image

将需要测试的图像放入里面,这里我建议大家将图像转化为yuv格式的,因为yuv颜色分量较低,并且图像信息不会发生太大改变,并且yuv尺寸非常小,非常适合在嵌入式设备中使用,因为我不用显示出来,只是为了做识别,并且推荐大家将图像尺寸缩小96x96。

如果你的内存大小无所谓的话可以忽略,但是请将图像缩小至96x96且通道数为1,也就是灰度图,因为训练时使用的样本图像都是96x96,这样准确率会更高。

然后在cmake添加如下代码:

generate_inc_file_for_target(app image/rose.yuv ${ZEPHYR_BINARY_DIR}/include/generated/rose.inc)

这样花的数组格式文件也有了。

然后在main函数中添加相关头文件,头文件含义这里就不介绍了,相关请参考开头的相关文章链接。

#include <tensorflow/lite/c/common.h>
#include <tensorflow/lite/micro/all_ops_resolver.h>
#include <tensorflow/lite/micro/micro_error_reporter.h>
#include <tensorflow/lite/micro/micro_interpreter.h>
#include <tensorflow/lite/schema/schema_generated.h>
#include <tensorflow/lite/version.h>
#include <zephyr.h>
#include <sys/printk.h>
#include <logging/log.h>
#include <device.h>
#include <soc.h>
#include <stdlib.h>

声明宏来定义图像类型:

#define H 96
#define W 96
#define C 1

声明与定义Flower模型的层数与标签

#define LAYER 5
char label[][1024] = {"Daisy","Dandelion","Roses","Sumflowers","Tulips"} 

定义TFLITE工具链

namespace {
    tflite::ErrorReporter* error_reporter = nullptr;
    const tflite::Model* model = nullptr;
    tflite::MicroInterpreter* interpreter = nullptr;
    TfLiteTensor* input = nullptr;
    constexpr int kTensorArenaSize = 136 * 1024;
    static uint8_t tensor_arena[kTensorArenaSize];
    static const struct device* flash_shell;
}

包含图像与模型

static  uint8_t model_buf[] = {
#include <model.inc>
};
static uint8_t rose_image_buf[] = {
#include <rose.inc>
};

以下代码在main函数中编写

创建日志工具与序列化模型,需要注意序列化以后,model_buf不能被删除或者修改,Tflite只是目前指向这块buff,后面使用模型时会去这个buff里寻找视图与算子等信息,如果清空了的话会导致模型预测失败。

    static tflite::MicroErrorReporter micro_error_reporter;
    error_reporter = &micro_error_reporter;
    model = tflite::GetModel(model_buf);
    if (model->version() != TFLITE_SCHEMA_VERSION) {
        TF_LITE_REPORT_ERROR(error_reporter,
                            "Model provided is schema version %d not equal "
                            "to supported version %d.",
                            model->version(), TFLITE_SCHEMA_VERSION);
        return -1;
    }

为模型添加2D控制器,AddAveragePool2D创建针对2D数据的池化层,如图像

AddConv2D添加图像处理的卷积层

AddDepthwiseConv2D添加深度卷积层

AddReshape添加矩阵变换功能

AddSoftmax添加Softmax回归分类功能

 static tflite::MicroMutableOpResolver<5> micro_op_resolver;
    micro_op_resolver.AddAveragePool2D();
    micro_op_resolver.AddConv2D();
    micro_op_resolver.AddDepthwiseConv2D();
    micro_op_resolver.AddReshape();
    micro_op_resolver.AddSoftmax();
    static tflite::MicroInterpreter static_interpreter(model, micro_op_resolver, tensor_arena, 
                                                       kTensorArenaSize, error_reporter);
    interpreter = &static_interpreter;

分配算法空间

 TfLiteStatus allocate_status = interpreter->AllocateTensors();
    if (allocate_status != kTfLiteOk) {
        TF_LITE_REPORT_ERROR(error_reporter, "AllocateTensors() failed");
        return -2;
    }

获取输入指针

input = interpreter->input(0);

编写一个输入图像的函数:

int GetImage(int h,int w,int c,int8_t* img_buff,uint8_t* src_buf){
    int img_size = h*w*c;
    for(int i = 0;i < img_size;++i){
        img_buff[i] = src_buf[i];
    }
    return kTfLiteOk;
}

然后将图像输入里面去:

if(kTfLiteOk != GetImage(96,96,1,input->data.int8,rose_image_buf)){
        TF_LITE_REPORT_ERROR(error_reporter, "Image capture failed.");
    }

开始预测

if (kTfLiteOk != interpreter->Invoke()) {
        TF_LITE_REPORT_ERROR(error_reporter, "Invoke failed.");          
    
    }

预测完成之后开始遍历所有输出层的权重,得到评分最高的那个并输出对应的标签

TfLiteTensor* output = interpreter->output(0);
int8_t person_score = output->data.uint8[0];
int8_t index = 0;

for(int i = 0;i < LAYER ;++i){

     if(output->data.uint8[i] > person_score){

            person_score = output->data.uint8[i];
            index = i;
     }


}

printk("score:%d,label:%s\n",person_score,label[i]);

输出结果:

score:240,label:Rose

完整代码:

#include <tensorflow/lite/c/common.h>
#include <tensorflow/lite/micro/all_ops_resolver.h>
#include <tensorflow/lite/micro/micro_error_reporter.h>
#include <tensorflow/lite/micro/micro_interpreter.h>
#include <tensorflow/lite/schema/schema_generated.h>
#include <tensorflow/lite/version.h>
#include <zephyr.h>
#include <sys/printk.h>
#include <logging/log.h>
#include <device.h>
#include <soc.h>
#include <stdlib.h>

#define H 96
#define W 96
#define C 1

#define LAYER 5
char label[][1024] = {"Daisy","Dandelion","Roses","Sumflowers","Tulips"} 

static  uint8_t model_buf[] = {
#include <model.inc>
};
static uint8_t rose_image_buf[] = {
#include <rose.inc>
};

int GetImage(int h,int w,int c,int8_t* img_buff,uint8_t* src_buf){
    int img_size = h*w*c;
    for(int i = 0;i < img_size;++i){
        img_buff[i] = src_buf[i];
    }
    return kTfLiteOk;
}

namespace {
    tflite::ErrorReporter* error_reporter = nullptr;
    const tflite::Model* model = nullptr;
    tflite::MicroInterpreter* interpreter = nullptr;
    TfLiteTensor* input = nullptr;
    constexpr int kTensorArenaSize = 136 * 1024;
    static uint8_t tensor_arena[kTensorArenaSize];
    static const struct device* flash_shell;
}

int main(){

    static tflite::MicroErrorReporter micro_error_reporter;
    error_reporter = &micro_error_reporter;
    model = tflite::GetModel(model_buf);
    if (model->version() != TFLITE_SCHEMA_VERSION) {
        TF_LITE_REPORT_ERROR(error_reporter,
                            "Model provided is schema version %d not equal "
                            "to supported version %d.",
                            model->version(), TFLITE_SCHEMA_VERSION);
        return -1;
    }

    static tflite::MicroMutableOpResolver<5> micro_op_resolver;
    micro_op_resolver.AddAveragePool2D();
    micro_op_resolver.AddConv2D();
    micro_op_resolver.AddDepthwiseConv2D();
    micro_op_resolver.AddReshape();
    micro_op_resolver.AddSoftmax();
    static tflite::MicroInterpreter static_interpreter(model, micro_op_resolver,         tensor_arena, 
                                                       kTensorArenaSize, error_reporter);
    interpreter = &static_interpreter;

    TfLiteStatus allocate_status = interpreter->AllocateTensors();
    if (allocate_status != kTfLiteOk) {
        TF_LITE_REPORT_ERROR(error_reporter, "AllocateTensors() failed");
        return -2;
    }
    
    input = interpreter->input(0);

    if(kTfLiteOk != GetImage(96,96,1,input->data.int8,rose_image_buf)){
        TF_LITE_REPORT_ERROR(error_reporter, "Image capture failed.");
    }

    if (kTfLiteOk != interpreter->Invoke()) {
        TF_LITE_REPORT_ERROR(error_reporter, "Invoke failed.");          
    
    }

    TfLiteTensor* output = interpreter->output(0);
    int8_t person_score = output->data.uint8[0];
    int8_t index = 0;

    for(int i = 0;i < LAYER ;++i){

         if(output->data.uint8[i] > person_score){

                person_score = output->data.uint8[i];
                index = i;
     }


    }

    printk("score:%d,label:%s\n",person_score,label[i]);

    return 0;

}

 类似资料: