PyTorch Python API:FX || Intro

郭永安
2023-12-01

参考:https://pytorch.org/docs/stable/fx.html

Intro

  FX 是针对 torch.nn.module 而开发的工具,其能动态地获取 model 前向传播的执行过程,以便动态地增加、删除、改动、检查运算操作。其由三个主要组件组成:符号追踪器(Symbolic Tracer)、中间表示(Intermediate Representation, IR)和 Python 代码生成。这三个组件常常同时出现,如下面的例子:

import torch
# 一个简单的模型
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)

module = MyModule()

from torch.fx import symbolic_trace
# 符号追踪。捕获模型的forward的内容
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)

# 查看该模型的 IR图
print(symbolic_traced.graph)
"""
graph():
    %x : [#users=1] = placeholder[target=x]
    %param : [#users=1] = get_attr[target=param]
    %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    return clamp
"""

# 查看由 IR 图生成的 Python 代码
print(symbolic_traced.code)
"""
def forward(self, x):
    param = self.param
    add = x + param;  x = param = None
    linear = self.linear(add);  add = None
    clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None
    return clamp
"""

  • 符号追踪器会对代码执行“符号”。其喂入(自己生成的)假数据(Proxy),来执行代码。由 Proxy 经过、执行的代码会被记录下来。
  • 中间表示是在 Trace 期间记录各种操作的“容器”。其由一个节点(Node)列表组成,这些节点表示了函数的输入、名字和返回值。
  • Python 代码生成是一种代码生成工具,可以根据当前 IR 图的内容生成正确、可执行的 Python 代码。这代码是可以复制出来黏贴使用的,可以用于进一步配置模型的(forward)定义。

  总的来说,FX 的使用流程为:符号跟踪->中间表示->转换->Python代码生成。这是一种 Python-to-Python 的方法。FX 的精髓在于“Dynamic Transformation”,即当你需要对模型进行额外改动设计(如插入量化节点、算子 Fusion)时,不需要繁琐地针对模型的每一个部分来修改代码,只需要按照 FX 的流程来高效自动化地实现。

FX 定义的类对象

  • GraphModule:是由 fx.Graph 生成而来的 nn.Module,其有对应的 graphcode 成员变量。当 graph 成员变量被重新赋值过,code 变量和 forward() 函数回自动重新生成。如果你编辑过 graph 的内容却没有重新赋值过,那你必须调用 recompile() 函数来更新信息。torch.fx.symbolic_trace() 函数作用完后 return 的就是 GraphModule
  • Graph:是 FX 的 IR 图的主要数据结构,由一系列有序的 Node 组成。这一一系列的 Node 就构成了执行逻辑。torch.fx.Tracer.trace() 函数作用完后 return 的就是 Graph
  • Node:是 graph 中操作的单位数据结构。大多数情况下,Node 代表了各种实体的调用方式,如输入(Input)、输出(Output)、算子(Operator)、已执行的成员函数(Method)和子模型(Module)。每个 Node 都有一个 op 属性,具体分类如下:
    • placeholder:表示整个模型的输入。
    • get_attr:表示从模型层次结构中检索参数。
    • call_function:表示将自由函数应用于某些值。
    • call_module:表示将模型层次结构的 forward() 成员函数中的子模块应用于给定参数。
    • call_method:表示对某值调用成员函数。
    • output:这与打印 graph 输出中的 return 语句内容相对应。
  • Proxy:在符号追踪期间会用到。其本质上是一个 Node Wrapper,用于流经程序的执行过程并记录下所有的操作(被调用的 torch function、method 和 operator)。若没有主动设置的话,Pytorch 会生成默认的 Proxy 用于符号追踪 。

Example for Transformation

  对模型的图进行额外改动的方法有很多,如直接获取图并修改图(Direct Graph Manipulation),或通过在 GraphModule 模型上间接获取图来修改图(GraphModule Modification)。

Direct Graph Manipulation

简单替换 Node (利用 Pattern)

  1. 遍历 GraphModuleGraph 中的所有 Node
  2. 判断当前 Node 是否满足替换要求(可以用 target 属性作为判断条件)。
  3. 创建一个新的 Node 并插入到 Graph 中。
  4. 使用 FX 内置的 replace_all_uses_with 函数来将要被替换 Node 的输入输出流(flow)重新定向到新 Node 身上。
  5. Graph 中删除旧 Node。
  6. 调用 recompile() 函数来更新 GraphModule

  下面一个例子展示 FX 如何将任何加法操作替换成二进制与(AND)运算:

