from typing import Any, Dict, Optional

import torch
from transformers import GenerationMixin
from transformers.cache_utils import Cache
from transformers.utils import ModelOutput


class VoraGenerationMixin(GenerationMixin):
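    """Generation mixin that makes `generate()` work with a caller-supplied 4D attention mask.

    A 4D additive mask (shape `[batch, 1, query_len, key_len]`) is hidden from the stock
    2D-mask handling in `GenerationMixin` where necessary and grown by one position per
    generated token.
    """
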
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
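        # `GenerationMixin.prepare_inputs_for_generation` expects a 2D padding mask, so the
        # 4D branch below hands it a collapsed 2D view for position bookkeeping and then
        # restores the original 4D mask on the returned inputs.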
        if attention_mask is not None and attention_mask.ndim == 4:
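            # Collapse to 2D: a key position counts as valid (1) if at least one query row
            # can attend to it, i.e. its additive bias is 0 rather than the minimum value.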
            attention_mask_2d = (attention_mask[:, 0, :, :] == 0).any(dim=1).long().to(attention_mask.device)
            model_input = super().prepare_inputs_for_generation(
                input_ids,
                past_key_values=past_key_values,
                attention_mask=attention_mask_2d,
                inputs_embeds=inputs_embeds,
                cache_position=cache_position,
                **kwargs,
            )
            model_input["attention_mask"] = attention_mask
            return model_input
        else:
            return super().prepare_inputs_for_generation(
                input_ids,
                past_key_values=past_key_values,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                cache_position=cache_position,
                **kwargs,
            )

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
        num_new_tokens: int = 1,
    ) -> Dict[str, Any]:
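        # The parent update appends one column to a 2D padding mask, which would mangle a
        # prebuilt 4D mask, so the 4D case is extended by hand below.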
        if "attention_mask" in model_kwargs and model_kwargs["attention_mask"].ndim == 4:
            attention_mask = model_kwargs.pop("attention_mask")
            model_kwargs = super()._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder, num_new_tokens=num_new_tokens
            )
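            # Grow the mask by one step: the new key column is blocked (min_dtype) for all
            # existing queries, and the appended query row attends to every position (zeros).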
            bs, _, seq_len, tgt_len = attention_mask.shape
            dtype = attention_mask.dtype
            min_dtype = torch.finfo(dtype).min
            new_col = attention_mask.new_zeros((bs, 1, seq_len, 1)).fill_(min_dtype)
            new_row = attention_mask.new_zeros((bs, 1, 1, tgt_len + 1))
            model_kwargs["attention_mask"] = torch.cat([
                torch.cat([attention_mask, new_col], dim=-1),
                new_row,
            ], dim=2)
            return model_kwargs
        else:
            return super()._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder, num_new_tokens=num_new_tokens
            )


def custom_prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    cache_position: torch.Tensor,
    batch_size: int,
    **kwargs,
):
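    """Build the 4D additive causal mask consumed by the attention layers.

    If `attention_mask` is already 4D it is passed through, sliced to the current query
    window and key/value length; otherwise a standard causal mask is built and any 2D
    padding mask is folded in.
    """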
    if attention_mask is not None and attention_mask.dim() == 4:
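        # A prebuilt 4D mask is used as-is, trimmed to the current query window and
        # key/value length.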
        causal_mask = attention_mask[:, :, -sequence_length:, -target_length:]
    else:
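        # No custom mask was provided: build the standard upper-triangular causal mask,
        # letting each query see key positions up to its cache position.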
        min_dtype = torch.finfo(dtype).min
        causal_mask = torch.full(
            (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
        )
        if sequence_length != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)
        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
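        # Fold a 2D padding mask into the causal mask: key positions the causal mask allows
        # but the padding mask marks as padding (0) are filled with min_dtype.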
        if attention_mask is not None:
            causal_mask = causal_mask.clone()
            mask_length = attention_mask.shape[-1]
            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                padding_mask, min_dtype
            )

    return causal_mask
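
# Minimal usage sketch (assumption, not part of this module): `VoraForCausalLM`,
# `LlamaForCausalLM`, and `custom_4d_mask` are illustrative names only. The mixin is listed
# before the transformers base class so its generation overrides take precedence.
#
#     class VoraForCausalLM(VoraGenerationMixin, LlamaForCausalLM):
#         """Hypothetical model that picks up the 4D-mask-aware generation overrides."""
#
#     model = VoraForCausalLM.from_pretrained("path/to/checkpoint")
#     outputs = model.generate(
#         input_ids,
#         attention_mask=custom_4d_mask,  # additive 4D mask: [batch, 1, query_len, key_len]
#         max_new_tokens=32,
#     )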