The main() function of tflite::label_image:
All command-line arguments are stored in the Settings struct:
./label_image accepts the following options (flag, example value, meaning):

| Flag | Example value | Meaning |
| ---- | ------------- | ------- |
| -i | ./grace_hopper.bmp | input image |
| -l | ./labels.txt | labels file used for the output, e.g. background, tench, goldfish, great white shark, tiger shark, hammerhead, ... The file contains many entries and could just as well hold Chinese labels. The output also prints the confidence for each label. |
| -m | ./mobilenet_quant_v1_224.tflite | model file |
| -a | 0 | whether to use the Android NNAPI for acceleration [interpreter->UseNNAPI(s->accel);] |
| -c | 1 | loop count loop_count [for (int i = 0; i < s->loop_count; i++) { if (interpreter->Invoke() != kTfLiteOk) { LOG(FATAL) << "Failed ...] |
| -b | 128 | input mean, 127.5 by default in the code [used to normalize float input as (pixel - input_mean) / input_std] |
| -p | 0 | whether to enable profiling [records per-operator execution times] |
| -t | 1 | number of threads [if (s->number_of_threads != -1) { interpreter->SetNumThreads(s->number_of_threads); }] |
| -v | 1 | whether to print more verbose runtime information |
| -s | | input std, 127.5 by default in the code [see -b; the normalization sketch after RunInference shows how both are used] |
Default values in label_image.h:
external/tensorflow$ vi tensorflow/contrib/lite/examples/label_image/label_image.h +24
#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H
#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H
#include "tensorflow/contrib/lite/string.h"
namespace tflite {
namespace label_image {
struct Settings {
bool verbose = false;
bool accel = false;
bool input_floating = false;
int loop_count = 1;
float input_mean = 127.5f;
float input_std = 127.5f;
string model_name = "./mobilenet_quant_v1_224.tflite";
string input_bmp_name = "./grace_hopper.bmp";
string labels_file_name = "./labels.txt";
string input_layer_type = "uint8_t";
int number_of_threads = 4;
};
} // namespace label_image
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H
The corresponding option definitions in main():
static struct option long_options[] = {
{"accelerated", required_argument, 0, 'a'},
{"count", required_argument, 0, 'c'},
{"verbose", required_argument, 0, 'v'},
{"image", required_argument, 0, 'i'},
{"labels", required_argument, 0, 'l'},
{"tflite_model", required_argument, 0, 'm'},
{"threads", required_argument, 0, 't'},
{"input_mean", required_argument, 0, 'b'},
{"input_std", required_argument, 0, 's'},
{0, 0, 0, 0}};
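Below is a simplified sketch (not the verbatim main(); the short-option string is my assumption, and the profiling flag -p plus error handling are omitted) of how these long options are typically read with getopt_long and written into the Settings struct:
#include <getopt.h>
#include <cstdlib>
// Simplified sketch of the option-parsing loop in main().
int main(int argc, char** argv) {
  tflite::label_image::Settings s;
  int c;
  while ((c = getopt_long(argc, argv, "a:b:c:i:l:m:p:s:t:v:",
                          long_options, nullptr)) != -1) {
    switch (c) {
      case 'a': s.accel = strtol(optarg, nullptr, 10); break;              // -a: use NNAPI
      case 'b': s.input_mean = strtod(optarg, nullptr); break;             // -b: input mean
      case 'c': s.loop_count = strtol(optarg, nullptr, 10); break;         // -c: loop count
      case 'i': s.input_bmp_name = optarg; break;                          // -i: input image
      case 'l': s.labels_file_name = optarg; break;                        // -l: labels file
      case 'm': s.model_name = optarg; break;                              // -m: model file
      case 's': s.input_std = strtod(optarg, nullptr); break;              // -s: input std
      case 't': s.number_of_threads = strtol(optarg, nullptr, 10); break;  // -t: threads
      case 'v': s.verbose = strtol(optarg, nullptr, 10); break;            // -v: verbose
      default: exit(-1);
    }
  }
  tflite::label_image::RunInference(&s);
  return 0;
}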
First, model.h contains two classes: FlatBufferModel and InterpreterBuilder.
I used to think FlatBufferModel builds a read-only model while InterpreterBuilder builds a modifiable one; judging from the comments, that understanding is wrong. FlatBufferModel represents the tflite model itself (read-only, copied from disk or mmapped), while InterpreterBuilder builds an Interpreter from that model.
// An RAII object that represents a read-only tflite model, copied from disk,
// or mmapped. This uses flatbuffers as the serialization format.
// (FlatBuffers is the serialization format used here: serialized data can be accessed in place, without a separate parsing/unpacking step.)
class FlatBufferModel {
public:
// Builds a model based on a file. Returns a nullptr in case of failure.
static std::unique_ptr<FlatBufferModel> BuildFromFile(
const char* filename,
ErrorReporter* error_reporter = DefaultErrorReporter());
// Builds a model based on a pre-loaded flatbuffer. The caller retains
// ownership of the buffer and should keep it alive until the returned object
// is destroyed. Returns a nullptr in case of failure.
static std::unique_ptr<FlatBufferModel> BuildFromBuffer(
const char* buffer, size_t buffer_size,
ErrorReporter* error_reporter = DefaultErrorReporter());
// Builds a model directly from a flatbuffer pointer. The caller retains
// ownership of the buffer and should keep it alive until the returned object
// is destroyed. Returns a nullptr in case of failure.
static std::unique_ptr<FlatBufferModel> BuildFromModel(
const tflite::Model* model_spec,
ErrorReporter* error_reporter = DefaultErrorReporter());
// Releases memory or unmaps mmapped memory.
~FlatBufferModel();
// Copying or assignment is disallowed to simplify ownership semantics.
FlatBufferModel(const FlatBufferModel&) = delete;
FlatBufferModel& operator=(const FlatBufferModel&) = delete;
bool initialized() const { return model_ != nullptr; }
const tflite::Model* operator->() const { return model_; }
const tflite::Model* GetModel() const { return model_; }
ErrorReporter* error_reporter() const { return error_reporter_; }
const Allocation* allocation() const { return allocation_; }
// Returns true if the model identifier is correct (otherwise false and
// reports an error).
bool CheckModelIdentifier() const;
private:
// Loads a model from `filename`. If `mmap_file` is true then use mmap,
// otherwise make a copy of the model in a buffer.
//
// Note, if `error_reporter` is null, then a DefaultErrorReporter() will be
// used.
explicit FlatBufferModel(
const char* filename, bool mmap_file = true,
ErrorReporter* error_reporter = DefaultErrorReporter(),
bool use_nnapi = false);
// Loads a model from `ptr` and `num_bytes` of the model file. The `ptr` has
// to remain alive and unchanged until the end of this flatbuffermodel's
// lifetime.
//
// Note, if `error_reporter` is null, then a DefaultErrorReporter() will be
// used.
FlatBufferModel(const char* ptr, size_t num_bytes,
ErrorReporter* error_reporter = DefaultErrorReporter());
// Loads a model from Model flatbuffer. The `model` has to remain alive and
// unchanged until the end of this flatbuffermodel's lifetime.
FlatBufferModel(const Model* model, ErrorReporter* error_reporter);
// Flatbuffer traverser pointer. (Model* is a pointer that is within the
// allocated memory of the data allocated by allocation's internals.
const tflite::Model* model_ = nullptr;
ErrorReporter* error_reporter_;
Allocation* allocation_ = nullptr;
};
InterpreterBuilder
// Build an interpreter capable of interpreting `model`.
//
// model: a scoped model whose lifetime must be at least as long as
// the interpreter. In principle multiple interpreters can be made from
// a single model.
// op_resolver: An instance that implements the Resolver interface which maps
// custom op names and builtin op codes to op registrations.
// reportError: a functor that is called to report errors that handles
// printf var arg semantics. The lifetime of the reportError object must
// be greater than or equal to the Interpreter created by operator().
//
// Returns a kTfLiteOk when successful and sets interpreter to a valid
// Interpreter. Note: the user must ensure the model lifetime is at least as
// long as interpreter's lifetime.
class InterpreterBuilder {
public:
InterpreterBuilder(const FlatBufferModel& model,
const OpResolver& op_resolver);
// Builds an interpreter given only the raw flatbuffer Model object (instead
// of a FlatBufferModel). Mostly used for testing.
// If `error_reporter` is null, then DefaultErrorReporter() is used.
InterpreterBuilder(const ::tflite::Model* model,
const OpResolver& op_resolver,
ErrorReporter* error_reporter = DefaultErrorReporter());
InterpreterBuilder(const InterpreterBuilder&) = delete;
InterpreterBuilder& operator=(const InterpreterBuilder&) = delete;
TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter);
TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter,
int num_threads);
private:
TfLiteStatus BuildLocalIndexToRegistrationMapping();
TfLiteStatus ParseNodes(
const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators,
Interpreter* interpreter);
TfLiteStatus ParseTensors(
const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
Interpreter* interpreter);
const ::tflite::Model* model_;
const OpResolver& op_resolver_;
ErrorReporter* error_reporter_;
std::vector<TfLiteRegistration*> flatbuffer_op_index_to_registration_;
std::vector<BuiltinOperator> flatbuffer_op_index_to_registration_types_;
const Allocation* allocation_ = nullptr;
};
The overall structure of model.h looks like this:
#ifndef TENSORFLOW_CONTRIB_LITE_MODEL_H_
#define TENSORFLOW_CONTRIB_LITE_MODEL_H_
#include <memory>
#include "tensorflow/contrib/lite/error_reporter.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"
namespace tflite {
class FlatBufferModel{
*************
};
class InterpreterBuilder{
****************
};
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_MODEL_H_
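Putting the two classes together, the typical call sequence (a minimal sketch with error checks omitted; the function name MinimalExample is mine for illustration, and the full annotated version is the RunInference function further below) is:
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
// Minimal sketch: load a model, build an interpreter from it, run it once.
void MinimalExample() {
  std::unique_ptr<tflite::FlatBufferModel> model =
      tflite::FlatBufferModel::BuildFromFile("./mobilenet_quant_v1_224.tflite");
  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(*model, resolver)(&interpreter);  // operator() fills in the unique_ptr
  interpreter->AllocateTensors();  // allocate buffers for all tensors
  interpreter->Invoke();           // run the whole graph once
}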
Next comes Interpreter, i.e. our interpreter ("translator"), declared in the header interpreter.h; the class has a large number of members.
// The Interpreter is what actually executes the model.
class Interpreter {
public:
// Instantiate an interpreter. All errors associated with reading and
// processing this model will be forwarded to the error_reporter object.
//
// Note, if error_reporter is nullptr, then a default StderrReporter is
// used.
explicit Interpreter(ErrorReporter* error_reporter = DefaultErrorReporter());
~Interpreter();
Interpreter(const Interpreter&) = delete;
Interpreter& operator=(const Interpreter&) = delete;
// Functions to build interpreter
// Provide a list of tensor indexes that are inputs to the model.
// Each index is bound check and this modifies the consistent_ flag of the
// interpreter.
TfLiteStatus SetInputs(std::vector<int> inputs);
// Provide a list of tensor indexes that are outputs to the model
// Each index is bound check and this modifies the consistent_ flag of the
// interpreter.
TfLiteStatus SetOutputs(std::vector<int> outputs);
// Adds a node with the given parameters and returns the index of the new
// node in `node_index` (optionally). Interpreter will take ownership of
// `builtin_data` and destroy it with `free`. Ownership of 'init_data'
// remains with the caller.
TfLiteStatus AddNodeWithParameters(const std::vector<int>& inputs,
const std::vector<int>& outputs,
const char* init_data,
size_t init_data_size, void* builtin_data,
const TfLiteRegistration* registration,
int* node_index = nullptr);
// Adds `tensors_to_add` tensors, preserving pre-existing Tensor entries.
// The value pointed to by `first_new_tensor_index` will be set to the
// index of the first new tensor if `first_new_tensor_index` is non-null.
TfLiteStatus AddTensors(int tensors_to_add,
int* first_new_tensor_index = nullptr);
// Set description of inputs/outputs/data/fptrs for node `node_index`.
// This variant assumes an external buffer has been allocated of size
// bytes. The lifetime of buffer must be ensured to be greater or equal
// to Interpreter.
TfLiteStatus SetTensorParametersReadOnly(
int tensor_index, TfLiteType type, const char* name,
const std::vector<int>& dims, TfLiteQuantizationParams quantization,
const char* buffer, size_t bytes, const Allocation* allocation = nullptr);
// Set description of inputs/outputs/data/fptrs for node `node_index`.
// This variant assumes an external buffer has been allocated of size
// bytes. The lifetime of buffer must be ensured to be greater or equal
// to Interpreter.
TfLiteStatus SetTensorParametersReadWrite(
int tensor_index, TfLiteType type, const char* name,
const std::vector<int>& dims, TfLiteQuantizationParams quantization);
// Functions to access tensor data
// Read only access to list of inputs.
const std::vector<int>& inputs() const { return inputs_; }
// Return the name of a given input. The given index must be between 0 and
// inputs().size().
const char* GetInputName(int index) const {
return context_.tensors[inputs_[index]].name;
}
// Read only access to list of outputs.
const std::vector<int>& outputs() const { return outputs_; }
// Return the name of a given output. The given index must be between 0 and
// outputs().size().
const char* GetOutputName(int index) const {
return context_.tensors[outputs_[index]].name;
}
// Return the number of tensors in the model.
int tensors_size() const { return context_.tensors_size; }
// Return the number of ops in the model.
int nodes_size() const { return nodes_and_registration_.size(); }
// WARNING: Experimental interface, subject to change
const std::vector<int>& execution_plan() const { return execution_plan_; }
// WARNING: Experimental interface, subject to change
// Overrides execution plan. This bounds checks indices sent in.
TfLiteStatus SetExecutionPlan(const std::vector<int>& new_plan);
// Get a tensor data structure.
// TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this
// read/write access to structure
TfLiteTensor* tensor(int tensor_index) {
if (tensor_index >= context_.tensors_size || tensor_index < 0)
return nullptr;
return &context_.tensors[tensor_index];
}
// Get an immutable tensor data structure.
const TfLiteTensor* tensor(int tensor_index) const {
if (tensor_index >= context_.tensors_size || tensor_index < 0)
return nullptr;
return &context_.tensors[tensor_index];
}
// Get a pointer to an operation and registration data structure if in bounds.
// TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this
// read/write access to structure
const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration(
int node_index) const {
if (node_index >= nodes_and_registration_.size() || node_index < 0)
return nullptr;
return &nodes_and_registration_[node_index];
}
// Perform a checked cast to the appropriate tensor type.
template <class T>
T* typed_tensor(int tensor_index) {
if (TfLiteTensor* tensor_ptr = tensor(tensor_index)) {
if (tensor_ptr->type == typeToTfLiteType<T>()) {
return reinterpret_cast<T*>(tensor_ptr->data.raw);
}
}
return nullptr;
}
// Return a pointer into the data of a given input tensor. The given index
// must be between 0 and inputs().size().
template <class T>
T* typed_input_tensor(int index) {
return typed_tensor<T>(inputs_[index]);
}
// Return a pointer into the data of a given output tensor. The given index
// must be between 0 and outputs().size().
template <class T>
T* typed_output_tensor(int index) {
return typed_tensor<T>(outputs_[index]);
}
// Change the dimensionality of a given tensor. Note, this is only acceptable
// for tensor indices that are inputs.
// Returns status of failure or success.
// TODO(aselle): Consider implementing ArraySlice equivalent to make this
// more adept at accepting data without an extra copy. Use absl::ArraySlice
// if our partners determine that dependency is acceptable.
TfLiteStatus ResizeInputTensor(int tensor_index,
const std::vector<int>& dims);
// Update allocations for all tensors. This will redim dependent tensors using
// the input tensor dimensionality as given. This is relatively expensive.
// If you know that your sizes are not changing, you need not call this.
// Returns status of success or failure.
TfLiteStatus AllocateTensors();
// Invoke the interpreter (run the whole graph in dependency order).
//
// NOTE: It is possible that the interpreter is not in a ready state
// to evaluate (i.e. if a ResizeTensor() has been performed without an
// AllocateTensors().
// Returns status of success or failure.
TfLiteStatus Invoke();  // this feels like the most important function here
// Enable or disable the NN API (true to enable)
void UseNNAPI(bool enable);
// Set the number of threads available to the interpreter.
void SetNumThreads(int num_threads);
// Allow a delegate to look at the graph and modify the graph to handle
// parts of the graph themselves. After this is called, the graph may
// contain new nodes that replace 1 more nodes.
TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegate* delegate);
// Retrieve an operator's description of its work, for profiling purposes.
const char* OpProfilingString(const TfLiteRegistration& op_reg,
const TfLiteNode* node) const {
// haili TODO:
//if (op_reg.profiling_string == nullptr) return nullptr;
//return op_reg.profiling_string(&context_, node);
return nullptr;
}
void SetProfiler(profiling::Profiler* profiler) { profiler_ = profiler; }
profiling::Profiler* GetProfiler() { return profiler_; }
private:
// Give 'op_reg' a chance to initialize itself using the contents of
// 'buffer'.
void* OpInit(const TfLiteRegistration& op_reg, const char* buffer,
size_t length) {
if (op_reg.init == nullptr) return nullptr;
return op_reg.init(&context_, buffer, length);
}
// Let 'op_reg' release any memory it might have allocated via 'OpInit'.
void OpFree(const TfLiteRegistration& op_reg, void* buffer) {
if (op_reg.free == nullptr) return;
if (buffer) {
op_reg.free(&context_, buffer);
}
}
// Prepare the given 'node' for execution.
TfLiteStatus OpPrepare(const TfLiteRegistration& op_reg, TfLiteNode* node) {
if (op_reg.prepare == nullptr) return kTfLiteOk;
return op_reg.prepare(&context_, node);
}
// Invoke the operator represented by 'node'.
TfLiteStatus OpInvoke(const TfLiteRegistration& op_reg, TfLiteNode* node) {
if (op_reg.invoke == nullptr) return kTfLiteError;
return op_reg.invoke(&context_, node);
}
// Call OpPrepare() for as many ops as possible, allocating memory for their
// tensors. If an op containing dynamic tensors is found, preparation will be
// postponed until this function is called again. This allows the interpreter
// to wait until Invoke() to resolve the sizes of dynamic tensors.
TfLiteStatus PrepareOpsAndTensors();
// Call OpPrepare() for all ops starting at 'first_node'. Stop when a
// dynamic tensors is found or all ops have been prepared. Fill
// 'last_node_prepared' with the id of the op containing dynamic tensors, or
// the last in the graph.
TfLiteStatus PrepareOpsStartingAt(int first_execution_plan_index,
int* last_execution_plan_index_prepared);
// Tensors needed by the interpreter. Use `AddTensors` to add more blank
// tensor entries. Note, `tensors_.data()` needs to be synchronized to the
// `context_` whenever this std::vector is reallocated. Currently this
// only happens in `AddTensors()`.
std::vector<TfLiteTensor> tensors_;
// Check if an array of tensor indices are valid with respect to the Tensor
// array.
// NOTE: this changes consistent_ to be false if indices are out of bounds.
TfLiteStatus CheckTensorIndices(const char* label, const int* indices,
int length);
// Compute the number of bytes required to represent a tensor with dimensions
// specified by the array dims (of length dims_size). Returns the status code
// and bytes.
TfLiteStatus BytesRequired(TfLiteType type, const int* dims, int dims_size,
size_t* bytes);
// Request an tensor be resized implementation. If the given tensor is of
// type kTfLiteDynamic it will also be allocated new memory.
TfLiteStatus ResizeTensorImpl(TfLiteTensor* tensor, TfLiteIntArray* new_size);
// Report a detailed error string (will be printed to stderr).
// TODO(aselle): allow user of class to provide alternative destinations.
void ReportErrorImpl(const char* format, va_list args);
// Entry point for C node plugin API to request an tensor be resized.
static TfLiteStatus ResizeTensor(TfLiteContext* context, TfLiteTensor* tensor,
TfLiteIntArray* new_size);
// Entry point for C node plugin API to report an error.
static void ReportError(TfLiteContext* context, const char* format, ...);
// Entry point for C node plugin API to add new tensors.
static TfLiteStatus AddTensors(TfLiteContext* context, int tensors_to_add,
int* first_new_tensor_index);
// WARNING: This is an experimental API and subject to change.
// Entry point for C API ReplaceSubgraphsWithDelegateKernels
static TfLiteStatus ReplaceSubgraphsWithDelegateKernels(
TfLiteContext* context, TfLiteRegistration registration,
const TfLiteIntArray* nodes_to_replace);
// Update the execution graph to replace some of the nodes with stub
// nodes. Specifically any node index that has `nodes[index]==1` will be
// slated for replacement with a delegate kernel specified by registration.
// WARNING: This is an experimental interface that is subject to change.
TfLiteStatus ReplaceSubgraphsWithDelegateKernels(
TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace);
// WARNING: This is an experimental interface that is subject to change.
// Gets the internal pointer to a TensorFlow lite node by node_index.
TfLiteStatus GetNodeAndRegistration(int node_index, TfLiteNode** node,
TfLiteRegistration** registration);
// WARNING: This is an experimental interface that is subject to change.
// Entry point for C node plugin API to get a node by index.
static TfLiteStatus GetNodeAndRegistration(struct TfLiteContext*,
int node_index, TfLiteNode** node,
TfLiteRegistration** registration);
// WARNING: This is an experimental interface that is subject to change.
// Gets an TfLiteIntArray* representing the execution plan. The caller owns
// this memory and must free it with TfLiteIntArrayFree().
TfLiteStatus GetExecutionPlan(TfLiteIntArray** execution_plan);
// WARNING: This is an experimental interface that is subject to change.
// Entry point for C node plugin API to get the execution plan
static TfLiteStatus GetExecutionPlan(struct TfLiteContext* context,
TfLiteIntArray** execution_plan);
// A pure C data structure used to communicate with the pure C plugin
// interface. To avoid copying tensor metadata, this is also the definitive
// structure to store tensors.
TfLiteContext context_;
// Node inputs/outputs are stored in TfLiteNode and TfLiteRegistration stores
// function pointers to actual implementation.
std::vector<std::pair<TfLiteNode, TfLiteRegistration>>
nodes_and_registration_;
// Whether the model is consistent. That is to say if the inputs and outputs
// of every node and the global inputs and outputs are valid indexes into
// the tensor array.
bool consistent_ = true;
// Whether the model is safe to invoke (if any errors occurred this
// will be false).
bool invokable_ = false;
// Array of indices representing the tensors that are inputs to the
// interpreter.
std::vector<int> inputs_;
// Array of indices representing the tensors that are outputs to the
// interpreter.
std::vector<int> outputs_;
// The error reporter delegate that tflite will forward queries errors to.
ErrorReporter* error_reporter_;
// Index of the next node to prepare.
// During Invoke(), Interpreter will allocate input tensors first, which are
// known to be fixed size. Then it will allocate outputs from nodes as many
// as possible. When there is a node that produces dynamic sized tensor.
// Intepreter will stop allocating tensors, set the value of next allocate
// node id, and execute the node to generate the output tensor before continue
// to allocate successors. This process repeats until all nodes are executed.
// NOTE: this relies on the order of nodes that is in topological order.
int next_execution_plan_index_to_prepare_;
// WARNING: This is an experimental interface that is subject to change.
// This is a list of node indices (to index into nodes_and_registration).
// This represents a valid topological sort (dependency ordered) execution
// plan. In particular, it is valid for this ordering to contain only a
// subset of the node indices.
std::vector<int> execution_plan_;
// In the future, we'd like a TfLiteIntArray compatible representation.
// TODO(aselle): replace execution_plan_ with this.
std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> plan_cache_;
// Whether to delegate to NN API
std::unique_ptr<NNAPIDelegate> nnapi_delegate_;
std::unique_ptr<MemoryPlanner> memory_planner_;
// Profiler for this interpreter instance.
profiling::Profiler* profiler_;
};
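To make the tensor-access API above concrete, here is a small sketch (assuming a single-input, single-output image model like the one in this example; the helper name is mine and error handling is omitted) of how the index lists, the dims metadata and the typed accessors fit together:
// Sketch: locate the (single) input tensor, inspect its shape, fill it, then read output 0.
void FillInputAndReadOutput(tflite::Interpreter* interpreter) {
  int input = interpreter->inputs()[0];                     // tensor index of the first input
  TfLiteIntArray* dims = interpreter->tensor(input)->dims;  // e.g. {1, height, width, channels}
  int height = dims->data[1];
  int width = dims->data[2];
  int channels = dims->data[3];
  if (interpreter->tensor(input)->type == kTfLiteUInt8) {
    uint8_t* dst = interpreter->typed_tensor<uint8_t>(input);
    // copy height * width * channels image bytes into dst here
  }
  interpreter->Invoke();
  // For a float model, the class scores can then be read back like this:
  float* probs = interpreter->typed_output_tensor<float>(0);
}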
Building the OpResolver:
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_
#include <unordered_map>
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/model.h"
namespace tflite {
namespace ops {
namespace builtin {
// OpResolver is the base class.
class BuiltinOpResolver : public OpResolver {  // OpResolver maintains the mapping from ops (builtin codes / custom names) to their registrations (function pointers)
public:
BuiltinOpResolver();
TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override;
TfLiteRegistration* FindOp(const char* op) const override;
void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration);
void AddCustom(const char* name, TfLiteRegistration* registration);
private:
struct BuiltinOperatorHasher {
size_t operator()(const tflite::BuiltinOperator& x) const {
return std::hash<size_t>()(static_cast<size_t>(x));
}
};
std::unordered_map<tflite::BuiltinOperator, TfLiteRegistration*,
BuiltinOperatorHasher>
builtins_;
std::unordered_map<std::string, TfLiteRegistration*> custom_ops_;
};
} // namespace builtin
} // namespace ops
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_BUILTIN_KERNELS_H
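The resolver is what InterpreterBuilder consults for every operator in the model. As a sketch, registering and looking up a custom op could look like this (MyCustomPrepare, MyCustomEval and the op name "MyCustomOp" are hypothetical, not part of the example; the hook signatures match the op_reg.prepare / op_reg.invoke calls seen in Interpreter::OpPrepare() / OpInvoke() above):
// Hypothetical op hooks (declarations only, for illustration).
TfLiteStatus MyCustomPrepare(TfLiteContext* context, TfLiteNode* node);
TfLiteStatus MyCustomEval(TfLiteContext* context, TfLiteNode* node);
void RegisterCustomOpExample() {
  tflite::ops::builtin::BuiltinOpResolver resolver;
  static TfLiteRegistration reg = {};  // the resolver stores the pointer, so keep it alive
  reg.prepare = MyCustomPrepare;
  reg.invoke = MyCustomEval;
  resolver.AddCustom("MyCustomOp", &reg);
  // InterpreterBuilder resolves each op in the model through FindOp():
  TfLiteRegistration* custom = resolver.FindOp("MyCustomOp");
  TfLiteRegistration* conv = resolver.FindOp(tflite::BuiltinOperator_CONV_2D);
}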
The complete RunInference function, with annotations, is as follows:
double get_us(struct timeval t) { return (t.tv_sec * 1000000 + t.tv_usec); }
void RunInference(Settings* s) {
if (!s->model_name.c_str()) {
LOG(ERROR) << "no model file name\n";
exit(-1);
}
std::unique_ptr<tflite::FlatBufferModel> model;
std::unique_ptr<tflite::Interpreter> interpreter;
// 1) Build the model
/*
public:
// Builds a model based on a file. Returns a nullptr in case of failure.
static std::unique_ptr<FlatBufferModel> BuildFromFile(
const char* filename,
ErrorReporter* error_reporter = DefaultErrorReporter());
*/
model = tflite::FlatBufferModel::BuildFromFile(s->model_name.c_str());
if (!model) {
LOG(FATAL) << "\nFailed to mmap model " << s->model_name << "\n";
exit(-1);
}
LOG(INFO) << "Loaded model " << s->model_name << "\n";
/* ErrorReporter* error_reporter() const { return error_reporter_; }*/
model->error_reporter();
LOG(INFO) << "resolved reporter\n";
// 2) Build the OpResolver, which maps each node to the function implementing its op
tflite::ops::builtin::BuiltinOpResolver resolver;
// 3) Build the interpreter: tflite::InterpreterBuilder(*model, resolver)(&interpreter);
/*
// Builds an interpreter given only the raw flatbuffer Model object (instead
// of a FlatBufferModel). Mostly used for testing.
// If `error_reporter` is null, then DefaultErrorReporter() is used.
InterpreterBuilder(const ::tflite::Model* model, const OpResolver& op_resolver, ErrorReporter* error_reporter = DefaultErrorReporter());
The second parameter is passed by reference; there are in fact several constructors (the one quoted above takes a raw ::tflite::Model*, while the call below uses the const FlatBufferModel& overload).
*/
// Building produces a class Interpreter instance
tflite::InterpreterBuilder(*model, resolver)(&interpreter); // the trailing (&interpreter) calls InterpreterBuilder::operator(), which fills in the unique_ptr
if (!interpreter) {
LOG(FATAL) << "Failed to construct interpreter\n";
exit(-1);
}
// 4) Configure the interpreter, including:
interpreter->UseNNAPI(s->accel);
// see the remaining member functions of class Interpreter for the other knobs
if (s->verbose) {
LOG(INFO) << "tensors size: " << interpreter->tensors_size() << "\n";
LOG(INFO) << "nodes size: " << interpreter->nodes_size() << "\n";
LOG(INFO) << "inputs: " << interpreter->inputs().size() << "\n";
LOG(INFO) << "input(0) name: " << interpreter->GetInputName(0) << "\n";
int t_size = interpreter->tensors_size();
for (int i = 0; i < t_size; i++) {
// tensor(i) returns a TfLiteTensor*
// tensors in the model are loaded into the TfLiteTensor format
if (interpreter->tensor(i)->name)
LOG(INFO) << i << ": " << interpreter->tensor(i)->name << ", "
<< interpreter->tensor(i)->bytes << ", "
<< interpreter->tensor(i)->type << ", "
<< interpreter->tensor(i)->params.scale << ", "
<< interpreter->tensor(i)->params.zero_point << "\n";
}
}
if (s->number_of_threads != -1) {
interpreter->SetNumThreads(s->number_of_threads);
}
// 5) Read in the BMP file and resize it as needed
int image_width = 224;
int image_height = 224;
int image_channels = 3;
// see examples/label_image/bitmap_helpers.cc for the implementation
uint8_t* in = read_bmp(s->input_bmp_name, &image_width, &image_height,
&image_channels, s);
// Why only the first entry? This model has a single input, so inputs()[0] is its tensor index.
int input = interpreter->inputs()[0];
if (s->verbose) LOG(INFO) << "input: " << input << "\n";
/*
// Array of indices representing the tensors that are inputs to the
// interpreter.
std::vector<int> inputs_;
// Array of indices representing the tensors that are outputs to the
// interpreter.
std::vector<int> outputs_;
*/
const std::vector<int> inputs = interpreter->inputs();
const std::vector<int> outputs = interpreter->outputs();
if (s->verbose) {
LOG(INFO) << "number of inputs: " << inputs.size() << "\n";
LOG(INFO) << "number of outputs: " << outputs.size() << "\n";
}
/*
// Returns status of success or failure.
TfLiteStatus AllocateTensors();
*/
if (interpreter->AllocateTensors() != kTfLiteOk) {
LOG(FATAL) << "Failed to allocate tensors!";
}
// Print interpreter state / runtime parameters
// see optional_debug_tools.cc +72
if (s->verbose) PrintInterpreterState(interpreter.get());
// get input dimension from the input tensor metadata
// assuming one input only
/*
// Fixed size list of integers. Used for dimensions and inputs/outputs tensor
// indices
typedef struct {
int size;
// gcc 6.1+ have a bug where flexible members aren't properly handled
// https://github.com/google/re2/commit/b94b7cd42e9f02673cd748c1ac1d16db4052514c
#if !defined(__clang__) && defined(__GNUC__) && __GNUC__ == 6 && \
__GNUC_MINOR__ >= 1
int data[0];
#else
int data[];
#endif
} TfLiteIntArray;
*/
TfLiteIntArray* dims = interpreter->tensor(input)->dims;
int wanted_height = dims->data[1];
int wanted_width = dims->data[2];
int wanted_channels = dims->data[3];
// Convert (and resize) the image data into the type the input tensor expects
switch (interpreter->tensor(input)->type) {
case kTfLiteFloat32:
s->input_floating = true;
resize<float>(interpreter->typed_tensor<float>(input), in, image_height,
image_width, image_channels, wanted_height, wanted_width, wanted_channels, s);
break;
case kTfLiteUInt8:
resize<uint8_t>(interpreter->typed_tensor<uint8_t>(input), in,
image_height, image_width, image_channels, wanted_height,wanted_width, wanted_channels, s);
break;
default:
LOG(FATAL) << "cannot handle input type "
<< interpreter->tensor(input)->type << " yet";
exit(-1);
}
struct timeval start_time, stop_time;
gettimeofday(&start_time, NULL);
// Run the model and measure the elapsed time
for (int i = 0; i < s->loop_count; i++) {
if (interpreter->Invoke() != kTfLiteOk) {
LOG(FATAL) << "Failed to invoke tflite!\n";
}
}
gettimeofday(&stop_time, NULL);
LOG(INFO) << "invoked \n";
LOG(INFO) << "average time: "
<< (get_us(stop_time) - get_us(start_time)) / (s->loop_count * 1000)
<< " ms \n";
const int output_size = 1000;
const size_t num_results = 5;
const float threshold = 0.001f;
std::vector<std::pair<float, int>> top_results;
// Likewise, the model has a single output, so outputs()[0] is its tensor index
int output = interpreter->outputs()[0];
// Fetch the output; as above, dispatch on the output tensor's type
switch (interpreter->tensor(output)->type) {
case kTfLiteFloat32:
get_top_n<float>(interpreter->typed_output_tensor<float>(0), output_size,
num_results, threshold, &top_results, true);
break;
case kTfLiteUInt8:
get_top_n<uint8_t>(interpreter->typed_output_tensor<uint8_t>(0),
output_size, num_results, threshold, &top_results,
false);
break;
default:
LOG(FATAL) << "cannot handle output type "
<< interpreter->tensor(input)->type << " yet";
exit(-1);
}
// Load the labels and print the corresponding results
std::vector<string> labels;
size_t label_count;
// see examples/label_image/label_image.cc +52
// read the labels file
if (ReadLabelsFile(s->labels_file_name, &labels, &label_count) != kTfLiteOk)
exit(-1);
// result.first is the confidence (float)
// result.second is the label index (int)
for (const auto& result : top_results) {
const float confidence = result.first;
const int index = result.second;
LOG(INFO) << confidence << ": " << index << " " << labels[index] << "\n";
}
}
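The resize<T> helper called above (from examples/label_image/bitmap_helpers.h) is where input_mean and input_std from the -b/-s options are actually applied: for float models each pixel is normalized, for quantized (uint8) models the bytes are passed through unchanged. A rough sketch of that per-pixel step (the real helper also performs a bilinear resize, which is omitted here; the function name Normalize is mine, for illustration only):
// Illustrative sketch of the per-pixel conversion inside resize<T>.
template <class T>
void Normalize(T* out, const uint8_t* in, int number_of_pixels, Settings* s) {
  for (int i = 0; i < number_of_pixels; i++) {
    if (s->input_floating)
      out[i] = (in[i] - s->input_mean) / s->input_std;  // float model: roughly [-1, 1] with mean/std = 127.5
    else
      out[i] = static_cast<T>(in[i]);                   // quantized model: keep the raw uint8 values
  }
}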
ReadLabelsFile is defined by the example itself:
// Takes a file name, and loads a list of labels from it, one per line, and
// returns a vector of the strings. It pads with empty strings so the length
// of the result is a multiple of 16, because our model expects that.
TfLiteStatus ReadLabelsFile(const string& file_name,
std::vector<string>* result,
size_t* found_label_count) {
std::ifstream file(file_name);
if (!file) {
LOG(FATAL) << "Labels file " << file_name << " not found\n";
return kTfLiteError;
}
result->clear();
string line;
while (std::getline(file, line)) {
result->push_back(line);
}
*found_label_count = result->size();
const int padding = 16;
while (result->size() % padding) {
result->emplace_back();
}
return kTfLiteOk;
}
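For completeness, the other helper used in RunInference, get_top_n<T> (examples/label_image/get_top_n.h), selects the num_results highest-scoring classes above threshold. Roughly (a sketch of the logic, not a verbatim copy of the example's implementation), it keeps the best results in a min-heap:
#include <algorithm>
#include <queue>
#include <utility>
#include <vector>
// Sketch: keep the best `num_results` (score, index) pairs using a min-heap.
template <class T>
void get_top_n(T* prediction, int prediction_size, size_t num_results,
               float threshold, std::vector<std::pair<float, int>>* top_results,
               bool input_floating) {
  // Min-heap ordered by score, so the weakest kept result sits on top.
  std::priority_queue<std::pair<float, int>, std::vector<std::pair<float, int>>,
                      std::greater<std::pair<float, int>>> top_result_pq;
  for (int i = 0; i < prediction_size; i++) {
    // uint8 outputs are rescaled to [0, 1] so the threshold means the same thing.
    float value = input_floating ? prediction[i] : prediction[i] / 255.0f;
    if (value < threshold) continue;
    top_result_pq.push(std::pair<float, int>(value, i));
    if (top_result_pq.size() > num_results) top_result_pq.pop();  // drop the weakest
  }
  // Copy out and reverse so the best result comes first.
  while (!top_result_pq.empty()) {
    top_results->push_back(top_result_pq.top());
    top_result_pq.pop();
  }
  std::reverse(top_results->begin(), top_results->end());
}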