import torch
from torch.fx import symbolic_trace
import operator

# 定义一个简单的模型
class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y, torch.add(x, y), x.add(y)

# 进行符号追踪
traced = symbolic_trace(M())

# 加法操作有三种:
#     1. x + y,其成为 Node 时的 target 为 operator.add。
#     2. torch.add(x, y),其成为 Node 时的 target 为 torch.add.
#     3. x.add(y),其成为 Node 时的 target 为字符串 "add".
patterns = set([operator.add, torch.add, "add"])

# 遍历 Graph 中所有 node
for n in traced.graph.nodes:
    # 如果满足 pattren 之一
    if any(n.target == pattern for pattern in patterns):
        # 在指定位置插入新 node (还没建立连接关系)
        with traced.graph.inserting_after(n):
            new_node = traced.graph.call_function(torch.bitwise_and, n.args, n.kwargs)
            # 建立连接,将旧 node 的连接关系重定向到新 node 上。
            n.replace_all_uses_with(new_node)
        # 从 Graph 中删除旧 node
        traced.graph.erase_node(n)

# 必须 recompile!
traced.recompile()

复杂替换 Node(利用 Proxy)

  另一个修改 Graph 的方式是利用 Proxy,再在一次主动 Trace 的过程中复制 Node、构建新 Node 来组成新的 Graph

import torch
import torch.fx as fx
import torch.nn.functional as F
# 定义一个简单的模型
class M(torch.nn.Module):
    def forward(self, x, y):
        o = F.relu(x) + F.relu(y)
        return o

# 数学定义
def relu_decomposition(x):
    return (x > 0) * x

decomposition_rules = {}
decomposition_rules[F.relu] = relu_decomposition

def decompose_relu(model: torch.nn.Module,
              tracer_class : type = fx.Tracer) -> torch.nn.Module:

    graph : fx.Graph = tracer_class().trace(model)

    new_graph = fx.Graph()
    # 这相当于一个探针,将旧graph里需要用到的node的名字映射到新graph里对应的node。
    mapping_table = {}  # {old node name : new node object}
    # 遍历 node
    for node in graph.nodes:
        # 判断是否是 relu函数。
        if node.op == 'call_function' and node.target in decomposition_rules:
            # 用于记录 proxy当前绑定
            proxy_args = []
            # node的arg即输入/上一个node。这一步其实就为该 node 生成对应输入的 proxy。
            for x in node.args:
                if isinstance(x, fx.Node):
                    proxy_args.append(fx.Proxy(mapping_table[x.name])) 
                else:
                    proxy_args.append(x)
            # 这一步就是在“穿线”。穿线完毕后,与Proxy绑定的Graph也自动完成了:
            # 在原末尾插新加入的node并建立连接。proxy会自动绑定到下一个(输出)node上,
            # 依次类推,最后就变成 output_proxy
            output_proxy = decomposition_rules[node.target](*proxy_args)

            # 获取 当前proxy绑定的node。
            new_node = output_proxy.node
            mapping_table[node.name] = new_node
        else: # 当 node 不需要被拆解时,只需要复制到新graph里就好。
            # 该函数就是实现旧node与新node的映射关系。
            def node_mapping(x):
                return mapping_table[x.name]
            # node_copy 确实吧 node 拷贝过来了,同时还建立了连接。
            # 其会访问 node 的 原来所有输入的node,然后再利用opera
            # 重定向来给新生成的node建立在目标Graph上的连接。
            new_node = new_graph.node_copy(node, node_mapping)
            mapping_table[node.name] = new_node
        # 最后返回的模型绑定的是新 graph
    return fx.GraphModule(model, new_graph)

decompose_relu(M())

  Proxy 可以想象为一个“穿线器”:绑定 Node 后,在经过新的 Node 时能自动“串”好连接关系并加入到原 Graph 中。能记录此时的“线头”,即记录访问到的 Node

GraphModule Modification

下面一个例子展示 FX 是如何通过 GraphModule 间接替换 torch.add() 为 torch.mul() 的:

