瞄准痛点:静态图好部署,动态图易调试,但两者难以兼得
import megengine.functional as F
from megengine.jit import trace
# import trace之后设置 enabled 属性切换动静态图
trace.enabled = True # 开启trace,使用静态图模式
# 使用 trace 类装饰网络 forward 的函数
@trace
def train_func(data, label, *, opt, net):
pred = net(data)
loss = F.cross_entropy_with_softmax(pred, label)
opt.backward(loss)
return pred, loss
# 调用函数训练网络,动静态图一套代码
train_func(data, label, opt=optimizer, net=le_net)
瞄准痛点:框架学习接口各异,模型复现困难,学习成本高
import megengine as mge
import megengine.functional as F
import megengine.module as M
import numpy as np
# 经典的基于 Module 的网络搭建接口
class LeNet(M.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = M.Conv2d(1, 6, 5)
self.relu1 = M.ReLU()
self.pool1 = M.MaxPool2d(2, 2)
# 省略部分代码...
self.classifer = M.Linear(84, 10)
# 符合 Pythonic 风格的计算流程代码
def forward(self, x):
x = self.pool1(self.relu1(self.conv1(x)))
# 省略部分代码...
x = self.classifer(x)
return x
瞄准痛点:生产环境计算设备繁多,缺乏优秀性能
瞄准痛点:从研究到生产,流程复杂,精度难以对齐
from megengine.jit import trace
# 使用 trace 类装饰网络 forward 的函数
@trace
def val_func(x, *, net):
return net(x)
# 调用trace接口无需运行直接编译网络
val_func.trace(inp, net=net)
# 将编译后的网络进行导出,直接生成可用于部署的序列化文件
val_func.dump('./mnist.mge', arg_names=["data"])