NNVM Compiler组件是和使用者比较近的一个组件,本篇文档将详细阅读相关代码。
NNVM Compiler组件中比较重要的函数是nnvm.compiler.build。可以将nnvm.compiler.build的执行过程总结为如下步骤:
参考文档1为了快速了解NNVM和TVM是如何交互的,只讲解了步骤6,本文档将介绍所有步骤。
python/nnvm/compiler/build_module.py build函数
# Correct the layout when needed (excerpt from nnvm.compiler.build).
layout = layout if layout else {}
# Attach the user-provided input layouts as graph attributes.
graph = graph_attr.set_layout_inputs(graph, layout)
# Run the "CorrectLayout" pass, which invokes each operator's FCorrectLayout.
graph = graph.apply("CorrectLayout")
index = graph.index
layouts = graph.json_attr("layout")
# Read back the (possibly corrected) layout of every graph input.
layout = {x : layouts[index.entry_id(x)] for x in index.input_names}
graph.apply在之前的参考文档1中已经讲解过:对于"CorrectLayout"这个Pass而言,会调用每个操作符的FCorrectLayout函数;操作符的FCorrectLayout函数是由参考文档3里面讲解的NNVM Top组件C++部分定义的。
下面从一个比较简单的操作符max_pool2d
入手理解FCorrectLayout
的功能。
src/top/nn/pooling.cc
NNVM_REGISTER_OP(max_pool2d)
.set_attr<FCorrectLayout>("FCorrectLayout", Pool2DCorrectLayout)
// Infer/correct the input and output layouts of pool2d.
// If the incoming layout keeps H and W at the same axis positions as the
// operator's default layout (param.layout) and has no split h/w
// sub-dimensions, it is kept; otherwise the default layout is enforced.
// Fix: the excerpt contained the mojibake "¶m" where the HTML entity
// "&para;" swallowed "&param"; restored to "&param" below.
inline bool Pool2DCorrectLayout(const NodeAttrs& attrs,
                                std::vector<Layout> *ilayouts,
                                const std::vector<Layout> *last_ilayouts,
                                std::vector<Layout> *olayouts) {
  const Pool2DParam &param = nnvm::get<Pool2DParam>(attrs.parsed);
  // pool2d has exactly one input and one output entry.
  CHECK_EQ(ilayouts->size(), 1);
  CHECK_EQ(last_ilayouts->size(), 1);
  CHECK_EQ(olayouts->size(), 1);

  Layout input = (*ilayouts)[0];
  const Layout layout(param.layout);

  if (input.defined()) {
    CHECK(input.convertible(layout)) << "Invalid input layout " << input;
    if (input.indexof('W') != layout.indexof('W') ||
        input.indexof('H') != layout.indexof('H') ||
        input.contains('w') || input.contains('h')) {
      // The H/W axis moved or was split into sub-dimensions: pool2d cannot
      // keep the incoming layout, so fall back to the default layout.
      input = layout;
    }
  } else {
    // No layout specified yet: use the operator's default layout.
    input = layout;
  }

  NNVM_ASSIGN_LAYOUT(*ilayouts, 0, input);
  NNVM_ASSIGN_LAYOUT(*olayouts, 0, input);
  return true;
}
/*
 * Pool2DCorrectLayout主要功能为:
 * 1. 如果input layout没有指定,则设置成默认layout(param.layout)
 * 2. 如果input layout已经指定:只要H/W轴的位置与默认layout一致且没有拆分出h/w子维度,
 *    就保留input layout;否则校正成默认layout
 */
