import copy
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class Embeddings(nn.Module):
def __init__(self, vocab_size, embedding_dim):
"""
vocab_size: 词表大小
embedding_dim: 词嵌入维度
"""
super().__init__()
        # lut: look-up table; embedding amounts to looking up a vector for each token id
self.lut = nn.Embedding(vocab_size, embedding_dim)
self.embedding_dim = embedding_dim
def forward(self, x):
"""
对词汇映射后的数字张量进行缩放,控制数值大小
"""
return self.lut(x) * math.sqrt(self.embedding_dim)
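# A quick sanity check for Embeddings (an illustrative sketch, not part of the original
# tutorial; it assumes the class above is in scope and uses arbitrary toy sizes).
_emb = Embeddings(vocab_size=10, embedding_dim=512)
_tokens = torch.randint(0, 10, (2, 4))       # (batch_size, seq_len) of token ids
print(_emb(_tokens).shape)                   # torch.Size([2, 4, 512]), values scaled by sqrt(512)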
class PositionalEncoding(nn.Module):
def __init__(self, embedding_dim, dropout, max_len = 5000):
super().__init__()
self.dropout = nn.Dropout(p = dropout)
pos_encode = torch.zeros(max_len, embedding_dim)
position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embedding_dim, 2) *
                             -(math.log(10000.0) / embedding_dim))
pos_encode[:, 0::2] = torch.sin(position * div_term)
pos_encode[:, 1::2] = torch.cos(position * div_term)
pos_encode = pos_encode.unsqueeze(0)
self.register_buffer('pos_encode', pos_encode)
def forward(self, x):
        # pos_encode is a registered buffer, so it is not a trainable parameter
        x = x + self.pos_encode[:, :x.size(1)]
return self.dropout(x)
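# The exp/log trick in div_term is just a numerically stable way of computing
# 1 / 10000^(2i / embedding_dim). A small equivalence check (an illustrative sketch,
# not part of the original tutorial; the toy dimension below is an arbitrary choice):
_dim = 8
_i = torch.arange(0, _dim, 2)
_direct = 1.0 / torch.pow(torch.tensor(10000.0), _i / _dim)
_trick = torch.exp(_i * -(math.log(10000.0) / _dim))
print(torch.allclose(_direct, _trick))       # True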
class MultiHeadedAttention(nn.Module):
def __init__(self, head, embedding_dim, dropout = 0.1):
super().__init__()
assert embedding_dim % head == 0
self.head_dim = embedding_dim // head
self.head = head
self.embedding_dim = embedding_dim
self.linears = clones(nn.Linear(embedding_dim, embedding_dim), 4)
self.attention = None
self.dropout = nn.Dropout(p = dropout)
def forward(self, query, key, value, mask = None):
if mask is not None:
mask = mask.unsqueeze(1)
        batch_size = query.size(0)
        # project q/k/v, then split into heads: (batch, seq_len, embed) -> (batch, head, seq_len, head_dim)
        query, key, value = \
            [model(x).view(batch_size, -1, self.head, self.head_dim).transpose(1, 2)
             for model, x in zip(self.linears, (query, key, value))]
x, self.attention = attention(query, key, value, mask = mask, dropout = self.dropout)
        # merge the heads back: (batch, head, seq_len, head_dim) -> (batch, seq_len, embed)
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.head * self.head_dim)
return self.linears[-1](x)
class PositionwiseFeedForward(nn.Module):
def __init__(self, embedding_dim, ff_dim, dropout):
super().__init__()
self.l1 = nn.Linear(embedding_dim, ff_dim)
self.l2 = nn.Linear(ff_dim, embedding_dim)
self.dropout = nn.Dropout(p = dropout)
def forward(self, x):
return self.l2(self.dropout(F.relu(self.l1(x))))
class LayerNorm(nn.Module):
def __init__(self, embedding_dim, eps = 1e-6):
super().__init__()
self.w1 = nn.Parameter(torch.ones(embedding_dim))
self.w2 = nn.Parameter(torch.zeros(embedding_dim))
self.eps = eps
def forward(self, x):
        mean = x.mean(-1, keepdim = True)
std = x.std(-1, keepdim = True)
return self.w1 * (x - mean) / (std + self.eps) + self.w2
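# Note: this hand-rolled LayerNorm adds eps to the (unbiased) standard deviation, while
# torch.nn.LayerNorm adds eps to the variance, so the two agree only approximately.
# A quick comparison (an illustrative sketch, not part of the original tutorial):
_x = torch.randn(2, 4, 512)
_diff = (LayerNorm(512)(_x) - nn.LayerNorm(512, eps=1e-6)(_x)).abs().max()
print(_diff)                                 # small but nonzero difference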
class SublayerConnection(nn.Module):
def __init__(self, embedding_dim, dropout = 0.1):
super().__init__()
self.norm = LayerNorm(embedding_dim)
self.dropout = nn.Dropout(p = dropout)
def forward(self, x, func):
return x + self.dropout(func(self.norm(x)))
# Helper that produces N identical (deep-copied) instances of a module, used to stack layers
def clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
def attention(query, key, value, mask = None, dropout = None):
embedding_dim = query.size(-1)
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(embedding_dim)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
scores = F.softmax(scores, dim = -1)
if dropout is not None:
scores = dropout(scores)
return torch.matmul(scores, value), scores
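# A tiny worked example of attention() and MultiHeadedAttention (an illustrative sketch,
# not part of the original tutorial; all shapes below are arbitrary toy values).
_q = _k = _v = torch.randn(2, 4, 512)        # (batch_size, seq_len, embedding_dim)
_mask = torch.ones(2, 4, 4)                  # 1 = attend, 0 = block
_out, _weights = attention(_q, _k, _v, mask=_mask)
print(_out.shape, _weights.shape)            # torch.Size([2, 4, 512]) torch.Size([2, 4, 4])
_mha = MultiHeadedAttention(head=8, embedding_dim=512)
print(_mha(_q, _k, _v, mask=_mask).shape)    # torch.Size([2, 4, 512])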
class EncoderLayer(nn.Module):
def __init__(self, embedding_dim, self_attention, feed_forward, dropout):
super().__init__()
self.self_attention = self_attention
self.feed_forward = feed_forward
self.sublayer = clones(SublayerConnection(embedding_dim, dropout), 2)
self.embedding_dim = embedding_dim
def forward(self, x, mask):
x = self.sublayer[0](x, lambda x : self.self_attention(x, x, x, mask))
return self.sublayer[1](x, self.feed_forward)
class Encoder(nn.Module):
def __init__(self, encoder_layer, N):
super().__init__()
self.encoder_layers = clones(encoder_layer, N)
self.norm = LayerNorm(encoder_layer.embedding_dim)
def forward(self, x, mask):
for encoder_layer in self.encoder_layers:
x = encoder_layer(x, mask)
return self.norm(x)
class DecoderLayer(nn.Module):
def __init__(self, embedding_dim, self_attention, attention, feed_forward, dropout):
super().__init__()
self.embedding_dim = embedding_dim
self.self_attention = self_attention
self.attention = attention
self.feed_forward = feed_forward
self.sublayer = clones(SublayerConnection(embedding_dim, dropout), 3)
def forward(self, x, encode_kv, src_mask, trg_mask):
x = self.sublayer[0](x, lambda x : self.self_attention(x, x, x, trg_mask))
x = self.sublayer[1](x, lambda x : self.attention(x, encode_kv, encode_kv, src_mask))
return self.sublayer[2](x, self.feed_forward)
class Decoder(nn.Module):
def __init__(self, decoder_layer, N):
super().__init__()
        self.decoder_layers = clones(decoder_layer, N)
self.norm = LayerNorm(decoder_layer.embedding_dim)
    def forward(self, x, encode_kv, src_mask, trg_mask):
        for decoder_layer in self.decoder_layers:
            x = decoder_layer(x, encode_kv, src_mask, trg_mask)
return self.norm(x)
class OutputLayer(nn.Module):
def __init__(self, embedding_dim, vocab_size):
super().__init__()
self.linear = nn.Linear(embedding_dim, vocab_size)
def forward(self, x):
return F.log_softmax(self.linear(x), dim = -1)
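# OutputLayer returns log-probabilities, so during training it pairs naturally with
# nn.NLLLoss (equivalent to nn.CrossEntropyLoss applied to the raw logits). A quick check
# (an illustrative sketch, not part of the original tutorial; toy sizes are assumptions):
_gen = OutputLayer(embedding_dim=512, vocab_size=10)
_log_probs = _gen(torch.randn(2, 4, 512))    # (batch_size, seq_len, vocab_size)
_targets = torch.randint(0, 10, (2, 4))
print(nn.NLLLoss()(_log_probs.view(-1, 10), _targets.view(-1)))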
class Transformer(nn.Module):
def __init__(self, encoder, decoder, src_embed, trg_embed, generator):
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.src_embed = src_embed
self.trg_embed = trg_embed
self.generator = generator
def forward(self, source, target, src_mask, trg_mask):
return self.decode(self.encode(source, src_mask), src_mask, target, trg_mask)
def encode(self, source, src_mask):
return self.encoder(self.src_embed(source), src_mask)
def decode(self, encode_kv, src_mask, target, trg_mask):
return self.decoder(self.trg_embed(target), encode_kv, src_mask, trg_mask)
def build_model(src_vocab_size, trg_vocab_size, N = 6, embedding_dim = 512, ff_dim = 2048, head = 8, dropout = 0.1):
c = copy.deepcopy
attention = MultiHeadedAttention(head, embedding_dim, dropout)
ff = PositionwiseFeedForward(embedding_dim, ff_dim, dropout)
position = PositionalEncoding(embedding_dim, dropout)
    model = Transformer(
Encoder(EncoderLayer(embedding_dim, c(attention), c(ff), dropout), N),
Decoder(DecoderLayer(embedding_dim, c(attention), c(attention), c(ff), dropout), N),
nn.Sequential(Embeddings(src_vocab_size, embedding_dim), c(position)),
nn.Sequential(Embeddings(trg_vocab_size, embedding_dim), c(position)),
OutputLayer(embedding_dim, trg_vocab_size)
)
for p in model.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
return model
embedding_dim = 512
ff_dim = 64
head = 8
N = 6
dropout = 0.1
src_vocab_size = 10
trg_vocab_size = 10
model = build_model(src_vocab_size, trg_vocab_size, N, embedding_dim, ff_dim, head, dropout)
print(model)
Transformer(
(encoder): Encoder(
(encoder_layers): ModuleList(
(0): EncoderLayer(
(self_attention): MultiHeadedAttention(
(linears): ModuleList(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Linear(in_features=512, out_features=512, bias=True)
(2): Linear(in_features=512, out_features=512, bias=True)
(3): Linear(in_features=512, out_features=512, bias=True)
)
(dropout): Dropout(p=0.1, inplace=False)
)
(feed_forward): PositionwiseFeedForward(
(l1): Linear(in_features=512, out_features=64, bias=True)
(l2): Linear(in_features=64, out_features=512, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(sublayer): ModuleList(
(0): SublayerConnection(
(norm): LayerNorm()
(dropout): Dropout(p=0.1, inplace=False)
)
(1): SublayerConnection(
(norm): LayerNorm()
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
      (1)-(5): five more EncoderLayer blocks, identical to (0) above (omitted for brevity)
)
(norm): LayerNorm()
)
(decoder): Decoder(
    (decoder_layers): ModuleList(
(0): DecoderLayer(
(self_attention): MultiHeadedAttention(
(linears): ModuleList(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Linear(in_features=512, out_features=512, bias=True)
(2): Linear(in_features=512, out_features=512, bias=True)
(3): Linear(in_features=512, out_features=512, bias=True)
)
(dropout): Dropout(p=0.1, inplace=False)
)
(attention): MultiHeadedAttention(
(linears): ModuleList(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Linear(in_features=512, out_features=512, bias=True)
(2): Linear(in_features=512, out_features=512, bias=True)
(3): Linear(in_features=512, out_features=512, bias=True)
)
(dropout): Dropout(p=0.1, inplace=False)
)
(feed_forward): PositionwiseFeedForward(
(l1): Linear(in_features=512, out_features=64, bias=True)
(l2): Linear(in_features=64, out_features=512, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(sublayer): ModuleList(
(0): SublayerConnection(
(norm): LayerNorm()
(dropout): Dropout(p=0.1, inplace=False)
)
(1): SublayerConnection(
(norm): LayerNorm()
(dropout): Dropout(p=0.1, inplace=False)
)
(2): SublayerConnection(
(norm): LayerNorm()
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
      (1)-(5): five more DecoderLayer blocks, identical to (0) above (omitted for brevity)
)
(norm): LayerNorm()
)
(src_embed): Sequential(
(0): Embeddings(
(lut): Embedding(10, 512)
)
(1): PositionalEncoding(
(dropout): Dropout(p=0.1, inplace=False)
)
)
(trg_embed): Sequential(
(0): Embeddings(
(lut): Embedding(10, 512)
)
(1): PositionalEncoding(
(dropout): Dropout(p=0.1, inplace=False)
)
)
(generator): OutputLayer(
(linear): Linear(in_features=512, out_features=10, bias=True)
)
)
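# End-to-end smoke test of the assembled model (an illustrative sketch, not part of the
# original tutorial). The subsequent_mask helper is introduced here only for illustration:
# it builds the usual lower-triangular mask that keeps each target position from
# attending to later positions.
def subsequent_mask(size):
    return torch.tril(torch.ones(1, size, size)).bool()

source = torch.randint(1, src_vocab_size, (2, 7))    # (batch_size, src_len)
target = torch.randint(1, trg_vocab_size, (2, 5))    # (batch_size, trg_len)
src_mask = torch.ones(2, 1, 7)                       # attend to every source position
trg_mask = subsequent_mask(5)                        # causal mask, broadcast over the batch
decoder_out = model(source, target, src_mask, trg_mask)
print(model.generator(decoder_out).shape)            # torch.Size([2, 5, 10]) -> log-probabilities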