import torch
import torch.fx as fx

# 定义一个简单模型
class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)
# 下面尝试用替换 target 的方式来改动 graph (不提倡,因为对应的 node 的 name 没有改动!)
def transform(m: torch.nn.Module) -> torch.nn.Module:
    gm : fx.GraphModule = fx.symbolic_trace(m)
    # FX 的 IR 图是顺序储存节点,所以可以遍历
    for node in gm.graph.nodes:
        # 检查该节点是否是函数操作 (i.e: torch.add)
        if node.op == 'call_function':
            # 确认是该节点是函数操作时
            if node.target == torch.add:
                node.target = torch.mul

    gm.recompile()  # 重新编译 GraphModule,更新 code 属性
    gm.graph.lint() # 最后需要检查修改过的 IR 图是否符合FX语法

    return gm

transform(M())

符号追踪的局限性(注意事项)

控制流(Control Flow)

  PyTorch 官方将 if 语句循环语句等具有选择/判断性质的语句称为控制流。在 FX语境中,控制流又可以分为动态控制流(Dynamic Control Flow)和静态控制流(Static Control Flow)。

  FX 无法 trace 动态控制流,但可以 trace 判断条件明确的静态控制流。

动态控制流

  若控制流的判断条件含有运算变量(Input Tensor)参与,那么该控制流就称为动态控制流,如:

def func_to_trace(x):
    if x.sum() > 0:
        # 可以看到x变量既参与计算,又参与判断
        return torch.relu(x)
    else:
        return torch.neg(x)

  此时对该函数使用 trace 功能就会报错:

"""
raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
"""

静态控制流

  类推可知,若控制流的判断条件无运算变量参与,也即判断条件的变量不参与流(Flow)计算,那么该控制流就称为静态控制流,如:

import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self, do_activation : bool = False):
        super().__init__()
        self.do_activation = do_activation
        self.linear = torch.nn.Linear(512, 512)

    def forward(self, x):
        x = self.linear(x)
        # 该if语句就是静态控制流
        if self.do_activation:
            x = torch.relu(x)
        return x

  若想 trace 静态控制流,就需要明确判断条件,即给判断变量显式赋值:

without_activation = MyModule(do_activation=False)
# 然后就可以 trace
traced_without_activation = torch.fx.symbolic_trace(without_activation)

非 torch 函数

  有些函数没有__torch_function__属性,例如 Python 自带的函数或 math 库中的函数,无法被 trace 追踪。例如,当你的模型里调用了 len() 函数,那么进行 trace 时会报错:

"""
raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want ")
RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
"""

  那么需要使用 wrap() API 来将普通函数包装成 torch 性质的函数:

torch.fx.wrap('len')
# 然后就可以正常 trace 了
traced = torch.fx.symbolic_trace(normalize)

查看 Graph 内容

通过 print() 函数

如:

# 模型定义过程就不展示了
print(traced_model.graph)
"""
graph():
    %x : [#users=1] = placeholder[target=x]
    %param : [#users=1] = get_attr[target=param]
    %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    return clamp
"""

通过 print_tabular() 函数

  通过调用 print_tabular() 函数就可以以 tabular 的格式输出 IR 图:

# 模型定义过程就不展示了
traced_model.graph.print_tabular()
"""
opcode         name      target                   args         kwargs
-------------  --------  -----------------------  -----------  ------------------------
placeholder    x         x                        ()           {}
get_attr       param     param                    ()           {}
call_function  add_1     <built-in function add>  (x, param)   {}
call_module    linear_1  linear                   (add_1,)     {}
call_method    clamp_1   clamp                    (linear_1,)  {'min': 0.0, 'max': 1.0}
output         output    output                   (clamp_1,)   {}
"""

总结

  目前,FX 没有提供任何方式来保证/验证运算符在语法上是有效的。也就是说,任何新(定义)加入的运算符都必须由用户自己来保证其正确性。

  最后官网建议的一点是,你在对 Graph 做变换时,应该让整个程序的输入 torch.nn.Module,然后获取对应的 Graph,做出修改,最后再返回一个 torch.nn.Module。这样更方便后续工作,比如又传入下一段 FX 代码中。

  以上总结如有谬误,还请包涵、指正。

 类似资料:

相关阅读

相关文章

相关问答