当前位置: 首页 > 知识库问答 >
问题:

Keras的BatchNormalization和PyTorch的BatchNormal2D之间的区别?

萧嘉禧
2023-03-14

我有一个在Keras和PyTorch中实现的小型CNN示例。当我打印这两个网络的摘要时,可训练参数的总数是相同的,但参数总数和批量规范化的参数数量不匹配。

以下是CNN在Keras的实施情况:

inputs = Input(shape = (64, 64, 1)). # Channel Last: (NHWC)

model = Conv2D(filters=32, kernel_size=(3, 3), padding='SAME', activation='relu', input_shape=(IMG_SIZE, IMG_SIZE, 1))(inputs)
model = BatchNormalization(momentum=0.15, axis=-1)(model)
model = Flatten()(model)

dense = Dense(100, activation = "relu")(model)
head_root = Dense(10, activation = 'softmax')(dense)

为上述模型打印的摘要是:

Model: "model_8"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_9 (InputLayer)         (None, 64, 64, 1)         0         
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 64, 64, 32)        320       
_________________________________________________________________
batch_normalization_2 (Batch (None, 64, 64, 32)        128       
_________________________________________________________________
flatten_3 (Flatten)          (None, 131072)            0         
_________________________________________________________________
dense_11 (Dense)             (None, 100)               13107300  
_________________________________________________________________
dense_12 (Dense)             (None, 10)                1010      
=================================================================
Total params: 13,108,758
Trainable params: 13,108,694
Non-trainable params: 64
_________________________________________________________________

下面是PyTorch中相同模型架构的实现:

# Image format: Channel first (NCHW) in PyTorch
class CustomModel(nn.Module):
def __init__(self):
    super(CustomModel, self).__init__()
    self.layer1 = nn.Sequential(
        nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 3), padding=1),
        nn.ReLU(True),
        nn.BatchNorm2d(num_features=32),
    )
    self.flatten = nn.Flatten()
    self.fc1 = nn.Linear(in_features=131072, out_features=100)
    self.fc2 = nn.Linear(in_features=100, out_features=10)

def forward(self, x):
    output = self.layer1(x)
    output = self.flatten(output)
    output = self.fc1(output)
    output = self.fc2(output)
    return output

以下是上述模型摘要的输出:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 32, 64, 64]             320
              ReLU-2           [-1, 32, 64, 64]               0
       BatchNorm2d-3           [-1, 32, 64, 64]              64
           Flatten-4               [-1, 131072]               0
            Linear-5                  [-1, 100]      13,107,300
            Linear-6                   [-1, 10]           1,010
================================================================
Total params: 13,108,694
Trainable params: 13,108,694
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.02
Forward/backward pass size (MB): 4.00
Params size (MB): 50.01
Estimated Total Size (MB): 54.02
----------------------------------------------------------------

正如您在上面的结果中看到的,Keras中的批处理归一化比PyTorch有更多的参数(确切地说是2倍)。那么上述CNN架构有什么区别?如果它们是等效的,那么我在这里错过了什么?

共有1个答案

姬和歌
2023-03-14

Keras将许多将在图层中“保存/加载”的东西视为参数(权重)。

虽然这两种实现自然都具有批次的累积“平均值”和“方差”,但这些值不能通过反向传播进行训练。

尽管如此,这些值每批都会更新一次,Keras将它们视为不可训练的权重,而PyTorch只是将它们隐藏起来。这里的术语“不可训练”意味着“不可通过反向传播进行训练”,但并不意味着值被冻结。

它们总共是BatchNormize层的4组“权重”。考虑到选定的轴(默认=-1,您的图层的大小=32)

  • 缩放比例(32)-可培训

在Keras中这样做的好处是,当您保存图层时,您还可以保存均值和方差值,就像您自动保存图层中的所有其他权重一样。当您加载图层时,这些权重会一起加载。

 类似资料:
  • 问题内容: 我错放了太多次了,我想我一直忘记,因为我不知道两者之间的区别,只是一个给了我我期望的价值,而另一个却没有。 为什么是这样? 问题答案: 是的简写形式(尽管请注意,该表达式只会被计算一次。) 是的,即指定一元的到。 例子:

  • 问题内容: 因此,我有一段简单的代码可以打印出整数1-10: 然后,如果仅在第3行上更改一个运算符,它将打印出无限数量的1整数(我知道为什么会这样做)。为什么在运行第二个程序时没有出现语法错误?如果赋值运算符后面跟着一个加法运算符,它不会调用语法错误吗? 问题答案: 与相同, 只是意味着。

  • 问题内容: 有人可以解释一下 和 我不知道“确切”的含义 问题答案: 在这个例子中,什么都没有。当您具有多个具有相似名称的路径时,该参数将起作用: 例如,假设我们有一个显示用户列表的组件。我们还有一个用于创建用户的组件。的网址应嵌套在下。因此,我们的设置可能如下所示: 现在,这里的问题是,当我们转到路由器时,将通过所有定义的路由,并返回它找到的第一个匹配项。因此,在这种情况下,它将首先找到路线,然

  • 我在尝试Python-Selenium的XPath。 我使用这个链接来尝试教程中的一些XPath: 所以我尝试了XPaths的这两个变体。 返回9个结果 “//”如何匹配5个更多的结果?

  • 问题内容: 我很好奇printStackTrace()和toString()之间的区别是什么。乍一看,他们 似乎 做的完全相同。 码: 问题答案: 不,有重要区别!使用toString,您只有异常的类型和错误消息。使用printStackTrace()可以获得异常的整个堆栈跟踪,这对于调试非常有帮助。 System.out.println(toString())的示例: printStackTra