There are many machine learning frameworks. To make models reusable and to unify backend deployment and inference, the industry has largely converged on the ONNX model format, which is supported by PyTorch, TensorFlow, MXNet, and other AI frameworks. To improve inference performance at deployment time, we use the onnxruntime inference backend for acceleration; a few simple C++ API calls are enough for the basic use cases.
See the Microsoft open-source project page: https://github.com/microsoft/onnxruntime
The example here uses the standard Fashion-MNIST dataset and trains a simple fully connected network for classification: the input is a clothing image (1 x 28 x 28) and the output is a score for each of the ten classes (1 x 10).
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import onnxruntime
# data file
data_dir = "../data"
model_file = "model.pt"
onnx_model_file = "model.onnx"
# download training and test data from open datasets.
training_data = datasets.FashionMNIST(
    root=data_dir, train=True, download=True, transform=ToTensor()
)
test_data = datasets.FashionMNIST(
    root=data_dir, train=False, download=True, transform=ToTensor()
)
# data loader
batch_size = 64
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)
# peek sample
for X, y in test_dataloader:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape)
    break
classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]
# define model
device = "cpu"
print("Using device: ", device)
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
model = NeuralNetwork().to(device)
print(model)
# define train
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print("loss: %7f, current: %5d / size: %5d" % (loss, current, size))
# define test
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print("Test Error: \n Accuracy: %0.1f%%, Avg loss: %8f \n" % (correct * 100, test_loss))
# training
epochs = 5
for t in range(epochs):
    print("Epoch %d\n-------------------------------" % (t + 1))
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")
# save the trained pytorch model
torch.save(model, model_file)
# choose sample
x, y = test_data[6][0], test_data[6][1]
# onnx save and load model
torch.onnx.export(
    model,
    x,
    onnx_model_file,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    })
onnx_model = onnxruntime.InferenceSession(onnx_model_file, providers=["CPUExecutionProvider"])
print(onnx_model)
# onnx predict
onnx_input_name = onnx_model.get_inputs()[0].name
onnx_output_name = onnx_model.get_outputs()[0].name
onnx_x = x.numpy()
onnx_pred_y = onnx_model.run([onnx_output_name], {onnx_input_name: onnx_x})
print(onnx_pred_y)
print(int(np.argmax(onnx_pred_y)))
Inference result:
[array([[ 0.78701156, -0.3907491 , 0.8121021 , 0.14819556, 1.0213459 ,
-0.7819631 , 0.95659614, -1.4445262 , 0.06370842, -1.2055752 ]],
dtype=float32)]
4
Next, build a C++ project that pulls in the onnxruntime dependency; both Windows and Linux are supported. The layout is shown below, followed by a sketch of the CMakeLists.txt.
Project structure
onnxruntime_demo
├── build
│ ├── onnxruntime_demo
│ └── model.onnx
├── CMakeLists.txt
├── onnxruntime
├── README.md
└── src
└── main.cpp
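A minimal CMakeLists.txt sketch is given below, assuming the prebuilt onnxruntime release has been unpacked into the onnxruntime/ directory shown above; ONNXRUNTIME_DIR and the library layout are assumptions that depend on the package you actually download.
CMakeLists.txt
cmake_minimum_required(VERSION 3.13)
project(onnxruntime_demo)
set(CMAKE_CXX_STANDARD 17)
# assumed location of the prebuilt onnxruntime package (headers and libraries)
set(ONNXRUNTIME_DIR ${CMAKE_SOURCE_DIR}/onnxruntime)
add_executable(onnxruntime_demo src/main.cpp)
target_include_directories(onnxruntime_demo PRIVATE ${ONNXRUNTIME_DIR}/include)
target_link_directories(onnxruntime_demo PRIVATE ${ONNXRUNTIME_DIR}/lib)
target_link_libraries(onnxruntime_demo PRIVATE onnxruntime)
# on Windows, copy onnxruntime.dll next to the executable; on Linux, make sure
# libonnxruntime.so is on the runtime library search path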
main.cpp
#include <iostream>
#include <array>
#include <algorithm>
#include "onnxruntime_cxx_api.h"
int main(int argc, char* argv[])
{
// --- define model path
#if _WIN32
const wchar_t* model_path = L"./model.onnx"; // ORT expects a wide-character path on Windows; convert from std::string if needed
#else
const char* model_path = "./model.onnx";
#endif
// --- init onnxruntime env
Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "Default");
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
// set options
Ort::SessionOptions session_option;
session_option.SetIntraOpNumThreads(5); // increase to run more ops in parallel
session_option.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
// --- prepare data
const char* input_names[] = { "input" }; // must match the names used at export time
const char* output_names[] = { "output" };
// use static arrays to preallocate the data buffers
std::array<float, 1 * 28 * 28> input_matrix;
std::array<float, 1 * 10> output_matrix;
// shape dimensions must be int64_t to match the API
std::array<int64_t, 3> input_shape{ 1, 28, 28 };
std::array<int64_t, 2> output_shape{ 1, 10 };
std::vector<std::vector<std::vector<float>>> sample_x = {
{
{0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000, 0.0667, 0.0000, 0.1373, 0.2157, 0.2039, 0.1765, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000},
{0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0118, 0.0000, 0.0039, 0.9804, 1.0000, 0.9608, 0.9961, 0.9333, 0.9569,0.9373, 0.5412, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000},
{0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3451, 0.4863, 0.6667, 0.9961, 0.5412, 0.7333, 1.0000, 0.7333, 0.1255, 0.0157, 0.0000, 0.0039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000},
{0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3294, 0.3843, 0.0000, 0.7137, 0.8235, 0.9529, 1.0000, 0.1451, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000},
{0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0235, 0.2196, 0.2824, 0.3569, 0.5216, 0.1686, 0.0000, 0.9412, 0.8549, 0.0000, 0.0000, 0.1529, 0.1804, 0.0784, 0.0980, 0.0078, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000},
{0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1882, 0.4275, 0.2745, 0.2118, 0.1725, 0.2784, 0.2196, 0.2510, 0.0588, 0.0784, 0.1137, 0.1098, 0.2275, 0.2235, 0.2000, 0.0784, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000},
{0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2549, 0.2863, 0.3216, 0.1922, 0.2275, 0.2039, 0.1255, 0.3294, 0.2706, 0.0980, 0.1961, 0.2471, 0.1804, 0.1059, 0.0980, 0.1137, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000},
{0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0471, 0.3059, 0.2078, 0.5137, 0.1451, 0.2235, 0.2039, 0.0784, 0.3529, 0.3059, 0.0784, 0.2078, 0.2431, 0.1412, 0.0667, 0.1059, 0.1529, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000},
{0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1098, 0.3333, 0.1137, 0.6039, 0.2275, 0.1843, 0.1686, 0.0471, 0.2980, 0.2784, 0.0824, 0.1333, 0.0745, 0.0824, 0.0745, 0.1294, 0.1686, 0.0275, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000},
{0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1686, 0.3098, 0.0510, 0.5373, 0.2549, 0.1608, 0.1647, 0.0392, 0.3294, 0.2627, 0.0510, 0.1137, 0.1098, 0.0706, 0.1608, 0.1765, 0.1137, 0.0824, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000},
{0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2078, 0.2863, 0.0392, 0.5725, 0.3333, 0.1686, 0.1647, 0.0353, 0.3294, 0.2471, 0.0627, 0.1216, 0.0941, 0.0549, 0.1294, 0.1137, 0.1137, 0.0510, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000},
{0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2392, 0.2745, 0.0078, 0.6627, 0.4000, 0.1098, 0.1843, 0.0588, 0.3137, 0.2353, 0.0392, 0.1137, 0.1020, 0.0000, 0.3020, 0.1098, 0.1059, 0.0549, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000},
{0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2627, 0.2353, 0.0118, 0.7176, 0.3059, 0.1725, 0.1804, 0.0510, 0.2941, 0.2431, 0.0353, 0.0941, 0.1098, 0.0000, 0.6314, 0.1529, 0.0510, 0.0824, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000},
{0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2902, 0.1961, 0.0157, 0.8706, 0.2745, 0.1451, 0.1804, 0.0627, 0.2941, 0.2549, 0.0275, 0.1020, 0.0627, 0.0000, 0.9490, 0.1804, 0.0275, 0.0980, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000},
{0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2863, 0.1412, 0.0431, 1.0000, 0.2235, 0.1725, 0.2118, 0.0431, 0.2902, 0.2471, 0.0157, 0.1020, 0.0235, 0.0039, 0.8549, 0.2863, 0.0000, 0.1059, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000},
{0.0000, 0.0000, 0.0000, 0.0000, 0.0314, 0.2941, 0.1137, 0.0980, 1.0000, 0.2627, 0.1804, 0.1961, 0.0235, 0.3098, 0.2471, 0.0314, 0.0980, 0.0000, 0.1059, 0.9569, 0.3961, 0.0000, 0.1137, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000},
{0.0000, 0.0000, 0.0000, 0.0000, 0.0431, 0.2941, 0.0588, 0.2078, 1.0000, 0.2275, 0.1529, 0.1922, 0.0706, 0.2980, 0.2549, 0.0235, 0.1059, 0.0157, 0.0000, 0.8627, 0.5412, 0.0000, 0.1098, 0.0118, 0.0000, 0.0000, 0.0000, 0.0000},
{0.0000, 0.0000, 0.0000, 0.0000, 0.0431, 0.2902, 0.0196, 0.4039, 0.9961, 0.1922, 0.1882, 0.1804, 0.0510, 0.2863, 0.2549, 0.0078, 0.0980, 0.0196, 0.0000, 0.8196, 0.6941, 0.0000, 0.1176, 0.0275, 0.0000, 0.0000, 0.0000, 0.0000},
{0.0000, 0.0000, 0.0000, 0.0000, 0.0627, 0.2941, 0.0157, 0.4745, 1.0000, 0.1412, 0.1843, 0.2039, 0.0627, 0.2627, 0.2706, 0.0078, 0.0863, 0.0549, 0.0000, 0.7451, 0.7098, 0.0000, 0.1098, 0.0314, 0.0000, 0.0000, 0.0000, 0.0000},
{0.0000, 0.0000, 0.0000, 0.0000, 0.0863, 0.3059, 0.0000, 0.5098, 0.9961, 0.0824, 0.2314, 0.2275, 0.1098, 0.2902, 0.2824, 0.0039, 0.1059, 0.0941, 0.0000, 0.6863, 0.8000, 0.0000, 0.0941, 0.0392, 0.0000, 0.0000, 0.0000, 0.0000},
{0.0000, 0.0000, 0.0000, 0.0000, 0.0902, 0.3020, 0.0000, 0.6078, 0.8549, 0.0784, 0.2235, 0.2078, 0.0941, 0.2745, 0.2863, 0.0078, 0.1059, 0.0863, 0.0000, 0.5255, 0.8392, 0.0000, 0.0784, 0.0471, 0.0000, 0.0000, 0.0000, 0.0000},
{0.0000, 0.0000, 0.0000, 0.0000, 0.0941, 0.2941, 0.0000, 0.7255, 0.7451, 0.0824, 0.2510, 0.2314, 0.1294, 0.2824, 0.2824, 0.0157, 0.1020, 0.1216, 0.0000, 0.4784, 0.8549, 0.0118, 0.0667, 0.0627, 0.0000, 0.0000, 0.0000, 0.0000},
{0.0000, 0.0000, 0.0000, 0.0000, 0.0941, 0.3176, 0.0000, 0.7608, 0.6157, 0.0706, 0.2235, 0.2196, 0.1176, 0.2784, 0.3020, 0.0157, 0.0902, 0.1020, 0.0000, 0.4353, 0.8902, 0.0510, 0.0510, 0.0745, 0.0000, 0.0000, 0.0000, 0.0000},
{0.0000, 0.0000, 0.0000, 0.0000, 0.1255, 0.3059, 0.0000, 0.8863, 0.5333, 0.2000, 0.3216, 0.2863, 0.1529, 0.2941, 0.3137, 0.0314, 0.1098, 0.1294, 0.0000, 0.4000, 0.9490, 0.0627, 0.0471, 0.0745, 0.0000, 0.0000, 0.0000, 0.0000},
{0.0000, 0.0000, 0.0000, 0.0000, 0.1412, 0.2745, 0.0118, 0.9176, 0.3294, 0.2039, 0.2941, 0.2941, 0.2235, 0.2510, 0.2588, 0.0745, 0.1608, 0.1529, 0.0314, 0.2039, 0.8549, 0.1765, 0.0235, 0.0667, 0.0000, 0.0000, 0.0000, 0.0000},
{0.0000, 0.0000, 0.0000, 0.0000, 0.1373, 0.2706, 0.1137, 0.9412, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.8980, 0.4353, 0.0000, 0.0667, 0.0000, 0.0000, 0.0000, 0.0000},
{0.0000, 0.0000, 0.0000, 0.0000, 0.2588, 0.3294, 0.1765, 0.4510, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4667, 0.3059, 0.0941, 0.1020, 0.0000, 0.0000, 0.0000, 0.0000},
{0.0000, 0.0000, 0.0000, 0.0000, 0.2118, 0.2784, 0.1216, 0.2000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1412, 0.1176, 0.1059, 0.1059, 0.0000, 0.0000, 0.0000, 0.0000}
}
};
int sample_y = 4;
// flatten the input into a one-dimensional array
for (int i = 0; i < 1; i++)
    for (int j = 0; j < 28; j++)
        for (int k = 0; k < 28; k++)
            input_matrix[i * 28 * 28 + j * 28 + k] = sample_x[i][j][k];
Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, input_matrix.data(), input_matrix.size(), input_shape.data(), input_shape.size());
Ort::Value output_tensor = Ort::Value::CreateTensor<float>(memory_info, output_matrix.data(), output_matrix.size(), output_shape.data(), output_shape.size());
// --- predict
Ort::Session session(env, model_path, session_option); // FIXME: check that the model file exists and is valid first, otherwise this call throws
session.Run(Ort::RunOptions{ nullptr }, input_names, &input_tensor, 1, output_names, &output_tensor, 1); // a single input and a single output here
// --- result
// output_matrix already holds the result; alternatively, fetch the raw buffer with:
//float* output_buffer = output_tensor.GetTensorMutableData<float>();
std::cout << "--- predict result ---" << std::endl;
// matrix output
std::cout << "ouput matrix: ";
for (int i = 0; i < 10; i++)
std::cout << output_matrix[i] << " ";
std::cout << std::endl;
// argmax value
int argmax_value = std::distance(output_matrix.begin(), std::max_element(output_matrix.begin(), output_matrix.end()));
std::cout << "output argmax value: " << argmax_value << std::endl;
getchar();
return 0;
}
Inference result:
--- predict result ---
output matrix: 0.787004 -0.390761 0.8121 0.148179 1.02133 -0.781948 0.956589 -1.44451 0.0637145 -1.20555
output argmax value: 4
As the results show, given the same input, the C++ and Python versions produce the same output.
Multiple input and output channels are also supported: session.Run simply takes parallel arrays of names and tensors along with their counts, as in the sketch below.
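A minimal multi-input / multi-output sketch follows, assuming a hypothetical model multi_io_model.onnx exported with two inputs (input_a, input_b) and two outputs (output_a, output_b); the path, tensor names, and shapes are assumptions and must match whatever model you actually export.
multi_io.cpp
#include <array>
#include <vector>
#include "onnxruntime_cxx_api.h"
int main()
{
#if _WIN32
    const wchar_t* model_path = L"./multi_io_model.onnx"; // hypothetical model file
#else
    const char* model_path = "./multi_io_model.onnx"; // hypothetical model file
#endif
    Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "Default");
    auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
    Ort::SessionOptions session_option;
    Ort::Session session(env, model_path, session_option);
    // names are assumptions; they must match the exported graph
    const char* input_names[] = { "input_a", "input_b" };
    const char* output_names[] = { "output_a", "output_b" };
    std::array<float, 1 * 28 * 28> a_buf{}, b_buf{};
    std::array<float, 1 * 10> out_a{}, out_b{};
    std::array<int64_t, 3> in_shape{ 1, 28, 28 };
    std::array<int64_t, 2> out_shape{ 1, 10 };
    // Run takes parallel arrays of names and tensors plus their common count
    std::vector<Ort::Value> inputs;
    inputs.push_back(Ort::Value::CreateTensor<float>(memory_info, a_buf.data(), a_buf.size(), in_shape.data(), in_shape.size()));
    inputs.push_back(Ort::Value::CreateTensor<float>(memory_info, b_buf.data(), b_buf.size(), in_shape.data(), in_shape.size()));
    std::vector<Ort::Value> outputs;
    outputs.push_back(Ort::Value::CreateTensor<float>(memory_info, out_a.data(), out_a.size(), out_shape.data(), out_shape.size()));
    outputs.push_back(Ort::Value::CreateTensor<float>(memory_info, out_b.data(), out_b.size(), out_shape.data(), out_shape.size()));
    session.Run(Ort::RunOptions{ nullptr },
                input_names, inputs.data(), inputs.size(),
                output_names, outputs.data(), outputs.size());
    // out_a / out_b now hold the two output tensors
    return 0;
}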