当前位置: 首页 > 工具软件 > x-hook > 使用案例 >

Pytorch获取中间层信息-hook函数

阚小云
2023-12-01

参考链接:https://www.cnblogs.com/hellcat/p/8512090.html
由于pytorch会自动舍弃图计算的中间结果,所以想要获取这些数值就需要使用hook函数。hook函数包括tensor的hook和nn.Module的hook,用法相似。hook函数在使用后应及时删除,以避免每次都运行钩子增加运行负载。hook函数主要用在获取某些中间结果的情景,如中间某一层的输出或某一层的梯度。这些结果本应写在forward函数中,但如果在forward函数中专门加上这些处理,可能会使处理逻辑比较复杂,这时候使用hook技术就更合适一些

Tensor对象

参考:https://pytorch.org/docs/stable/tensors.html
有如下的register_hook(hook)方法,为Tensor注册一个backward hook,用来获取变量的梯度。
hook必须遵循如下的格式:hook(grad) -> Tensor or None,其中grad为获取的梯度
具体的实例如下:

import torch

grad_list = []
def print_grad(grad):
    grad = grad * 2
    grad_list.append(grad)

x = torch.tensor([[1., -1.], [1., 1.]], requires_grad=True)
h = x.register_hook(print_grad)    # double the gradient
out = x.pow(2).sum()
out.backward()
print(grad_list)
'''
[tensor([[ 4., -4.],
        [ 4.,  4.]])]
'''
# 删除hook函数
h.remove()

Module对象

register_forward_hook(hook)register_backward_hook(hook)两种方法,分别对应前向传播和反向传播的hook函数。

register_forward_hook(hook)

在网络执行forward()之后,执行hook函数,需要具有如下的形式:

hook(module, input, output) -> None or modified output

hook可以修改input和output,但是不会影响forward的结果。最常用的场景是需要提取模型的某一层(不是最后一层)的输出特征,但又不希望修改其原有的模型定义文件,这时就可以利用forward_hook函数。

import torch
import torch.nn as nn
import torch.nn.functional as F


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out

features = []
def hook(module, input, output):
    features.append(output.clone().detach())


net = LeNet()
x = torch.randn(2, 3, 32, 32)
handle = net.conv2.register_forward_hook(hook)
y = net(x)

print(features[0].size())
handle.remove()

register_backward_hook(hook)

每一次module的inputs的梯度被计算后调用hook,hook必须具有如下的签名:

hook(module, grad_input, grad_output) -> Tensor or None

grad_inputgrad_output参数分别表示输入的梯度和输出的梯度,是不能修改的,但是可以通过return一个梯度元组tuple来替代grad_input
展示一个实例来解析grad_inputgrad_output参数:

import torch
import torch.nn as nn


def hook(module, grad_input, grad_output):
    print('grad_input: ', grad_input)
    print('grad_output: ', grad_output)


x = torch.tensor([[1., 2., 10.]], requires_grad=True)
module = nn.Linear(3, 1)
handle = module.register_backward_hook(hook)
y = module(x)
y.backward()
print('module_weight: ', module.weight.grad)

handle.remove()

输出:

grad_input:  (tensor([1.]), tensor([[ 0.1236, -0.0232, -0.5687]]), tensor([[ 1.],
        [ 2.],
        [10.]]))
grad_output:  (tensor([[1.]]),)
module_weight:  tensor([[ 1.,  2., 10.]])

可以看出,grad_input元组包含(bias的梯度输入x的梯度权重weight的梯度),grad_output元组包含输出y的梯度。
可以在hook函数中通过return来修改grad_input

import torch
import torch.nn as nn


def hook(module, grad_input, grad_output):
    print('grad_input: ', grad_input)
    print('grad_output: ', grad_output)
    return grad_input[0] * 0, grad_input[1] * 0, grad_input[2] * 0,


x = torch.tensor([[1., 2., 10.]], requires_grad=True)
module = nn.Linear(3, 1)
handle = module.register_backward_hook(hook)
y = module(x)
y.backward()
print('module_bias: ', module.bias.grad)
print('x: ', x.grad)
print('module_weight: ', module.weight.grad)

handle.remove()

输出:

grad_input:  (tensor([1.]), tensor([[ 0.1518,  0.0798, -0.3170]]), tensor([[ 1.],
        [ 2.],
        [10.]]))
grad_output:  (tensor([[1.]]),)
module_bias:  tensor([0.])
x:  tensor([[0., 0., -0.]])
module_weight:  tensor([[0., 0., 0.]])

对于没有参数的Module,比如nn.ReLU来说,grad_input元组包含(输入x的梯度),grad_output元组包含(输出y的梯度)。

def hook(module, grad_input, grad_output):
    print('grad_input: ', grad_input)
    print('grad_output: ', grad_output)
    return (grad_input[0] / 4, )


x = torch.tensor([-1., 2., 10.], requires_grad=True)
module = nn.ReLU()
handle = module.register_backward_hook(hook)
y = module(x).sum()
z = y * y
z.backward()

print(x.grad)  # tensor([0., 6., 6.])
handle.remove()

输出:

grad_input:  (tensor([ 0., 24., 24.]),)
grad_output:  (tensor([24., 24., 24.]),)
tensor([0., 6., 6.])

y = R e L U ( x 1 ) + R e L U ( x 2 ) + R e L U ( x 3 ) y=ReLU(x_{1})+ReLU(x_{2})+ReLU(x_{3}) y=ReLU(x1)+ReLU(x2)+ReLU(x3)
z = y 2 z=y^{2} z=y2
grad_output是传到ReLU模块的输出值的梯度,即 ∂ z ∂ y = 2 y = 24 \frac{\partial z}{\partial y}=2y=24 yz=2y=24
grad_input是进入ReLU模块的输入值的梯度,由 ∂ y ∂ x 1 = 0 , ∂ y ∂ x 2 = 1 , ∂ y ∂ x 3 = 1 \frac{\partial y}{\partial x_{1}}=0,\frac{\partial y}{\partial x_{2}}=1,\frac{\partial y}{\partial x_{3}}=1 x1y=0,x2y=1,x3y=1,可得:
∂ z ∂ y ∂ y ∂ x 1 = 0 , ∂ z ∂ y ∂ y ∂ x 2 = 24 , ∂ z ∂ y ∂ y ∂ x 3 = 24 \frac{\partial z}{\partial y}\frac{\partial y}{\partial x_{1}}=0,\frac{\partial z}{\partial y}\frac{\partial y}{\partial x_{2}}=24,\frac{\partial z}{\partial y}\frac{\partial y}{\partial x_{3}}=24 yzx1y=0,yzx2y=24,yzx3y=24
在hook函数中可以对输入值 x x x的梯度进行缩放:
[ 0 , 24 , 24 ] / 4 = [ 0 , 6 , 6 ] [0,24,24]/4=[0,6,6] [0,24,24]/4=[0,6,6]

 类似资料: