Transformer from "Attention Is All You Need": a PyTorch implementation

仇承志
2023-12-01

Components

Embeddings

import copy
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class Embeddings(nn.Module):
  def __init__(self, vocab_size, embedding_dim):
    """
    vocab_size: size of the vocabulary
    embedding_dim: dimensionality of the token embeddings
    """
    super().__init__()
    # lut: look-up table; embedding is essentially a table lookup on token ids
    self.lut = nn.Embedding(vocab_size, embedding_dim)
    self.embedding_dim = embedding_dim

  def forward(self, x):
    """
    Map token ids to embeddings and scale by sqrt(embedding_dim)
    to keep their magnitude comparable to the positional encodings.
    """
    return self.lut(x) * math.sqrt(self.embedding_dim)
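
A minimal, hedged smoke test for the embedding layer (the vocabulary size and token ids are arbitrary, chosen only for illustration):

embed = Embeddings(vocab_size = 1000, embedding_dim = 512)
tokens = torch.randint(0, 1000, (2, 4))   # (batch, seq_len)
out = embed(tokens)                       # (2, 4, 512), scaled by sqrt(512)
print(out.shape)                          # torch.Size([2, 4, 512])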

PositionalEncoding

class PositionalEncoding(nn.Module):
  def __init__(self, embedding_dim, dropout, max_len = 5000):
    super().__init__()

    self.dropout = nn.Dropout(p = dropout)

    # Precompute the sinusoidal position table once: (max_len, embedding_dim)
    pos_encode = torch.zeros(max_len, embedding_dim)
    position = torch.arange(0, max_len).unsqueeze(1)

    # 1 / 10000^(2i / embedding_dim), computed in log space for numerical stability
    div_term = torch.exp(torch.arange(0, embedding_dim, 2) *
                        - (math.log(10000.0) / embedding_dim))

    pos_encode[:, 0::2] = torch.sin(position * div_term)
    pos_encode[:, 1::2] = torch.cos(position * div_term)

    pos_encode = pos_encode.unsqueeze(0)

    # Buffer: saved with the module but not a trainable parameter
    self.register_buffer('pos_encode', pos_encode)

  def forward(self, x):
    # Slice the table to the input length; the buffer does not require gradients
    x = x + self.pos_encode[:, :x.size(1)]
    return self.dropout(x)
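
A quick sketch (using the class above; the shapes are arbitrary) showing that the precomputed table is sliced to the input length and broadcast over the batch:

pos_enc = PositionalEncoding(embedding_dim = 512, dropout = 0.1)
x = torch.zeros(2, 10, 512)               # (batch, seq_len, embedding_dim)
y = pos_enc(x)                            # adds pos_encode[:, :10, :] to every sequence
print(y.shape)                            # torch.Size([2, 10, 512])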

MultiHeadedAttention

class MultiHeadedAttention(nn.Module):
  def __init__(self, head, embedding_dim, dropout = 0.1):
    super().__init__()

    # embedding_dim must split evenly across the heads
    assert embedding_dim % head == 0

    self.head_dim = embedding_dim // head

    self.head = head
    self.embedding_dim = embedding_dim

    # Three projections for Q, K, V plus one output projection
    self.linears = clones(nn.Linear(embedding_dim, embedding_dim), 4)
    self.attention = None
    self.dropout = nn.Dropout(p = dropout)

  def forward(self, query, key, value, mask = None):
    if mask is not None:
      # The same mask is applied to every head
      mask = mask.unsqueeze(1)

    batch_size = query.size(0)

    # Project, then reshape to (batch, head, seq_len, head_dim)
    query, key, value = \
      [model(x).view(batch_size, -1, self.head, self.head_dim).transpose(1, 2)
      for model, x in zip(self.linears, (query, key, value))]

    x, self.attention = attention(query, key, value, mask = mask, dropout = self.dropout)
    # Merge the heads back: (batch, seq_len, head * head_dim)
    x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.head * self.head_dim)

    return self.linears[-1](x)
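
A hedged shape check for the attention block; it relies on the clones helper and the attention function defined further below, and the tensor sizes are arbitrary:

mha = MultiHeadedAttention(head = 8, embedding_dim = 512)
q = k = v = torch.rand(2, 10, 512)        # (batch, seq_len, embedding_dim)
mask = torch.ones(2, 10, 10)              # every position may attend everywhere
out = mha(q, k, v, mask = mask)
print(out.shape)                          # torch.Size([2, 10, 512])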

PositionwiseFeedForward

class PositionwiseFeedForward(nn.Module):
  def __init__(self, embedding_dim, ff_dim, dropout):
    super().__init__()

    self.l1 = nn.Linear(embedding_dim, ff_dim)
    self.l2 = nn.Linear(ff_dim, embedding_dim)
    self.dropout = nn.Dropout(p = dropout)

  def forward(self, x):
    return self.l2(self.dropout(F.relu(self.l1(x))))

LayerNorm

class LayerNorm(nn.Module):
  def __init__(self, embedding_dim, eps = 1e-6):
    super().__init__()

    self.w1 = nn.Parameter(torch.ones(embedding_dim))
    self.w2 = nn.Parameter(torch.zeros(embedding_dim))
    self.eps = eps

  def forward(self, x):
    mean = x.mean(-1, keepdim = True)
    std = x.std(-1, keepdim = True)
    return self.w1 * (x - mean) / (std + self.eps) + self.w2

SublayerConnection

class SublayerConnection(nn.Module):
  def __init__(self, embedding_dim, dropout = 0.1):
    super().__init__()

    self.norm = LayerNorm(embedding_dim)
    self.dropout = nn.Dropout(p = dropout)

  def forward(self, x, func):
    return x + self.dropout(func(self.norm(x)))

Clones

A helper for stacking N identical copies of a layer.

def clones(module, N):
  return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

Attention

def attention(query, key, value, mask = None, dropout = None):
  embedding_dim = query.size(-1)

  scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(embedding_dim)

  if mask is not None:
    scores = scores.masked_fill(mask == 0, -1e9)

  scores = F.softmax(scores, dim = -1)

  if dropout is not None:
    scores = dropout(scores)

  return torch.matmul(scores, value), scores
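
The function above computes scaled dot-product attention, softmax(Q·Kᵀ / sqrt(d_k))·V, where d_k is query.size(-1). A small hedged example with a causal (lower-triangular) mask, using arbitrary sizes:

q = k = v = torch.rand(1, 3, 8)           # (batch, seq_len, dim)
causal = torch.tril(torch.ones(1, 3, 3))  # position i attends only to j <= i
context, weights = attention(q, k, v, mask = causal)
print(context.shape, weights.shape)       # torch.Size([1, 3, 8]) torch.Size([1, 3, 3])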

Encoder

EncoderLayer

class EncoderLayer(nn.Module):
  def __init__(self, embedding_dim, self_attention, feed_forward, dropout):
    super().__init__()

    self.self_attention = self_attention
    self.feed_forward = feed_forward
    self.sublayer = clones(SublayerConnection(embedding_dim, dropout), 2)
    self.embedding_dim = embedding_dim
  
  def forward(self, x, mask):
    x = self.sublayer[0](x, lambda x : self.self_attention(x, x, x, mask))
    return self.sublayer[1](x, self.feed_forward)

Encoder

class Encoder(nn.Module):
  def __init__(self, encoder_layer, N):
    super().__init__()

    self.encoder_layers = clones(encoder_layer, N)
    self.norm = LayerNorm(encoder_layer.embedding_dim)

  def forward(self, x, mask):
    for encoder_layer in self.encoder_layers:
      x = encoder_layer(x, mask)

    return self.norm(x)

Decoder

DecoderLayer

class DecoderLayer(nn.Module):
  def __init__(self, embedding_dim, self_attention, attention, feed_forward, dropout):
    super().__init__()

    self.embedding_dim = embedding_dim
    self.self_attention = self_attention
    self.attention = attention
    self.feed_forward = feed_forward

    self.sublayer = clones(SublayerConnection(embedding_dim, dropout), 3)

  def forward(self, x, encode_kv, src_mask, trg_mask):
    x = self.sublayer[0](x, lambda x : self.self_attention(x, x, x, trg_mask))
    x = self.sublayer[1](x, lambda x : self.attention(x, encode_kv, encode_kv, src_mask))

    return self.sublayer[2](x, self.feed_forward)
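
The decoder's self-attention needs a target mask that hides future positions. The original code does not define one, so the helper below is only an assumed sketch (the name subsequent_mask is not from this post):

def subsequent_mask(size):
  # Assumed helper: (1, size, size) lower-triangular mask so that
  # position i can only attend to positions j <= i.
  return torch.tril(torch.ones(1, size, size)).bool()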

Decoder

class Decoder(nn.Module):
  def __init__(self, decoder_layer, N):
    super().__init__()

    self.decode_layers = clones(decoder_layer, N)
    self.norm = LayerNorm(decoder_layer.embedding_dim)

  def forward(self, x, encode_kv, src_mask, trg_mask):
    for decoder_layer in self.decode_layers:
      x = decoder_layer(x, encode_kv, src_mask, trg_mask)

    return self.norm(x)

Transformer

OutputLayer

class OutputLayer(nn.Module):
  def __init__(self, embedding_dim, vocab_size):
    super().__init__()

    self.linear = nn.Linear(embedding_dim, vocab_size)

  def forward(self, x):
    return F.log_softmax(self.linear(x), dim = -1)

Transformer

class Transformer(nn.Module):
  def __init__(self, encoder, decoder, src_embed, trg_embed, generator):
    super().__init__()

    self.encoder = encoder
    self.decoder = decoder
    self.src_embed = src_embed
    self.trg_embed = trg_embed
    self.generator = generator

  def forward(self, source, target, src_mask, trg_mask):
    return self.decode(self.encode(source, src_mask), src_mask, target, trg_mask)

  def encode(self, source, src_mask):
    return self.encoder(self.src_embed(source), src_mask)

  def decode(self, encode_kv, src_mask, target, trg_mask):
    return self.decoder(self.trg_embed(target), encode_kv, src_mask, trg_mask)

def build_model(src_vocab_size, trg_vocab_size, N = 6, embedding_dim = 512, ff_dim = 2048, head = 8, dropout = 0.1):
  c = copy.deepcopy
  attention = MultiHeadedAttention(head, embedding_dim, dropout)
  ff = PositionwiseFeedForward(embedding_dim, ff_dim, dropout)
  position = PositionalEncoding(embedding_dim, dropout)

  model = Transformer(
      Encoder(EncoderLayer(embedding_dim, c(attention), c(ff), dropout), N),
      Decoder(DecoderLayer(embedding_dim, c(attention), c(attention), c(ff), dropout), N),
      nn.Sequential(Embeddings(src_vocab_size, embedding_dim), c(position)),
      nn.Sequential(Embeddings(trg_vocab_size, embedding_dim), c(position)),
      OutputLayer(embedding_dim, trg_vocab_size)
  )

  for p in model.parameters():
    if p.dim() > 1:
      nn.init.xavier_uniform_(p)

  return model

Example

embedding_dim = 512
ff_dim = 64
head = 8
N = 6
dropout = 0.1
src_vocab_size = 10
trg_vocab_size = 10

model = build_model(src_vocab_size, trg_vocab_size, N, embedding_dim, ff_dim, head, dropout)
print(model)
Transformer(
  (encoder): Encoder(
    (encoder_layers): ModuleList(
      (0): EncoderLayer(
        (self_attention): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512, bias=True)
            (1): Linear(in_features=512, out_features=512, bias=True)
            (2): Linear(in_features=512, out_features=512, bias=True)
            (3): Linear(in_features=512, out_features=512, bias=True)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (feed_forward): PositionwiseFeedForward(
          (l1): Linear(in_features=512, out_features=64, bias=True)
          (l2): Linear(in_features=64, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (sublayer): ModuleList(
          (0): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (1): EncoderLayer(
        (self_attention): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512, bias=True)
            (1): Linear(in_features=512, out_features=512, bias=True)
            (2): Linear(in_features=512, out_features=512, bias=True)
            (3): Linear(in_features=512, out_features=512, bias=True)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (feed_forward): PositionwiseFeedForward(
          (l1): Linear(in_features=512, out_features=64, bias=True)
          (l2): Linear(in_features=64, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (sublayer): ModuleList(
          (0): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (2): EncoderLayer(
        (self_attention): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512, bias=True)
            (1): Linear(in_features=512, out_features=512, bias=True)
            (2): Linear(in_features=512, out_features=512, bias=True)
            (3): Linear(in_features=512, out_features=512, bias=True)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (feed_forward): PositionwiseFeedForward(
          (l1): Linear(in_features=512, out_features=64, bias=True)
          (l2): Linear(in_features=64, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (sublayer): ModuleList(
          (0): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (3): EncoderLayer(
        (self_attention): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512, bias=True)
            (1): Linear(in_features=512, out_features=512, bias=True)
            (2): Linear(in_features=512, out_features=512, bias=True)
            (3): Linear(in_features=512, out_features=512, bias=True)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (feed_forward): PositionwiseFeedForward(
          (l1): Linear(in_features=512, out_features=64, bias=True)
          (l2): Linear(in_features=64, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (sublayer): ModuleList(
          (0): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (4): EncoderLayer(
        (self_attention): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512, bias=True)
            (1): Linear(in_features=512, out_features=512, bias=True)
            (2): Linear(in_features=512, out_features=512, bias=True)
            (3): Linear(in_features=512, out_features=512, bias=True)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (feed_forward): PositionwiseFeedForward(
          (l1): Linear(in_features=512, out_features=64, bias=True)
          (l2): Linear(in_features=64, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (sublayer): ModuleList(
          (0): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (5): EncoderLayer(
        (self_attention): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512, bias=True)
            (1): Linear(in_features=512, out_features=512, bias=True)
            (2): Linear(in_features=512, out_features=512, bias=True)
            (3): Linear(in_features=512, out_features=512, bias=True)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (feed_forward): PositionwiseFeedForward(
          (l1): Linear(in_features=512, out_features=64, bias=True)
          (l2): Linear(in_features=64, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (sublayer): ModuleList(
          (0): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (norm): LayerNorm()
  )
  (decoder): Decoder(
    (decode_layers): ModuleList(
      (0): DecoderLayer(
        (self_attention): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512, bias=True)
            (1): Linear(in_features=512, out_features=512, bias=True)
            (2): Linear(in_features=512, out_features=512, bias=True)
            (3): Linear(in_features=512, out_features=512, bias=True)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (attention): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512, bias=True)
            (1): Linear(in_features=512, out_features=512, bias=True)
            (2): Linear(in_features=512, out_features=512, bias=True)
            (3): Linear(in_features=512, out_features=512, bias=True)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (feed_forward): PositionwiseFeedForward(
          (l1): Linear(in_features=512, out_features=64, bias=True)
          (l2): Linear(in_features=64, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (sublayer): ModuleList(
          (0): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (2): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (1): DecoderLayer(
        (self_attention): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512, bias=True)
            (1): Linear(in_features=512, out_features=512, bias=True)
            (2): Linear(in_features=512, out_features=512, bias=True)
            (3): Linear(in_features=512, out_features=512, bias=True)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (attention): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512, bias=True)
            (1): Linear(in_features=512, out_features=512, bias=True)
            (2): Linear(in_features=512, out_features=512, bias=True)
            (3): Linear(in_features=512, out_features=512, bias=True)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (feed_forward): PositionwiseFeedForward(
          (l1): Linear(in_features=512, out_features=64, bias=True)
          (l2): Linear(in_features=64, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (sublayer): ModuleList(
          (0): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (2): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (2): DecoderLayer(
        (self_attention): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512, bias=True)
            (1): Linear(in_features=512, out_features=512, bias=True)
            (2): Linear(in_features=512, out_features=512, bias=True)
            (3): Linear(in_features=512, out_features=512, bias=True)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (attention): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512, bias=True)
            (1): Linear(in_features=512, out_features=512, bias=True)
            (2): Linear(in_features=512, out_features=512, bias=True)
            (3): Linear(in_features=512, out_features=512, bias=True)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (feed_forward): PositionwiseFeedForward(
          (l1): Linear(in_features=512, out_features=64, bias=True)
          (l2): Linear(in_features=64, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (sublayer): ModuleList(
          (0): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (2): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (3): DecoderLayer(
        (self_attention): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512, bias=True)
            (1): Linear(in_features=512, out_features=512, bias=True)
            (2): Linear(in_features=512, out_features=512, bias=True)
            (3): Linear(in_features=512, out_features=512, bias=True)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (attention): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512, bias=True)
            (1): Linear(in_features=512, out_features=512, bias=True)
            (2): Linear(in_features=512, out_features=512, bias=True)
            (3): Linear(in_features=512, out_features=512, bias=True)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (feed_forward): PositionwiseFeedForward(
          (l1): Linear(in_features=512, out_features=64, bias=True)
          (l2): Linear(in_features=64, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (sublayer): ModuleList(
          (0): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (2): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (4): DecoderLayer(
        (self_attention): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512, bias=True)
            (1): Linear(in_features=512, out_features=512, bias=True)
            (2): Linear(in_features=512, out_features=512, bias=True)
            (3): Linear(in_features=512, out_features=512, bias=True)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (attention): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512, bias=True)
            (1): Linear(in_features=512, out_features=512, bias=True)
            (2): Linear(in_features=512, out_features=512, bias=True)
            (3): Linear(in_features=512, out_features=512, bias=True)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (feed_forward): PositionwiseFeedForward(
          (l1): Linear(in_features=512, out_features=64, bias=True)
          (l2): Linear(in_features=64, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (sublayer): ModuleList(
          (0): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (2): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (5): DecoderLayer(
        (self_attention): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512, bias=True)
            (1): Linear(in_features=512, out_features=512, bias=True)
            (2): Linear(in_features=512, out_features=512, bias=True)
            (3): Linear(in_features=512, out_features=512, bias=True)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (attention): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512, bias=True)
            (1): Linear(in_features=512, out_features=512, bias=True)
            (2): Linear(in_features=512, out_features=512, bias=True)
            (3): Linear(in_features=512, out_features=512, bias=True)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (feed_forward): PositionwiseFeedForward(
          (l1): Linear(in_features=512, out_features=64, bias=True)
          (l2): Linear(in_features=64, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (sublayer): ModuleList(
          (0): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (2): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (norm): LayerNorm()
  )
  (src_embed): Sequential(
    (0): Embeddings(
      (lut): Embedding(10, 512)
    )
    (1): PositionalEncoding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (trg_embed): Sequential(
    (0): Embeddings(
      (lut): Embedding(10, 512)
    )
    (1): PositionalEncoding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (generator): OutputLayer(
    (linear): Linear(in_features=512, out_features=10, bias=True)
  )
)
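
A hedged end-to-end smoke test, assuming the model built above and the subsequent_mask helper sketched earlier; batch size, sequence lengths, and token ids are arbitrary:

src = torch.randint(1, src_vocab_size, (2, 7))   # (batch, src_len)
trg = torch.randint(1, trg_vocab_size, (2, 5))   # (batch, trg_len)
src_mask = torch.ones(2, 1, 7)                   # no source padding in this toy example
trg_mask = subsequent_mask(5)                    # hide future target positions
out = model(src, trg, src_mask, trg_mask)
print(out.shape)                                 # torch.Size([2, 5, 10]): log-probabilities over trg_vocab_size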