当前位置: 首页 > 工具软件 > TNN > 使用案例 >

腾讯开源框架TNN调用过程

陈琪
2023-12-01

第一步：模型转换。按照 GitHub 上的文档一步一步操作即可，这一步没有坑。

第二步：用 CMake 生成 VS 工程。注意要在 CMakeLists.txt 中打开需要使用的 accelerator（加速后端）选项，否则调用 GetDevice 时会返回 NULL。

第三步:调用

#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iterator>
#include <string>
#include <vector>

#include "tnn/core/tnn.h"
#include "tnn/utils/blob_transfer_utils.h"
#include "tnn/utils/dims_vector_utils.h"
// Transposes a densely packed NHWC float tensor into NCHW layout.
//
// dst and src must each hold batch*channel*height*width floats and must not
// overlap. If any dimension is <= 0 nothing is written.
// Offsets are computed in size_t so that large tensors cannot overflow the
// 32-bit int products the naive index expression would use.
static void ModifynhwcTonchw(float* dst, const float* src,
	int batch, int channel,
	int height, int width)
{
	if (batch <= 0 || channel <= 0 || height <= 0 || width <= 0)
	{
		return;
	}

	const size_t channels = static_cast<size_t>(channel);
	const size_t plane = static_cast<size_t>(height) * static_cast<size_t>(width);
	const size_t image = plane * channels;

	for (int n = 0; n < batch; n++)
	{
		const float* src_image = src + static_cast<size_t>(n) * image;
		float* dst_image = dst + static_cast<size_t>(n) * image;
		for (int c = 0; c < channel; c++)
		{
			float* dst_plane = dst_image + static_cast<size_t>(c) * plane;
			for (size_t p = 0; p < plane; p++)
			{
				// p == h*width + w; in NHWC consecutive pixels are `channel` floats apart.
				dst_plane[p] = src_image[p * channels + static_cast<size_t>(c)];
			}
		}
	}
}

// Loads raw binary float32 data from `input_file`, transposes it from NHWC to
// NCHW, and copies it into the first blob of `blob_map` via CopyToDevice.
//
// The file must contain at least as many floats as the blob's element count
// (trailing bytes are ignored). The blob must have exactly 4 dims, which are
// treated as N, C, H, W. On any failure an error is printed and the function
// returns without copying anything to the device.
void CopyDataToDeviceFromFile(TNN_NS::BlobMap blob_map,std::string input_file,void* command_queue)
{
	if (blob_map.empty())
	{
		printf("CopyDataToDeviceFromFile Err,input blob map is empty\n");
		return;
	}

	// Only the first input blob is filled; presumably the model has a single input.
	TNN_NS::Blob* device_blob = blob_map.begin()->second;
	TNN_NS::BlobDesc blob_desc = device_blob->GetBlobDesc();
	if (blob_desc.dims.size() != 4)
	{
		printf("CopyDataToDeviceFromFile Err,expected a 4-D input blob, got %d dims\n",
			static_cast<int>(blob_desc.dims.size()));
		return;
	}

	const int data_count = TNN_NS::DimsVectorUtils::Count(blob_desc.dims);
	if (data_count <= 0)
	{
		printf("CopyDataToDeviceFromFile Err,invalid input element count: %d\n", data_count);
		return;
	}

	// Read the raw NHWC floats from disk.
	std::vector<float> nhwc_data(static_cast<size_t>(data_count));
	FILE* fp = fopen(input_file.c_str(), "rb");
	if (fp == NULL)
	{
		printf("CopyDataToDeviceFromFile Err,read input file failed: %s\n",input_file.data());
		return;
	}
	const size_t read_count = fread(nhwc_data.data(), sizeof(float), nhwc_data.size(), fp);
	fclose(fp);
	if (read_count != nhwc_data.size())
	{
		printf("CopyDataToDeviceFromFile Err,input file %s holds %zu floats, expected %d\n",
			input_file.c_str(), read_count, data_count);
		return;
	}

	// The file is stored as NHWC while the blob is configured as NCHW.
	// Drop this step if the input file is already NCHW.
	std::vector<float> nchw_data(nhwc_data.size());
	ModifynhwcTonchw(nchw_data.data(), nhwc_data.data(),
		blob_desc.dims[0], blob_desc.dims[1],
		blob_desc.dims[2], blob_desc.dims[3]);

	// Wrap the host buffer in a Blob (it does not take ownership) and copy it over.
	TNN_NS::BlobHandle data_handle;
	data_handle.base = nchw_data.data();
	data_handle.bytes_offset = 0;
	TNN_NS::Blob data_blob(blob_desc, data_handle);

	TNN_NS::Status status = TNN_NS::CopyToDevice(device_blob, &data_blob, command_queue);
	if (status != TNN_NS::TNN_OK)
	{
		printf("CopyDataToDeviceFromFile Err,copy to device failed: %s\n", status.description().c_str());
	}
}

