# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections import namedtuple

import numpy as np
import torch
from fairseq import utils


DecoderOut = namedtuple(
    "IterativeRefinementDecoderOut",
    ["output_tokens", "output_scores", "attn", "step", "max_step", "history"],
)
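
# DecoderOut carries the running decoding state between refinement steps: the
# current output tokens and their scores, optional attention, the current step
# and maximum step counters, and (when requested) the per-step history.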


class IterativeRefinementGenerator(object):
    def __init__(
        self,
        tgt_dict,
        models=None,
        eos_penalty=0.0,
        max_iter=10,
        max_ratio=2,
        beam_size=1,
        decoding_format=None,
        retain_dropout=False,
        adaptive=True,
        retain_history=False,
        reranking=False,
    ):
""" | |
Generates translations based on iterative refinement. | |
Args: | |
tgt_dict: target dictionary | |
eos_penalty: if > 0.0, it penalized early-stopping in decoding | |
max_iter: maximum number of refinement iterations | |
max_ratio: generate sequences of maximum length ax, where x is the source length | |
decoding_format: decoding mode in {'unigram', 'ensemble', 'vote', 'dp', 'bs'} | |
retain_dropout: retaining dropout in the inference | |
adaptive: decoding with early stop | |
""" | |
        self.bos = tgt_dict.bos()
        self.pad = tgt_dict.pad()
        self.unk = tgt_dict.unk()
        self.eos = tgt_dict.eos()
        self.vocab_size = len(tgt_dict)
        self.eos_penalty = eos_penalty
        self.max_iter = max_iter
        self.max_ratio = max_ratio
        self.beam_size = beam_size
        self.reranking = reranking
        self.decoding_format = decoding_format
        self.retain_dropout = retain_dropout
        self.retain_history = retain_history
        self.adaptive = adaptive
        self.models = models

    def generate_batched_itr(
        self,
        data_itr,
        maxlen_a=None,
        maxlen_b=None,
        cuda=False,
        timer=None,
        prefix_size=0,
    ):
        """Iterate over a batched dataset and yield individual translations.

        Args:
            maxlen_a/b: generate sequences of maximum length ax + b,
                where x is the source sentence length.
            cuda: use GPU for generation
            timer: StopwatchMeter for timing generations.
        """

        for sample in data_itr:
            if "net_input" not in sample:
                continue
            if timer is not None:
                timer.start()
            with torch.no_grad():
                hypos = self.generate(
                    self.models,
                    sample,
                    prefix_tokens=sample["target"][:, :prefix_size]
                    if prefix_size > 0
                    else None,
                )
            if timer is not None:
                timer.stop(sample["ntokens"])
            for i, id in enumerate(sample["id"]):
                # remove padding
                src = utils.strip_pad(sample["net_input"]["src_tokens"][i, :], self.pad)
                ref = utils.strip_pad(sample["target"][i, :], self.pad)
                yield id, src, ref, hypos[i]

    def generate(self, models, sample, prefix_tokens=None, constraints=None):
        if constraints is not None:
            raise NotImplementedError(
                "Constrained decoding with the IterativeRefinementGenerator is not supported"
            )

        # TODO: iterative refinement generator does not support ensemble for now.
        if not self.retain_dropout:
            for model in models:
                model.eval()

        model, reranker = models[0], None
        if self.reranking:
            assert len(models) > 1, "Assuming the last checkpoint is the reranker"
            assert (
                self.beam_size > 1
            ), "Reranking requires multiple translations for each example"

            reranker = models[-1]
            models = models[:-1]

        if len(models) > 1 and hasattr(model, "enable_ensemble"):
            assert model.allow_ensemble, "{} does not support ensembling".format(
                model.__class__.__name__
            )
            model.enable_ensemble(models)

        # TODO: better encoder inputs?
        src_tokens = sample["net_input"]["src_tokens"]
        src_lengths = sample["net_input"]["src_lengths"]
        bsz, src_len = src_tokens.size()

        # initialize
        encoder_out = model.forward_encoder([src_tokens, src_lengths])
        prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens)
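
        # With beam_size > 1 the model decodes over a "length beam": every source
        # sentence is duplicated beam_size times and each copy is re-initialized
        # with a different candidate target length, so the batch grows accordingly.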
        if self.beam_size > 1:
            assert (
                model.allow_length_beam
            ), "{} does not support decoding with length beam.".format(
                model.__class__.__name__
            )

            # regenerate data based on length-beam
            length_beam_order = (
                utils.new_arange(src_tokens, self.beam_size, bsz).t().reshape(-1)
            )
            encoder_out = model.encoder.reorder_encoder_out(
                encoder_out, length_beam_order
            )
            prev_decoder_out = model.regenerate_length_beam(
                prev_decoder_out, self.beam_size
            )
            bsz = bsz * self.beam_size

        sent_idxs = torch.arange(bsz)
        prev_output_tokens = prev_decoder_out.output_tokens.clone()

        if self.retain_history:
            prev_decoder_out = prev_decoder_out._replace(history=[prev_output_tokens])

        finalized = [[] for _ in range(bsz)]
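
        # is_a_loop pads this step's output to the previous step's length (or vice
        # versa) and reports, per sentence, whether refinement produced exactly the
        # same tokens again; adaptive decoding uses this as its stop signal.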
        def is_a_loop(x, y, s, a):
            b, l_x, l_y = x.size(0), x.size(1), y.size(1)
            if l_x > l_y:
                y = torch.cat([y, x.new_zeros(b, l_x - l_y).fill_(self.pad)], 1)
                s = torch.cat([s, s.new_zeros(b, l_x - l_y)], 1)
                if a is not None:
                    a = torch.cat([a, a.new_zeros(b, l_x - l_y, a.size(2))], 1)
            elif l_x < l_y:
                x = torch.cat([x, y.new_zeros(b, l_y - l_x).fill_(self.pad)], 1)
            return (x == y).all(1), y, s, a
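
        # finalized_hypos packages one finished sentence into the standard fairseq
        # hypothesis dict: non-pad tokens, per-position scores, their mean as the
        # sentence score, and (if available) attention plus a hard alignment.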
        def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
            cutoff = prev_out_token.ne(self.pad)
            tokens = prev_out_token[cutoff]
            if prev_out_score is None:
                scores, score = None, None
            else:
                scores = prev_out_score[cutoff]
                score = scores.mean()

            if prev_out_attn is None:
                hypo_attn, alignment = None, None
            else:
                hypo_attn = prev_out_attn[cutoff]
                alignment = hypo_attn.max(dim=1)[1]
            return {
                "steps": step,
                "tokens": tokens,
                "positional_scores": scores,
                "score": score,
                "hypo_attn": hypo_attn,
                "alignment": alignment,
            }
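
        # Main refinement loop: run the decoder up to max_iter + 1 times. After each
        # pass, sentences that have stopped changing (adaptive mode) or that reached
        # the last iteration are finalized and dropped from the working batch.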
        for step in range(self.max_iter + 1):

            decoder_options = {
                "eos_penalty": self.eos_penalty,
                "max_ratio": self.max_ratio,
                "decoding_format": self.decoding_format,
            }
            prev_decoder_out = prev_decoder_out._replace(
                step=step,
                max_step=self.max_iter + 1,
            )

            decoder_out = model.forward_decoder(
                prev_decoder_out, encoder_out, **decoder_options
            )

            if self.adaptive:
                # terminate if there is a loop
                terminated, out_tokens, out_scores, out_attn = is_a_loop(
                    prev_output_tokens,
                    decoder_out.output_tokens,
                    decoder_out.output_scores,
                    decoder_out.attn,
                )
                decoder_out = decoder_out._replace(
                    output_tokens=out_tokens,
                    output_scores=out_scores,
                    attn=out_attn,
                )

            else:
                terminated = decoder_out.output_tokens.new_zeros(
                    decoder_out.output_tokens.size(0)
                ).bool()

            if step == self.max_iter:  # reach last iteration, terminate
                terminated.fill_(1)

            # collect finalized sentences
            finalized_idxs = sent_idxs[terminated]
            finalized_tokens = decoder_out.output_tokens[terminated]
            finalized_scores = decoder_out.output_scores[terminated]
            finalized_attn = (
                None
                if (decoder_out.attn is None or decoder_out.attn.size(0) == 0)
                else decoder_out.attn[terminated]
            )

            if self.retain_history:
                finalized_history_tokens = [h[terminated] for h in decoder_out.history]

            for i in range(finalized_idxs.size(0)):
                finalized[finalized_idxs[i]] = [
                    finalized_hypos(
                        step,
                        finalized_tokens[i],
                        finalized_scores[i],
                        None if finalized_attn is None else finalized_attn[i],
                    )
                ]

                if self.retain_history:
                    finalized[finalized_idxs[i]][0]["history"] = []
                    for j in range(len(finalized_history_tokens)):
                        finalized[finalized_idxs[i]][0]["history"].append(
                            finalized_hypos(
                                step, finalized_history_tokens[j][i], None, None
                            )
                        )

            # check if all terminated
            if terminated.sum() == terminated.size(0):
                break
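
            # Otherwise shrink the batch to the unfinished sentences and reorder the
            # encoder output to match, so the next iteration only refines those.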
            # for next step
            not_terminated = ~terminated
            prev_decoder_out = decoder_out._replace(
                output_tokens=decoder_out.output_tokens[not_terminated],
                output_scores=decoder_out.output_scores[not_terminated],
                attn=decoder_out.attn[not_terminated]
                if (decoder_out.attn is not None and decoder_out.attn.size(0) > 0)
                else None,
                history=[h[not_terminated] for h in decoder_out.history]
                if decoder_out.history is not None
                else None,
            )
            encoder_out = model.encoder.reorder_encoder_out(
                encoder_out, not_terminated.nonzero(as_tuple=False).squeeze()
            )
            sent_idxs = sent_idxs[not_terminated]
            prev_output_tokens = prev_decoder_out.output_tokens.clone()
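
        # With a length beam, `finalized` now holds beam_size hypotheses per source
        # sentence (stored consecutively). Optionally rescore them with the reranker,
        # then keep only the highest-scoring hypothesis for each sentence.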
        if self.beam_size > 1:
            if reranker is not None:
                finalized = self.rerank(
                    reranker, finalized, [src_tokens, src_lengths], self.beam_size
                )

            # aggregate information from length beam
            finalized = [
                finalized[
                    np.argmax(
                        [
                            finalized[self.beam_size * i + j][0]["score"]
                            for j in range(self.beam_size)
                        ]
                    )
                    + self.beam_size * i
                ]
                for i in range(len(finalized) // self.beam_size)
            ]

        return finalized

    def rerank(self, reranker, finalized, encoder_input, beam_size):
        def rebuild_batch(finalized):
            finalized_tokens = [f[0]["tokens"] for f in finalized]
            finalized_maxlen = max(f.size(0) for f in finalized_tokens)
            final_output_tokens = (
                finalized_tokens[0]
                .new_zeros(len(finalized_tokens), finalized_maxlen)
                .fill_(self.pad)
            )
            for i, f in enumerate(finalized_tokens):
                final_output_tokens[i, : f.size(0)] = f
            return final_output_tokens
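
        # Rerank by teacher-forcing the finalized hypotheses through an autoregressive
        # model: pad them into one batch, then replace each hypothesis's "score" with
        # the reranker's length-normalized log-probability.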
        final_output_tokens = rebuild_batch(finalized)
        final_output_tokens[
            :, 0
        ] = self.eos  # autoregressive model assumes starting with EOS

        reranker_encoder_out = reranker.encoder(*encoder_input)
        length_beam_order = (
            utils.new_arange(
                final_output_tokens, beam_size, reranker_encoder_out.encoder_out.size(1)
            )
            .t()
            .reshape(-1)
        )
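        # Repeat each source sentence's encoder states beam_size times so they line
        # up with the flattened length-beam of hypotheses being scored.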
        reranker_encoder_out = reranker.encoder.reorder_encoder_out(
            reranker_encoder_out, length_beam_order
        )
        reranking_scores = reranker.get_normalized_probs(
            reranker.decoder(final_output_tokens[:, :-1], reranker_encoder_out),
            True,
            None,
        )
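        # Gather each hypothesis token's log-probability under the reranker, zero out
        # pad positions, and average over the true (non-pad) length of each hypothesis.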
        reranking_scores = reranking_scores.gather(2, final_output_tokens[:, 1:, None])
        reranking_masks = final_output_tokens[:, 1:].ne(self.pad)
        reranking_scores = (
            reranking_scores[:, :, 0].masked_fill_(~reranking_masks, 0).sum(1)
        )
        reranking_scores = reranking_scores / reranking_masks.sum(1).type_as(
            reranking_scores
        )

        for i in range(len(finalized)):
            finalized[i][0]["score"] = reranking_scores[i]

        return finalized