NNVM Frontend组件主要负责将多种深度学习框架训练出来的模型转换成如下内容:
NNVM Frontend组件将不同深度学习框架的模型格式统一转换成nnvm.Graph和tvm.nd.array的组合。
本篇文档暂时先只关注nnvm.Graph对象和mxnet模型转换。
相关代码位于:
mxnet模型加载与转换的接口函数为nnvm.frontend.from_mxnet
在介绍转换接口函数前先了解一下nnvm.Graph
这个数据结构,nnvm.Graph
定义于python/nnvm/graph.py
:
nnvm.Graph
用来表示一个graph对象,这个对象可以被用于应用优化pass。它包含了额外的一些计算图级别专用的属性。
class Graph(object):
def json_attr(self, key) # 获取属性字符串
def _set_json_attr(self, key, value, type_name=None) # 设置属性
def json(self) # 获取计算图的json表示
def _tvm_graph_json(self) # 获取TVM计算图的json表示
def ir(self, join_entry_attrs=None, join_node_attrs=None) # 获取计算图IR的文本形式
def apply(self, passes) # 针对某个graph应用pass
Graph对象比较重要的一个函数是apply,具体是通过调用NNGraphApplyPasses
来实现。
接下来介绍一下mxnet模型的具体转换过程。
python/nnvm/frontend/mxnet.py
def _convert_symbol(op_name, inputs, attrs,
identity_list=None,
convert_map=None):
identity_list = identity_list if identity_list else _identity_list
convert_map = convert_map if convert_map else _convert_map
if op_name in identity_list:
op = _get_nnvm_op(op_name)
sym = op(*inputs, **attrs)
elif op_name in convert_map:
sym = convert_map[op_name](inputs, attrs)
else:
_raise_not_supported('Operator: ' + op_name)
return sym
针对单个运算符的转换过程主要由_convert_symbol
函数完成,其中涉及到两个运算符列表
python/nnvm/frontend/mxnet.py
_identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
'__div_symbol__', '__mul_scalar__', '__mul_symbol__',
'__pow_scalar__', '__rdiv_scalar__', '__rpow_scalar__',
'__rsub_scalar__', '__sub_scalar__', '__sub_symbol__',
'broadcast_add', 'broadcast_div', 'broadcast_mul',
'broadcast_sub', 'broadcast_to', 'cast', 'elemwise_add',
'elemwise_div', 'elemwise_mul', 'elemwise_sub', 'exp',
'flatten', 'log', 'log_softmax', 'max', 'min', 'negative',
'relu', 'sigmoid', 'softmax', 'sum', 'tanh', 'transpose']
# _convert_map列表较长,只列出部分运算符
_convert_map = {
'Activation' : _activations,
'BatchNorm' : _batch_norm,
'BatchNorm_v1' : _batch_norm,
'Cast' : _rename('cast'),
'Concat' : _concat,
'Convolution' : _conv2d,
'Convolution_v1': _conv2d,
'Deconvolution' : _conv2d_transpose,
'Dropout' : _dropout,
}
获取到运算符名称之后可以通过_get_nnvm_op
函数来获取nnvm运算符
python/nnvm/frontend/mxnet.py
from .. import symbol as _sym
def _get_nnvm_op(op_name):
op = getattr(_sym, op_name)
if not op:
raise RuntimeError("Unable to map op_name {} to nnvm.sym".format(op_name))
return op
_get_nnvm_op
的主要功能是通过getattr
内建函数来获取nnvm op对象,这个函数能获取到所有通过NNVM_REGISTER_OP
注册的运算符