# VoRA-7B-Base / vora_generation_utils.py
from typing import Any, Dict, Optional

import torch
from transformers import GenerationMixin
from transformers.cache_utils import Cache
from transformers.utils import ModelOutput

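# As the overrides below suggest, this module exists to carry a caller-supplied
# 4D additive attention mask through transformers' generate() loop, which
# natively tracks only 2D padding masks: the mixin hides/restores the 4D mask
# around the base-class hooks and grows it by one position per decoding step.
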
class VoraGenerationMixin(GenerationMixin):
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        if attention_mask is not None and attention_mask.ndim == 4:
            # The base implementation expects a 2D padding mask, so collapse
            # the 4D additive mask first: a key position counts as attended (1)
            # if any query row leaves it unmasked (additive value 0).
            attention_mask_2d = (attention_mask[:, 0, :, :] == 0).any(dim=1).long().to(attention_mask.device)
            model_input = super().prepare_inputs_for_generation(
                input_ids,
                past_key_values=past_key_values,
                attention_mask=attention_mask_2d,
                inputs_embeds=inputs_embeds,
                cache_position=cache_position,
                **kwargs,
            )
            # Hand the model the original 4D mask, not the collapsed one.
            model_input["attention_mask"] = attention_mask
            return model_input
        else:
            return super().prepare_inputs_for_generation(
                input_ids,
                past_key_values=past_key_values,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                cache_position=cache_position,
                **kwargs,
            )

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
        num_new_tokens: int = 1,
    ) -> Dict[str, Any]:
        if "attention_mask" in model_kwargs and model_kwargs["attention_mask"].ndim == 4:
            # Hide the 4D mask from the base implementation (which would try
            # to pad a 2D mask), then grow it by one step manually.
            attention_mask = model_kwargs.pop("attention_mask")
            model_kwargs = super()._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder, num_new_tokens=num_new_tokens
            )
            bs, _, seq_len, tgt_len = attention_mask.shape
            dtype = attention_mask.dtype
            min_dtype = torch.finfo(dtype).min
            # New key column: existing queries must not attend to the freshly
            # generated token (causality), so it is filled with the additive
            # masking value.
            new_col = attention_mask.new_full((bs, 1, seq_len, 1), min_dtype)
            # New query row: the new token attends to every key position.
            new_row = attention_mask.new_zeros((bs, 1, 1, tgt_len + 1))
            model_kwargs["attention_mask"] = torch.cat([
                torch.cat([attention_mask, new_col], dim=-1),
                new_row,
            ], dim=2)
            return model_kwargs
        else:
            return super()._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder, num_new_tokens=num_new_tokens
            )
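
# Minimal usage sketch (an assumption, not part of the original upload): the
# mixin is meant to sit before a concrete transformers model in the MRO so
# that these overrides take precedence inside generate(). LlamaForCausalLM is
# only an illustrative backbone; VoRA's actual base model may differ.
#
#     from transformers import LlamaForCausalLM
#
#     class VoraForCausalLM(VoraGenerationMixin, LlamaForCausalLM):
#         pass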

def custom_prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    cache_position: torch.Tensor,
    batch_size: int,
    **kwargs,
):
    if attention_mask is not None and attention_mask.dim() == 4:
        # The mask already comes in inverted (additive) 4D form and requires
        # no inversion, only slicing to the current query/key window.
        causal_mask = attention_mask[:, :, -sequence_length:, -target_length:]
    else:
        min_dtype = torch.finfo(dtype).min
        causal_mask = torch.full(
            (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
        )
        if sequence_length != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)
        # Unmask (zero out) every key position at or before each query's cache
        # position; only strictly-future keys stay masked.
        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            mask_length = attention_mask.shape[-1]
            # Fold in the 2D padding mask: positions where both terms are zero
            # (causally visible but padded) must be masked out.
            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                padding_mask, min_dtype
            )
    return causal_mask
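

if __name__ == "__main__":
    # Smoke test (an addition, not part of the original upload): build a toy
    # 4D mask with the standalone helper above, then grow it by one decoding
    # step the same way _update_model_kwargs_for_generation does. All values
    # here are arbitrary illustrations.
    bs, seq_len = 1, 4
    mask_2d = torch.tensor([[1, 1, 1, 0]])  # last position is padding
    mask_4d = custom_prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask=mask_2d,
        sequence_length=seq_len,
        target_length=seq_len,
        dtype=torch.float32,
        device=torch.device("cpu"),
        cache_position=torch.arange(seq_len),
        batch_size=bs,
    )
    print(mask_4d.shape)  # torch.Size([1, 1, 4, 4])

    # One-step growth: a fully masked new-key column for the old queries, then
    # an all-visible query row for the new token.
    min_dtype = torch.finfo(mask_4d.dtype).min
    new_col = mask_4d.new_full((bs, 1, seq_len, 1), min_dtype)
    new_row = mask_4d.new_zeros((bs, 1, 1, seq_len + 1))
    grown = torch.cat([torch.cat([mask_4d, new_col], dim=-1), new_row], dim=2)
    print(grown.shape)  # torch.Size([1, 1, 5, 5])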