from typing import Dict, Optional

import flax
import jax
import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla
import numpy as np
from jax import lax

from transformers.file_utils import ModelOutput
from transformers.generation_flax_logits_process import (
    FlaxForcedBOSTokenLogitsProcessor,
    FlaxForcedEOSTokenLogitsProcessor,
    FlaxLogitsProcessorList,
    FlaxMinLengthLogitsProcessor,
    FlaxTemperatureLogitsWarper,
    FlaxTopKLogitsWarper,
    FlaxTopPLogitsWarper,
)
from transformers.utils import logging

logger = logging.get_logger(__name__)

@flax.struct.dataclass
class FlaxGreedySearchOutput(ModelOutput):
    """
    Flax base class for outputs of decoder-only generation models using greedy search.

    Args:
        sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`):
            The generated sequences.
    """

    sequences: jax_xla.DeviceArray = None


@flax.struct.dataclass
class FlaxSampleOutput(ModelOutput):
    """
    Flax base class for outputs of decoder-only generation models using sampling.

    Args:
        sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`):
            The generated sequences.
    """

    sequences: jax_xla.DeviceArray = None


@flax.struct.dataclass
class FlaxBeamSearchOutput(ModelOutput):
    """
    Flax base class for outputs of decoder-only generation models using beam search.

    Args:
        sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`):
            The generated sequences.
        scores (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size,)`):
            The scores (log probabilities) of the generated sequences.
    """

    sequences: jax_xla.DeviceArray = None
    scores: jax_xla.DeviceArray = None

# The loop states below are registered as JAX pytrees via `flax.struct.dataclass`
# so they can be carried through `lax.while_loop` and traversed by `jax.tree_map`.
@flax.struct.dataclass
class GreedyState:
    cur_len: jax_xla.DeviceArray
    sequences: jax_xla.DeviceArray
    running_token: jax_xla.DeviceArray
    is_sent_finished: jax_xla.DeviceArray
    model_kwargs: Dict[str, jax_xla.DeviceArray]


@flax.struct.dataclass
class SampleState:
    cur_len: jax_xla.DeviceArray
    sequences: jax_xla.DeviceArray
    running_token: jax_xla.DeviceArray
    is_sent_finished: jax_xla.DeviceArray
    prng_key: jax_xla.DeviceArray
    model_kwargs: Dict[str, jax_xla.DeviceArray]


@flax.struct.dataclass
class BeamSearchState:
    cur_len: jax_xla.DeviceArray
    running_sequences: jax_xla.DeviceArray
    running_scores: jax_xla.DeviceArray
    sequences: jax_xla.DeviceArray
    scores: jax_xla.DeviceArray
    is_sent_finished: jax_xla.DeviceArray
    model_kwargs: Dict[str, jax_xla.DeviceArray]

class FlaxCLIPVisionMarianGenerationMixin:
    """
    A class containing all of the functions supporting generation, to be used as a mixin in
    :class:`~transformers.FlaxPreTrainedModel`.
    """

    @staticmethod
    def _run_loop_in_debug(cond_fn, body_fn, init_state):
        """
        Run generation in untraced mode. This should only be used for debugging purposes.
        """
        state = init_state
        while cond_fn(state):
            state = body_fn(state)
        return state

    def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids, model_kwargs):
        # forward everything that is not decoder- or cross-attention-specific to the encoder
        encoder_kwargs = {
            argument: value
            for argument, value in model_kwargs.items()
            if not (argument.startswith("decoder_") or argument.startswith("cross_attn"))
        }
        model_kwargs["encoder_outputs"] = self.encode(input_ids, return_dict=True, **encoder_kwargs)
        return model_kwargs

    @staticmethod
    def _expand_to_num_beams(tensor, num_beams):
        # insert a beam axis after the batch axis and broadcast along it
        return jnp.broadcast_to(tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:])
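    # Illustrative shape check (hypothetical values, not part of the original code):
    #   x = jnp.ones((2, 7))  # (batch_size, seq_len)
    #   FlaxCLIPVisionMarianGenerationMixin._expand_to_num_beams(x, num_beams=4).shape
    #   -> (2, 4, 7)  # (batch_size, num_beams, seq_len)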

    def generate(
        self,
        input_ids: jax_xla.DeviceArray,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        decoder_start_token_id: Optional[int] = None,
        do_sample: Optional[bool] = None,
        prng_key: Optional[jax_xla.DeviceArray] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        temperature: Optional[float] = None,
        num_beams: Optional[int] = None,
        no_repeat_ngram_size: Optional[int] = None,
        min_length: Optional[int] = None,
        forced_bos_token_id: Optional[int] = None,
        forced_eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        early_stopping: Optional[bool] = None,
        trace: bool = True,
        params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
        **model_kwargs,
    ):
        r"""
        Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
        multinomial sampling, and beam search.

        Apart from :obj:`input_ids`, all the arguments below will default to the value of the attribute of the same
        name inside the :class:`~transformers.PretrainedConfig` of the model. The default values indicated are the
        default values of those configs.

        Most of these parameters are explained in more detail in `this blog post
        <https://huggingface.co/blog/how-to-generate>`__.

        Parameters:
            input_ids (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
                The sequence used as a prompt for the generation.
            max_length (:obj:`int`, `optional`, defaults to 20):
                The maximum length of the sequence to be generated.
            do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to use sampling; use greedy decoding otherwise.
            temperature (:obj:`float`, `optional`, defaults to 1.0):
                The value used to modulate the next token probabilities.
            top_k (:obj:`int`, `optional`, defaults to 50):
                The number of highest probability vocabulary tokens to keep for top-k-filtering.
            top_p (:obj:`float`, `optional`, defaults to 1.0):
                If set to float < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or
                higher are kept for generation.
            pad_token_id (:obj:`int`, `optional`):
                The id of the `padding` token.
            bos_token_id (:obj:`int`, `optional`):
                The id of the `beginning-of-sequence` token.
            eos_token_id (:obj:`int`, `optional`):
                The id of the `end-of-sequence` token.
            num_beams (:obj:`int`, `optional`, defaults to 1):
                Number of beams for beam search. 1 means no beam search.
            decoder_start_token_id (:obj:`int`, `optional`):
                If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
            trace (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether to trace generation. Setting ``trace=False`` should only be used for debugging and will lead to
                a considerably slower runtime.
            params (:obj:`Dict[str, jax_xla.DeviceArray]`, `optional`):
                Optionally the model parameters can be passed. Can be useful for parallelized generation.
            model_kwargs:
                Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.

        Return:
            :class:`~transformers.file_utils.ModelOutput`.

        Examples::

            >>> from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

            >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
            >>> model = FlaxAutoModelForCausalLM.from_pretrained("distilgpt2")
            >>> input_context = "The dog"
            >>> # encode input context
            >>> input_ids = tokenizer(input_context, return_tensors="jax").input_ids
            >>> # generate candidates using sampling
            >>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
            >>> print("Generated:", tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True))
        """
        # set init values
        max_length = max_length if max_length is not None else self.config.marian_config.max_length
        bos_token_id = bos_token_id if bos_token_id is not None else self.config.marian_config.bos_token_id
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.marian_config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.marian_config.eos_token_id
        decoder_start_token_id = (
            decoder_start_token_id
            if decoder_start_token_id is not None
            else self.config.marian_config.decoder_start_token_id
        )
        prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)

        if decoder_start_token_id is None and self.config.is_encoder_decoder:
            raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")

        if self.config.is_encoder_decoder:
            # add encoder_outputs to model_kwargs
            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
            # prepare decoder_input_ids for generation
            input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

        do_sample = do_sample if do_sample is not None else self.config.marian_config.do_sample
        num_beams = num_beams if num_beams is not None else self.config.marian_config.num_beams
        if not do_sample and num_beams == 1:
            logits_processor = self._get_logits_processor(
                no_repeat_ngram_size,
                min_length,
                max_length,
                eos_token_id,
                forced_bos_token_id,
                forced_eos_token_id,
            )
            return self._greedy_search(
                input_ids,
                max_length,
                pad_token_id,
                eos_token_id,
                logits_processor=logits_processor,
                trace=trace,
                params=params,
                model_kwargs=model_kwargs,
            )
        elif do_sample and num_beams == 1:
            logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature)
            logits_processor = self._get_logits_processor(
                no_repeat_ngram_size,
                min_length,
                max_length,
                eos_token_id,
                forced_bos_token_id,
                forced_eos_token_id,
            )
            return self._sample(
                input_ids,
                max_length,
                pad_token_id,
                eos_token_id,
                prng_key,
                logits_warper=logits_warper,
                logits_processor=logits_processor,
                trace=trace,
                params=params,
                model_kwargs=model_kwargs,
            )
        elif not do_sample and num_beams > 1:
            # broadcast input_ids & encoder_outputs
            input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams)

            if "encoder_outputs" in model_kwargs:
                model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams(
                    model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=num_beams
                )

            if "attention_mask" in model_kwargs:
                model_kwargs["attention_mask"] = self._expand_to_num_beams(
                    model_kwargs["attention_mask"], num_beams=num_beams
                )

            logits_processor = self._get_logits_processor(
                no_repeat_ngram_size,
                min_length,
                max_length,
                eos_token_id,
                forced_bos_token_id,
                forced_eos_token_id,
            )
            return self._beam_search(
                input_ids,
                max_length,
                pad_token_id,
                eos_token_id,
                length_penalty=length_penalty,
                early_stopping=early_stopping,
                logits_processor=logits_processor,
                trace=trace,
                params=params,
                model_kwargs=model_kwargs,
            )
        else:
            raise NotImplementedError("Beam sampling is currently not implemented.")

    def _get_logits_warper(
        self, top_k: int = None, top_p: float = None, temperature: float = None
    ) -> FlaxLogitsProcessorList:
        """
        This method returns a :obj:`~transformers.FlaxLogitsProcessorList` list object that contains all relevant
        :obj:`~transformers.FlaxLogitsWarper` instances used for multinomial sampling.
        """
        # init warp parameters
        top_k = top_k if top_k is not None else self.config.marian_config.top_k
        top_p = top_p if top_p is not None else self.config.marian_config.top_p
        temperature = temperature if temperature is not None else self.config.marian_config.temperature

        # instantiate warpers list
        warpers = FlaxLogitsProcessorList()

        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        # all samplers can be found in `generation_utils_samplers.py`
        if temperature is not None and temperature != 1.0:
            warpers.append(FlaxTemperatureLogitsWarper(temperature))
        if top_k is not None and top_k != 0:
            warpers.append(FlaxTopKLogitsWarper(top_k=top_k, min_tokens_to_keep=1))
        if top_p is not None and top_p < 1.0:
            warpers.append(FlaxTopPLogitsWarper(top_p=top_p, min_tokens_to_keep=1))

        return warpers
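    # Note on ordering (an added observation, not original commentary): the warpers run in
    # the order they were appended, so logits are first scaled by temperature, then truncated
    # by top-k, then by top-p. E.g. with top_k=30 and top_p=0.9, nucleus filtering operates
    # only on the 30 tokens that survived top-k.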

    def _get_logits_processor(
        self,
        no_repeat_ngram_size: int,
        min_length: int,
        max_length: int,
        eos_token_id: int,
        forced_bos_token_id: int,
        forced_eos_token_id: int,
    ) -> FlaxLogitsProcessorList:
        """
        This method returns a :obj:`~transformers.FlaxLogitsProcessorList` list object that contains all relevant
        :obj:`~transformers.FlaxLogitsProcessor` instances used to modify the scores of the language model head.
        """
        processors = FlaxLogitsProcessorList()

        # init warp parameters
        no_repeat_ngram_size = (
            no_repeat_ngram_size
            if no_repeat_ngram_size is not None
            else self.config.marian_config.no_repeat_ngram_size
        )
        min_length = min_length if min_length is not None else self.config.marian_config.min_length
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.marian_config.eos_token_id
        forced_bos_token_id = (
            forced_bos_token_id if forced_bos_token_id is not None else self.config.marian_config.forced_bos_token_id
        )
        forced_eos_token_id = (
            forced_eos_token_id if forced_eos_token_id is not None else self.config.marian_config.forced_eos_token_id
        )

        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        # all samplers can be found in `generation_utils_samplers.py`
        # NOTE: `no_repeat_ngram_size` is resolved above but no corresponding Flax processor
        # is appended here, so the argument currently has no effect.
        if min_length is not None and eos_token_id is not None and min_length > -1:
            processors.append(FlaxMinLengthLogitsProcessor(min_length, eos_token_id))
        if forced_bos_token_id is not None:
            processors.append(FlaxForcedBOSTokenLogitsProcessor(forced_bos_token_id))
        if forced_eos_token_id is not None:
            processors.append(FlaxForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))

        return processors

    def _greedy_search(
        self,
        input_ids: jax_xla.DeviceArray,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        logits_processor: Optional[FlaxLogitsProcessorList] = None,
        trace: bool = True,
        params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
        model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
    ):
        # init values
        max_length = max_length if max_length is not None else self.config.marian_config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.marian_config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.marian_config.eos_token_id

        batch_size, cur_len = input_ids.shape

        eos_token_id = jnp.array(eos_token_id)
        pad_token_id = jnp.array(pad_token_id)
        cur_len = jnp.array(cur_len)

        # per batch-item holding current token in loop.
        sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
        sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))

        # per batch-item state bit indicating if sentence has finished.
        is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)

        # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
        # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
        model = self.decode if self.config.is_encoder_decoder else self

        # initialize model specific kwargs
        model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)

        # initialize state
        state = GreedyState(
            cur_len=cur_len,
            sequences=sequences,
            running_token=input_ids,
            is_sent_finished=is_sent_finished,
            model_kwargs=model_kwargs,
        )

        def greedy_search_cond_fn(state):
            """state termination condition fn."""
            has_reached_max_length = state.cur_len == max_length
            all_sequence_finished = jnp.all(state.is_sent_finished)
            finish_generation = jnp.logical_or(has_reached_max_length, all_sequence_finished)
            return ~finish_generation

        def greedy_search_body_fn(state):
            """state update fn."""
            model_outputs = model(state.running_token, params=params, **state.model_kwargs)
            logits = model_outputs.logits[:, -1]

            # apply min_length, ...
            logits = logits_processor(state.sequences, logits, state.cur_len)

            next_token = jnp.argmax(logits, axis=-1)

            # once a sentence is finished, keep emitting the pad token
            next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
            next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished
            next_token = next_token[:, None]

            next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
            next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)

            return GreedyState(
                cur_len=state.cur_len + 1,
                sequences=next_sequences,
                running_token=next_token,
                is_sent_finished=next_is_sent_finished,
                model_kwargs=next_model_kwargs,
            )

        # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
        if input_ids.shape[1] > 1:
            state = greedy_search_body_fn(state)

        if not trace:
            state = self._run_loop_in_debug(greedy_search_cond_fn, greedy_search_body_fn, state)
        else:
            state = lax.while_loop(greedy_search_cond_fn, greedy_search_body_fn, state)

        return FlaxGreedySearchOutput(sequences=state.sequences)
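    # Debugging sketch (hypothetical usage, not from the original code): calling
    # `_greedy_search(..., trace=False)` runs `greedy_search_body_fn` eagerly through
    # `_run_loop_in_debug`, so intermediate arrays such as `state.running_token` are
    # concrete and can be inspected with ordinary `print` statements.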

    def _sample(
        self,
        input_ids: jax_xla.DeviceArray,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        prng_key: Optional[jax_xla.DeviceArray] = None,
        logits_processor: Optional[FlaxLogitsProcessorList] = None,
        logits_warper: Optional[FlaxLogitsProcessorList] = None,
        trace: bool = True,
        params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
        model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
    ):
        # init values
        max_length = max_length if max_length is not None else self.config.marian_config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.marian_config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.marian_config.eos_token_id
        prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)

        batch_size, cur_len = input_ids.shape

        eos_token_id = jnp.array(eos_token_id)
        pad_token_id = jnp.array(pad_token_id)
        cur_len = jnp.array(cur_len)

        # per batch-item holding current token in loop.
        sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
        sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))

        # per batch-item state bit indicating if sentence has finished.
        is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)

        # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
        # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
        model = self.decode if self.config.is_encoder_decoder else self

        # initialize model specific kwargs
        model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)

        # initialize state
        state = SampleState(
            cur_len=cur_len,
            sequences=sequences,
            running_token=input_ids,
            is_sent_finished=is_sent_finished,
            prng_key=prng_key,
            model_kwargs=model_kwargs,
        )

        def sample_search_cond_fn(state):
            """state termination condition fn."""
            has_reached_max_length = state.cur_len == max_length
            all_sequence_finished = jnp.all(state.is_sent_finished)
            finish_generation = jnp.logical_or(has_reached_max_length, all_sequence_finished)
            return ~finish_generation

        def sample_search_body_fn(state):
            """state update fn."""
            prng_key, prng_key_next = jax.random.split(state.prng_key)
            model_outputs = model(state.running_token, params=params, **state.model_kwargs)

            logits = model_outputs.logits[:, -1]

            # apply min_length, ...
            logits = logits_processor(state.sequences, logits, state.cur_len)
            # apply top_k, top_p, temperature (warpers ignore the input_ids argument,
            # so `logits` is passed as a placeholder in its position)
            logits = logits_warper(logits, logits, state.cur_len)

            # sample from the processed and warped logits, not the raw model logits
            next_token = jax.random.categorical(prng_key, logits, axis=-1)

            next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
            next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished
            next_token = next_token[:, None]

            next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
            next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)

            return SampleState(
                cur_len=state.cur_len + 1,
                sequences=next_sequences,
                running_token=next_token,
                is_sent_finished=next_is_sent_finished,
                model_kwargs=next_model_kwargs,
                prng_key=prng_key_next,
            )

        # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
        if input_ids.shape[1] > 1:
            state = sample_search_body_fn(state)

        if not trace:
            state = self._run_loop_in_debug(sample_search_cond_fn, sample_search_body_fn, state)
        else:
            state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state)

        return FlaxSampleOutput(sequences=state.sequences)
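    # Reproducibility note (an added observation): the PRNG key is split once per decoding
    # step from `state.prng_key`, so calling `_sample` twice with the same `prng_key`
    # (e.g. jax.random.PRNGKey(42)) and identical inputs yields identical samples.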

    def _beam_search(
        self,
        input_ids: jax_xla.DeviceArray,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        early_stopping: Optional[bool] = None,
        logits_processor: Optional[FlaxLogitsProcessorList] = None,
        trace: bool = True,
        params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
        model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
    ):
        """
        This beam search function is heavily inspired by Flax's official example:
        https://github.com/google/flax/blob/master/examples/wmt/train.py#L254
        """

        def flatten_beam_dim(tensor):
            """Flattens the first two dimensions of a non-scalar array."""
            # ignore scalars (e.g. cache index)
            if tensor.ndim == 0:
                return tensor
            return tensor.reshape((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])

        def unflatten_beam_dim(tensor, batch_size, num_beams):
            """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
            # ignore scalars (e.g. cache index)
            if tensor.ndim == 0:
                return tensor
            return tensor.reshape((batch_size, num_beams) + tensor.shape[1:])
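        # Illustrative round trip (hypothetical shapes, not from the original code): a
        # (4, 3, 128) array, i.e. (batch_size, num_beams, hidden), flattens to (12, 128)
        # for the model call, and unflatten_beam_dim(t, 4, 3) restores (4, 3, 128).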

        def gather_beams(nested, beam_indices, batch_size, new_num_beams):
            """
            Gathers the beam slices indexed by beam_indices into new beam array.
            """
            batch_indices = jnp.reshape(
                jnp.arange(batch_size * new_num_beams) // new_num_beams, (batch_size, new_num_beams)
            )

            def gather_fn(tensor):
                # ignore scalars (e.g. cache index)
                if tensor.ndim == 0:
                    return tensor
                else:
                    return tensor[batch_indices, beam_indices]

            return jax.tree_map(gather_fn, nested)
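        # Illustrative example (hypothetical values, not from the original code): with
        # batch_size=2 and new_num_beams=2, batch_indices is [[0, 0], [1, 1]], so
        # beam_indices=[[1, 0], [0, 2]] selects beams 1 and 0 of batch item 0 and
        # beams 0 and 2 of batch item 1 from every array in `nested`.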

        # init values
        max_length = max_length if max_length is not None else self.config.marian_config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.marian_config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.marian_config.eos_token_id
        length_penalty = length_penalty if length_penalty is not None else self.config.marian_config.length_penalty
        early_stopping = early_stopping if early_stopping is not None else self.config.marian_config.early_stopping

        batch_size, num_beams, cur_len = input_ids.shape

        eos_token_id = jnp.array(eos_token_id)
        pad_token_id = jnp.array(pad_token_id)
        cur_len = jnp.array(cur_len)

        # per batch,beam-item holding current token in loop.
        sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
        running_sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
        running_sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0, 0))

        # per batch,beam-item state bit indicating if sentence has finished.
        is_sent_finished = jnp.zeros((batch_size, num_beams), dtype=jnp.bool_)

        # per batch,beam-item score, logprobs
        # only the first beam starts at 0; the rest start at -1e7 so that the identical
        # prompt copies are not all selected during the first top-k step
        running_scores = jnp.tile(jnp.array([0.0] + [np.array(-1.0e7)] * (num_beams - 1)), [batch_size, 1])
        scores = jnp.ones((batch_size, num_beams)) * np.array(-1.0e7)

        # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
        # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
        model = self.decode if self.config.is_encoder_decoder else self

        # flatten beam dim
        if "encoder_outputs" in model_kwargs:
            model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim(
                model_kwargs["encoder_outputs"]["last_hidden_state"]
            )
        if "attention_mask" in model_kwargs:
            model_kwargs["attention_mask"] = flatten_beam_dim(model_kwargs["attention_mask"])

        # initialize model specific kwargs
        model_kwargs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), max_length, **model_kwargs)

        # initialize state
        state = BeamSearchState(
            cur_len=cur_len,
            running_sequences=running_sequences,
            running_scores=running_scores,
            sequences=sequences,
            scores=scores,
            is_sent_finished=is_sent_finished,
            model_kwargs=model_kwargs,
        )

        def beam_search_cond_fn(state):
            """beam search state termination condition fn."""
            # 1. is less than max length?
            not_max_length_yet = state.cur_len < max_length

            # 2. can the new beams still improve?
            best_running_score = state.running_scores[:, -1:] / (max_length ** length_penalty)
            worst_finished_score = jnp.where(
                state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7)
            )
            improvement_still_possible = jnp.all(worst_finished_score < best_running_score)

            # 3. is there still a beam that has not finished?
            still_open_beam = ~(jnp.all(state.is_sent_finished) & early_stopping)

            return not_max_length_yet & still_open_beam & improvement_still_possible
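        # Worked example (hypothetical numbers, not from the original code): with
        # length_penalty=1.0 and max_length=64, a best running log prob sum of -12.8
        # gives best_running_score = -12.8 / 64 = -0.2; if the worst finished beam
        # already scores above -0.2 everywhere, no running beam can still win, so
        # improvement_still_possible is False and the loop terminates early.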

        def beam_search_body_fn(state):
            """beam search state update fn."""
            # 1. Forward current tokens
            # Collect the current position slice along length to feed the fast
            # autoregressive decoder model. Flatten the beam dimension into batch
            # dimension for feeding into the model.
            input_token = flatten_beam_dim(
                lax.dynamic_slice(state.running_sequences, (0, 0, state.cur_len - 1), (batch_size, num_beams, 1))
            )
            model_outputs = model(input_token, params=params, **state.model_kwargs)

            # unflatten beam dimension
            logits = unflatten_beam_dim(model_outputs.logits[:, 0], batch_size, num_beams)

            # Unflatten beam dimension in attention cache arrays
            cache = jax.tree_map(
                lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams), model_outputs.past_key_values
            )

            # 2. Compute log probs
            # get log probabilities from logits, process logits with processors (*e.g.* min_length, ...), and
            # add new logprobs to existing running logprobs scores.
            log_probs = jax.nn.log_softmax(logits)
            log_probs = logits_processor(
                flatten_beam_dim(state.running_sequences), flatten_beam_dim(log_probs), state.cur_len
            )
            log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
            log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2)
            vocab_size = log_probs.shape[2]
            log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))

            # 3. Retrieve top-K
            # Each item in batch has num_beams * vocab_size candidate sequences.
            # For each item, get the top 2*K candidates with the highest log-probabilities.
            # We gather the top 2*K beams here so that even if the best K sequences reach
            # EOS simultaneously, we have another K sequences remaining to continue the
            # live beam search.
            beams_to_keep = 2 * num_beams
            # Gather the top 2*K scores from _all_ beams.
            topk_log_probs, topk_indices = lax.top_k(log_probs, k=beams_to_keep)
            # Recover the beam index by floor division.
            topk_beam_indices = topk_indices // vocab_size
            # Gather 2*K top beams.
            topk_running_sequences = gather_beams(
                state.running_sequences, topk_beam_indices, batch_size, beams_to_keep
            )
            # Recover token id by modulo division and expand id array for broadcasting.
            topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
            # Update sequences for the 2*K top-k new sequences.
            topk_sequences = lax.dynamic_update_slice(topk_running_sequences, topk_ids, (0, 0, state.cur_len))

            # 4. Check which sequences have ended
            # Did any of these sequences reach an end marker? To prevent the just finished
            # sequences from being added to the current set of active beam search sequences,
            # set their log probs to a very large negative value.
            did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id
            topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7)

            # 5. Get running sequences scores for next
            # Determine the top k beam indices (from top 2*k beams) from log probs and
            # gather top k beams (from top 2*k beams).
            next_topk_indices = jnp.flip(lax.top_k(topk_log_probs, k=num_beams)[1], axis=1)
            next_running_sequences, next_running_scores = gather_beams(
                [topk_sequences, topk_log_probs], next_topk_indices, batch_size, num_beams
            )

            # 6. Process topk logits
            # Further process log probs:
            # - add length penalty
            # - make sure no scores can be added anymore if beam is full
            # - make sure still running sequences cannot be chosen as finalized beam
            topk_log_probs = topk_log_probs / (state.cur_len ** length_penalty)
            beams_in_batch_are_full = (
                jnp.broadcast_to(state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape)
                & early_stopping
            )
            add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
            topk_log_probs += add_penalty * np.array(-1.0e7)

            # 7. Get scores, sequences, is sentence finished for next.
            # Combine sequences, scores, and flags along the beam dimension and compare
            # new finished sequence scores to existing finished scores and select the
            # best from the new set of beams.
            merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1)
            merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1)
            merged_is_sent_finished = jnp.concatenate([state.is_sent_finished, did_topk_just_finished], axis=1)
            topk_merged_indices = jnp.flip(lax.top_k(merged_scores, k=num_beams)[1], axis=1)
            next_sequences, next_scores, next_is_sent_finished = gather_beams(
                [merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices, batch_size, num_beams
            )

            # 8. Update model kwargs.
            # Determine the top k beam indices from the original set of all beams.
            # With these, gather the top k beam-associated caches.
            next_running_indices = gather_beams(topk_beam_indices, next_topk_indices, batch_size, num_beams)
            next_cache = gather_beams(cache, next_running_indices, batch_size, num_beams)
            model_outputs["past_key_values"] = jax.tree_map(lambda x: flatten_beam_dim(x), next_cache)
            next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)

            return BeamSearchState(
                cur_len=state.cur_len + 1,
                running_scores=next_running_scores,
                running_sequences=next_running_sequences,
                scores=next_scores,
                sequences=next_sequences,
                is_sent_finished=next_is_sent_finished,
                model_kwargs=next_model_kwargs,
            )

        # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
        state = beam_search_body_fn(state)

        if not trace:
            state = self._run_loop_in_debug(beam_search_cond_fn, beam_search_body_fn, state)
        else:
            state = lax.while_loop(beam_search_cond_fn, beam_search_body_fn, state)

        # Account for the edge-case where there are no finished sequences for a
        # particular batch item. If so, return running sequences for that batch item.
        none_finished = jnp.any(state.is_sent_finished, axis=1)
        sequences = jnp.where(none_finished[:, None, None], state.sequences, state.running_sequences)
        scores = jnp.where(none_finished[:, None], state.scores, state.running_scores)

        # take best beam for each batch (beams are kept in ascending score order, so the last one is the best)
        sequences = sequences[:, -1]
        scores = scores[:, -1]

        return FlaxBeamSearchOutput(sequences=sequences, scores=scores)
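
# Minimal usage sketch (not part of the original module; all names below are hypothetical).
# Assuming a model class that mixes in `FlaxCLIPVisionMarianGenerationMixin`, exposes
# `encode`/`decode`, and accepts CLIP pixel values in place of text `input_ids`:
#
#     model = FlaxCLIPVisionMarianMT.from_pretrained("...")    # placeholder checkpoint
#     pixel_values = jnp.zeros((1, 224, 224, 3), jnp.float32)  # one dummy image
#     out = model.generate(pixel_values, num_beams=4, max_length=64, early_stopping=True)
#     out.sequences  # (1, 64) int32 token ids of the best beam per batch item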