LayerNormalization, attention, self_attention, multi_head_attention: Keras code

齐文林
2023-12-01
import numpy as np

import keras
from keras import backend as K
from keras.models import Model, load_model
from keras.layers import Input, BatchNormalization, Activation, Add, Multiply, Dot
from keras.layers import Embedding, Permute, Reshape, GaussianNoise
from keras.layers import Dropout, Lambda, Dense, Flatten
from keras.layers import Conv2D, Conv2DTranspose, UpSampling2D, Conv1D
from keras.layers import GlobalMaxPooling1D, GlobalAveragePooling2D, AveragePooling2D
from keras.layers import concatenate, Concatenate
from keras.layers import Layer
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, LearningRateScheduler
from keras.optimizers import Adam, SGD, Nadam

import tensorflow as tf


class LayerNormalization(keras.layers.Layer):
    
    def __init__(self,
                 center = True,
                 scale = True,
                 epsilon = None,
                 gamma_initializer = "ones",
                 beta_initializer = "zeros",
                 gamma_regularizer = None,
                 beta_regularizer = None,
                 gamma_constraint = None,
                 beta_constraint = None,
                 **kwargs):
        """
        Layer normalization layer
            reference: [Layer Normalization](https://arxiv.org/pdf/1607.06450.pdf)
                :param center: Add an offset parameter if it is True.
                :param scale: Add a scale parameter if it is True.
                :param epsilon: Epsilon for calculating variance.
                :param gamma_initializer: Initializer for the gamma weight.
                :param beta_initializer: Initializer for the beta weight.
                :param gamma_regularizer: Optional regularizer for the gamma weight.
                :param beta_regularizer: Optional regularizer for the beta weight.
                :param gamma_constraint: Optional constraint for the gamma weight.
                :param beta_constraint: Optional constraint for the beta weight.
                :param kwargs:
        """        
        super(LayerNormalization,self).__init__(**kwargs)
        self.supports_masking = True
        self.center = center
        self.scale = scale
        if epsilon is None:
            epsilon = K.epsilon() * K.epsilon()
        self.epsilon = epsilon
        self.gamma_initializer = keras.initializers.get(gamma_initializer)
        self.beta_initializer = keras.initializers.get(beta_initializer)
        self.gamma_regularizer = keras.regularizers.get(gamma_regularizer)
        self.beta_regularizer = keras.regularizers.get(beta_regularizer)
        self.gamma_constraint = keras.constraints.get(gamma_constraint)
        self.beta_constraint = keras.constraints.get(beta_constraint)
        self.gamma, self.beta = None, None

    def get_config(self):
        config = {
            'center': self.center,
            'scale': self.scale,
            'epsilon': self.epsilon,
            'gamma_initializer': keras.initializers.serialize(self.gamma_initializer),
            'beta_initializer': keras.initializers.serialize(self.beta_initializer),
            'gamma_regularizer': keras.regularizers.serialize(self.gamma_regularizer),
            'beta_regularizer': keras.regularizers.serialize(self.beta_regularizer),
            'gamma_constraint': keras.constraints.serialize(self.gamma_constraint),
            'beta_constraint': keras.constraints.serialize(self.beta_constraint),
        }
        base_config = super(LayerNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return input_shape

    def compute_mask(self, inputs, input_mask=None):
        return input_mask

    def build(self, input_shape):
        shape = input_shape[-1:]
        if self.scale:
            self.gamma = self.add_weight(
                shape=shape,
                initializer=self.gamma_initializer,
                regularizer=self.gamma_regularizer,
                constraint=self.gamma_constraint,
                name='gamma',
            )
        if self.center:
            self.beta = self.add_weight(
                shape=shape,
                initializer=self.beta_initializer,
                regularizer=self.beta_regularizer,
                constraint=self.beta_constraint,
                name='beta',
            )
        super(LayerNormalization, self).build(input_shape)

    def call(self, inputs, training=None):
        mean = K.mean(inputs, axis=-1, keepdims=True)
        variance = K.mean(K.square(inputs - mean), axis=-1, keepdims=True)
        std = K.sqrt(variance + self.epsilon)
        outputs = (inputs - mean) / std
        if self.scale:
            outputs *= self.gamma
        if self.center:
            outputs += self.beta
        return outputs        
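
# A minimal usage sketch for the LayerNormalization layer above, assuming a toy
# (batch, timesteps=10, features=16) input; the shapes are arbitrary and only for illustration.
def _layer_norm_demo():
    inp = Input(shape=(10, 16))          # toy sequence: 10 timesteps, 16 features
    out = LayerNormalization()(inp)      # normalizes each timestep over its 16 features
    model = Model(inp, out)
    model.summary()                      # output shape stays (None, 10, 16); weights: 16 gammas + 16 betas
    return model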



def attention(x, n_factor, dropout):
    """
    attention layer @ Morxrc
        attention - 类似于RNN的效果,收集全局信息,方便并行化处理
        Conv1D - 1D卷积,Q,K,V,分别为Query,Key,Value, Q,K为比较相似度所使用,Value为节点本身信息的向量
                 目前主流的NLP研究中,key和value常常都是同一个,即key=value。n_factor卷积核数量也就是输出维度
                 1为kernel_size 表示卷积窗口的长度。
        Permute - Permute层是置换模式,即(2,1)就是置换输入的第一和第二个维度,即转置所用
        axis = -1 - 在最后一个维度进行操作
    """
    x_Q = Conv1D(n_factor, 1, activation='linear',
                 kernel_initializer='glorot_uniform',
                 bias_initializer='glorot_uniform',
                 )(x)
    x_K = Conv1D(n_factor, 1, activation='linear',
                 kernel_initializer='glorot_uniform',
                 bias_initializer='glorot_uniform',
                 )(x)
    x_V = Conv1D(n_factor, 1, activation='linear',
                 kernel_initializer='glorot_uniform',
                 bias_initializer='glorot_uniform',
                 )(x)

    x_KT = Permute((2, 1))(x_K)                                                        # (batch, n_factor, timesteps)
    res = Lambda(lambda c: K.batch_dot(c[0], c[1]) / np.sqrt(n_factor))([x_Q, x_KT])   # scaled dot-product scores
    att = Lambda(lambda c: K.softmax(c, axis=-1))(res)                                 # attention weights
    att = Lambda(lambda c: K.batch_dot(c[0], c[1]))([att, x_V])                        # weighted sum of values
    return att
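
# A minimal shape check for the attention block above; the (10, 16) input shape is an
# illustrative assumption, and n_factor is set equal to the feature dimension so the
# result can later be combined with the input through a residual Add.
def _attention_demo():
    inp = Input(shape=(10, 16))                     # toy sequence: 10 timesteps, 16 features
    att = attention(inp, n_factor=16, dropout=0.0)  # scaled dot-product attention over the 10 positions
    model = Model(inp, att)
    model.summary()                                 # output shape: (None, 10, 16)
    return model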

def self_attention(x, n_factor, dropout):
    att = attention(x,n_factor,dropout)
    att = LayerNormalization()(att)
    if dropout > 0:
        att = Dropout(dropout)(att)
    x = Add()([x,att])
    return x

def multi_head_attention(x,n_factor,n_head,dropout):
    n_factor_head = n_factor // n_head
    # n_factor_head: output dimension of each individual attention head
    heads = [attention(x,n_factor_head,dropout) for i in range(n_head)]
    att = Concatenate()(heads)
    att = Dense(n_factor,
                kernel_initializer = "glorot_uniform",
                bias_initializer = "glorot_uniform",
               )(att)
    # Add & Norm: residual connection followed by layer normalization
    x = Add()([x,att])
    x = LayerNormalization()(x)
    if dropout > 0:
        x = Dropout(dropout)(x)
    return x
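
# A minimal end-to-end sketch wiring self_attention and multi_head_attention into a tiny
# Keras model. The hyperparameters here (sequence length 10, 16 features, 4 heads, dropout
# 0.1) and the sigmoid classification head are illustrative assumptions; note that n_factor
# must equal the input feature dimension for the residual Add()([x, att]) connections to match.
def _build_attention_demo(seq_len=10, n_factor=16, n_head=4, dropout=0.1):
    inp = Input(shape=(seq_len, n_factor))
    x = self_attention(inp, n_factor, dropout)                # single-head attention + LayerNorm + residual
    x = multi_head_attention(x, n_factor, n_head, dropout)    # n_head heads of size n_factor // n_head, then Add & Norm
    x = GlobalMaxPooling1D()(x)                                # pool over the time dimension
    out = Dense(1, activation='sigmoid')(x)                    # toy binary-classification head
    model = Model(inp, out)
    model.compile(optimizer=Adam(1e-3), loss='binary_crossentropy')
    return model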

 
