The fairseq version used here is 0.6.2.
The beam search logic in fairseq lives in fairseq.sequence_generator.SequenceGenerator.generate.
SequenceGenerator drives the overall search process, while the per-step candidate selection strategies (BeamSearch, Sampling, DiverseBeamSearch, LengthConstrainedBeamSearch) live in
fairseq.search.
The original code is lightly commented, so while reading it I wrote the logic directly into inline comments to make it easier to follow.
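Before going into the code, here is a minimal usage sketch of my own (not taken from the fairseq source); tgt_dict, models and sample are placeholders assumed to be prepared elsewhere, e.g. by a fairseq task and its data iterator:
from fairseq.sequence_generator import SequenceGenerator
# assumed to already exist:
#   tgt_dict - the target Dictionary
#   models   - a list of trained FairseqModel instances (wrapped into an EnsembleModel inside generate)
#   sample   - a batch whose sample['net_input'] contains 'src_tokens'
generator = SequenceGenerator(tgt_dict, beam_size=5, max_len_a=0, max_len_b=200)
finalized = generator.generate(models, sample)
# finalized[i] is a list of at most beam_size hypothesis dicts for the i-th sentence, each with
# 'tokens', 'score', 'attention', 'alignment' and 'positional_scores' (see get_hypo below).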
class SequenceGenerator(object):
def __init__(
self,
tgt_dict,
beam_size=1,
max_len_a=0,
max_len_b=200,
min_len=1,
stop_early=True,
normalize_scores=True,
len_penalty=1.,
unk_penalty=0.,
retain_dropout=False,
sampling=False,
sampling_topk=-1,
sampling_temperature=1.,
diverse_beam_groups=-1,
diverse_beam_strength=0.5,
match_source_len=False,
no_repeat_ngram_size=0,
):
"""Generates translations of a given source sentence.
Args:
tgt_dict (~fairseq.data.Dictionary): target dictionary
beam_size (int, optional): beam width (default: 1)
max_len_a/b (int, optional): generate sequences of maximum length
ax + b, where x is the source length
min_len (int, optional): the minimum length of the generated output
(not including end-of-sentence)
stop_early (bool, optional): stop generation immediately after we
finalize beam_size hypotheses, even though longer hypotheses
might have better normalized scores (default: True)
normalize_scores (bool, optional): normalize scores by the length
of the output (default: True)
len_penalty (float, optional): length penalty, where <1.0 favors
shorter, >1.0 favors longer sentences (default: 1.0)
unk_penalty (float, optional): unknown word penalty, where <0
produces more unks, >0 produces fewer (default: 0.0)
retain_dropout (bool, optional): use dropout when generating
(default: False)
sampling (bool, optional): sample outputs instead of beam search
(default: False)
sampling_topk (int, optional): only sample among the top-k choices
at each step (default: -1)
sampling_temperature (float, optional): temperature for sampling,
where values >1.0 produces more uniform sampling and values
<1.0 produces sharper sampling (default: 1.0)
diverse_beam_groups/strength (float, optional): parameters for
Diverse Beam Search sampling
match_source_len (bool, optional): outputs should match the source
length (default: False)
"""
self.pad = tgt_dict.pad()
self.unk = tgt_dict.unk()
self.eos = tgt_dict.eos()
self.vocab_size = len(tgt_dict)
self.beam_size = beam_size
# the max beam size is the dictionary size - 1, since we never select pad
self.beam_size = min(beam_size, self.vocab_size - 1)
self.max_len_a = max_len_a
self.max_len_b = max_len_b
self.min_len = min_len
self.stop_early = stop_early
self.normalize_scores = normalize_scores
self.len_penalty = len_penalty
self.unk_penalty = unk_penalty
self.retain_dropout = retain_dropout
self.match_source_len = match_source_len
self.no_repeat_ngram_size = no_repeat_ngram_size
assert sampling_topk < 0 or sampling, '--sampling-topk requires --sampling'
if sampling:
self.search = search.Sampling(tgt_dict, sampling_topk, sampling_temperature)
elif diverse_beam_groups > 0:
self.search = search.DiverseBeamSearch(tgt_dict, diverse_beam_groups, diverse_beam_strength)
elif match_source_len:
self.search = search.LengthConstrainedBeamSearch(
tgt_dict, min_len_a=1, min_len_b=0, max_len_a=1, max_len_b=0,
)
else:
self.search = search.BeamSearch(tgt_dict)
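# Note (mine): all strategies in fairseq.search expose the same interface that generate() relies on below:
# set_src_lengths(src_lengths) and
#   step(step, lprobs, scores) -> (cand_scores, cand_indices, cand_beams)
# where lprobs is viewed as (bsz, beam_size, vocab_size) and the three outputs are (bsz, cand_size)
# tensors: cumulative candidate scores, candidate token ids, and the beam each candidate extends.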
@torch.no_grad()
def generate(
self,
models,
sample,
prefix_tokens=None,
bos_token=None,
**kwargs
):
"""Generate a batch of translations.
Args:
models (List[~fairseq.models.FairseqModel]): ensemble of models
sample (dict): batch
prefix_tokens (torch.LongTensor, optional): force decoder to begin
with these tokens
"""
model = EnsembleModel(models)
if not self.retain_dropout:
model.eval()
# model.forward normally channels prev_output_tokens into the decoder
# separately, but SequenceGenerator directly calls model.encoder
encoder_input = {
k: v for k, v in sample['net_input'].items()
if k != 'prev_output_tokens'
}
src_tokens = encoder_input['src_tokens']
src_lengths = (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
input_size = src_tokens.size()
# batch dimension goes first followed by source lengths
bsz = input_size[0]
src_len = input_size[1]
beam_size = self.beam_size
if self.match_source_len:
max_len = src_lengths.max().item()
else:
max_len = min(
int(self.max_len_a * src_len + self.max_len_b),
# exclude the EOS marker
model.max_decoder_positions() - 1,
)
# compute the encoder output for each beam
encoder_outs = model.forward_encoder(encoder_input)
new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
new_order = new_order.to(src_tokens.device).long() # (bsz*beam,) each entry is the index of the batch sentence it corresponds to
# reorder_encoder_out rearranges encoder_outs according to new_order, so each position now holds
# the encoder output of the sentence new_order points to; encoder_outs now has bsz*beam entries
encoder_outs = model.reorder_encoder_out(encoder_outs, new_order)
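# e.g. with bsz = 2 and beam_size = 3, new_order = [0, 0, 0, 1, 1, 1]: each source sentence's
# encoder output is tiled beam_size times, so every beam hypothesis decodes against its own copy.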
# initialize buffers
# scores holds the cumulative log probs along each hypothesis, bsz*beam x max_len+1
scores = src_tokens.new(bsz * beam_size, max_len + 1).float().fill_(0)
scores_buf = scores.clone()
# max_len excludes EOS; there is one at the beginning and one at the end, hence +2
# bsz*beam x max_len+2
tokens = src_tokens.data.new(bsz * beam_size, max_len + 2).long().fill_(self.pad)
tokens_buf = tokens.clone()
tokens[:, 0] = bos_token or self.eos
attn, attn_buf = None, None
nonpad_idxs = None
# list of completed sentences
# per sentence, the hypotheses that have reached EOS so far
finalized = [[] for i in range(bsz)]
# whether each sentence has finished searching
finished = [False for i in range(bsz)]
# the worst finalized hypothesis per sentence, kept so it can be replaced when a better one is found
worst_finalized = [{'idx': None, 'score': -math.inf} for i in range(bsz)]
# how many sentences have not finished searching yet
num_remaining_sent = bsz
# number of candidate hypos per step
# at each step we take cand_size candidates; those ending in eos are moved into finalized, which guarantees at least beam_size candidates remain to expand at the next step
cand_size = 2 * beam_size # 2 x beam size in case half are EOS
# offset arrays for converting between different indexing schemes
# the bsz x cand_size matrices below get flattened to 1-D, so these offsets record which sentence / which candidate an entry originally came from
bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens)
cand_offsets = torch.arange(0, cand_size).type_as(tokens)
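# e.g. with bsz = 2 and beam_size = 2: bbsz_offsets = [[0], [2]] and cand_offsets = [0, 1, 2, 3];
# adding bbsz_offsets to a per-sentence beam index gives its row in the flattened bsz*beam tensors
# (tokens, scores, attn), while cand_offsets is used later to build active_mask.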
# helper function for allocating buffers on the fly
buffers = {}
def buffer(name, type_of=tokens): # noqa
if name not in buffers:
buffers[name] = type_of.new()
return buffers[name]
def is_finished(sent, step, unfinalized_scores=None):
"""
If a sentence's unfinalized hypotheses already score worse than its finalized ones, there is no need to keep searching it.
Check whether we've finished generation for a given sentence, by
comparing the worst score among finalized hypotheses to the best
possible score among unfinalized hypotheses.
"""
assert len(finalized[sent]) <= beam_size
if len(finalized[sent]) == beam_size:
if self.stop_early or step == max_len or unfinalized_scores is None:
return True
# stop if the best unfinalized score is worse than the worst
# finalized one
best_unfinalized_score = unfinalized_scores[sent].max()
if self.normalize_scores:
best_unfinalized_score /= max_len ** self.len_penalty
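# normalizing by max_len (the longest possible length) gives an optimistic bound: for the usual
# case of non-positive log probs and len_penalty > 0, the eventual normalized score of any
# unfinalized hypothesis cannot exceed this value, so stopping here is safe.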
if worst_finalized[sent]['score'] >= best_unfinalized_score:
return True
return False
def finalize_hypos(step, bbsz_idx, eos_scores, unfinalized_scores=None):
"""
At each step, move the hypotheses that have reached eos into the corresponding finalized list and update the finished flags.
Finalize the given hypotheses at this step, while keeping the total
number of finalized hypotheses per sentence <= beam_size.
Note: the input must be in the desired finalization order, so that
hypotheses that appear earlier in the input are preferred to those
that appear later.
Args:
step: current time step
bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
indicating which hypotheses to finalize
eos_scores: A vector of the same size as bbsz_idx containing
scores for each hypothesis
unfinalized_scores: A vector containing scores for all
unfinalized hypotheses
"""
assert bbsz_idx.numel() == eos_scores.numel()
# clone relevant token and attention tensors
tokens_clone = tokens.index_select(0, bbsz_idx)
tokens_clone = tokens_clone[:, 1:step + 2] # skip the first index, which is EOS
tokens_clone[:, step] = self.eos
attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step+2] if attn is not None else None
# compute scores per token position
pos_scores = scores.index_select(0, bbsz_idx)[:, :step+1]
pos_scores[:, step] = eos_scores
# convert from cumulative to per-position scores
pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]
# normalize sentence-level scores
if self.normalize_scores:
eos_scores /= (step + 1) ** self.len_penalty
# accumulate how many finished sentences precede each unfinished one, which maps it back to its index in the original batch
cum_unfin = []
prev = 0
for f in finished:
if f:
prev += 1
else:
cum_unfin.append(prev)
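# e.g. if finished == [True, False, False] then cum_unfin == [1, 1]: unfinished sentence 0
# actually sits at position 1 of the original batch and unfinished sentence 1 at position 2.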
sents_seen = set()
for i, (idx, score) in enumerate(zip(bbsz_idx.tolist(), eos_scores.tolist())):
unfin_idx = idx // beam_size
# the number of finished sentences in front of this unfinished one determines its position in the original batch
sent = unfin_idx + cum_unfin[unfin_idx]
sents_seen.add((sent, unfin_idx))
if self.match_source_len and step > src_lengths[unfin_idx]:
score = -math.inf
def get_hypo():
"""得到hypo, 存起全部的信息."""
if attn_clone is not None:
# remove padding tokens from attn scores
hypo_attn = attn_clone[i][nonpad_idxs[sent]]
_, alignment = hypo_attn.max(dim=0)
else:
hypo_attn = None
alignment = None
return {
'tokens': tokens_clone[i],
'score': score,
'attention': hypo_attn, # src_len x tgt_len
'alignment': alignment,
'positional_scores': pos_scores[i],
}
if len(finalized[sent]) < beam_size:
finalized[sent].append(get_hypo())
elif not self.stop_early and score > worst_finalized[sent]['score']:
# replace worst hypo for this sentence with new/better one
worst_idx = worst_finalized[sent]['idx']
if worst_idx is not None:
finalized[sent][worst_idx] = get_hypo()
# find new worst finalized hypo for this sentence
idx, s = min(enumerate(finalized[sent]), key=lambda r: r[1]['score'])
worst_finalized[sent] = {
'score': s['score'],
'idx': idx,
}
newly_finished = []
for sent, unfin_idx in sents_seen:
# check termination conditions for this sentence
if not finished[sent] and is_finished(sent, step, unfinalized_scores):
finished[sent] = True
# record the finished sentence's index within the current batch
newly_finished.append(unfin_idx)
return newly_finished
reorder_state = None
batch_idxs = None
for step in range(max_len + 1): # one extra step for EOS marker
# reorder decoder internal states based on the prev choice of beams
if reorder_state is not None:
if batch_idxs is not None:
# update beam indices to take into account removed sentences
# reorder_state holds flat indices laid out by the shrunken (new_bsz) batch, while the decoder and
# encoder states being reordered still use the previous batch layout; adding
# (batch_idxs - arange) * beam_size per row converts the indices back to the old layout
corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(batch_idxs)
reorder_state.view(-1, beam_size).add_(corr.unsqueeze(-1) * beam_size)
model.reorder_incremental_state(reorder_state)
model.reorder_encoder_out(encoder_outs, reorder_state)
# run the decoder for one step to get the predicted log probs and attention
# lprobs: bsz*beam x vocab_size
lprobs, avg_attn_scores = model.forward_decoder(tokens[:, :step + 1], encoder_outs)
lprobs[:, self.pad] = -math.inf # never select pad, bsz*beam x vocab_size
lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty
if self.no_repeat_ngram_size > 0:
# for each beam and batch sentence, generate a list of previous ngrams
# repeated ngrams are disallowed: collect the ngrams generated so far to compute banned tokens later
gen_ngrams = [{} for bbsz_idx in range(bsz * beam_size)]
for bbsz_idx in range(bsz * beam_size):
gen_tokens = tokens[bbsz_idx].tolist()
for ngram in zip(*[gen_tokens[i:] for i in range(self.no_repeat_ngram_size)]):
gen_ngrams[bbsz_idx][tuple(ngram[:-1])] = \
gen_ngrams[bbsz_idx].get(tuple(ngram[:-1]), []) + [ngram[-1]]
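# e.g. with no_repeat_ngram_size = 2 and gen_tokens = [2, 5, 7, 5], the bigrams are
# (2,5), (5,7), (7,5), so gen_ngrams becomes {(2,): [5], (5,): [7], (7,): [5]}: each
# (n-1)-gram prefix maps to the tokens that have already followed it.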
# Record attention scores
if avg_attn_scores is not None:
if attn is None:
attn = scores.new(bsz * beam_size, src_tokens.size(1), max_len + 2)
attn_buf = attn.clone()
nonpad_idxs = src_tokens.ne(self.pad)
attn[:, :, step + 1].copy_(avg_attn_scores)
scores = scores.type_as(lprobs)
scores_buf = scores_buf.type_as(lprobs)
# flat indices of the candidates that reached eos
eos_bbsz_idx = buffer('eos_bbsz_idx')
# and their corresponding scores
eos_scores = buffer('eos_scores', type_of=scores)
if step < max_len:
self.search.set_src_lengths(src_lengths)
if self.no_repeat_ngram_size > 0:
def calculate_banned_tokens(bbsz_idx):
# before decoding the next token, prevent decoding of ngrams that have already appeared
ngram_index = tuple(tokens[bbsz_idx, step + 2 - self.no_repeat_ngram_size:step + 1].tolist())
return gen_ngrams[bbsz_idx].get(ngram_index, [])
if step + 2 - self.no_repeat_ngram_size >= 0:
# no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
banned_tokens = [calculate_banned_tokens(bbsz_idx) for bbsz_idx in range(bsz * beam_size)]
else:
banned_tokens = [[] for bbsz_idx in range(bsz * beam_size)]
for bbsz_idx in range(bsz * beam_size):
lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = -math.inf
if prefix_tokens is not None and step < prefix_tokens.size(1):
# if prefix tokens are specified, build cand_scores, cand_indices and cand_beams from them instead of searching freely
probs_slice = lprobs.view(bsz, -1, lprobs.size(-1))[:, 0, :]
cand_scores = torch.gather(
probs_slice, dim=1,
index=prefix_tokens[:, step].view(-1, 1)
).view(-1, 1).repeat(1, cand_size)
if step > 0:
# save cumulative scores for each hypothesis
cand_scores.add_(scores[:, step - 1].view(bsz, beam_size).repeat(1, 2))
cand_indices = prefix_tokens[:, step].view(-1, 1).repeat(1, cand_size)
cand_beams = torch.zeros_like(cand_indices)
# handle prefixes of different lengths
partial_prefix_mask = prefix_tokens[:, step].eq(self.pad)
if partial_prefix_mask.any():
partial_scores, partial_indices, partial_beams = self.search.step(
step,
lprobs.view(bsz, -1, self.vocab_size),
scores.view(bsz, beam_size, -1)[:, :, :step],
)
cand_scores[partial_prefix_mask] = partial_scores[partial_prefix_mask]
cand_indices[partial_prefix_mask] = partial_indices[partial_prefix_mask]
cand_beams[partial_prefix_mask] = partial_beams[partial_prefix_mask]
else:
cand_scores, cand_indices, cand_beams = self.search.step(
step,
lprobs.view(bsz, -1, self.vocab_size),
scores.view(bsz, beam_size, -1)[:, :, :step],
)
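# cand_scores / cand_indices / cand_beams are all bsz x cand_size: the cumulative score of each of
# the 2*beam_size best continuations, the vocabulary id of the token that produces it, and the
# beam (0..beam_size-1) within the sentence that it extends.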
else:
# make probs contain cumulative scores for each hypothesis
lprobs.add_(scores[:, step - 1].unsqueeze(-1))
# finalize all active hypotheses once we hit max_len
# pick the hypothesis with the highest prob of EOS right now
torch.sort(
lprobs[:, self.eos],
descending=True,
out=(eos_scores, eos_bbsz_idx),
)
num_remaining_sent -= len(finalize_hypos(step, eos_bbsz_idx, eos_scores))
assert num_remaining_sent == 0
break
# cand_bbsz_idx contains beam indices for the top candidate
# hypotheses, with a range of values: [0, bsz*beam_size),
# and dimensions: [bsz, cand_size]
# assign each candidate an index into the flattened batch*beam dimension, used below for index_select after flattening
cand_bbsz_idx = cand_beams.add(bbsz_offsets) # bsz x cand_size
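# e.g. with bsz = 2, beam_size = 2: a cand_beams row of [0, 1, 1, 0] for sentence 1 becomes
# [2, 3, 3, 2] after adding its offset 2, i.e. rows of the flattened bsz*beam tensors.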
# finalize hypotheses that end in eos
eos_mask = cand_indices.eq(self.eos) # bsz x cand_size
finalized_sents = set() # never actually used as a set: it gets overwritten with a list of the sentences in the current batch that finished searching
if step >= self.min_len:
# only consider eos when it's among the top beam_size indices
# only look at the first beam_size candidates; considering all cand_size of them would effectively use the wrong beam size
torch.masked_select(
cand_bbsz_idx[:, :beam_size],
mask=eos_mask[:, :beam_size],
out=eos_bbsz_idx,
)
if eos_bbsz_idx.numel() > 0:
torch.masked_select(
cand_scores[:, :beam_size],
mask=eos_mask[:, :beam_size],
out=eos_scores,
)
finalized_sents = finalize_hypos(step, eos_bbsz_idx, eos_scores, cand_scores)
num_remaining_sent -= len(finalized_sents)
assert num_remaining_sent >= 0
if num_remaining_sent == 0:
break
assert step < max_len
if len(finalized_sents) > 0:
new_bsz = bsz - len(finalized_sents)
# construct batch_idxs which holds indices of batches to keep for the next pass
batch_mask = cand_indices.new_ones(bsz) # bsz,
batch_mask[cand_indices.new(finalized_sents)] = 0
batch_idxs = batch_mask.nonzero().squeeze(-1) # new_bsz,
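# e.g. with bsz = 3 and finalized_sents = [1]: batch_mask = [1, 0, 1] and batch_idxs = [0, 2],
# so sentence 1 is dropped and the remaining tensors are re-indexed to the new, smaller batch.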
eos_mask = eos_mask[batch_idxs] # new_bsz x cand_size
cand_beams = cand_beams[batch_idxs] # new_bsz x cand_size
bbsz_offsets.resize_(new_bsz, 1) # keep only the first new_bsz offsets
cand_bbsz_idx = cand_beams.add(bbsz_offsets) # new_bsz x cand_size
cand_scores = cand_scores[batch_idxs] # new_bsz x cand_size
cand_indices = cand_indices[batch_idxs] # new_bsz x cand_size
if prefix_tokens is not None:
prefix_tokens = prefix_tokens[batch_idxs]
src_lengths = src_lengths[batch_idxs]
scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
scores_buf.resize_as_(scores) # its rows no longer line up with scores by batch idx, but the values are never read afterwards, so it does not matter
tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
tokens_buf.resize_as_(tokens)
if attn is not None:
attn = attn.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, attn.size(1), -1)
attn_buf.resize_as_(attn)
bsz = new_bsz
else:
batch_idxs = None
# set active_mask so that values > cand_size indicate eos hypos
# and values < cand_size indicate candidate active hypos.
# After, the min values per row are the top candidate active hypos
# rank the eos-ending candidates last, then take the first beam_size as the hypotheses to expand at the next step
active_mask = buffer('active_mask') # new_bsz x cand_size
torch.add(
eos_mask.type_as(cand_offsets) * cand_size,
cand_offsets[:eos_mask.size(1)],
out=active_mask,
)
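# e.g. with beam_size = 2 and an eos_mask row of [1, 0, 0, 1]: the active_mask row is [4, 1, 2, 7],
# so the topk below (largest=False, k=beam_size) picks positions 1 and 2, i.e. the best
# non-eos candidates in their original order.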
# get the top beam_size active hypotheses, which are just the hypos
# with the smallest values in active_mask
# active_hypos: new_bsz x beam_size
active_hypos, _ignore = buffer('active_hypos'), buffer('_ignore')
torch.topk(
active_mask, k=beam_size, dim=1, largest=False,
out=(_ignore, active_hypos)
)
active_bbsz_idx = buffer('active_bbsz_idx') # new_bsz x beam_size
torch.gather(
cand_bbsz_idx, dim=1, index=active_hypos,
out=active_bbsz_idx,
)
active_scores = torch.gather(
cand_scores, dim=1, index=active_hypos,
out=scores[:, step].view(bsz, beam_size),
)
active_bbsz_idx = active_bbsz_idx.view(-1)
active_scores = active_scores.view(-1)
# copy tokens and scores for active hypotheses
torch.index_select(
tokens[:, :step + 1], dim=0, index=active_bbsz_idx,
out=tokens_buf[:, :step + 1],
)
torch.gather(
cand_indices, dim=1, index=active_hypos,
out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],
)
if step > 0:
torch.index_select(
scores[:, :step], dim=0, index=active_bbsz_idx,
out=scores_buf[:, :step],
)
torch.gather(
cand_scores, dim=1, index=active_hypos,
out=scores_buf.view(bsz, beam_size, -1)[:, :, step],
)
# copy attention for active hypotheses
if attn is not None:
torch.index_select(
attn[:, :, :step + 2], dim=0, index=active_bbsz_idx,
out=attn_buf[:, :, :step + 2],
)
# swap buffers
tokens, tokens_buf = tokens_buf, tokens
scores, scores_buf = scores_buf, scores
if attn is not None:
attn, attn_buf = attn_buf, attn
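# tokens_buf / scores_buf / attn_buf now hold the reordered copies for the surviving hypotheses;
# swapping the references instead of copying back avoids allocating new tensors at every step.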
# reorder incremental state in decoder
reorder_state = active_bbsz_idx
# sort by score descending
for sent in range(len(finalized)):
finalized[sent] = sorted(finalized[sent], key=lambda r: r['score'], reverse=True)
return finalized
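To make the flattening and candidate-selection bookkeeping concrete, here is a small self-contained snippet of my own (toy numbers, independent of fairseq) that reproduces the bbsz_offsets / active_mask trick used above:
import torch

bsz, beam_size = 2, 2
cand_size = 2 * beam_size
# toy values: which beam each of the 2*beam candidates extends, per sentence
cand_beams = torch.tensor([[0, 1, 1, 0],
                           [1, 0, 0, 1]])
# toy values: which candidates ended in eos
eos_mask = torch.tensor([[1, 0, 0, 1],
                         [0, 0, 1, 0]])

# flatten (sentence, beam) into a single bsz*beam row index, exactly like cand_bbsz_idx
bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1)
cand_bbsz_idx = cand_beams + bbsz_offsets
print(cand_bbsz_idx)    # tensor([[0, 1, 1, 0], [3, 2, 2, 3]])

# push eos candidates behind the others, then keep the beam_size smallest values per row
cand_offsets = torch.arange(0, cand_size)
active_mask = eos_mask * cand_size + cand_offsets
_, active_hypos = torch.topk(active_mask, k=beam_size, dim=1, largest=False)
print(active_hypos)     # tensor([[1, 2], [0, 1]])

# rows of the flattened bsz*beam tensors (tokens, scores, attn) to keep expanding
active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos)
print(active_bbsz_idx)  # tensor([[1, 1], [3, 2]])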