前言:本教程利用fairseq新增一个FairseqEncoderDecoderModel。该模型先用一个LSTM对输入句(source sentence)进行编码(encode),再把最终的隐藏状态(hidden state)传给第二个LSTM,由它解码(decode)出目标句(target sentence)。
教程包括:
a.编写Encoder和Decoder,分别用于encode、decode输入句与目标句;
b.注册一个适用于Command-line Tools的新模型;
c.使用现有的command-line tools来训练这个模型;
d.使用增量解码(Incremental decoding)修改Decoder,加快生成速度(making generation faster);
1.搭建Encoder和Decoder
利用FairseqEncoder和FairseqDecoder接口来搭建自己的Encoder和Decoder。这两个接口都继承自torch.nn.Module,因此可以像编写普通的PyTorch Module一样来编写Encoder和Decoder。
Encoder
该encoder会在输入句中嵌入tokens,然后传入torch.nn.LSTM并且返回最终的隐藏状态hidden state。创建的encoder保存为fairseq/models/simple_lstm.py。
import torch.nn as nn
from fairseq import utils
from fairseq.models import FairseqEncoder
class SimpleLSTMEncoder(FairseqEncoder):
    """Single-layer, unidirectional LSTM encoder.

    Embeds the source tokens, feeds them through an ``nn.LSTM`` and hands the
    final hidden state to the decoder.

    Args:
        args: parsed command-line namespace; ``args.left_pad_source`` is read
            in ``forward()``.
        dictionary: source dictionary, used for the vocabulary size and the
            padding index.
        embed_dim (int): dimensionality of the token embeddings.
        hidden_dim (int): dimensionality of the LSTM hidden state.
        dropout (float): dropout probability applied to the embeddings.
    """

    def __init__(
        self, args, dictionary, embed_dim=128, hidden_dim=128, dropout=0.1,
    ):
        super().__init__(dictionary)
        self.args = args

        # Our encoder will embed the inputs before feeding them to the LSTM.
        self.embed_tokens = nn.Embedding(
            num_embeddings=len(dictionary),
            embedding_dim=embed_dim,
            padding_idx=dictionary.pad(),
        )
        self.dropout = nn.Dropout(p=dropout)

        # We'll use a single-layer, unidirectional LSTM for simplicity.
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=1,
            bidirectional=False,
            batch_first=True,
        )

    def forward(self, src_tokens, src_lengths):
        """Encode a batch of source sentences.

        The inputs to ``forward()`` are determined by the Task, and in
        particular by the ``'net_input'`` key of each mini-batch.

        Args:
            src_tokens (LongTensor): source tokens of shape
                `(batch, src_len)`
            src_lengths (LongTensor): source lengths of shape `(batch)`

        Returns:
            dict: ``{'final_hidden': Tensor}`` where the tensor has shape
            `(batch, hidden_dim)`. This object is passed as-is to the
            Decoder.
        """
        # Note that the source is typically padded on the left. This can be
        # configured by adding the `--left-pad-source "False"` command-line
        # argument, but here we make the Encoder handle either kind of
        # padding by converting everything to be right-padded.
        if self.args.left_pad_source:
            # Convert left-padding to right-padding.
            src_tokens = utils.convert_padding_direction(
                src_tokens,
                padding_idx=self.dictionary.pad(),
                left_to_right=True
            )

        # Embed the source and apply dropout.
        x = self.embed_tokens(src_tokens)
        x = self.dropout(x)

        # Pack the sequence into a PackedSequence object to feed to the LSTM.
        # ``pack_padded_sequence()`` requires a lengths tensor to live on the
        # CPU, so move it there explicitly; this is a no-op when the lengths
        # are already on the CPU and avoids a RuntimeError on GPU batches.
        x = nn.utils.rnn.pack_padded_sequence(
            x, src_lengths.cpu(), batch_first=True,
        )

        # Get the output from the LSTM; only the final hidden state is used.
        _outputs, (final_hidden, _final_cell) = self.lstm(x)

        # Return the Encoder's output. This can be any object and will be
        # passed directly to the Decoder.
        return {
            # The LSTM has one layer and one direction, so dropping the
            # leading dimension leaves shape `(bsz, hidden_dim)`.
            'final_hidden': final_hidden.squeeze(0),
        }

    # Encoders are required to implement this method so that we can rearrange
    # the order of the batch elements during inference (e.g., beam search).
    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to `new_order`.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            `encoder_out` rearranged according to `new_order`
        """
        final_hidden = encoder_out['final_hidden']
        return {
            # The batch dimension is dim 0 of `final_hidden`.
            'final_hidden': final_hidden.index_select(0, new_order),
        }
Decoder
Decoder基于Encoder的最终hidden state和嵌入后的前一个目标词(target word)来预测下一个词;这种把真实目标词作为Decoder输入的训练方式也叫teacher forcing。具体来说,使用torch.nn.LSTM生成隐藏状态(hidden state)序列,再把它投影到输出词表的大小,用于预测每一个目标词。
import torch
from fairseq.models import FairseqDecoder
class SimpleLSTMDecoder(FairseqDecoder):
    """Single-layer, unidirectional LSTM decoder trained with teacher forcing.

    Every embedded target token is concatenated with the Encoder's final
    hidden state before being fed to the LSTM, and the LSTM outputs are
    projected onto the target vocabulary.

    Args:
        dictionary: target dictionary, used for the vocabulary size and the
            padding index.
        encoder_hidden_dim (int): size of the Encoder's final hidden state.
        embed_dim (int): dimensionality of the target token embeddings.
        hidden_dim (int): dimensionality of the LSTM hidden state.
        dropout (float): dropout probability applied to the embeddings.
    """

    def __init__(
        self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,
        dropout=0.1,
    ):
        super().__init__(dictionary)
        vocab_size = len(dictionary)

        # Target-side token embeddings, followed by dropout.
        self.embed_tokens = nn.Embedding(
            vocab_size, embed_dim, padding_idx=dictionary.pad(),
        )
        self.dropout = nn.Dropout(p=dropout)

        # The LSTM input at each step is the token embedding concatenated
        # with the Encoder's final hidden state, hence the summed input size.
        self.lstm = nn.LSTM(
            input_size=encoder_hidden_dim + embed_dim,
            hidden_size=hidden_dim,
            num_layers=1,
            bidirectional=False,
        )

        # Maps each LSTM output to scores over the vocabulary.
        self.output_projection = nn.Linear(hidden_dim, vocab_size)

    def forward(self, prev_output_tokens, encoder_out):
        """Compute logits for every target position.

        During training the Decoder receives the entire target sequence
        shifted right by one position: *prev_output_tokens* begins with the
        end-of-sentence symbol, ``dictionary.eos()``, followed by the target
        sequence.

        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (dict): output of the Encoder; its ``'final_hidden'``
                entry is read here

        Returns:
            tuple:
                - the logits, of shape `(batch, tgt_len, vocab)`
                - ``None``, since this decoder produces no attention weights
        """
        bsz, tgt_len = prev_output_tokens.size()
        enc_hidden = encoder_out['final_hidden']

        # Embed the shifted target sequence and apply dropout.
        embedded = self.dropout(self.embed_tokens(prev_output_tokens))

        # Repeat the Encoder's final hidden state along the time axis so it
        # can be appended to *every* embedded target token.
        tiled_hidden = enc_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)
        lstm_input = torch.cat((embedded, tiled_hidden), dim=2)

        # PackedSequence objects are harder to use in the Decoder than in the
        # Encoder, since the targets are not sorted in descending length
        # order as ``pack_padded_sequence()`` requires, so nn.LSTM is fed
        # directly. The Encoder's final hidden state seeds the hidden state
        # and the cell state starts at zero.
        # NOTE: seeding the hidden state this way presumes the Encoder's
        # hidden size matches this LSTM's hidden size.
        h0 = enc_hidden.unsqueeze(0)
        c0 = torch.zeros_like(enc_hidden).unsqueeze(0)

        # The LSTM is time-major: `(tgt_len, bsz, dim)` in and out.
        hidden_seq, _ = self.lstm(lstm_input.transpose(0, 1), (h0, c0))

        # Back to `(bsz, tgt_len, hidden)`, then project to vocabulary size.
        logits = self.output_projection(hidden_seq.transpose(0, 1))

        # Return the logits and ``None`` for the attention weights.
        return logits, None
2.注册模型(Registering the Model)
利用fairseq中的register_model()函数修饰器来注册模型。一旦模型注册成功,就可以使用Command-line Tools。
所有注册的模型都必须实现BaseFairseqModel接口;对于sequence-to-sequence模型(即任何包含Encoder和Decoder的模型),则需要实现FairseqEncoderDecoderModel接口。
在同一个文件中创建一个小的包装类(wrapper class)SimpleLSTMModel,用它封装上面的Encoder和Decoder,并以simple_lstm为名在fairseq中注册。
from fairseq.models import FairseqEncoderDecoderModel, register_model
# Note: the register_model "decorator" should immediately precede the
# definition of the Model class.
@register_model('simple_lstm')
class SimpleLSTMModel(FairseqEncoderDecoderModel):
    """Encoder-decoder model pairing SimpleLSTMEncoder with SimpleLSTMDecoder.

    Registered under the name ``'simple_lstm'``.
    """

    @staticmethod
    def add_args(parser):
        """Add the model-specific command-line arguments to *parser*.

        The encoder and the decoder each get three options: embedding
        dimensionality, hidden-state dimensionality and dropout probability.
        """
        for side in ('encoder', 'decoder'):
            parser.add_argument(
                '--{}-embed-dim'.format(side), type=int, metavar='N',
                help='dimensionality of the {} embeddings'.format(side),
            )
            parser.add_argument(
                '--{}-hidden-dim'.format(side), type=int, metavar='N',
                help='dimensionality of the {} hidden state'.format(side),
            )
            parser.add_argument(
                '--{}-dropout'.format(side), type=float, default=0.1,
                help='{} dropout probability'.format(side),
            )

    @classmethod
    def build_model(cls, args, task):
        """Build the model from the parsed *args* and the *task*.

        Fairseq initializes models through ``build_model()``, which gives
        more flexibility because the returned instance may be of a different
        type than the class it was called on. Here a SimpleLSTMModel is
        always returned.

        Args:
            args: parsed command-line namespace holding the dimensions and
                dropout probabilities
            task: task providing ``source_dictionary`` and
                ``target_dictionary``

        Returns:
            the assembled SimpleLSTMModel
        """
        model = SimpleLSTMModel(
            SimpleLSTMEncoder(
                args=args,
                dictionary=task.source_dictionary,
                embed_dim=args.encoder_embed_dim,
                hidden_dim=args.encoder_hidden_dim,
                dropout=args.encoder_dropout,
            ),
            SimpleLSTMDecoder(
                dictionary=task.target_dictionary,
                # The decoder needs the encoder's hidden size because it
                # consumes the encoder's final hidden state.
                encoder_hidden_dim=args.encoder_hidden_dim,
                embed_dim=args.decoder_embed_dim,
                hidden_dim=args.decoder_hidden_dim,
                dropout=args.decoder_dropout,
            ),
        )

        # Print the model architecture.
        print(model)

        return model

    # ``forward()`` is deliberately not overridden: the default provided by
    # the FairseqEncoderDecoderModel base class is enough for this tutorial.
    # It looks like:
    #
    #     def forward(self, src_tokens, src_lengths, prev_output_tokens):
    #         encoder_out = self.encoder(src_tokens, src_lengths)
    #         decoder_out = self.decoder(prev_output_tokens, encoder_out)
    #         return decoder_out
最后,为模型定义一个带有具体配置(configuration)的命名结构(named architecture),这通过register_model_architecture()函数修饰器完成。之后就可以通过--arch命令行参数来使用这个结构,如--arch tutorial_simple_lstm。
from fairseq.models import register_model_architecture
# The first argument to ``register_model_architecture()`` should be the name
# of the model we registered above (i.e., 'simple_lstm'). The function we
# register here should take a single argument *args* and modify it in-place
# to match the desired architecture.
@register_model_architecture('simple_lstm', 'tutorial_simple_lstm')
def tutorial_simple_lstm(args):
    """Fill in the defaults of the ``tutorial_simple_lstm`` architecture.

    Modifies *args* in place. Any of the four dimensions already present on
    *args* is kept; a missing one is set to 256.
    """
    dimension_names = (
        'encoder_embed_dim',
        'encoder_hidden_dim',
        'decoder_embed_dim',
        'decoder_hidden_dim',
    )
    for name in dimension_names:
        # ``getattr()`` with a fallback gives priority to a value that is
        # already set, so the default only applies when none was specified.
        setattr(args, name, getattr(args, name, 256))
3.模型训练
可以利用fairseq-train命令行工具Command-line tool来训练模型,并且确保你的新模型结构为(--arch tutorial_simple_lstm)
> fairseq-train data-bin/iwslt14.tokenized.de-en \
--arch tutorial_simple_lstm \
--encoder-dropout 0.2 --decoder-dropout 0.2 \
--optimizer adam --lr 0.005 --lr-shrink 0.5 \
--max-tokens 12000
(...)
| epoch 052 | loss 4.027 | ppl 16.30 | wps 420805 | ups 39.7 | wpb 9841 | bsz 400 | num_updates 20852 | lr 1.95313e-05 | gnorm 0.218 | clip 0% | oom 0 | wall 529 | train_wall 396
| epoch 052 | valid on 'valid' subset | valid_loss 4.74989 | valid_ppl 26.91 | num_updates 20852 | best 4.74954
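The model files should appear in the checkpoints/ directory. While this model architecture is not very good, we can use the fairseq-generate script to generate translations and compute our BLEU score over the test set: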
模型文件将存储到./checkpoints文件夹中。可以使用fairseq-generate命令在测试集上来计算BLEU score。
> fairseq-generate data-bin/iwslt14.tokenized.de-en \
--path checkpoints/checkpoint_best.pt \
--beam 5 \
--remove-bpe
(...)
| Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s)
| Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)
4.让代码飞起来Making generation faster
虽然sequence-to-sequence 模型天生就很慢,但我们上面的实现也特别慢,因为它为每个输出标记重新计算整个Decoder隐藏状态序列(即,它是O(n^2))。我们可以通过缓存以前的隐藏状态来大大加快这个过程。
在fairseq中,这被称为增量译码incremental decoding。增量解码是推理时的一种特殊模式,其中Model只接收与前一个输出令牌(用于教师强制)对应的单个时间步输入,并且必须增量地产生下一个输出。因此,模型必须缓存序列所需的任何长期状态,如隐藏状态、卷积状态等。
为了实现增量解码,将修改模型来实现FairseqIncrementalDecoder接口。与标准的FairseqDecoder接口相比,增量解码器接口允许forward()方法接受一个额外的关键字参数(incremental_state),该参数可用于跨时间步缓存状态。
用下面这个支持增量解码的版本替换原来的SimpleLSTMDecoder:
import torch
from fairseq.models import FairseqIncrementalDecoder
class SimpleLSTMDecoder(FairseqIncrementalDecoder):
    """Single-layer LSTM decoder with support for incremental decoding.

    In incremental mode only the last target token is processed at each call
    and the LSTM hidden and cell states are cached between calls under the
    ``'prev_state'`` key.

    Args:
        dictionary: target dictionary, used for the vocabulary size and the
            padding index.
        encoder_hidden_dim (int): size of the Encoder's final hidden state.
        embed_dim (int): dimensionality of the target token embeddings.
        hidden_dim (int): dimensionality of the LSTM hidden state.
        dropout (float): dropout probability applied to the embeddings.
    """

    def __init__(
        self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,
        dropout=0.1,
    ):
        # This remains the same as in the non-incremental decoder.
        super().__init__(dictionary)
        self.embed_tokens = nn.Embedding(
            num_embeddings=len(dictionary),
            embedding_dim=embed_dim,
            padding_idx=dictionary.pad(),
        )
        self.dropout = nn.Dropout(p=dropout)
        self.lstm = nn.LSTM(
            # Each step's input is the token embedding concatenated with the
            # Encoder's final hidden state.
            input_size=encoder_hidden_dim + embed_dim,
            hidden_size=hidden_dim,
            num_layers=1,
            bidirectional=False,
        )
        self.output_projection = nn.Linear(hidden_dim, len(dictionary))

    # We now take an additional kwarg (*incremental_state*) for caching the
    # previous hidden and cell states.
    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
        """Compute logits, optionally one step at a time.

        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`
            encoder_out (dict): output of the Encoder; its ``'final_hidden'``
                entry is read here
            incremental_state (optional): cache carried across time steps;
                when it is not ``None`` only the last token of
                *prev_output_tokens* is processed

        Returns:
            tuple:
                - the logits, of shape `(batch, tgt_len, vocab)`, where
                  `tgt_len` is 1 in incremental mode
                - ``None``, since this decoder produces no attention weights
        """
        if incremental_state is not None:
            # If the *incremental_state* argument is not ``None`` then we are
            # in incremental inference mode. While *prev_output_tokens* will
            # still contain the entire decoded prefix, we will only use the
            # last step and assume that the rest of the state is cached.
            prev_output_tokens = prev_output_tokens[:, -1:]

        # This remains the same as in the non-incremental decoder.
        bsz, tgt_len = prev_output_tokens.size()
        final_encoder_hidden = encoder_out['final_hidden']
        x = self.embed_tokens(prev_output_tokens)
        x = self.dropout(x)
        x = torch.cat(
            [x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
            dim=2,
        )

        # Load the cached previous hidden and cell states if they exist;
        # otherwise start from the Encoder's final hidden state and a zero
        # cell state, as in the non-incremental decoder.
        initial_state = utils.get_incremental_state(
            self, incremental_state, 'prev_state',
        )
        if initial_state is None:
            # first time initialization, same as the original version
            initial_state = (
                final_encoder_hidden.unsqueeze(0),  # hidden
                torch.zeros_like(final_encoder_hidden).unsqueeze(0),  # cell
            )

        # Run the LSTM (a single step in incremental mode). The LSTM is
        # time-major, hence the transpose to `(tgt_len, bsz, dim)`.
        output, latest_state = self.lstm(x.transpose(0, 1), initial_state)

        # Update the cache with the latest hidden and cell states.
        utils.set_incremental_state(
            self, incremental_state, 'prev_state', latest_state,
        )

        # This remains the same as in the non-incremental decoder.
        x = output.transpose(0, 1)
        x = self.output_projection(x)
        return x, None

    # The ``FairseqIncrementalDecoder`` interface also requires implementing a
    # ``reorder_incremental_state()`` method, which is used during beam search
    # to select and reorder the incremental state.
    def reorder_incremental_state(self, incremental_state, new_order):
        """Reorder the cached LSTM state according to *new_order*.

        Args:
            incremental_state: cache holding the ``'prev_state'`` entry
            new_order (LongTensor): desired order of the batch elements
        """
        # Load the cached state.
        prev_state = utils.get_incremental_state(
            self, incremental_state, 'prev_state',
        )
        if prev_state is None:
            # Nothing has been cached yet, so there is nothing to reorder.
            return

        # Reorder batches according to *new_order*. The cached states are
        # shaped `(1, bsz, hidden)`, so the batch dimension is dim 1.
        reordered_state = (
            prev_state[0].index_select(1, new_order),  # hidden
            prev_state[1].index_select(1, new_order),  # cell
        )

        # Update the cached state.
        utils.set_incremental_state(
            self, incremental_state, 'prev_state', reordered_state,
        )
最后,可以重新运行生成并观察加速情况:
# Before
> fairseq-generate data-bin/iwslt14.tokenized.de-en \
--path checkpoints/checkpoint_best.pt \
--beam 5 \
--remove-bpe
(...)
| Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s)
| Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)
# After
> fairseq-generate data-bin/iwslt14.tokenized.de-en \
--path checkpoints/checkpoint_best.pt \
--beam 5 \
--remove-bpe
(...)
| Translated 6750 sentences (153132 tokens) in 5.5s (1225.54 sentences/s, 27802.94 tokens/s)
| Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)