第一步:模型转换。按照GitHub上TNN的文档一步一步操作即可,此处无坑。
第二步:用cmake生成VS工程。需要在CMakeLists里开启要使用的accelerator(设备后端),否则GetDevice会返回NULL。
第三步:调用
#include <cstddef>
#include <cstdio>
#include <fstream>
#include <iterator>
#include <string>
#include <vector>

#include "tnn/utils/dims_vector_utils.h"
#include "tnn/utils/blob_transfer_utils.h"
#include "tnn/core/tnn.h"
// Reorders a dense float tensor from NHWC layout (src) into NCHW layout (dst).
//
// dst, src : buffers of batch*channel*height*width floats each; they must not
//            overlap. src is only read.
// batch, channel, height, width : logical tensor extents. If any of them is
//            not positive, nothing is written (same as the loops never running).
//
// Offsets are computed in size_t so that large tensors cannot overflow a
// signed int, which would be undefined behaviour.
static void ModifynhwcTonchw(float* dst, const float* src,
                             int batch, int channel,
                             int height, int width)
{
    if (batch <= 0 || channel <= 0 || height <= 0 || width <= 0)
    {
        return;
    }
    const size_t batch_count = static_cast<size_t>(batch);
    const size_t channel_count = static_cast<size_t>(channel);
    const size_t height_count = static_cast<size_t>(height);
    const size_t width_count = static_cast<size_t>(width);
    const size_t plane_size = height_count * width_count;   // elements in one channel plane
    const size_t image_size = plane_size * channel_count;   // elements in one batch item

    for (size_t n = 0; n < batch_count; ++n)
    {
        const float* src_image = src + n * image_size;
        float* dst_image = dst + n * image_size;
        for (size_t c = 0; c < channel_count; ++c)
        {
            for (size_t h = 0; h < height_count; ++h)
            {
                for (size_t w = 0; w < width_count; ++w)
                {
                    // NCHW: c-major planes; NHWC: channels interleaved per pixel.
                    dst_image[c * plane_size + h * width_count + w] =
                        src_image[(h * width_count + w) * channel_count + c];
                }
            }
        }
    }
}
void CopyDataToDeviceFromFile(TNN_NS::BlobMap blob_map,std::string input_file,void* command_queue)
{
//get input_blob info
std::string input_name = (blob_map.begin())->first;
TNN_NS::Blob* device_blob = (blob_map.begin())->second;
TNN_NS::BlobConverter blob_converter(device_blob);
TNN_NS::BlobDesc blob_desc = device_blob->GetBlobDesc();
//get input data
TNN_NS::BlobHandle data_handle;
int data_count = TNN_NS::DimsVectorUtils::Count(blob_desc.dims);
float* input_data = (float*)malloc(data_count * sizeof(float));
FILE* fp = fopen(input_file.data(), "rb");
if (fp == NULL)
{
printf("CopyDataToDeviceFromFile Err,read input file failed: %s\n",input_file.data());
}
fread(input_data, data_count, sizeof(float), fp);
fclose(fp);
//if necessary
if (1)
{
float* trans_data = (float*)malloc(data_count * sizeof(float));
ModifynhwcTonchw(trans_data, input_data,
blob_desc.dims[0], blob_desc.dims[1],
blob_desc.dims[2], blob_desc.dims[3]);
free(input_data);
input_data = trans_data;
}
data_handle.base = input_data;
data_handle.bytes_offset = 0;
//convert
TNN_NS::Blob data_blob(blob_desc,data_handle);
TNN_NS::CopyToDevice(device_blob, &data_blob, command_queue);
free(input_data);
}
void CopyDataFromDevicveToFile(TNN_NS::BlobMap blob_map, std::string out_file, void* command_queue)
{
//get output info
TNN_NS::Blob* device_blob = (blob_map.begin())->second;
TNN_NS::BlobConverter blob_converter_out(device_blob);
TNN_NS::BlobDesc blob_desc = device_blob->GetBlobDesc();
int data_count = TNN_NS::DimsVectorUtils::Count(blob_desc.dims);
//get input data
TNN_NS::BlobHandle data_handle;
float* input_data = (float*)malloc(data_count * sizeof(float));
data_handle.base = input_data;
data_handle.bytes_offset = 0;
//convert
TNN_NS::Blob data_blob(blob_desc, data_handle);
TNN_NS::CopyFromDevice(&data_blob, device_blob, command_queue);
//write file
FILE *fp = fopen(out_file.data(),"w");
for (int i = 0; i < data_count; i++)
{
fprintf(fp, "%f\n", input_data[i]);
}
fclose(fp);
free(input_data);
}
int main()
{
std::string model_name = "test.opt.tnnmodel";
std::string bin_name = "test.opt.tnnproto";
std::string input_file = "input.txt";
std::string output_file = "data.txt";
TNN_NS::NetworkConfig myNet;
TNN_NS::ModelConfig myModel;
myModel.model_type = TNN_NS::MODEL_TYPE_TNN;
myNet.device_type = TNN_NS::DEVICE_NAIVE;
myNet.data_format = TNN_NS::DATA_FORMAT_NCHW;
//read proto first
std::ifstream proto_stream(bin_name);
if (!proto_stream.is_open() || !proto_stream.good()) {
printf("read proto_file failed!\n");
}
auto buffer =
std::string((std::istreambuf_iterator<char>(proto_stream)), std::istreambuf_iterator<char>());
myModel.params.push_back(buffer);
//read model bin
std::ifstream model_stream(model_name, std::ios::binary);
if (!model_stream.is_open() || !model_stream.good()) {
myModel.params.push_back("");
}
auto model_content =
std::string((std::istreambuf_iterator<char>(model_stream)), std::istreambuf_iterator<char>());
myModel.params.push_back(model_content);
//Init
TNN_NS::TNN net;
TNN_NS::Status ret = net.Init(myModel);
TNN_NS::InputShapesMap input_shape;
auto instance = net.CreateInst(myNet, ret);
TNN_NS::BlobMap input_blob_maps;
TNN_NS::BlobMap output_blob_maps;
void* command_queue;
instance->GetAllInputBlobs(input_blob_maps);
instance->GetAllOutputBlobs(output_blob_maps);
instance->GetCommandQueue(&command_queue);
CopyDataToDeviceFromFile(input_blob_maps,input_file, command_queue);
ret = instance->Forward();
CopyDataFromDevicveToFile(output_blob_maps, output_file, command_queue);
ret = net.DeInit();
}