首先需要保证你已经拥有了一个图像分类的模型。
其次我们需要Zephyr RTOS。
这些可以参考如下文章:
基于Stm32F746g_disg平台下移植zephry使用TinyML预测模型_17岁boy的博客-CSDN博客
Zephyr在编译时将二进制文件转化成c语言数组_17岁boy的博客-CSDN博客
什么是Tensor Flow和lite以及数据流图_17岁boy的博客-CSDN博客
TFLite模型文件转C语言文件_17岁boy的博客-CSDN博客
Tensor Flow V2:将Tensor Flow H5模型文件转换为tflite_17岁boy的博客-CSDN博客
Tensor Flow V2:基于Tensor Flow Keras的摄氏度到华氏度温度转换的训练模型_17岁boy的博客-CSDN博客
如果你需要这里也可以使用我的模型:
我训练了两个模型,一个是识别图像中是否有人,一个是识别花的类型,你可以在这里下载到它们:
Tflite_Model.rar-嵌入式文档类资源-CSDN下载
这里我就以这里面的模型为例
下载下来以后解压可以看到两个文件夹:
People
Flowers
这两个文件夹分别包含了人检测模型与花种类模型,这两个模型是使用TF-Slim模型框架训练的
这里我以Flowers模型为例
首先我们在我们的项目文件夹中创建一个model的文件夹
mkdir model
然后将我们刚刚下载模型文件,Flowers文件夹里的flower.tflite放入到这个刚刚创建的model文件夹中
cp ~/Tflite_Model/Flowers/flower.tflite ./model
然后修改cmake文件,添加如下代码:
generate_inc_file_for_target(app model/flower.tflite ${ZEPHYR_BINARY_DIR}/include/generated/model.inc)
这段代码的意思是在编译时将flower.flite转化为数组格式的文件,并放入到bin目录的include里,这样我们在代码中就可以直接包含它了,就不需要将数组转化为数组了,这样做的原因是因为实时操作系统中是没有文件系统的。
然后我们在创建一个文件夹:image
mkdir image
将需要测试的图像放入里面,这里我建议大家将图像转化为yuv格式的,因为yuv颜色分量较低,并且图像信息不会发生太大改变,并且yuv尺寸非常小,非常适合在嵌入式设备中使用,因为我不用显示出来,只是为了做识别,并且推荐大家将图像尺寸缩小96x96。
如果你的内存大小无所谓的话可以忽略,但是请将图像缩小至96x96且通道数为1,也就是灰度图,因为训练时使用的样本图像都是96x96,这样准确率会更高。
然后在cmake添加如下代码:
generate_inc_file_for_target(app image/rose.yuv ${ZEPHYR_BINARY_DIR}/include/generated/rose.inc)
这样花的数组格式文件也有了。
然后在main函数中添加相关头文件,头文件含义这里就不介绍了,相关请参考开头的相关文章链接。
#include <tensorflow/lite/c/common.h>
#include <tensorflow/lite/micro/all_ops_resolver.h>
#include <tensorflow/lite/micro/micro_error_reporter.h>
#include <tensorflow/lite/micro/micro_interpreter.h>
#include <tensorflow/lite/schema/schema_generated.h>
#include <tensorflow/lite/version.h>
#include <zephyr.h>
#include <sys/printk.h>
#include <logging/log.h>
#include <device.h>
#include <soc.h>
#include <stdlib.h>
声明宏来定义图像类型:
#define H 96
#define W 96
#define C 1
声明与定义Flower模型的层数与标签
#define LAYER 5
char label[][1024] = {"Daisy","Dandelion","Roses","Sumflowers","Tulips"}
定义TFLITE工具链
namespace {
tflite::ErrorReporter* error_reporter = nullptr;
const tflite::Model* model = nullptr;
tflite::MicroInterpreter* interpreter = nullptr;
TfLiteTensor* input = nullptr;
constexpr int kTensorArenaSize = 136 * 1024;
static uint8_t tensor_arena[kTensorArenaSize];
static const struct device* flash_shell;
}
包含图像与模型
static uint8_t model_buf[] = {
#include <model.inc>
};
static uint8_t rose_image_buf[] = {
#include <rose.inc>
};
以下代码在main函数中编写
创建日志工具与序列化模型,需要注意序列化以后,model_buf不能被删除或者修改,Tflite只是目前指向这块buff,后面使用模型时会去这个buff里寻找视图与算子等信息,如果清空了的话会导致模型预测失败。
static tflite::MicroErrorReporter micro_error_reporter;
error_reporter = µ_error_reporter;
model = tflite::GetModel(model_buf);
if (model->version() != TFLITE_SCHEMA_VERSION) {
TF_LITE_REPORT_ERROR(error_reporter,
"Model provided is schema version %d not equal "
"to supported version %d.",
model->version(), TFLITE_SCHEMA_VERSION);
return -1;
}
为模型添加2D控制器,AddAveragePool2D创建针对2D数据的池化层,如图像
AddConv2D添加图像处理的卷积层
AddDepthwiseConv2D添加深度卷积层
AddReshape添加矩阵变换功能
AddSoftmax添加Softmax回归分类功能
static tflite::MicroMutableOpResolver<5> micro_op_resolver;
micro_op_resolver.AddAveragePool2D();
micro_op_resolver.AddConv2D();
micro_op_resolver.AddDepthwiseConv2D();
micro_op_resolver.AddReshape();
micro_op_resolver.AddSoftmax();
static tflite::MicroInterpreter static_interpreter(model, micro_op_resolver, tensor_arena,
kTensorArenaSize, error_reporter);
interpreter = &static_interpreter;
分配算法空间
TfLiteStatus allocate_status = interpreter->AllocateTensors();
if (allocate_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "AllocateTensors() failed");
return -2;
}
获取输入指针
input = interpreter->input(0);
编写一个输入图像的函数:
int GetImage(int h,int w,int c,int8_t* img_buff,uint8_t* src_buf){
int img_size = h*w*c;
for(int i = 0;i < img_size;++i){
img_buff[i] = src_buf[i];
}
return kTfLiteOk;
}
然后将图像输入里面去:
if(kTfLiteOk != GetImage(96,96,1,input->data.int8,rose_image_buf)){
TF_LITE_REPORT_ERROR(error_reporter, "Image capture failed.");
}
开始预测
if (kTfLiteOk != interpreter->Invoke()) {
TF_LITE_REPORT_ERROR(error_reporter, "Invoke failed.");
}
预测完成之后开始遍历所有输出层的权重,得到评分最高的那个并输出对应的标签
TfLiteTensor* output = interpreter->output(0);
int8_t person_score = output->data.uint8[0];
int8_t index = 0;
for(int i = 0;i < LAYER ;++i){
if(output->data.uint8[i] > person_score){
person_score = output->data.uint8[i];
index = i;
}
}
printk("score:%d,label:%s\n",person_score,label[i]);
输出结果:
score:240,label:Rose
完整代码:
#include <tensorflow/lite/c/common.h>
#include <tensorflow/lite/micro/all_ops_resolver.h>
#include <tensorflow/lite/micro/micro_error_reporter.h>
#include <tensorflow/lite/micro/micro_interpreter.h>
#include <tensorflow/lite/schema/schema_generated.h>
#include <tensorflow/lite/version.h>
#include <zephyr.h>
#include <sys/printk.h>
#include <logging/log.h>
#include <device.h>
#include <soc.h>
#include <stdlib.h>
#define H 96
#define W 96
#define C 1
#define LAYER 5
char label[][1024] = {"Daisy","Dandelion","Roses","Sumflowers","Tulips"}
static uint8_t model_buf[] = {
#include <model.inc>
};
static uint8_t rose_image_buf[] = {
#include <rose.inc>
};
int GetImage(int h,int w,int c,int8_t* img_buff,uint8_t* src_buf){
int img_size = h*w*c;
for(int i = 0;i < img_size;++i){
img_buff[i] = src_buf[i];
}
return kTfLiteOk;
}
namespace {
tflite::ErrorReporter* error_reporter = nullptr;
const tflite::Model* model = nullptr;
tflite::MicroInterpreter* interpreter = nullptr;
TfLiteTensor* input = nullptr;
constexpr int kTensorArenaSize = 136 * 1024;
static uint8_t tensor_arena[kTensorArenaSize];
static const struct device* flash_shell;
}
int main(){
static tflite::MicroErrorReporter micro_error_reporter;
error_reporter = µ_error_reporter;
model = tflite::GetModel(model_buf);
if (model->version() != TFLITE_SCHEMA_VERSION) {
TF_LITE_REPORT_ERROR(error_reporter,
"Model provided is schema version %d not equal "
"to supported version %d.",
model->version(), TFLITE_SCHEMA_VERSION);
return -1;
}
static tflite::MicroMutableOpResolver<5> micro_op_resolver;
micro_op_resolver.AddAveragePool2D();
micro_op_resolver.AddConv2D();
micro_op_resolver.AddDepthwiseConv2D();
micro_op_resolver.AddReshape();
micro_op_resolver.AddSoftmax();
static tflite::MicroInterpreter static_interpreter(model, micro_op_resolver, tensor_arena,
kTensorArenaSize, error_reporter);
interpreter = &static_interpreter;
TfLiteStatus allocate_status = interpreter->AllocateTensors();
if (allocate_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "AllocateTensors() failed");
return -2;
}
input = interpreter->input(0);
if(kTfLiteOk != GetImage(96,96,1,input->data.int8,rose_image_buf)){
TF_LITE_REPORT_ERROR(error_reporter, "Image capture failed.");
}
if (kTfLiteOk != interpreter->Invoke()) {
TF_LITE_REPORT_ERROR(error_reporter, "Invoke failed.");
}
TfLiteTensor* output = interpreter->output(0);
int8_t person_score = output->data.uint8[0];
int8_t index = 0;
for(int i = 0;i < LAYER ;++i){
if(output->data.uint8[i] > person_score){
person_score = output->data.uint8[i];
index = i;
}
}
printk("score:%d,label:%s\n",person_score,label[i]);
return 0;
}