参考:https://pytorch.org/docs/stable/fx.html
FX 是针对 torch.nn.module
而开发的工具,其能动态地获取 model 前向传播的执行过程,以便动态地增加、删除、改动、检查运算操作。其由三个主要组件组成:符号追踪器(Symbolic Tracer)、中间表示(Intermediate Representation, IR)和 Python 代码生成。这三个组件常常同时出现,如下面的例子:
import torch
# 一个简单的模型
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
def forward(self, x):
return self.linear(x + self.param).clamp(min=0.0, max=1.0)
module = MyModule()
from torch.fx import symbolic_trace
# 符号追踪。捕获模型的forward的内容
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)
# 查看该模型的 IR图
print(symbolic_traced.graph)
"""
graph():
%x : [#users=1] = placeholder[target=x]
%param : [#users=1] = get_attr[target=param]
%add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
%linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
%clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
return clamp
"""
# 查看由 IR 图生成的 Python 代码
print(symbolic_traced.code)
"""
def forward(self, x):
param = self.param
add = x + param; x = param = None
linear = self.linear(add); add = None
clamp = linear.clamp(min = 0.0, max = 1.0); linear = None
return clamp
"""
forward
)定义。总的来说,FX 的使用流程为:符号跟踪->中间表示->转换->Python代码生成。这是一种 Python-to-Python 的方法。FX 的精髓在于“Dynamic Transformation”,即当你需要对模型进行额外改动设计(如插入量化节点、算子 Fusion)时,不需要繁琐地针对模型的每一个部分来修改代码,只需要按照 FX 的流程来高效自动化地实现。
fx.Graph
生成而来的 nn.Module
,其有对应的 graph
、code
成员变量。当 graph
成员变量被重新赋值过,code
变量和 forward()
函数回自动重新生成。如果你编辑过 graph
的内容却没有重新赋值过,那你必须调用 recompile()
函数来更新信息。torch.fx.symbolic_trace()
函数作用完后 return
的就是 GraphModule
。Node
组成。这一一系列的 Node
就构成了执行逻辑。torch.fx.Tracer.trace()
函数作用完后 return 的就是 Graph
。graph
中操作的单位数据结构。大多数情况下,Node
代表了各种实体的调用方式,如输入(Input)、输出(Output)、算子(Operator)、已执行的成员函数(Method)和子模型(Module)。每个 Node
都有一个 op
属性,具体分类如下:
placeholder
:表示整个模型的输入。get_attr
:表示从模型层次结构中检索参数。call_function
:表示将自由函数应用于某些值。call_module
:表示将模型层次结构的 forward()
成员函数中的子模块应用于给定参数。call_method
:表示对某值调用成员函数。output
:这与打印 graph
输出中的 return
语句内容相对应。Node Wrapper
,用于流经程序的执行过程并记录下所有的操作(被调用的 torch function、method 和 operator)。若没有主动设置的话,Pytorch 会生成默认的 Proxy
用于符号追踪 。 对模型的图进行额外改动的方法有很多,如直接获取图并修改图(Direct Graph Manipulation),或通过在 GraphModule
模型上间接获取图来修改图(GraphModule Modification)。
GraphModule
的 Graph
中的所有 Node
。Node
是否满足替换要求(可以用 target
属性作为判断条件)。Node
并插入到 Graph
中。Node
的输入输出流(flow)重新定向到新 Node
身上。Graph
中删除旧 Node。recompile()
函数来更新 GraphModule
。下面一个例子展示 FX 如何将任何加法操作替换成二进制与(AND)运算:
import torch
from torch.fx import symbolic_trace
import operator
# 定义一个简单的模型
class M(torch.nn.Module):
def forward(self, x, y):
return x + y, torch.add(x, y), x.add(y)
# 进行符号追踪
traced = symbolic_trace(M())
# 加法操作有三种:
# 1. x + y,其成为 Node 时的 target 为 operator.add。
# 2. torch.add(x, y),其成为 Node 时的 target 为 torch.add.
# 3. x.add(y),其成为 Node 时的 target 为字符串 "add".
patterns = set([operator.add, torch.add, "add"])
# 遍历 Graph 中所有 node
for n in traced.graph.nodes:
# 如果满足 pattren 之一
if any(n.target == pattern for pattern in patterns):
# 在指定位置插入新 node (还没建立连接关系)
with traced.graph.inserting_after(n):
new_node = traced.graph.call_function(torch.bitwise_and, n.args, n.kwargs)
# 建立连接,将旧 node 的连接关系重定向到新 node 上。
n.replace_all_uses_with(new_node)
# 从 Graph 中删除旧 node
traced.graph.erase_node(n)
# 必须 recompile!
traced.recompile()
另一个修改 Graph
的方式是利用 Proxy
,再在一次主动 Trace
的过程中复制 Node
、构建新 Node
来组成新的 Graph
。
import torch
import torch.fx as fx
import torch.nn.functional as F
# 定义一个简单的模型
class M(torch.nn.Module):
def forward(self, x, y):
o = F.relu(x) + F.relu(y)
return o
# 数学定义
def relu_decomposition(x):
return (x > 0) * x
decomposition_rules = {}
decomposition_rules[F.relu] = relu_decomposition
def decompose_relu(model: torch.nn.Module,
tracer_class : type = fx.Tracer) -> torch.nn.Module:
graph : fx.Graph = tracer_class().trace(model)
new_graph = fx.Graph()
# 这相当于一个探针,将旧graph里需要用到的node的名字映射到新graph里对应的node。
mapping_table = {} # {old node name : new node object}
# 遍历 node
for node in graph.nodes:
# 判断是否是 relu函数。
if node.op == 'call_function' and node.target in decomposition_rules:
# 用于记录 proxy当前绑定
proxy_args = []
# node的arg即输入/上一个node。这一步其实就为该 node 生成对应输入的 proxy。
for x in node.args:
if isinstance(x, fx.Node):
proxy_args.append(fx.Proxy(mapping_table[x.name]))
else:
proxy_args.append(x)
# 这一步就是在“穿线”。穿线完毕后,与Proxy绑定的Graph也自动完成了:
# 在原末尾插新加入的node并建立连接。proxy会自动绑定到下一个(输出)node上,
# 依次类推,最后就变成 output_proxy
output_proxy = decomposition_rules[node.target](*proxy_args)
# 获取 当前proxy绑定的node。
new_node = output_proxy.node
mapping_table[node.name] = new_node
else: # 当 node 不需要被拆解时,只需要复制到新graph里就好。
# 该函数就是实现旧node与新node的映射关系。
def node_mapping(x):
return mapping_table[x.name]
# node_copy 确实吧 node 拷贝过来了,同时还建立了连接。
# 其会访问 node 的 原来所有输入的node,然后再利用opera
# 重定向来给新生成的node建立在目标Graph上的连接。
new_node = new_graph.node_copy(node, node_mapping)
mapping_table[node.name] = new_node
# 最后返回的模型绑定的是新 graph
return fx.GraphModule(model, new_graph)
decompose_relu(M())
Proxy
可以想象为一个“穿线器”:绑定 Node
后,在经过新的 Node
时能自动“串”好连接关系并加入到原 Graph
中。能记录此时的“线头”,即记录访问到的 Node
。
下面一个例子展示 FX 是如何通过 GraphModule 间接替换 torch.add() 为 torch.mul() 的:
import torch
import torch.fx as fx
# 定义一个简单模型
class M(torch.nn.Module):
def forward(self, x, y):
return torch.add(x, y)
# 下面尝试用替换 target 的方式来改动 graph (不提倡,因为对应的 node 的 name 没有改动!)
def transform(m: torch.nn.Module) -> torch.nn.Module:
gm : fx.GraphModule = fx.symbolic_trace(m)
# FX 的 IR 图是顺序储存节点,所以可以遍历
for node in gm.graph.nodes:
# 检查该节点是否是函数操作 (i.e: torch.add)
if node.op == 'call_function':
# 确认是该节点是函数操作时
if node.target == torch.add:
node.target = torch.mul
gm.recompile() # 重新编译 GraphModule,更新 code 属性
gm.graph.lint() # 最后需要检查修改过的 IR 图是否符合FX语法
return gm
transform(M())
PyTorch 官方将 if 语句、循环语句等具有选择/判断性质的语句称为控制流。在 FX语境中,控制流又可以分为动态控制流(Dynamic Control Flow)和静态控制流(Static Control Flow)。
FX 无法 trace
动态控制流,但可以 trace
判断条件明确的静态控制流。
若控制流的判断条件含有运算变量(Input Tensor)参与,那么该控制流就称为动态控制流,如:
def func_to_trace(x):
if x.sum() > 0:
# 可以看到x变量既参与计算,又参与判断
return torch.relu(x)
else:
return torch.neg(x)
此时对该函数使用 trace
功能就会报错:
"""
raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
"""
类推可知,若控制流的判断条件无运算变量参与,也即判断条件的变量不参与流(Flow)计算,那么该控制流就称为静态控制流,如:
import torch
import torch.fx
class MyModule(torch.nn.Module):
def __init__(self, do_activation : bool = False):
super().__init__()
self.do_activation = do_activation
self.linear = torch.nn.Linear(512, 512)
def forward(self, x):
x = self.linear(x)
# 该if语句就是静态控制流
if self.do_activation:
x = torch.relu(x)
return x
若想 trace
静态控制流,就需要明确判断条件,即给判断变量显式赋值:
without_activation = MyModule(do_activation=False)
# 然后就可以 trace
traced_without_activation = torch.fx.symbolic_trace(without_activation)
有些函数没有__torch_function__
属性,例如 Python 自带的函数或 math
库中的函数,无法被 trace
追踪。例如,当你的模型里调用了 len() 函数,那么进行 trace
时会报错:
"""
raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want ")
RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
"""
那么需要使用 wrap()
API 来将普通函数包装成 torch 性质的函数:
torch.fx.wrap('len')
# 然后就可以正常 trace 了
traced = torch.fx.symbolic_trace(normalize)
如:
# 模型定义过程就不展示了
print(traced_model.graph)
"""
graph():
%x : [#users=1] = placeholder[target=x]
%param : [#users=1] = get_attr[target=param]
%add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
%linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
%clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
return clamp
"""
通过调用 print_tabular()
函数就可以以 tabular 的格式输出 IR 图:
# 模型定义过程就不展示了
traced_model.graph.print_tabular()
"""
opcode name target args kwargs
------------- -------- ----------------------- ----------- ------------------------
placeholder x x () {}
get_attr param param () {}
call_function add_1 <built-in function add> (x, param) {}
call_module linear_1 linear (add_1,) {}
call_method clamp_1 clamp (linear_1,) {'min': 0.0, 'max': 1.0}
output output output (clamp_1,) {}
"""
目前,FX 没有提供任何方式来保证/验证运算符在语法上是有效的。也就是说,任何新(定义)加入的运算符都必须由用户自己来保证其正确性。
最后官网建议的一点是,你在对 Graph
做变换时,应该让整个程序的输入 torch.nn.Module
,然后获取对应的 Graph
,做出修改,最后再返回一个 torch.nn.Module
。这样更方便后续工作,比如又传入下一段 FX 代码中。
以上总结如有谬误,还请包涵、指正。