// Copies the first blob of `blob_map` back from the device via CopyFromDevice
// and writes its values to `out_file` as text, one "%f" value per line.
//
// Only the first output blob is written; presumably the model has a single
// output. On any failure an error is printed and no file content is written.
void CopyDataFromDevicveToFile(TNN_NS::BlobMap blob_map, std::string out_file, void* command_queue)
{
	if (blob_map.empty())
	{
		printf("CopyDataFromDevicveToFile Err,output blob map is empty\n");
		return;
	}

	TNN_NS::Blob* device_blob = blob_map.begin()->second;
	TNN_NS::BlobDesc blob_desc = device_blob->GetBlobDesc();
	const int data_count = TNN_NS::DimsVectorUtils::Count(blob_desc.dims);
	if (data_count <= 0)
	{
		printf("CopyDataFromDevicveToFile Err,invalid output element count: %d\n", data_count);
		return;
	}

	// Host-side buffer that receives the output; the wrapping Blob does not own it.
	std::vector<float> output_data(static_cast<size_t>(data_count));
	TNN_NS::BlobHandle data_handle;
	data_handle.base = output_data.data();
	data_handle.bytes_offset = 0;
	TNN_NS::Blob data_blob(blob_desc, data_handle);

	TNN_NS::Status status = TNN_NS::CopyFromDevice(&data_blob, device_blob, command_queue);
	if (status != TNN_NS::TNN_OK)
	{
		// Do not dump the buffer: its contents are undefined if the copy failed.
		printf("CopyDataFromDevicveToFile Err,copy from device failed: %s\n", status.description().c_str());
		return;
	}

	FILE* fp = fopen(out_file.c_str(), "w");
	if (fp == NULL)
	{
		printf("CopyDataFromDevicveToFile Err,open output file failed: %s\n", out_file.c_str());
		return;
	}
	for (const float value : output_data)
	{
		fprintf(fp, "%f\n", value);
	}
	fclose(fp);
}

int main()
{
	std::string model_name = "test.opt.tnnmodel";
	std::string bin_name = "test.opt.tnnproto";
	std::string input_file = "input.txt";
	std::string output_file = "data.txt";

	TNN_NS::NetworkConfig myNet;
	TNN_NS::ModelConfig myModel;
	myModel.model_type = TNN_NS::MODEL_TYPE_TNN;
	myNet.device_type = TNN_NS::DEVICE_NAIVE;
	myNet.data_format = TNN_NS::DATA_FORMAT_NCHW;

	//read proto first
	std::ifstream proto_stream(bin_name);
	if (!proto_stream.is_open() || !proto_stream.good()) {
		printf("read proto_file failed!\n");
	}
	auto buffer =
		std::string((std::istreambuf_iterator<char>(proto_stream)), std::istreambuf_iterator<char>());
	myModel.params.push_back(buffer);

	//read model bin
	std::ifstream model_stream(model_name, std::ios::binary);
	if (!model_stream.is_open() || !model_stream.good()) {
		myModel.params.push_back("");
	}
	auto model_content =
		std::string((std::istreambuf_iterator<char>(model_stream)), std::istreambuf_iterator<char>());

	myModel.params.push_back(model_content);

	//Init
	TNN_NS::TNN net;
	TNN_NS::Status ret = net.Init(myModel);
	TNN_NS::InputShapesMap input_shape;
	auto instance = net.CreateInst(myNet, ret);

	TNN_NS::BlobMap input_blob_maps;
	TNN_NS::BlobMap output_blob_maps;
	void* command_queue;
	instance->GetAllInputBlobs(input_blob_maps);
	instance->GetAllOutputBlobs(output_blob_maps);
	instance->GetCommandQueue(&command_queue);

	CopyDataToDeviceFromFile(input_blob_maps,input_file, command_queue);

	ret = instance->Forward();

	CopyDataFromDevicveToFile(output_blob_maps, output_file, command_queue);

	ret = net.DeInit();
}

 

 类似资料: