编写自己的层

优质
小牛编辑
135浏览
2023-12-01

对于简单的定制操作,我们或许可以通过使用layers.core.Lambda层来完成。但对于任何具有可训练权重的定制层,你应该自己来实现。

这里是一个Keras层应该具有的框架结构(1.1.3以后的版本,如果你的版本更旧请升级),要定制自己的层,你需要实现下面三个方法

  • build(input_shape):这是定义权重的方法,可训练的权应该在这里被加入列表`self.trainable_weights中。其他的属性还包括self.non_trainabe_weights(列表)和self.updates(需要更新的形如(tensor, new_tensor)的tuple的列表)。你可以参考BatchNormalization层的实现来学习如何使用上面两个属性。这个方法必须设置self.built = True,可通过调用super([layer],self).build()实现

  • call(x):这是定义层功能的方法,除非你希望你写的层支持masking,否则你只需要关心call的第一个参数:输入张量

  • get_output_shape_for(input_shape):如果你的层修改了输入数据的shape,你应该在这里指定shape变化的方法,这个函数使得Keras可以做自动shape推断

from keras import backend as K
from keras.engine.topology import Layer

class MyLayer(Layer):
    def __init__(self, output_dim, **kwargs):
        self.output_dim = output_dim
        super(MyLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        self.W = self.add_weight(shape=(input_shape[1], self.output_dim),
                                initializer='random_uniform',
                                trainable=True)
        super(MyLayer, self).build()  # be sure you call this somewhere! 

    def call(self, x, mask=None):
        return K.dot(x, self.W)

    def get_output_shape_for(self, input_shape):
        return (input_shape[0] + self.output_dim)

调整旧版Keras编写的层以适应Keras1.0

以下内容是你在将旧版Keras实现的层调整为新版Keras应注意的内容,这些内容对你在Keras1.0中编写自己的层也有所帮助。

  • 你的Layer应该继承自keras.engine.topology.Layer,而不是之前的keras.layers.core.Layer。另外,MaskedLayer已经被移除。

  • build方法现在接受input_shape参数,而不是像以前一样通过self.input_shape来获得该值,所以请把build(self)转为build(self, input_shape)

  • 请正确将output_shape属性转换为方法get_output_shape_for(self, train=False),并删去原来的output_shape

  • 新层的计算逻辑现在应实现在call方法中,而不是之前的get_output。注意不要改动__call__方法。将get_output(self,train=False)转换为call(self,x,mask=None)后请删除原来的get_output方法。

  • Keras1.0不再使用布尔值train来控制训练状态和测试状态,如果你的层在测试和训练两种情形下表现不同,请在call中使用指定状态的函数。如,x=K.in_train_phase(train_x, test_y)。例如,在Dropout的call方法中你可以看到:

return K.in_train_phase(K.dropout(x, level=self.p), x)
  • get_config返回的配置信息可能会包括类名,请从该函数中将其去掉。如果你的层在实例化时需要更多信息(即使将config作为kwargs传入也不能提供足够信息),请重新实现from_config。请参考LambdaMerge层看看复杂的from_config是如何实现的。

  • 如果你在使用Masking,请实现compute_mas(input_tensor, input_mask),该函数将返回output_mask。请确保在__init__()中设置self.supports_masking = True

  • 如果你希望Keras在你编写的层与Keras内置层相连时进行输入兼容性检查,请在__init__设置self.input_specs或实现input_specs()并包装为属性(@property)。该属性应为engine.InputSpec的对象列表。在你希望在call中获取输入shape时,该属性也比较有用。

  • 下面的方法和属性是内置的,请不要覆盖它们

    • __call__

    • add_input

    • assert_input_compatibility

    • set_input

    • input

    • output

    • input_shape

    • output_shape

    • input_mask

    • output_mask

    • get_input_at

    • get_output_at

    • get_input_shape_at

    • get_output_shape_at

    • get_input_mask_at

    • get_output_mask_at

现存的Keras层代码可以为你的实现提供良好参考,阅读源代码吧!