python/nnvm/compiler/build_module.py build函数
# Infer shapes from the provided input shapes and merge them back.
ishape, _ = graph_util.infer_shape(graph, **shape)
shape.update(zip(graph.index.input_names, ishape))
# When dtype is a per-input dict (not a single str), infer dtypes as well.
if not isinstance(dtype, str):
idtype, _ = graph_util.infer_dtype(graph, **dtype)
dtype.update(zip(graph.index.input_names, idtype))
python/nnvm/compiler/graph_util.py
# infer_shape函数功能为利用提供的输入节点的shape信息计算计算图涉及节点的推理shape
def infer_shape(graph, **shape):
    """Infer the shape of every node entry in ``graph``.

    Seeds the graph with the provided input shapes, runs the
    "InferShape" pass, then reads the resulting "shape" attribute.

    Returns
    -------
    (input_shape, output_shape) : pair of lists
        Inferred shapes of the graph inputs and graph outputs.
    """
    seeded = graph_attr.set_shape_inputs(graph, shape)
    inferred = seeded.apply("InferShape")
    all_shapes = inferred.json_attr("shape")
    idx = inferred.index
    in_shapes = [all_shapes[idx.entry_id(name)] for name in idx.input_names]
    out_shapes = [all_shapes[idx.entry_id(entry)] for entry in idx.output_entries]
    return in_shapes, out_shapes
下面从一个比较简单的操作符max_pool2d入手理解InferShape的功能。(注意:下面pooling.cc的节选展示的其实是FInferType的注册,类型推断与形状推断走的是同一套InferAttr机制。)
/src/pass/infer_shape_type.cc
// The core of InferShape is delegating to the generic InferAttr<TShape>
// engine, which invokes each operator's FInferShape attribute function.
NNVM_REGISTER_PASS(InferShape)
.describe("Infer the shape of each node entries.")
.set_body([](Graph ret) {
return InferAttr<TShape>(
std::move(ret), TShape(),
"FInferShape", "shape_inputs", "shape_attr_key",
"shape", "shape_num_unknown_nodes",
// A shape counts as "unknown" when it has no dims or zero elements.
[](const TShape& s) { return s.ndim() == 0 || s.Size() == 0; },
nullptr);
})
.set_change_graph(false)
.provide_graph_attr("shape");
/src/top/nn/pooling.cc
NNVM_REGISTER_OP(max_pool2d)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
/src/top/elemwise_op_common.h
// Type inference for element-wise style operators: all inputs and outputs
// must end up sharing one dtype, unified through ElemwiseAttr.
template<int n_in, int n_out>
inline bool ElemwiseType(const NodeAttrs& attrs,
                         std::vector<int> *in_attrs,
                         std::vector<int> *out_attrs) {
  // Arity checks; a template argument of -1 means "any count" and skips
  // the corresponding check.
  if (n_in != -1) {
    CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_in)) << " in operator " << attrs.name;
  }
  if (n_out != -1) {
    CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator " << attrs.name;
  }
  // Delegate dtype unification; -1 plays the role of "unknown dtype".
  return ElemwiseAttr<int, type_is_none, type_assign, true, type_string>(
      attrs, in_attrs, out_attrs, -1);
}
// Generic attribute unification for element-wise operators.
// Folds all known input (and, when reverse_infer is set, output) attributes
// into one value `dattr`, then writes that value back to every slot.
// Returns false when the unified attribute is still "none" (unknown).
template<typename AttrType, bool (*is_none)(const AttrType&),
bool (*assign)(AttrType*, const AttrType&), bool reverse_infer,
std::string (*attr_string)(const AttrType&),
int n_in = -1, int n_out = -1>
inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs,
std::vector<AttrType> *in_attrs,
std::vector<AttrType> *out_attrs,
const AttrType& none) {
// Accumulator for the unified attribute; starts as "unknown".
AttrType dattr = none;
size_t in_size = in_attrs->size();
size_t out_size = out_attrs->size();
// Template parameters, when not -1, cap how many slots participate.
if (n_in != -1)
in_size = static_cast<size_t>(n_in);
if (n_out != -1)
out_size = static_cast<size_t>(n_out);
// Phase 1: fold each known slot into dattr; CHECK fails on conflicts.
auto deduce = [&](std::vector<AttrType> *vec, size_t size, const char *name) {
for (size_t i = 0; i < size; ++i) {
CHECK(assign(&dattr, (*vec)[i]))
<< "Incompatible attr in node " << attrs.name << " at " << i << "-th "
<< name << ": " << "expected " << attr_string(dattr)
<< ", got " << attr_string((*vec)[i]);
}
};
deduce(in_attrs, in_size, "input");
// Outputs contribute to deduction only when reverse inference is enabled.
if (reverse_infer) deduce(out_attrs, out_size, "output");
// Phase 2: propagate the unified value back to every slot.
auto write = [&](std::vector<AttrType> *vec, size_t size, const char *name) {
for (size_t i = 0; i < size; ++i) {
CHECK(assign(&(*vec)[i], dattr))
<< "Incompatible attr in node " << attrs.name << " at " << i << "-th "
<< name << ": " << "expected " << attr_string(dattr)
<< ", got " << attr_string((*vec)[i]);
}
};
write(in_attrs, in_size, "input");
write(out_attrs, out_size, "output");
// Inference succeeded only if something concrete was deduced.
if (is_none(dattr)) return false;
return true;
}
在推测出推理阶段使用的shape参数之后,可以申请变量空间和初始化变量
python/nnvm/compiler/build_module.py build函数
# Initialize variables only when the user has registered initializers.
init_var = {}
if _all_var_init:
init_var = initialize_variables(shape, dtype)
def initialize_variables(ishape, idtype):
    """Initialize variables stored in the ``_all_var_init`` dictionary.

    Parameters
    ----------
    ishape : dict of str to tuple of int
        The input shape to the graph
    idtype : str or dict of str to str
        The input types to the graph

    Returns
    -------
    init_var : dict of str to tvm.ndarray
    """
    # Split registered initializers: symbolic ones need a graph run,
    # constant ones are wrapped into tvm ndarrays directly.
    sym_inits = {}
    const_inits = {}
    for var_name, value in _all_var_init.items():
        if isinstance(value, sym.Symbol):
            sym_inits[var_name] = value
        else:
            const_inits[var_name] = tvm.nd.array(value)
    # Clear so that variables are only ever initialized once.
    _all_var_init.clear()
    init_var = {}
    if sym_inits:
        # Build dummy params so the initialization graph can execute.
        dummy_params = {
            var_name: tvm.nd.empty(
                var_shape,
                idtype if isinstance(idtype, str) else idtype[var_name],
                ctx=tvm.cpu())
            for var_name, var_shape in ishape.items()
        }
        init_graph = _graph.create(sym.Group(sym_inits.values()))
        with tvm.build_config(auto_unroll_max_step=0):
            results = _run_graph(init_graph, dummy_params)
        init_var.update(zip(sym_inits.keys(), results))
    init_var.update(const_inits)
    # Record the concrete shape of every initialized value.
    for var_name, data in init_var.items():
        ishape[var_name] = data.shape
    return init_var
