当前位置: 首页 > 工具软件 > tvm > 使用案例 >

TVM Pass概述

辛可人
2023-12-01

是什么

Pass又称transform,每一个transform要么把现有程序转换并优化为一个等价的程序,要么把程序lower到下层。
Pass和Schedule的区别在于,前者包括一些Schedule Primitives(调度原语),其用于生成IR,而后者是提供了修改IR的方法。

TVM中的Pass有两种:

  • Relay层的Pass。relay/transforms/包括很多优化图结构用的Pass,包括fusion(图融合),常量折叠(constant folding)和死代码删除(dead-code elimination)等。属于前端优化。
  • TIR层的Pass。tir/transforms包括偏向编译器方面的优化,比如prefetch注入,unrollLoop等。属于后端优化。

实现上,Pass分为:

  • Module-Level Pass
    • 利用全局信息进行优化
    • 可以删减Function,如DSE Pass
    • 核心Pass函数是PackedFunc类型
  • Function-Level Pass
    • 对Module中的每个Function进行优化,只有局部信息
    • 不允许删减Function

Pass的转化逻辑可以简化为:IRModule -> Pass -> … -> IRModule

这里拿tests/python/relay/test_pass_fold_constant.py里的单测作为例子:

import numpy as np
import tvm
from tvm import te
import tvm.relay as relay
c_data = np.array([1, 2, 3]).astype("float32")
t = relay.TensorType([1, 2, 3], "float32")
def example():
    c = relay.const(c_data)
    x = relay.var("x", t)
    y = relay.add(c, c)
    y = relay.multiply(y, relay.const(2, "float32"))
    y = relay.add(x, y)
    z = relay.add(y, c)
    return relay.Function([x], z)

可以将其IRModule打印出来:

f = example()
mod = tvm.IRModule.from_expr(f)
print(mod)

得到一个不经过任何pass优化的script:

def @main(%x: Tensor[(1, 2, 3), float32]) -> Tensor[(1, 2, 3), float32] {
  %0 = add(%x, meta[relay.Constant][0] /* ty=Tensor[(3), float32] */) /* ty=Tensor[(1, 2, 3), float32] */;
  add(%0, meta[relay.Constant][1] /* ty=Tensor[(3), float32] */) /* ty=Tensor[(1, 2, 3), float32] */
}

对应关系:i个声明为const(relay.const)的数据会被储存在meta[relay.Constant][i]位置上。

加一个fold_constant pass后得到log信息(在运行时,需要export TVM_LOG_DEBUG="relay/transforms/fold_constant.cc=1"来指定需要debug的cc文件):

fold_const = relay.transform.FoldConstant()
mod = fold_const(mod)
print(mod)
[17:30:51] /home/yuan/Coding/compiler/repos/tvm/src/runtime/logging.cc:239: TVM_LOG_DEBUG enables VLOG statements in 'relay/transforms/fold_constant.cc' up to level 1
[17:30:51] /home/yuan/Coding/compiler/repos/tvm/src/relay/transforms/fold_constant.cc:414: FoldConstant: FoldConstantExpr: folding:
fn (%x: Tensor[(1, 2, 3), float32]) {
  %0 = add(meta[relay.Constant][0], meta[relay.Constant][0]);
  %1 = multiply(%0, 2f);
  %2 = add(%x, %1);
  add(%2, meta[relay.Constant][0])
}

[17:30:51] /home/yuan/Coding/compiler/repos/tvm/src/relay/transforms/fold_constant.cc:247: FoldConstant: FoldConstantExpr: ConstEvaluate: Evaluating :
add(meta[relay.Constant][0], meta[relay.Constant][0])

[17:30:51] /home/yuan/Coding/compiler/repos/tvm/src/relay/transforms/fold_constant.cc:259: FoldConstant: FoldConstantExpr: ConstEvaluate: Evaluated to constant:
meta[relay.Constant][0]

[17:30:51] /home/yuan/Coding/compiler/repos/tvm/src/relay/transforms/fold_constant.cc:247: FoldConstant: FoldConstantExpr: ConstEvaluate: Evaluating :
multiply(meta[relay.Constant][0], 2f)

[17:30:51] /home/yuan/Coding/compiler/repos/tvm/src/relay/transforms/fold_constant.cc:259: FoldConstant: FoldConstantExpr: ConstEvaluate: Evaluated to constant:
meta[relay.Constant][0]

[17:30:51] /home/yuan/Coding/compiler/repos/tvm/src/relay/transforms/fold_constant.cc:416: FoldConstant: FoldConstantExpr: folded to:
fn (%x: Tensor[(1, 2, 3), float32]) {
  %0 = add(%x, meta[relay.Constant][0]);
  add(%0, meta[relay.Constant][1])
}

def @main(%x: Tensor[(1, 2, 3), float32]) -> Tensor[(1, 2, 3), float32] {
  %0 = add(%x, meta[relay.Constant][0] /* ty=Tensor[(3), float32] */) /* ty=Tensor[(1, 2, 3), float32] */;
  add(%0, meta[relay.Constant][1] /* ty=Tensor[(3), float32] */) /* ty=Tensor[(1, 2, 3), float32] */
}

常量折叠的主要目的是,将代码中所有的常量用它的值替换。

常量折叠是一个在编译时期简化常量的一个过程,常量在表示式中仅仅代表一个简单的数值,就像是整数 2,若是一个变量从未被修改也可作为常量,或者直接将一个变量被明确地被标注为常量。

先分析其中的一段pass,它对应原先script中的第一行:

[17:30:51] /home/yuan/Coding/compiler/repos/tvm/src/relay/transforms/fold_constant.cc:247: FoldConstant: FoldConstantExpr: ConstEvaluate: Evaluating :
add(meta[relay.Constant][0], meta[relay.Constant][0])

[17:30:51] /home/yuan/Coding/compiler/repos/tvm/src/relay/transforms/fold_constant.cc:259: FoldConstant: FoldConstantExpr: ConstEvaluate: Evaluated to constant:
meta[relay.Constant][0]

可以推断pass直接把add(meta[relay.Constant][0], meta[relay.Constant][0])的结果算出来了,并inplace地替代掉了原来的值。
这样不断地替换,也称为常量传播。

怎么用

自定义Pass

所有Pass需要继承自ExprFunctor接口。

  • AST遍历。用于确定哪些Node需要修改。重载VisitExpr_/VisitExpr来定义。VisitExpr将特定类型的Expr分派到对应的VisitExpr_上。
  • 节点修改。Expression Mutators,用于修改和替换满足条件的Node。

PassInfo: the basic information needed by a pass。保存Pass name、opt_level优化等级和pass依赖。
PassContext,Pass上下文:全局的信息,包括错误信息,当前启用的Pass和禁用的Pass等。
memo_:一个Map,记录哪些node是常量node。

参考资料

use pass infra
Design and Architecture
博客

 类似资料: