当前位置: 首页 > 知识库问答 >
问题:

如何将PReLU合并到量化模型中?

云开诚
2023-03-14

我试图量化一个使用PReLU的模型。用ReLU替换PReLU是不可能的,因为它会严重影响网络性能,甚至毫无用处。

据我所知,PReLU在Pytorch中不支持量化。因此,我尝试手动重写这个模块,并使用torch实现乘法和加法。FloatFunctional()以绕过此限制。

这就是我到目前为止提出的问题:

class PReLU_Quantized(nn.Module):
    def __init__(self, prelu_object):
        super().__init__()
        self.weight = prelu_object.weight
        self.quantized_op = nn.quantized.FloatFunctional()
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, inputs):
        # inputs = torch.max(0, inputs) + self.weight * torch.min(0, inputs)    
        self.weight = self.quant(self.weight)
        weight_min_res = self.quantized_op.mul(self.weight, torch.min(inputs)[0])
        inputs = self.quantized_op.add(torch.max(inputs)[0], weight_min_res).unsqueeze(0)
        self.weight = self.dequant(self.weight)
        return inputs

对于替换:

class model(nn.Module):
     def __init__(self)
         super().__init__()
         .... 
        self.prelu = PReLU()
        self.prelu_q = PReLU_Quantized(self.prelu)
         ....

基本上,我读取了现有prelu模块的学习参数,并在一个新模块中自己运行计算。该模块似乎在某种意义上工作,它没有失败整个应用程序。

然而,为了评估我的实现是否真的正确并产生与原始模块相同的结果,我尝试对其进行测试
这里是正常模型(即非量化模型)的对应项:
由于某种原因,实际的PReLU与我的实现之间的误差非常大!

以下是不同层中的示例差异:

diff : 1.1562038660049438
diff : 0.02868632599711418
diff : 0.3653906583786011
diff : 1.6100226640701294
diff : 0.8999372720718384
diff : 0.03773299604654312
diff : -0.5090572834014893
diff : 0.1654307246208191
diff : 1.161868691444397
diff : 0.026089997962117195
diff : 0.4205571115016937
diff : 1.5337920188903809
diff : 0.8799554705619812
diff : 0.03827812895178795
diff : -0.40296515822410583
diff : 0.15618863701820374

差值是这样计算的:

def forward(self, x):
    residual = x
    out = self.bn0(x)
    out = self.conv1(out)
    out = self.bn1(out)

    out = self.prelu(out)
    out2 = self.prelu2(out)
    print(f'diff : {( out - out2).mean().item()}')

    out = self.conv2(out)
...

这是我在普通模型上使用的正常实现(即没有量化!)要评估它是否产生正确的结果,然后转到量化版本:

class PReLU_2(nn.Module):
    def __init__(self, prelu_object):
        super().__init__()
        self.prelu_weight = prelu_object.weight
        self.weight = self.prelu_weight

    def forward(self, inputs):
        x = self.weight
        tmin, _ = torch.min(inputs,dim=0)
        tmax, _ = torch.max(inputs,dim=0)
        weight_min_res = torch.mul(x, tmin)
        inputs = torch.add(tmax, weight_min_res)
        inputs = inputs.unsqueeze(0)
        return inputs

我错过了什么?

共有1个答案

隆谦
2023-03-14

我想通了!我一开始就犯了一个巨大的错误。我需要计算

PReLU(x)=max(0,x)+a∗min(0,x)
class PReLU_2(nn.Module):
    def __init__(self, prelu_object):
        super().__init__()
        self.prelu_weight = prelu_object.weight
        self.weight = self.prelu_weight

    def forward(self, inputs):
        pos = torch.relu(inputs)
        neg = -self.weight * torch.relu(-inputs)
        inputs = pos + neg
        return inputs

这是量化版本:

class PReLU_Quantized(nn.Module):
    def __init__(self, prelu_object):
        super().__init__()
        self.prelu_weight = prelu_object.weight
        self.weight = self.prelu_weight
        self.quantized_op = nn.quantized.FloatFunctional()
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, inputs):
        # inputs = max(0, inputs) + alpha * min(0, inputs) 
        self.weight = self.quant(self.weight)
        weight_min_res = self.quantized_op.mul(-self.weight, torch.relu(-inputs))
        inputs = self.quantized_op.add(torch.relu(inputs), weight_min_res)
        inputs = self.dequant(inputs)
        self.weight = self.dequant(self.weight)
        return inputs

附带说明:
我在计算差异时也有一个打字错误:

    out = self.prelu(out)
    out2 = self.prelu2(out)
    print(f'diff : {( out - out2).mean().item()}')

    out = self.conv2(out)

需要

    out1 = self.prelu(out)
    out2 = self.prelu2(out)
    print(f'diff : {( out1 - out2).mean().item()}')
    out = self.conv2(out1)

如果您面临量化问题,您可以尝试以下版本:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.quantized as nnq
from torch.quantization import fuse_modules


class QPReLU(nn.Module):
    def __init__(self, num_parameters=1, init: float = 0.25):
        super(QPReLU, self).__init__()
        self.num_parameters = num_parameters
        self.weight = nn.Parameter(torch.Tensor(num_parameters).fill_(init))
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.f_mul_neg_one1 = nnq.FloatFunctional()
        self.f_mul_neg_one2 = nnq.FloatFunctional()
        self.f_mul_alpha = nnq.FloatFunctional()
        self.f_add = nnq.FloatFunctional()
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
        self.quant2 = torch.quantization.QuantStub()
        self.quant3 = torch.quantization.QuantStub()
        # self.dequant2 = torch.quantization.QuantStub()
        self.neg_one = torch.Tensor([-1.0])
        
    
    def forward(self, x):
        x = self.quant(x)
        
        # PReLU, with modules only
        x1 = self.relu1(x)
        
        neg_one_q = self.quant2(self.neg_one)
        weight_q = self.quant3(self.weight)
        x2 = self.f_mul_alpha.mul(
            weight_q, self.f_mul_neg_one2.mul(
                self.relu2(
                    self.f_mul_neg_one1.mul(x, neg_one_q),
                ),
            neg_one_q)
        )
        
        x = self.f_add.add(x1, x2)
        x = self.dequant(x)
        return x
    
m1 = nn.PReLU()
m2 = QPReLU()

# check correctness in fp
for i in range(10):
    data = torch.randn(2, 2) * 1000
    assert torch.allclose(m1(data), m2(data))

# toy model
class M(nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.prelu = QPReLU()
        
    def forward(self, x):
        x = self.prelu(x)
        return x
    
# quantize it
m = M()
m.qconfig = torch.quantization.default_qconfig
torch.quantization.prepare(m, inplace=True)
# calibrate
m(torch.randn(4, 4))
# convert
torch.quantization.convert(m, inplace=True)
# run some data through
res = m(torch.randn(4, 4))
print(res)

一定要看这里的注释

 类似资料:
  • 我有一个包含Conv2D层的网络,然后是ReLU激活,声明如下: 它被移植到TFLite,具有以下代表性: 无Q感知训练的基本TFLite网络 然而,在网络上执行量化感知训练并再次移植后,ReLU层现在在图中是明确的: TFLite网络在Q感知训练后 这导致它们在目标上被单独处理,而不是在Conv2D内核的评估期间,在我的整个网络中导致10%的性能损失。 使用以下隐式语法声明激活不会产生问题: 具

  • 问题内容: 我有两个对象:a 和a 。 将它们合并为单个的最佳方法是什么? 这些列在数据库中分别存储。我通过JDBC 和获得它们。 问题答案: 您可以创建两个Calendar实例。在第一个中,您初始化日期,在第二个中,时间。您可以从“时间”实例中提取时间值,并将其设置为“日期”。

  • 此代码不断要求用户提供双精度值,直到用户输入空行。当用户输入非双精度值(如字符串)时,将显示“无效输入”消息。目前,即使用户输入空行,也会显示无效的输入消息,我理解原因。当我输入空行时,获得无效输入不显示的最佳方法是什么。我不确定是否有办法使用try-catch,或者我只需要使用其他东西。

  • 我从Rabbitmq中的队列收到了这条json消息: 然后我需要映射到这个模型类: 为此,我在@RabbitListener类中执行了此操作: 另一方面,我有一个服务类,它在customer类中为我提供customer对象,该对象需要添加到我的模型类中,并提供一个特定的id,如下所示: 最后,我的问题是如何将这个对象添加到模型类中?所以我可以在《邮递员》中看到这样的内容:

  • 我使用以下代码生成量化的tflite模型 但是根据训练后量化: 生成的模型将完全量化,但为了方便起见,仍然采用浮点输入和输出。 要为Google Coral Edge TPU编译tflite模型,我还需要量化输入和输出。 在模型中,我看到第一个网络层将浮点输入转换为,最后一个网络层将转换为浮点输出。如何编辑tflite模型以除去第一个和最后一个浮动层? 我知道我可以在转换期间将输入和输出类型设置为

  • 我们已经知道,数据模型+模板=输出,我们有了一个数据模型 (root) 和一个模板 (temp), 为了得到输出就需要合并它们。这是由模板的 process 方法完成的。它用数据模型root和 Writer 对象作为参数,然后向 Writer 对象写入产生的内容。 为简单起见,这里我们只做标准的输出: Writer out = new OutputStreamWriter(System.out);