graph = optimize(graph, shape, dtype, layout)
optimize函数主要针对计算图应用图优化相关的pass:
预计算裁剪的主要功能是通过预计算整个计算图的一部分,然后排除掉一些和前向推理无关的计算节点
参考文档1已经讲解了这个步骤的一部分知识,这里将继续介绍两个pass:
GraphFusePartition的功能主要是将可以fuse的节点放到一个segment中,以供后面编译使用。
GraphFuseCompile是进行lowering编译的过程,其中调用了nnvm.compiler.lower
。
nnvm.compiler.lower
的定义位于tvm/python/tvm/build_module.py
# 代码节选
# Excerpt of nnvm.compiler.lower (tvm/python/tvm/build_module.py);
# elided parts are marked with "...".
def lower(sch,
args,
name="default_function",
binds=None,
simple_mode=False):
"""Lower a schedule into a statement (code excerpt).

Parameters
----------
sch : tvm.Schedule
The schedule to be compiled.
"""
...
# normalize schedule first
sch = sch.normalize()
# Phase 0: infer iteration bounds, emit stage statements, inject prefetch.
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.InjectPrefetch(stmt)
...
lower接口中的参数sch的类型为tvm.Schedule,这里sch是由nnvm Top组件和TOPI组件一起决定的。
lower的具体实现中值得注意的是schedule.ScheduleOps这个函数,利用sch生成HalideIR::Internal::Stmt
表达式。
tvm/src/api/api_schedule.cc
// FFI registration of schedule.ScheduleOps. With two arguments the
// debug_keep_trivial_loop flag defaults to false; a third argument
// overrides it explicitly.
TVM_REGISTER_API("schedule.ScheduleOps")
.set_body([](TVMArgs args, TVMRetValue* ret) {
if (args.size() == 2)
*ret = ScheduleOps(args[0], args[1], false);
else
*ret = ScheduleOps(args[0], args[1], args[2]);
});
tvm/src/schedule/schedule_ops.cc
// Build the final statement for a schedule: walk the stages in reverse
// post-DFS order and attach each stage's computation into the growing body.
Stmt ScheduleOps(Schedule sch, Map<IterVar, Range> dom_map_, bool debug_keep_trivial_loop) {
Stmt body = Stmt();
std::unordered_map<IterVar, Range> dom_map = as_unordered_map(dom_map_);
// scan init and scan updates
// Map each scan-init tensor's producing op to the scan op that owns it.
std::unordered_map<Operation, Operation> scan_init;
for (Stage s : sch->stages) {
const ScanOpNode* scan = s->op.as<ScanOpNode>();
if (!scan) continue;
for (Tensor t : scan->init) {
if (scan_init.count(t->op)) {
CHECK(scan_init.at(t->op).same_as(s->op))
<< "Scan init tensor can only belong to one scan";
} else {
scan_init[t->op] = s->op;
}
}
}
// Verify the groups are well-formed: a group stage carries no op of its
// own and no leaf iteration variables.
for (Stage g : sch->groups) {
CHECK(!g->op.defined());
CHECK_EQ(g->leaf_iter_vars.size(), 0U);
}
// reverse the post DFS order.
for (size_t i = sch->stages.size(); i != 0; --i) {
Stage s = sch->stages[i - 1];
CHECK_NE(s->attach_type, kInline)
<< "call schedule.normalize before scheduleops";
CHECK(s->op.defined());
// no need to specify place holder op.
if (s->op.as<PlaceholderOpNode>()) continue;
// Remove grouping sugar, get the real attach spec.
Stage attach_spec = s.GetAttachSpec();
if (scan_init.count(s->op)) {
// This stage produces a scan-init tensor: splice it into its scan.
CHECK(body.defined());
InjectScanStep mu(s, scan_init.at(s->op), dom_map, true, debug_keep_trivial_loop);
body = mu.Mutate(body);
CHECK(mu.found_attach)
<< "did not find attachment point for scan.init";
} else if (attach_spec->attach_type == kScanUpdate) {
// Handle scan update
CHECK(body.defined());
InjectScanStep mu(s, attach_spec->attach_stage->op, dom_map, false, debug_keep_trivial_loop);
body = mu.Mutate(body);
CHECK(mu.found_attach)
<< "did not find attachment point for scan.update";
} else if (attach_spec->attach_type == kInlinedAlready) {
// do nothing
} else if (attach_spec->attach_type == kGroupRoot) {
CHECK(!s->group.defined());
// Root of a group: prepend a fresh pipeline for this stage.
body = MakePipeline(s, dom_map, body, debug_keep_trivial_loop);
} else {
// compute_at attachment: inject this stage at its attach point.
CHECK_EQ(attach_spec->attach_type, kScope);
CHECK(body.defined());
InjectAttach mutator(s, attach_spec, dom_map, debug_keep_trivial_loop);
body = mutator.Mutate(body);
CHECK(mutator.found_attach)
<< "did not find attachment point for " << s << " in "
<< attach_spec->attach_stage->op << " x " << attach_spec->attach_ivar
<< ", body:\n"
<< body;
}
}
// Final cleanup pass over schedule-internal annotations.
SchedulePostProc post_proc;
post_proc.Init(sch);
return post_proc.Mutate(body);
}