
Exporting a model from PyTorch and loading it for inference with the ONNX Runtime C++ API

严元白
2023-12-01

There are many machine-learning frameworks. To make backend model deployment and inference reusable and uniform, the industry has largely converged on the ONNX model format, which is supported by PyTorch, TensorFlow, MXNet and other AI frameworks. To improve inference performance at deployment time, this article uses the onnxruntime inference engine for acceleration; a few simple C++ API calls cover the basic use cases.

Dependencies

See the Microsoft open-source project page: https://github.com/microsoft/onnxruntime

  • the onnxruntime Python package, installed via pip
  • the onnxruntime C++ SDK, built from source (example commands below)
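
For reference, a typical setup looks roughly like this (a sketch; check the project's build documentation for the flags current on your platform):

pip install onnxruntime

git clone --recursive https://github.com/microsoft/onnxruntime
cd onnxruntime
./build.sh --config Release --build_shared_lib --parallel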

PyTorch training and export

The example uses the standard Fashion-MNIST dataset and trains a simple linear model for classification: the input is an image of a garment (1 x 28 x 28), and the output is a score per class (1 x 10).

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import onnxruntime

# data file
data_dir = "../data"
model_file = "model.pt"
onnx_model_file = "model.onnx"

# download training data from open datasets.
training_data = datasets.FashionMNIST(
    root=data_dir,
    train=True,
    download=True,
    transform=ToTensor(),
)

# download test data from open datasets.
test_data = datasets.FashionMNIST(
    root=data_dir,
    train=False,
    download=True,
    transform=ToTensor(),
)

# data loader
batch_size = 64

train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

# peek sample
for X, y in test_dataloader:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape)
    break
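
# expected output for the first test batch:
#   Shape of X [N, C, H, W]:  torch.Size([64, 1, 28, 28])
#   Shape of y:  torch.Size([64])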

classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

# define model
device = "cpu"
print("Using device: ", device)

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
    
model = NeuralNetwork().to(device)
print(model)

# define train
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print("loss: %7f, current: %5d / size: %5d" % (loss, current, size))

# define test
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: %0.1f, Avg loss: %8f \n" % (correct * 100, test_loss))

# training
epochs = 5
for t in range(epochs):
    print("Epoch %d\n-------------------------------" % (t + 1))
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")

# pytorch save model (pickles the whole module)
torch.save(model, model_file)
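
# reloading is symmetric (a sketch: torch.save above pickles the whole module,
# so the NeuralNetwork class must be importable when loading; weights_only=False
# is required on newer PyTorch versions for full-module pickles)
model = torch.load(model_file, weights_only=False)
model.eval()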

# choose sample
x, y = test_data[6][0], test_data[6][1]

# export the model to onnx
torch.onnx.export(
    model,
    x,
    onnx_model_file,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'},
                  'output': {0: 'batch_size'}})
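
# optional sanity check on the exported file (a sketch; assumes the separate
# `onnx` package is installed, e.g. via `pip install onnx`)
import onnx
onnx.checker.check_model(onnx.load(onnx_model_file))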

Loading the model and running inference with onnxruntime in Python

onnx_model = onnxruntime.InferenceSession(onnx_model_file, providers=["CPUExecutionProvider"])
print(onnx_model)

# onnx predict
onnx_input_name = onnx_model.get_inputs()[0].name
onnx_output_name = onnx_model.get_outputs()[0].name

onnx_x = x.numpy()
onnx_pred_y = onnx_model.run([onnx_output_name], {onnx_input_name: onnx_x})
print(onnx_pred_y)
print(int(np.argmax(onnx_pred_y)))

Inference result

[array([[ 0.78701156, -0.3907491 ,  0.8121021 ,  0.14819556,  1.0213459 ,
         -0.7819631 ,  0.95659614, -1.4445262 ,  0.06370842, -1.2055752 ]],
       dtype=float32)]
4
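
To confirm that the export is faithful, the ONNX output can be compared against the original PyTorch model (a minimal sketch; model, x and onnx_pred_y are the objects from the script above):

with torch.no_grad():
    torch_pred_y = model(x).numpy()
print(np.allclose(torch_pred_y, onnx_pred_y[0], atol=1e-5))  # expect True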

Loading the model and running inference with onnxruntime in C++

This requires a C++ project that pulls in the onnxruntime dependency; both Windows and Linux are supported.

Project structure

onnxruntime_demo
├── build
│   ├── onnxruntime_demo
│   └── model.onnx
├── CMakeLists.txt
├── onnxruntime
├── README.md
└── src
    └── main.cpp
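
The CMakeLists.txt is not listed here; a minimal sketch could look like the following, assuming the onnxruntime SDK headers and library have been unpacked into the onnxruntime/ directory shown above (the paths and minimum version are illustrative):

cmake_minimum_required(VERSION 3.13)
project(onnxruntime_demo CXX)

set(CMAKE_CXX_STANDARD 17)

# assumed layout: onnxruntime/include and onnxruntime/lib inside the project
set(ONNXRUNTIME_DIR ${CMAKE_SOURCE_DIR}/onnxruntime)

add_executable(onnxruntime_demo src/main.cpp)
target_include_directories(onnxruntime_demo PRIVATE ${ONNXRUNTIME_DIR}/include)
target_link_directories(onnxruntime_demo PRIVATE ${ONNXRUNTIME_DIR}/lib)
target_link_libraries(onnxruntime_demo PRIVATE onnxruntime)

Build from the build/ directory with cmake .. followed by cmake --build ., and keep model.onnx next to the executable as in the tree above.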

main.cpp

#include <iostream>
#include <array>
#include <algorithm>
#include "onnxruntime_cxx_api.h"

int main(int argc, char* argv[])
{
    // --- define model path
#if _WIN32
    const wchar_t* model_path = L"./model.onnx"; // Windows wants a wide-character path; convert from std::string if needed
#else
    const char* model_path = "./model.onnx";
#endif

    // --- init onnxruntime env
    Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "Default");

    auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);

    // set options
    Ort::SessionOptions session_option;
    session_option.SetIntraOpNumThreads(5); // increase for more intra-op parallelism
    session_option.SetGraphOptimizationLevel(ORT_ENABLE_ALL);

    // --- prepare data
    const char* input_names[] = { "input" }; // must match the names used at export time
    const char* output_names[] = { "output" };

    // use static arrays to preallocate the data buffers
    std::array<float, 1 * 28 * 28> input_matrix;
    std::array<float, 1 * 10> output_matrix;

    // must use int64_t type to match args
    std::array<int64_t, 3> input_shape{ 1, 28, 28 };
    std::array<int64_t, 2> output_shape{ 1, 10 };

    std::vector<std::vector<std::vector<float>>> sample_x = {
        {
        {0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000, 0.0667, 0.0000, 0.1373, 0.2157, 0.2039, 0.1765, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000},
        {0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0118, 0.0000, 0.0039, 0.9804, 1.0000, 0.9608, 0.9961, 0.9333, 0.9569,0.9373, 0.5412, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000},
        {0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3451, 0.4863, 0.6667, 0.9961, 0.5412, 0.7333, 1.0000, 0.7333, 0.1255, 0.0157, 0.0000, 0.0039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000},
        {0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3294, 0.3843, 0.0000, 0.7137, 0.8235, 0.9529, 1.0000, 0.1451, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000},
        {0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0235, 0.2196, 0.2824, 0.3569, 0.5216, 0.1686, 0.0000, 0.9412, 0.8549, 0.0000, 0.0000, 0.1529, 0.1804, 0.0784, 0.0980, 0.0078, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000},
        {0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1882, 0.4275, 0.2745, 0.2118, 0.1725, 0.2784, 0.2196, 0.2510, 0.0588, 0.0784, 0.1137, 0.1098, 0.2275, 0.2235, 0.2000, 0.0784, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000},
        {0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2549, 0.2863, 0.3216, 0.1922, 0.2275, 0.2039, 0.1255, 0.3294, 0.2706, 0.0980, 0.1961, 0.2471, 0.1804, 0.1059, 0.0980, 0.1137, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000},
        {0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0471, 0.3059, 0.2078, 0.5137, 0.1451, 0.2235, 0.2039, 0.0784, 0.3529, 0.3059, 0.0784, 0.2078, 0.2431, 0.1412, 0.0667, 0.1059, 0.1529, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000},
        {0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1098, 0.3333, 0.1137, 0.6039, 0.2275, 0.1843, 0.1686, 0.0471, 0.2980, 0.2784, 0.0824, 0.1333, 0.0745, 0.0824, 0.0745, 0.1294, 0.1686, 0.0275, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000},
        {0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1686, 0.3098, 0.0510, 0.5373, 0.2549, 0.1608, 0.1647, 0.0392, 0.3294, 0.2627, 0.0510, 0.1137, 0.1098, 0.0706, 0.1608, 0.1765, 0.1137, 0.0824, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000},
        {0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2078, 0.2863, 0.0392, 0.5725, 0.3333, 0.1686, 0.1647, 0.0353, 0.3294, 0.2471, 0.0627, 0.1216, 0.0941, 0.0549, 0.1294, 0.1137, 0.1137, 0.0510, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000},
        {0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2392, 0.2745, 0.0078, 0.6627, 0.4000, 0.1098, 0.1843, 0.0588, 0.3137, 0.2353, 0.0392, 0.1137, 0.1020, 0.0000, 0.3020, 0.1098, 0.1059, 0.0549, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000},
        {0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2627, 0.2353, 0.0118, 0.7176, 0.3059, 0.1725, 0.1804, 0.0510, 0.2941, 0.2431, 0.0353, 0.0941, 0.1098, 0.0000, 0.6314, 0.1529, 0.0510, 0.0824, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000},
        {0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2902, 0.1961, 0.0157, 0.8706, 0.2745, 0.1451, 0.1804, 0.0627, 0.2941, 0.2549, 0.0275, 0.1020, 0.0627, 0.0000, 0.9490, 0.1804, 0.0275, 0.0980, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000},
        {0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2863, 0.1412, 0.0431, 1.0000, 0.2235, 0.1725, 0.2118, 0.0431, 0.2902, 0.2471, 0.0157, 0.1020, 0.0235, 0.0039, 0.8549, 0.2863, 0.0000, 0.1059, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000},
        {0.0000, 0.0000, 0.0000, 0.0000, 0.0314, 0.2941, 0.1137, 0.0980, 1.0000, 0.2627, 0.1804, 0.1961, 0.0235, 0.3098, 0.2471, 0.0314, 0.0980, 0.0000, 0.1059, 0.9569, 0.3961, 0.0000, 0.1137, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000},
        {0.0000, 0.0000, 0.0000, 0.0000, 0.0431, 0.2941, 0.0588, 0.2078, 1.0000, 0.2275, 0.1529, 0.1922, 0.0706, 0.2980, 0.2549, 0.0235, 0.1059, 0.0157, 0.0000, 0.8627, 0.5412, 0.0000, 0.1098, 0.0118, 0.0000, 0.0000, 0.0000, 0.0000},
        {0.0000, 0.0000, 0.0000, 0.0000, 0.0431, 0.2902, 0.0196, 0.4039, 0.9961, 0.1922, 0.1882, 0.1804, 0.0510, 0.2863, 0.2549, 0.0078, 0.0980, 0.0196, 0.0000, 0.8196, 0.6941, 0.0000, 0.1176, 0.0275, 0.0000, 0.0000, 0.0000, 0.0000},
        {0.0000, 0.0000, 0.0000, 0.0000, 0.0627, 0.2941, 0.0157, 0.4745, 1.0000, 0.1412, 0.1843, 0.2039, 0.0627, 0.2627, 0.2706, 0.0078, 0.0863, 0.0549, 0.0000, 0.7451, 0.7098, 0.0000, 0.1098, 0.0314, 0.0000, 0.0000, 0.0000, 0.0000},
        {0.0000, 0.0000, 0.0000, 0.0000, 0.0863, 0.3059, 0.0000, 0.5098, 0.9961, 0.0824, 0.2314, 0.2275, 0.1098, 0.2902, 0.2824, 0.0039, 0.1059, 0.0941, 0.0000, 0.6863, 0.8000, 0.0000, 0.0941, 0.0392, 0.0000, 0.0000, 0.0000, 0.0000},
        {0.0000, 0.0000, 0.0000, 0.0000, 0.0902, 0.3020, 0.0000, 0.6078, 0.8549, 0.0784, 0.2235, 0.2078, 0.0941, 0.2745, 0.2863, 0.0078, 0.1059, 0.0863, 0.0000, 0.5255, 0.8392, 0.0000, 0.0784, 0.0471, 0.0000, 0.0000, 0.0000, 0.0000},
        {0.0000, 0.0000, 0.0000, 0.0000, 0.0941, 0.2941, 0.0000, 0.7255, 0.7451, 0.0824, 0.2510, 0.2314, 0.1294, 0.2824, 0.2824, 0.0157, 0.1020, 0.1216, 0.0000, 0.4784, 0.8549, 0.0118, 0.0667, 0.0627, 0.0000, 0.0000, 0.0000, 0.0000},
        {0.0000, 0.0000, 0.0000, 0.0000, 0.0941, 0.3176, 0.0000, 0.7608, 0.6157, 0.0706, 0.2235, 0.2196, 0.1176, 0.2784, 0.3020, 0.0157, 0.0902, 0.1020, 0.0000, 0.4353, 0.8902, 0.0510, 0.0510, 0.0745, 0.0000, 0.0000, 0.0000, 0.0000},
        {0.0000, 0.0000, 0.0000, 0.0000, 0.1255, 0.3059, 0.0000, 0.8863, 0.5333, 0.2000, 0.3216, 0.2863, 0.1529, 0.2941, 0.3137, 0.0314, 0.1098, 0.1294, 0.0000, 0.4000, 0.9490, 0.0627, 0.0471, 0.0745, 0.0000, 0.0000, 0.0000, 0.0000},
        {0.0000, 0.0000, 0.0000, 0.0000, 0.1412, 0.2745, 0.0118, 0.9176, 0.3294, 0.2039, 0.2941, 0.2941, 0.2235, 0.2510, 0.2588, 0.0745, 0.1608, 0.1529, 0.0314, 0.2039, 0.8549, 0.1765, 0.0235, 0.0667, 0.0000, 0.0000, 0.0000, 0.0000},
        {0.0000, 0.0000, 0.0000, 0.0000, 0.1373, 0.2706, 0.1137, 0.9412, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.8980, 0.4353, 0.0000, 0.0667, 0.0000, 0.0000, 0.0000, 0.0000},
        {0.0000, 0.0000, 0.0000, 0.0000, 0.2588, 0.3294, 0.1765, 0.4510, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4667, 0.3059, 0.0941, 0.1020, 0.0000, 0.0000, 0.0000, 0.0000},
        {0.0000, 0.0000, 0.0000, 0.0000, 0.2118, 0.2784, 0.1216, 0.2000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1412, 0.1176, 0.1059, 0.1059, 0.0000, 0.0000, 0.0000, 0.0000}
        }
    };
    int sample_y = 4;

    // flatten the input into a one-dimensional array
    for (int i = 0; i < 1; i++)
        for (int j = 0; j < 28; j++)
            for (int k = 0; k < 28; k++)
                input_matrix[i * 28 * 28 + j * 28 + k] = sample_x[i][j][k];

    Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, input_matrix.data(), input_matrix.size(), input_shape.data(), input_shape.size());
    Ort::Value output_tensor = Ort::Value::CreateTensor<float>(memory_info, output_matrix.data(), output_matrix.size(), output_shape.data(), output_shape.size());

    // --- predict
    Ort::Session session(env, model_path, session_option); // FIXME: check that the model file exists and is valid first, otherwise this throws (see the sketch after main)
    session.Run(Ort::RunOptions{ nullptr }, input_names, &input_tensor, 1, output_names, &output_tensor, 1); // a single input and a single output here

    // --- result
    // output_matrix already holds the result; alternatively you can read it
    // directly from the tensor:
    //float* output_buffer = output_tensor.GetTensorMutableData<float>();

    std::cout << "--- predict result ---" << std::endl;
    // matrix output
    std::cout << "ouput matrix: ";
    for (int i = 0; i < 10; i++)
        std::cout << output_matrix[i] << " ";
    std::cout << std::endl;
    // argmax value
    int argmax_value = std::distance(output_matrix.begin(), std::max_element(output_matrix.begin(), output_matrix.end()));
    std::cout << "output argmax value: " << argmax_value << std::endl;

    getchar();

    return 0;
}
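
On the FIXME above: the C++ API signals failures such as a missing or corrupt model file by throwing Ort::Exception, so a minimal defensive sketch wraps session creation and Run:

try {
    Ort::Session session(env, model_path, session_option);
    session.Run(Ort::RunOptions{ nullptr }, input_names, &input_tensor, 1, output_names, &output_tensor, 1);
} catch (const Ort::Exception& e) {
    std::cerr << "onnxruntime error: " << e.what() << std::endl;
    return 1;
}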

Inference result

--- predict result ---
output matrix: 0.787004 -0.390761 0.8121 0.148179 1.02133 -0.781948 0.956589 -1.44451 0.0637145 -1.20555
output argmax value: 4

As shown, with identical input the C++ and Python versions produce the same output.
The Run call also supports multiple inputs and outputs, as sketched below.
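
For a multi-input model you pass parallel arrays of names and tensors to Run (a sketch; the two input names and tensors here are hypothetical):

// hypothetical model with two inputs; note that Ort::Value is move-only
const char* in_names[]  = { "input_a", "input_b" };
const char* out_names[] = { "output" };
Ort::Value inputs[] = { std::move(input_a_tensor), std::move(input_b_tensor) };
session.Run(Ort::RunOptions{ nullptr }, in_names, inputs, 2, out_names, &output_tensor, 1);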

Code

onnxruntime_demo
