当前位置: 首页 > 工具软件 > C Script > 使用案例 >

Pytorch框架TorchScript模型转换方法

从渊
2023-12-01

为什么要使用TorchScript对模型进行转换?

a)、TorchScript代码可以在它自己的解释器中调用,它本质上是一个受限的Python解释器。这个解释器不获取全局解释器锁,因此可以在同一个实例上同时处理多个请求。

b)、这种格式允许我们将整个模型保存到磁盘上,并将其加载到另一个环境中,比如用Python以外的语言编写的服务器中

c)、TorchScript提供了一种表示方式,我们可以在其中对代码进行编译器优化,以提供更有效的执行

d)、TorchScript允许我们与许多后端/设备运行时进行交互

 

1、trace方法转换模型

trace方法首先使用输入数据执行一遍模型,并记录下模型执行过程中的参数,并创建一个torch.jit.ScriptModule实例。trace方法转换模型示例:

import torch
import numpy as np

class MyCell_v1(torch.nn.Module):
    def __init__(self):
        super(MyCell_v1, self).__init__()
        self.linear = torch.nn.Linear(4, 4)
        
    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h
    
my_cell_v1 = MyCell_v1()
x = torch.ones(3, 4)
h = torch.ones(3, 4)
# 对模型对象进使用trace方法进行转换
traced_cell_1 = torch.jit.trace(my_cell_v1, (x, h))
# 查看转换后的代码
print(traced_cell_1.code)
# 转换后的模型进行推理
print(traced_cell_1(x, h))
# 原始模型进行推理
print(my_cell_v1(x, h))

输出结果如下:

def forward(self,
    input: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = torch.add((self.linear).forward(input, ), h, alpha=1)
  _1 = torch.tanh(_0)
  return (_1, _1)

(tensor([[ 0.3909,  0.9382, -0.1499,  0.8349],
        [ 0.3909,  0.9382, -0.1499,  0.8349],
        [ 0.3909,  0.9382, -0.1499,  0.8349]], grad_fn=<TanhBackward>), tensor([[ 0.3909,  0.9382, -0.1499,  0.8349],
        [ 0.3909,  0.9382, -0.1499,  0.8349],
        [ 0.3909,  0.9382, -0.1499,  0.8349]], grad_fn=<TanhBackward>))
(tensor([[ 0.3909,  0.9382, -0.1499,  0.8349],
        [ 0.3909,  0.9382, -0.1499,  0.8349],
        [ 0.3909,  0.9382, -0.1499,  0.8349]], grad_fn=<TanhBackward>), tensor([[ 0.3909,  0.9382, -0.1499,  0.8349],
        [ 0.3909,  0.9382, -0.1499,  0.8349],
        [ 0.3909,  0.9382, -0.1499,  0.8349]], grad_fn=<TanhBackward>))

trace弊端:由于要执行一遍模型,当模型中存在循环或者if语句时,不能覆盖所有的模型代码分支。

2、script方法转换模型

a)、trace方法的不足之处分析(Module中包含分支语句)

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x
        
class MyCell_v2(torch.nn.Module):
    def __init__(self, dg):
        super(MyCell_v2, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)
        
    def forward(self, x, h):
        new_h = torch.tanh(self.linear(self.dg(x)) + h)
        return new_h, new_h

dg = MyDecisionGate()
my_cell_v2 = MyCell_v2(dg)

i. 使用Trace方法转换模型,转换后的模型覆盖if分支:

# x 是一个3 x 4的全1单位矩阵,所以 x .sum() > 0 必然成立
x = torch.ones(3, 4)
h = torch.ones(3, 4)
# 对模型进行trace转换,trace 方法把my_cell_v2模型在当前的x和h上执行一遍
# 因为 x .sum() > 0 成立,所以现在MyDecisionGate的forward执行if分支,返回x本身,else分支没有执行
traced_cell_2 = torch.jit.trace(my_cell_v2, (x, h))
# 查看转换后的MyDecisionGate模型对象,转换后的forward里面的if-else不见了,相当于少了一个分支,另一个分支判断肯定会出问题
# 这就是因为trace方法:Tracing does exactly what we said it would: run the code, record the operations that happen and construct a ScriptModule that does exactly that
# trace方法只记录执行过程中的操作,另一个分支没有执行到,所以记录不到
print(traced_cell_2.dg.code)
print(traced_cell_2.code)
# 查看模型推理结果,traced_cell_2与my_cell_v2的计算结果相同
print(traced_cell_2(x, h))
print(my_cell_v2(x, h))

输出结果如下:

def forward(self,
    x: Tensor) -> None:
  return None

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = self.linear
  _1 = (self.dg).forward(x, )
  _2 = torch.add((_0).forward(x, ), h, alpha=1)
  _3 = torch.tanh(_2)
  return (_3, _3)

(tensor([[0.9647, 0.7493, 0.9096, 0.4581],
        [0.9647, 0.7493, 0.9096, 0.4581],
        [0.9647, 0.7493, 0.9096, 0.4581]], grad_fn=<TanhBackward>), tensor([[0.9647, 0.7493, 0.9096, 0.4581],
        [0.9647, 0.7493, 0.9096, 0.4581],
        [0.9647, 0.7493, 0.9096, 0.4581]], grad_fn=<TanhBackward>))
(tensor([[0.9647, 0.7493, 0.9096, 0.4581],
        [0.9647, 0.7493, 0.9096, 0.4581],
        [0.9647, 0.7493, 0.9096, 0.4581]], grad_fn=<TanhBackward>), tensor([[0.9647, 0.7493, 0.9096, 0.4581],
        [0.9647, 0.7493, 0.9096, 0.4581],
        [0.9647, 0.7493, 0.9096, 0.4581]], grad_fn=<TanhBackward>))

结论:对于输入的x和h,代码应该执行到MyDecisionGate中forward函数的if分支,trace记录到if分支的操作,trace后的模型输出与原模型一致。

ii. 使用Trace方法转后覆盖if分支的模型测试else分支:

# 现在将x乘以-1
# 查看模型推理结果, traced_cell_2与my_cell_v2的计算结果不同
# 由于trace方法生成的trace_cell_2对象只记录了MyDecisionGate中forward方法的if分支,丢弃了else分支
# 所以traced_cell_2把所有输入都当做满足if分支来处理,所以当应该是else处理时(x.sum < 0 时),计算结果就出错了
print(traced_cell_2(-x, h))
print(my_cell_v2(-x, h))

输出结果如下:

(tensor([[-0.2671,  0.1883,  0.2345,  0.6458],
        [-0.2671,  0.1883,  0.2345,  0.6458],
        [-0.2671,  0.1883,  0.2345,  0.6458]], grad_fn=<TanhBackward>), tensor([[-0.2671,  0.1883,  0.2345,  0.6458],
        [-0.2671,  0.1883,  0.2345,  0.6458],
        [-0.2671,  0.1883,  0.2345,  0.6458]], grad_fn=<TanhBackward>))
(tensor([[0.9647, 0.7493, 0.9096, 0.4581],
        [0.9647, 0.7493, 0.9096, 0.4581],
        [0.9647, 0.7493, 0.9096, 0.4581]], grad_fn=<TanhBackward>), tensor([[0.9647, 0.7493, 0.9096, 0.4581],
        [0.9647, 0.7493, 0.9096, 0.4581],
        [0.9647, 0.7493, 0.9096, 0.4581]], grad_fn=<TanhBackward>))

结论:对于输入的-x和h,代码应该执行到MyDecisionGate中forward函数的else分支,由于trace中记录到的是if分支的操作,所以trace后的模型对于-x的输出与原模型不一致。

iii. 使用Trace方法转换模型,转换后的模型覆盖else分支:

# 现在将x乘以-1
# 并同时在新的x和h的基础上重新进行trace
traced_cell_3 = torch.jit.trace(my_cell_v2, (-x, h))
print(traced_cell_3.dg.code)
# 查看模型推理结果, traced_cell_3与my_cell_v2的计算结果相同
print(traced_cell_3(-x, h))
print(my_cell_v2(-x, h))

输出结果如下:

def forward(self,
    x: Tensor) -> Tensor:
  return torch.neg(x)

(tensor([[0.9647, 0.7493, 0.9096, 0.4581],
        [0.9647, 0.7493, 0.9096, 0.4581],
        [0.9647, 0.7493, 0.9096, 0.4581]], grad_fn=<TanhBackward>), tensor([[0.9647, 0.7493, 0.9096, 0.4581],
        [0.9647, 0.7493, 0.9096, 0.4581],
        [0.9647, 0.7493, 0.9096, 0.4581]], grad_fn=<TanhBackward>))
(tensor([[0.9647, 0.7493, 0.9096, 0.4581],
        [0.9647, 0.7493, 0.9096, 0.4581],
        [0.9647, 0.7493, 0.9096, 0.4581]], grad_fn=<TanhBackward>), tensor([[0.9647, 0.7493, 0.9096, 0.4581],
        [0.9647, 0.7493, 0.9096, 0.4581],
        [0.9647, 0.7493, 0.9096, 0.4581]], grad_fn=<TanhBackward>))

结论:对于输入的-x和h,代码应该执行到MyDecisionGate中forward函数的else分支,重新执行trace方法进行转换,转换后的模型记录到else分支的操作,trace后的模型输出与原模型一致。

总结:”Tracing does exactly what we said it would: run the code, record the operations that happen and construct a ScriptModule that does exactly that”trace方法只记录执行过程中遇到的操作,另一个分支没有执行到,所以记录不到。所以trace方法不适用于Module中具有分支和循环结构的模型。

 

b)、Script方法转换模型

x = torch.ones(3, 4)
h = torch.ones(3, 4)
scripted_gate = torch.jit.script(MyDecisionGate())
my_cell_script = MyCell_v2(scripted_gate)
scripted_cell = torch.jit.script(my_cell_script)

print(scripted_gate.code)
print(scripted_cell.code)

输出结果如下:

def forward(self,
    x: Tensor) -> Tensor:
  _0 = bool(torch.gt(torch.sum(x, dtype=None), 0))
  if _0:
    _1 = x
  else:
    _1 = torch.neg(x)
  return _1

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = (self.linear).forward((self.dg).forward(x, ), )
  new_h = torch.tanh(torch.add(_0, h, alpha=1))
  return (new_h, new_h)
print(scripted_cell(x, h))
print(my_cell_script(x, h))

"""
(tensor([[0.8947, 0.5101, 0.5769, 0.8964],
        [0.8947, 0.5101, 0.5769, 0.8964],
        [0.8947, 0.5101, 0.5769, 0.8964]], grad_fn=<TanhBackward>), tensor([[0.8947, 0.5101, 0.5769, 0.8964],
        [0.8947, 0.5101, 0.5769, 0.8964],
        [0.8947, 0.5101, 0.5769, 0.8964]], grad_fn=<TanhBackward>))
(tensor([[0.8947, 0.5101, 0.5769, 0.8964],
        [0.8947, 0.5101, 0.5769, 0.8964],
        [0.8947, 0.5101, 0.5769, 0.8964]], grad_fn=<TanhBackward>), tensor([[0.8947, 0.5101, 0.5769, 0.8964],
        [0.8947, 0.5101, 0.5769, 0.8964],
        [0.8947, 0.5101, 0.5769, 0.8964]], grad_fn=<TanhBackward>))
"""

print(scripted_cell(-x, h))
print(my_cell_script(-x, h))

"""
(tensor([[0.8947, 0.5101, 0.5769, 0.8964],
        [0.8947, 0.5101, 0.5769, 0.8964],
        [0.8947, 0.5101, 0.5769, 0.8964]], grad_fn=<TanhBackward>), tensor([[0.8947, 0.5101, 0.5769, 0.8964],
        [0.8947, 0.5101, 0.5769, 0.8964],
        [0.8947, 0.5101, 0.5769, 0.8964]], grad_fn=<TanhBackward>))
(tensor([[0.8947, 0.5101, 0.5769, 0.8964],
        [0.8947, 0.5101, 0.5769, 0.8964],
        [0.8947, 0.5101, 0.5769, 0.8964]], grad_fn=<TanhBackward>), tensor([[0.8947, 0.5101, 0.5769, 0.8964],
        [0.8947, 0.5101, 0.5769, 0.8964],
        [0.8947, 0.5101, 0.5769, 0.8964]], grad_fn=<TanhBackward>))
"""

总结:script能够处理Module模型中的控制流。

 

3、trace、script混合方法转换模型

a)、tracescript针对的都是torch.nn.Module类对象及其子类对象

b)、在创建Module对象时就可以使用trace或者script,并根据具体的Module代码逻辑选择使用trace还是script

c)、trace里面可以包含scriptscript里面也可以包含trace

d)、如果Module中包含if分支或者loop循环处理,则使用script进行转换,否则使用trace进行转换

"""
    个人理解,由内而外:
    1、MyDecisionGate是Module的子类,并且其中包含if分支语句,所以要使用script方法进行转换
    2、MyCell_v2是Module的子类,虽然其中包含的MyDecisionGate中包含if分支,但是已经经过script方法转换了,所以对MyCell_v2使用trace方法进行转换
    3、MyRNNLoop_TraceScript是Module的子类,并且其中包含for循环,所以要使用script方法进行转换
"""

class MyRNNLoop_TraceScript(torch.nn.Module):
    def __init__(self):
        super(MyRNNLoop_TraceScript, self).__init__()
        self.cell = torch.jit.trace(MyCell_v2(torch.jit.script(MyDecisionGate())), (x, h))
        
    def forward(self, xs):
        h = torch.zeros(3, 4)
        y = torch.zeros(3, 4)
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
        return y, h

rnn_loop = torch.jit.script(MyRNNLoop_TraceScript())
print(rnn_loop.code)

输出结果如下:

def forward(self,
    xs: Tensor) -> Tuple[Tensor, Tensor]:
  h = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
  y = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
  y0 = y
  h0 = h
  for i in range(torch.size(xs, 0)):
    _0 = (self.cell).forward(torch.select(xs, 0, i), h0, )
    y1, h1, = _0
    y0, h0 = y1, h1
  return (y0, h0)

4、trace、script转换模型性能测试对比

a)、使用trace/script转换的模型

class MyRNNLoop_TraceScript(torch.nn.Module):
    def __init__(self):
        super(MyRNNLoop_TraceScript, self).__init__()
        self.cell = torch.jit.trace(MyCell_v2(torch.jit.script(MyDecisionGate())), (x, h))
        
    def forward(self, xs):
        h = torch.zeros(3, 4)
        y = torch.zeros(3, 4)
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
        return y, h

rnn_loop = torch.jit.script(MyRNNLoop_TraceScript())
xs = torch.randn(100, 3, 4)
# 统计模型运行时间
print(rnn_loop(xs))
%timeit rnn_loop(xs)

输出结果如下:

(tensor([[-0.0442, -0.9005,  0.6578,  0.0710],
        [ 0.7052,  0.8484,  0.6964, -0.8223],
        [ 0.3576, -0.7521,  0.4776,  0.0805]], grad_fn=<TanhBackward>), tensor([[-0.0442, -0.9005,  0.6578,  0.0710],
        [ 0.7052,  0.8484,  0.6964, -0.8223],
        [ 0.3576, -0.7521,  0.4776,  0.0805]], grad_fn=<TanhBackward>))
13.6 ms ± 815 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

b)、不使用trace/script转换的模型

class MyRNNLoop_Normal(torch.nn.Module):
    def __init__(self):
        super(MyRNNLoop_Normal, self).__init__()
        self.cell = MyCell_v2(MyDecisionGate())
        
    def forward(self, xs):
        h = torch.zeros(3, 4)
        y = torch.zeros(3, 4)
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
        return y, h

rnn_loop_normal = MyRNNLoop_Normal()
# 统计模型运行时间
print(rnn_loop_normal(xs))
%timeit rnn_loop_normal(xs)

输出结果如下:

(tensor([[ 0.0571, -0.8888, -0.8796, -0.8841],
        [ 0.8694,  0.3169,  0.6436,  0.4143],
        [ 0.6243, -0.7954, -0.8708, -0.6556]], grad_fn=<TanhBackward>), tensor([[ 0.0571, -0.8888, -0.8796, -0.8841],
        [ 0.8694,  0.3169,  0.6436,  0.4143],
        [ 0.6243, -0.7954, -0.8708, -0.6556]], grad_fn=<TanhBackward>))
30.6 ms ± 11.6 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

5、trace、script模型save和load

a)、保存的模型文件包含代码、参数、属性特征以及debug信息,也就是说保存的模型文件包含所有模型需要的信息

b)、所以这个模型文件可以在任何独立的进程中独立运行,甚至是和python运行环境完全不相关的环境中运行

 

i. 保存模型

# 保存模型
rnn_loop.save("./models/rnn_loop.pth")

ii. 加载模型

# 加载模型
loaded = torch.jit.load("./models/rnn_loop.pth")

iii. 测试模型

# 测试模型
print(loaded(xs))
%timeit loaded(xs)

输出结果如下:

(tensor([[-0.0442, -0.9005,  0.6578,  0.0710],
        [ 0.7052,  0.8484,  0.6964, -0.8223],
        [ 0.3576, -0.7521,  0.4776,  0.0805]], grad_fn=<TanhBackward>), tensor([[-0.0442, -0.9005,  0.6578,  0.0710],
        [ 0.7052,  0.8484,  0.6964, -0.8223],
        [ 0.3576, -0.7521,  0.4776,  0.0805]], grad_fn=<TanhBackward>))
13.5 ms ± 3.78 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)

6、总结

script(obj[,optimize,_frames_up,_rcb])

Scripting a function or nn.Module will inspect the source code, compile it as TorchScript code using the TorchScript compiler, and return a ScriptModule or ScriptFunction.

trace(func, example_inputs[, optimize, …])

Trace a function and return an executable or ScriptFunction that will be optimized using just-in-time compilation.

script_if_tracing(fn)

Compiles fn when it is first called during tracing.

trace_module(mod, inputs[, optimize, …])

Trace a module and return an executable ScriptModule that will be optimized using just-in-time compilation.

fork(func, *args, **kwargs)

Creates an asynchronous task executing func and a reference to the value of the result of this execution.

wait(future)

Forces completion of a torch.jit.Future[T] asynchronous task, returning the result of the task.

ScriptModule()

A wrapper around C++ torch::jit::Module.

ScriptFunction

Functionally equivalent to a ScriptModule, but represents a single function and does not have any attributes or Parameters.

freeze(mod[, preserved_attrs, optimize_numerics])

Freezing a ScriptModule will clone it and attempt to inline the cloned module’s submodules, parameters, and attributes as constants in the TorchScript IR Graph.

save(m, f[, _extra_files])

Save an offline version of this module for use in a separate process.

load(f[, map_location, _extra_files])

Load a ScriptModule or ScriptFunction previously saved with torch.jit.save

ignore([drop])

This decorator indicates to the compiler that a function or method should be ignored and left as a Python function.

unused(fn)

This decorator indicates to the compiler that a function or method should be ignored and replaced with the raising of an exception.

isinstance(obj, target_type)

This function provides for conatiner type refinement in TorchScript.

 

 

 

 

 

 

 

 

 

 

 

 

 

1)、torch.jit.trace

trace方法非常适合那些只操作单张量或张量的列表、字典和元组的代码。使用torch.jit.trace和torch.jit.trace_module ,你能将一个模型或python函数转为TorchScript中的ScriptModule或ScriptFunction。根据提供的输入样例,trace将会运行该函数并记录所有张量上执行的操作。Trace方法只会记录当前输入样例执行到的操作,以后不管遇到什么输入都会执行相同的操作,所以不适用于具有分支控制流的场景。

2)、torch.jit.script

可以编译一个带有控制流的function(使用装饰器方式或者函数调用方式)或者Module

3)、torch.jit.trace_module

当传入trace函数的是一个Module时,默认只执行并跟踪Module的forward方法,在trace_module方法中,通过传入一个字典参数(字典的key是Module中方法名称,字典的value是对应方法的输入参数),可以执行并跟踪多个方法。

4)、torch.jit.ignore

在编译时忽视某些方法,被torch.jit.ignore装饰器装饰的函数不会被TorchScript编译,被忽略的函数将直接使用Python解释器执行。如果一个Module中有被ignore忽略的方法,则这个经过TorchScript编译的模型不能执行sav方法导出模型。

5)、torch.jit.unused

在编译时忽视某些方法,被torch.jit.unused装饰器装饰的函数不会被编译到TorchScript中,并使用raise Exception来替代这个函数。

6)、torch.jit.export

编译一个不在forward中调用的方法以及递归地编译其内的所有方法,可在此方法上使用装饰器torch.jit.export

 

TORCHSCRIPT

INTRODUCTION TO TORCHSCRIPT

LOADING A TORCHSCRIPT MODEL IN C++

Pytorch C++部署 之 TorchScript 踩坑记录

在C++平台上部署PyTorch模型流程+踩坑实录

pytorch怎么使用c++调用部署模型?

TORCH.JIT.SCRIPT

TORCHSCRIPT LANGUAGE REFERENCE

TORCH.JIT.TRACE

 类似资料: