from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    GenerationMixin,
    LlamaForCausalLM,
    PreTrainedModel,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
from transformers.utils import logging, replace_return_docstrings

from .configuration_sarashina2_vision import Sarashina2VisionConfig
|
logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "Sarashina2VisionConfig"

|
class Sarashina2VisionPreTrainedModel(PreTrainedModel):
    config_class = Sarashina2VisionConfig
    base_model_prefix = "model"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_static_cache = True

    def _init_weights(self, module):
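        """Initialize weights from N(0, std), where std is the composite config's
        `initializer_range` if present, otherwise the text sub-config's value."""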
|
        std = (
            self.config.initializer_range
            if hasattr(self.config, "initializer_range")
            else self.config.text_config.initializer_range
        )

        if hasattr(module, "class_embedding"):
            module.class_embedding.data.normal_(mean=0.0, std=std)

        if isinstance(module, (nn.Linear, nn.Conv3d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

|
class Sarashina2VisionForCausalLM(Sarashina2VisionPreTrainedModel, GenerationMixin): |
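    """Sarashina2Vision model for conditional text generation.

    A Qwen2-VL vision transformer (`self.visual`) encodes image patches; its merged
    output is layer-normalized (`self.norm`) to the text hidden size and scattered into
    the token embedding sequence at the image placeholder positions, which is then fed
    to a Llama-based language model (`self.llm`).
    """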
|
    def __init__(self, config: Sarashina2VisionConfig):
        super().__init__(config)
        self.visual = Qwen2VisionTransformerPretrainedModel._from_config(config.vision_config)
        self.norm = nn.LayerNorm(config.text_config.hidden_size)
        self.llm = LlamaForCausalLM._from_config(config.text_config)

        self.post_init()

    def get_input_embeddings(self):
        return self.llm.get_input_embeddings()
|
    def get_image_embeds(
        self,
        hidden_states: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> torch.Tensor:
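        """Encode flattened image patches with the vision tower and project them to the text hidden size.

        Args:
            hidden_states (torch.Tensor): Flattened image patches produced by the image processor.
            grid_thw (torch.Tensor): Temporal, height and width grid sizes of each image.

        Returns:
            torch.Tensor: One layer-normalized embedding per merged patch, in the text hidden size.
        """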
|
        rotary_pos_emb = self.visual.rot_pos_emb(grid_thw)
        hidden_states = self.visual.patch_embed(hidden_states)

        # Cumulative patch-sequence lengths per image (h * w repeated over the temporal
        # dimension), marking the attention boundaries used by the vision blocks.
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
        ).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        for blk in self.visual.blocks:
            hidden_states = blk(
                hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
            )
        # Merge neighboring patches, project to the text hidden size, and normalize.
        return self.norm(self.visual.merger(hidden_states))
|
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: torch.FloatTensor = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **lm_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Args:
            input_ids (torch.LongTensor, optional): Indices of input sequence tokens in the vocabulary. Defaults to None.
            attention_mask (Optional[torch.Tensor], optional): Mask to avoid performing attention on padding token indices. Defaults to None.
            position_ids (Optional[torch.LongTensor], optional): Indices of positions of each input sequence token in the position embeddings. Defaults to None.
            past_key_values (Optional[List[torch.FloatTensor]], optional): Pre-computed key and value hidden states of the attention blocks, used to speed up sequential decoding. Defaults to None.
            inputs_embeds (Optional[torch.FloatTensor], optional): Instead of passing `input_ids` you can choose to directly pass an embedded representation. Defaults to None.
            labels (Optional[torch.LongTensor], optional): Labels for computing the language modeling loss. Defaults to None.
            use_cache (Optional[bool], optional): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding. Defaults to None.
            output_attentions (Optional[bool], optional): Whether or not to return the attentions tensors of all attention layers. Defaults to None.
            output_hidden_states (Optional[bool], optional): Whether or not to return the hidden states of all layers. Defaults to None.
            return_dict (Optional[bool], optional): Whether or not to return a `CausalLMOutputWithPast` instead of a plain tuple. Defaults to None.
            pixel_values (torch.FloatTensor, optional): The tensors corresponding to the input images. Defaults to None.
            image_grid_thw (Optional[torch.LongTensor], optional): The temporal, height and width sizes of each image's feature grid as seen by the LLM. Defaults to None.
            cache_position (Optional[torch.LongTensor], optional): Indices depicting the position of the input sequence tokens in the sequence. Defaults to None.
            logits_to_keep (Union[int, torch.Tensor]): If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
                `input_ids` (special case). Only the last token's logits are needed for generation, and computing them only for that
                token saves memory, which becomes significant for long sequences or a large vocabulary size.
                If a `torch.Tensor`, it must be 1D and correspond to the indices to keep in the sequence length dimension.
                This is useful when using packed tensor format (single dimension for batch and sequence length).

        Returns:
            CausalLMOutputWithPast: The output of the model.
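
        Example (an illustrative sketch, not taken from this file: the checkpoint name,
        prompt, and processor call below are assumptions, and the exact prompt format is
        defined by the model's processor):

        ```python
        >>> from PIL import Image
        >>> from transformers import AutoModelForCausalLM, AutoProcessor

        >>> model = AutoModelForCausalLM.from_pretrained("sbintuitions/sarashina2-vision-8b", trust_remote_code=True)
        >>> processor = AutoProcessor.from_pretrained("sbintuitions/sarashina2-vision-8b", trust_remote_code=True)

        >>> image = Image.open("example.jpg")
        >>> inputs = processor(text="Describe this image.", images=[image], return_tensors="pt")
        >>> generated_ids = model.generate(**inputs, max_new_tokens=64)
        >>> print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
        ```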
|
""" |
|
        output_attentions = (
            output_attentions if output_attentions is not None else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)
            if pixel_values is not None:
                pixel_values = pixel_values.type(self.visual.get_dtype())
                image_embeds = self.get_image_embeds(pixel_values, image_grid_thw)
                n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
                n_image_features = image_embeds.shape[0]
                if n_image_tokens != n_image_features:
                    raise ValueError(
                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                    )
                # Scatter the image features into the embedding sequence at the
                # positions of the image placeholder tokens.
                image_mask = (
                    (input_ids == self.config.image_token_index)
                    .unsqueeze(-1)
                    .expand_as(inputs_embeds)
                    .to(inputs_embeds.device)
                )
                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
        outputs = self.llm(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            **lm_kwargs,
        )
|
        logits = outputs[0]

        loss = None
        if labels is not None:
            # Upcast to float for a numerically stable cross-entropy.
            logits = logits.float()
            # Shift so that tokens < n predict token n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens.
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Labels may live on a different device when using model parallelism.
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
|
        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        pixel_values=None,
        attention_mask=None,
        cache_position=None,
        logits_to_keep=None,
        image_grid_thw=None,
        **kwargs,
    ):
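        """Delegate input preparation to the language model, re-attaching image inputs only on the prefill step."""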
|
        model_inputs = self.llm.prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )

        if cache_position[0] == 0:
            # Pixel values are only needed on the first (prefill) forward pass; during
            # cached decoding the new input ids contain no image placeholder tokens.
            model_inputs["pixel_values"] = pixel_values
            model_inputs["image_grid_thw"] = image_grid_thw

        return model_inputs
|

AutoConfig.register("sarashina2_vision", Sarashina2VisionConfig)
AutoModelForCausalLM.register(Sarashina2VisionConfig, Sarashina2VisionForCausalLM)
Sarashina2VisionConfig.register_for_auto_class()
Sarashina2VisionForCausalLM.register_for_auto_class("AutoModelForCausalLM")
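
# The register calls above make the model reachable through the Auto classes. A minimal
# loading sketch (illustrative only; "<path-or-hub-id>" is a placeholder, not a value
# taken from this file):
#
#   model = AutoModelForCausalLM.from_pretrained("<path-or-hub-id>", trust_remote_code=True)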
|
|