|
import ast |
|
import contextlib |
|
import gc |
|
import json |
|
import math |
|
import os |
|
from dataclasses import dataclass |
|
from functools import partial |
|
from itertools import chain |
|
from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
|
import torch |
|
import torch.distributed as dist |
|
import torch.nn as nn |
|
from einops import rearrange |
|
from timm.layers import LayerNorm, LayerNorm2d |
|
from timm.models.regnet import RegStage |
|
from torch.nn import CrossEntropyLoss |
|
from transformers import ( |
|
AutoConfig, |
|
AutoModel, |
|
AutoModelForCausalLM, |
|
AutoTokenizer, |
|
PreTrainedModel, |
|
) |
|
from transformers.generation.utils import GenerationMixin |
|
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled |
|
from transformers.modeling_utils import ( |
|
is_fsdp_enabled, |
|
is_local_dist_rank_0, |
|
no_init_weights, |
|
) |
|
from transformers.models.auto import CONFIG_MAPPING |
|
from transformers.utils import ModelOutput |
|
|
|
from .configuration_hyperclovax import HCXVisionConfig |
|
from .preprocessor import select_best_resolution |
|
|
|
EOT = "<|endofturn|>" |
|
IMG_LOC = "<|dummy3|>" |
|
|
|
|
|
def get_rank(): |
|
if dist.is_initialized(): |
|
return dist.get_rank() |
|
return 0 |
|
|
|
|
|
def get_world_size():
    if dist.is_initialized():
        return dist.get_world_size()
    return 1
|
|
|
|
|
def unpad_image(tensor: torch.Tensor, original_size: Tuple[int, int]) -> torch.Tensor: |
|
"""Unpads a PyTorch tensor of a padded and resized image. |
|
|
|
This function removes padding from a tensor image that was previously padded and resized. |
|
The padding is removed based on the aspect ratio difference between the original and current image dimensions. |
|
|
|
Args: |
|
tensor: The image tensor, assumed to be in CxHxW format. |
|
original_size: The original size of the image as (width, height). |
|
|
|
Returns: |
|
The unpadded image tensor. |
|
|
|
Examples: |
|
>>> import torch |
|
>>> # Example 1: Unpadding with height padding |
|
>>> padded_tensor = torch.randn(1, 64, 48) # Padded tensor (C=1, H=64, W=48) |
|
>>> original_size = (32, 32) # Original size (width=32, height=32) |
|
>>> unpadded_tensor = unpad_image(padded_tensor, original_size) |
|
>>> unpadded_tensor.shape |
|
torch.Size([1, 48, 48]) |
|
>>> # Example 2: Unpadding with width padding |
|
>>> padded_tensor = torch.randn(1, 48, 64) # Padded tensor (C=1, H=48, W=64) |
|
>>> original_size = (32, 32) # Original size (width=32, height=32) |
|
>>> unpadded_tensor = unpad_image(padded_tensor, original_size) |
|
>>> unpadded_tensor.shape |
|
torch.Size([1, 48, 48]) |
|
""" |
|
original_width, original_height = original_size |
|
current_height, current_width = tensor.shape[1:] |
|
|
|
original_aspect_ratio = original_width / original_height |
|
current_aspect_ratio = current_width / current_height |
|
|
|
if original_aspect_ratio > current_aspect_ratio: |
|
scale_factor = current_width / original_width |
|
new_height = int(original_height * scale_factor) |
|
padding = (current_height - new_height) // 2 |
|
unpadded_tensor = tensor[:, padding : current_height - padding, :] |
|
else: |
|
scale_factor = current_height / original_height |
|
new_width = int(original_width * scale_factor) |
|
padding = (current_width - new_width) // 2 |
|
unpadded_tensor = tensor[:, :, padding : current_width - padding] |
|
|
|
return unpadded_tensor |
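

# Illustrative usage sketch (not called anywhere in this module): shows the shape effect of
# `unpad_image` on a feature map that was letterboxed along the height. The sizes below are
# arbitrary and mirror the docstring example.
def _demo_unpad_image() -> torch.Size:
    padded = torch.randn(1, 64, 48)  # C x H x W feature map with vertical padding
    return unpad_image(padded, original_size=(32, 32)).shape  # torch.Size([1, 48, 48])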
|
|
|
|
|
def get_anyres_image_grid_shape( |
|
image_size: Tuple[int, int], |
|
grid_pinpoints: Union[str, List[Tuple[int, int]]], |
|
patch_size: int, |
|
) -> Tuple[int, int]: |
|
"""Calculates the image patch grid shape after any-resolution preprocessing. |
|
|
|
Selects the optimal resolution from predefined grid pinpoints based on input image |
|
dimensions using `select_best_resolution`, then computes the grid layout by |
|
dividing the selected resolution by the patch size using integer division. |
|
|
|
Args: |
|
image_size (Tuple[int, int]): Original image dimensions in (width, height) format. |
|
grid_pinpoints (Union[str, List[Tuple[int, int]]]): Accepts either: |
|
- List of (height, width) resolution tuples |
|
- String representation of list (e.g., "[(224, 224), (336, 336)]") |
|
patch_size (int): Spatial dimension of square patches for grid division. |
|
|
|
Returns: |
|
Tuple[int, int]: Grid dimensions as (num_patches_width, num_patches_height). |
|
|
|
Examples: |
|
>>> # Basic case with list input |
|
>>> get_anyres_image_grid_shape((1000, 800), [(224, 224), (448, 448)], 112) |
|
(4, 4) |
|
|
|
>>> # Basic case with string input |
|
>>> get_anyres_image_grid_shape((600, 400), "[(336, 336), (672, 672)]", 112) |
|
(6, 6) |
|
|
|
>>> # Case where resolution is not perfectly divisible by patch_size |
|
>>> # select_best_resolution picks (224, 224). 224 // 100 = 2 |
|
>>> get_anyres_image_grid_shape((500, 500), [(224, 224)], 100) |
|
(2, 2) |
|
|
|
>>> # Different patch size |
|
>>> # select_best_resolution picks (448, 448). 448 // 224 = 2 |
|
>>> get_anyres_image_grid_shape((1200, 900), [(448, 448), (224, 224)], 224) |
|
(2, 2) |
|
|
|
Note: |
|
        String-formatted grid_pinpoints are converted via ast.literal_eval; invalid strings
        raise a parsing error. The doctest outputs assume `select_best_resolution` picks the
        candidate resolution that best fits each input image.
|
""" |
|
possible_resolutions = grid_pinpoints if isinstance(grid_pinpoints, list) else ast.literal_eval(grid_pinpoints) |
|
|
|
original_width, original_height = image_size |
|
height, width = select_best_resolution((original_height, original_width), possible_resolutions) |
|
return width // patch_size, height // patch_size |
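

# Illustrative usage sketch (not part of the model), assuming `select_best_resolution`
# returns the single candidate resolution it is given: the grid shape is then simply that
# resolution divided by the patch size.
def _demo_get_anyres_image_grid_shape() -> Tuple[int, int]:
    return get_anyres_image_grid_shape(
        image_size=(1000, 800),
        grid_pinpoints=[(448, 448)],
        patch_size=112,
    )  # -> (4, 4)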
|
|
|
|
|
def reshape_and_unpad_image_features( |
|
image_feature: torch.Tensor, |
|
height: int, |
|
width: int, |
|
image_size: Tuple[int, int], |
|
possible_resolutions: List[Tuple[int, int]], |
|
grid_size: int, |
|
unpad: bool, |
|
image_newline: torch.Tensor, |
|
) -> torch.Tensor: |
|
"""Reshapes and processes image features with optional unpadding operation. |
|
|
|
Processes input image features by: |
|
1. Separating base features from spatial features |
|
2. Reshaping spatial features into a 5D tensor (num_patch_height, num_patch_width, height, width, channels) |
|
3. Performing either unpadding operation or simple reshaping based on 'unpad' flag |
|
4. Concatenating processed features with base features |
|
|
|
Args: |
|
        image_feature: Input tensor of shape [1 + num_grids, height * width, feature_dim],
            where the first entry is the base (global) grid feature
        height: Height of each grid's feature map, in patch tokens
        width: Width of each grid's feature map, in patch tokens
        image_size: Original image size as a (width, height) tuple
        possible_resolutions: List of candidate (height, width) resolutions used to select the grid layout
        grid_size: Pixel size of one square grid, used as the divisor when computing the patch grid layout
        unpad: Flag to enable the unpadding operation
        image_newline: Embedding appended as a row separator when unpadding
|
|
|
Returns: |
|
torch.Tensor: Processed image features tensor with shape [1 + num_processed_patches, feature_dim] |
|
|
|
Raises: |
|
AssertionError: If base feature dimension doesn't match height*width |
|
""" |
|
base_image_feature = image_feature[0] |
|
image_feature = image_feature[1:] |
|
|
|
assert ( |
|
height * width == base_image_feature.shape[0] |
|
), f"height: {height}, width: {width}, base_image_feature.shape[0]: {base_image_feature.shape[0]}" |
|
|
|
num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_size, possible_resolutions, grid_size) |
|
image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) |
|
|
|
if unpad: |
|
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() |
|
image_feature = image_feature.flatten(1, 2).flatten(2, 3) |
|
image_feature = unpad_image(image_feature, image_size) |
|
image_feature = torch.cat( |
|
( |
|
image_feature, |
|
image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device), |
|
), |
|
dim=-1, |
|
) |
|
image_feature = image_feature.flatten(1, 2).transpose(0, 1) |
|
else: |
|
image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous() |
|
image_feature = image_feature.flatten(0, 3) |
|
image_feature = torch.cat((base_image_feature, image_feature), dim=0) |
|
|
|
return image_feature |
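

# Illustrative usage sketch (not part of the model), assuming `select_best_resolution`
# returns the single candidate resolution it is given: one base grid plus a 2x2 layout of
# 3x3-token grid features is concatenated back into a flat token sequence.
def _demo_reshape_and_unpad_image_features() -> torch.Size:
    feats = torch.randn(1 + 4, 9, 8)  # [base grid + 2*2 grids, 3*3 tokens, 8 channels]
    out = reshape_and_unpad_image_features(
        image_feature=feats,
        height=3,
        width=3,
        image_size=(448, 448),
        possible_resolutions=[(448, 448)],
        grid_size=224,  # 448 // 224 = 2 patch grids per side
        unpad=False,
        image_newline=torch.randn(8),
    )
    return out.shape  # torch.Size([45, 8]) = 9 base tokens + 36 spatial tokens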
|
|
|
|
|
def anyres_postprocessing( |
|
image_forward_outs: torch.FloatTensor, |
|
split_sizes: List[int], |
|
image_sizes: List[List[int]], |
|
possible_resolutions: List[Tuple[int, int]], |
|
is_videos: List[bool], |
|
patch_size: int, |
|
grid_size: int, |
|
image_newline: torch.FloatTensor, |
|
num_queries_vis_abstractor: int = -1, |
|
unpad: bool = False, |
|
) -> List[torch.FloatTensor]: |
|
"""Processes 2D visual features into 1D sequences with post-processing steps. |
|
|
|
Performs AnyRes postprocessing by flattening 2D visual features from grid partitions into 1D sequences, adding |
|
newline embeddings at row boundaries for images, and optionally removing padding regions based on original image |
|
sizes. For video data, processes each frame's features separately into a single sequence per video and disables |
|
unpadding and newline insertion. |
|
|
|
Args: |
|
        image_forward_outs (torch.FloatTensor): Visual features of shape
            (total_num_grids, patches_per_grid, feature_dim), concatenated over all samples in the batch.
|
split_sizes (List[int]): A list containing the number of patches for each sample in the batch. The sum of |
|
`split_sizes` should equal `image_forward_outs.shape[0]`. |
|
image_sizes (List[List[int]]): A list where each element is a list `[width, height]` representing the original |
|
dimensions of the corresponding image sample. Used for unpadding. |
|
possible_resolutions (List[Tuple[int, int]]): A list of supported resolution tuples `(height, width)` used by |
|
`reshape_and_unpad_image_features` for spatial reconstruction, especially during unpadding. |
|
is_videos (List[bool]): A list of boolean flags indicating whether each corresponding sample in the batch is a |
|
video [`True`] or an image [`False`]. |
|
patch_size (int): The spatial dimension (height and width) of the square patches the image was divided into. |
|
grid_size (int): The spatial dimension (height and width) of the square grid onto which patches are mapped. |
|
`grid_size` should be divisible by `patch_size`. |
|
        image_newline (torch.FloatTensor): A learnable newline embedding of shape (feature_dim,),
            appended after each row of image patches when unpadding is enabled.
|
num_queries_vis_abstractor (int, optional): If a visual abstractor with a fixed number of output queries is used |
|
instead of grid patching, this specifies the number of queries. Must be a perfect square if > 0. |
|
Defaults to -1 (indicating standard grid patching is used). |
|
unpad (bool, optional): If `True`, removes padding tokens from image features based on `image_sizes` and |
|
`possible_resolutions`. Does not apply to video features. Defaults to False. |
|
|
|
Returns: |
|
List[torch.FloatTensor]: A list of tensors, where each tensor represents the processed 1D sequence of visual |
|
features for a single sample from the input batch. The length of the sequence varies depending on processing |
|
(unpadding, newlines, video flattening). |
|
|
|
Raises: |
|
AssertionError: If `num_queries_vis_abstractor` is greater than 0 but not a perfect square. |
|
""" |
|
height = width = grid_size // patch_size |
|
|
|
if num_queries_vis_abstractor > 0: |
|
        assert (num_queries_vis_abstractor**0.5).is_integer(), "num_queries_vis_abstractor must be a perfect square"
|
height = width = int(num_queries_vis_abstractor**0.5) |
|
|
|
image_features = torch.split(image_forward_outs, split_sizes, dim=0) |
|
|
|
|
|
new_image_features = [] |
|
for image_idx, (image_feature, is_video) in enumerate(zip(image_features, is_videos)): |
|
if image_feature.shape[0] > 1: |
|
if not is_video: |
|
image_feature = reshape_and_unpad_image_features( |
|
image_feature=image_feature, |
|
height=height, |
|
width=width, |
|
image_size=image_sizes[image_idx], |
|
possible_resolutions=possible_resolutions, |
|
grid_size=grid_size, |
|
unpad=unpad, |
|
image_newline=image_newline, |
|
) |
|
else: |
|
image_feature = image_feature.flatten(0, 1) |
|
else: |
|
image_feature = image_feature[0] |
|
if unpad and not is_video: |
|
image_feature = torch.cat((image_feature, image_newline[None].to(image_feature.device)), dim=0) |
|
new_image_features.append(image_feature) |
|
image_features = new_image_features |
|
return image_features |
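

# Illustrative usage sketch (not part of the model): for video inputs the per-frame grids
# are simply flattened into one sequence, with no unpadding or newline tokens. All sizes
# below are arbitrary.
def _demo_anyres_postprocessing_video() -> List[torch.Size]:
    outs = anyres_postprocessing(
        image_forward_outs=torch.randn(2, 9, 8),  # 2 grids of 3x3 tokens, 8 channels
        split_sizes=[2],
        image_sizes=[[336, 336]],
        possible_resolutions=[(336, 336)],
        is_videos=[True],
        patch_size=112,
        grid_size=336,  # grid_size // patch_size = 3 tokens per side
        image_newline=torch.randn(8),
    )
    return [o.shape for o in outs]  # [torch.Size([18, 8])]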
|
|
|
|
|
def adaptive_anyres_postprocessing( |
|
    image_forward_outs: List[torch.FloatTensor],
|
image_sizes: List[List[int]], |
|
possible_resolutions: List[Tuple[int, int]], |
|
is_videos: List[bool], |
|
group_ids: List[List[int]], |
|
    num_queries_vis_abstractors: List[int],
|
grid_size: int, |
|
image_newline: torch.FloatTensor, |
|
unpad: bool = False, |
|
) -> List[torch.FloatTensor]: |
|
"""Adaptive AnyRes postprocessing for multi-group feature aggregation. |
|
|
|
Processes 2D visual features into 1D sequences with group-wise adaptive processing. Each image can belong to |
|
multiple processing groups with different query configurations. Features are processed per group and aggregated |
|
according to group_ids. |
|
|
|
Args: |
|
        image_forward_outs (List[torch.FloatTensor]): Per-group visual features, each tensor of shape
            (num_grids_in_group, num_queries, feature_dim).
        image_sizes (List[List[int]]): Original image dimensions for each grid group as [width, height].
        possible_resolutions (List[Tuple[int, int]]): Supported (height, width) resolutions.
        is_videos (List[bool]): Flags indicating which grid groups come from videos.
        group_ids (List[List[int]]): Indices of grid groups that belong to the same original image or frame.
        num_queries_vis_abstractors (List[int]): Number of visual tokens produced for each grid group.
        grid_size (int): Pixel size of one square grid, used for spatial reconstruction.
        image_newline (torch.FloatTensor): Newline embedding appended when unpadding.
        unpad (bool, optional): Whether to remove padding regions. Defaults to False.
|
|
|
Returns: |
|
        List[torch.FloatTensor]: One aggregated feature tensor per entry of `group_ids`.
|
|
|
Raises: |
|
AssertionError: If num_queries is not square number in any group |
|
""" |
|
|
|
new_image_features = [] |
|
for image_idx, (image_feature, is_video) in enumerate(zip(image_forward_outs, is_videos)): |
|
num_queries_vis_abstractor = num_queries_vis_abstractors[image_idx] |
|
        assert (num_queries_vis_abstractor**0.5).is_integer(), "num_queries_vis_abstractor must be a perfect square"
|
height = width = int(num_queries_vis_abstractor**0.5) |
|
|
|
if image_feature.shape[0] > 1: |
|
if not is_video: |
|
image_feature = reshape_and_unpad_image_features( |
|
image_feature=image_feature, |
|
height=height, |
|
width=width, |
|
image_size=image_sizes[image_idx], |
|
possible_resolutions=possible_resolutions, |
|
grid_size=grid_size, |
|
unpad=unpad, |
|
image_newline=image_newline, |
|
) |
|
else: |
|
image_feature = image_feature.flatten(0, 1) |
|
else: |
|
image_feature = image_feature[0] |
|
if unpad and not is_video: |
|
image_feature = torch.cat((image_feature, image_newline[None].to(image_feature.device)), dim=0) |
|
new_image_features.append(image_feature) |
|
|
|
image_features = [ |
|
torch.cat([new_image_features[group_id] for group_id in group_ids_list], dim=0) for group_ids_list in group_ids |
|
] |
|
return image_features |
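

# Illustrative usage sketch (not part of the model), assuming the projector returns one
# tensor per grid group: two single-grid video frames with 9 queries each are aggregated
# into a single 18-token sequence because they share one entry in `group_ids`.
def _demo_adaptive_anyres_postprocessing() -> List[torch.Size]:
    frame_feats = [torch.randn(1, 9, 8), torch.randn(1, 9, 8)]
    outs = adaptive_anyres_postprocessing(
        image_forward_outs=frame_feats,
        image_sizes=[[336, 336], [336, 336]],
        possible_resolutions=[(336, 336)],
        is_videos=[True, True],
        group_ids=[[0, 1]],
        num_queries_vis_abstractors=[9, 9],
        grid_size=336,
        image_newline=torch.randn(8),
    )
    return [o.shape for o in outs]  # [torch.Size([18, 8])]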
|
|
|
|
|
@dataclass |
|
class HCXVisionOutput(ModelOutput): |
|
"""Output class for vision models, containing various computation results. |
|
|
|
Args: |
|
loss (Optional[torch.FloatTensor], optional): Total cross-entropy loss calculated from logits and labels. |
|
loss_per_sample (Optional[torch.FloatTensor], optional): Per-sample loss values for advanced loss processing. |
|
logits (torch.FloatTensor): Classification scores (before SoftMax) of shape (batch_size, num_classes). |
|
past_key_values (Optional[Tuple[Tuple[torch.FloatTensor]]], optional): Contains precomputed hidden-states |
|
that can be used (see `past_key_values` input) to speed up sequential decoding. |
|
hidden_states (Optional[Tuple[torch.FloatTensor]], optional): |
|
Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each layer) of |
|
shape (batch_size, sequence_length, hidden_size). |
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs. |
|
attentions (Optional[Tuple[torch.FloatTensor]], optional): Tuple of torch.FloatTensor (one for each layer) |
|
of shape (batch_size, num_heads, sequence_length, sequence_length). Attentions weights after the attention |
|
softmax, used to compute the weighted average in the self-attention heads. |
|
""" |
|
|
|
loss: Optional[torch.FloatTensor] = None |
|
loss_per_sample: Optional[torch.FloatTensor] = None |
|
logits: torch.FloatTensor = None |
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None |
|
hidden_states: Optional[Tuple[torch.FloatTensor]] = None |
|
attentions: Optional[Tuple[torch.FloatTensor]] = None |
|
|
|
|
|
class HCXVisionForCausalLM(PreTrainedModel, GenerationMixin): |
|
"""HCX Vision model for causal language modeling with vision-language capabilities. |
|
|
|
This class combines a vision model with a language model to create a multimodal model |
|
capable of processing images or videos and generating text based on the visual inputs. |
|
|
|
Attributes: |
|
config_class: Configuration class for the model. |
|
vision_model_name: Name of the vision model component. |
|
_no_split_modules: List of modules that should not be split during parallel processing. |
|
supports_gradient_checkpointing: Whether the model supports gradient checkpointing. |
|
_skip_keys_device_placement: Keys to skip during device placement. |
|
""" |
|
|
|
config_class = HCXVisionConfig |
|
vision_model_name = "vision_model" |
|
_no_split_modules = ["CLIPAttention", "SiglipVisionModel"] |
|
supports_gradient_checkpointing = True |
|
_skip_keys_device_placement = "past_key_values" |
|
|
|
def __init__( |
|
self, |
|
config: HCXVisionConfig, |
|
**kwargs: Optional[Any], |
|
) -> None: |
|
"""Initialize the HCXVisionForCausalLM model. |
|
|
|
Args: |
|
config: Configuration object for the model containing parameters for both |
|
vision and language components. |
|
**kwargs: Additional keyword arguments: |
|
- use_liger: Whether to use liger kernel for hyperclovax models. |
|
- use_fused_ce: Whether to use fused cross-entropy loss. |
|
- use_sum_loss: Whether to use sum reduction for loss instead of mean. |
|
- is_safetensor_save: Whether to save model using safetensors format. |
|
|
|
Raises: |
|
ValueError: If vision_config is not defined or if language_config is not defined. |
|
""" |
|
super().__init__(config) |
|
|
|
self.flag_changed_max_position_embeddings = False |
|
|
|
vision_model_type = config.vision_config["model_type"] |
|
if vision_model_type in CONFIG_MAPPING: |
|
vision_config = CONFIG_MAPPING[vision_model_type](**config.vision_config) |
|
vision_config.auto_map = {} |
|
else: |
|
if config.vision_model_name_or_path is not None: |
|
vision_config = AutoConfig.from_pretrained(config.vision_model_name_or_path, trust_remote_code=True) |
|
elif config.vision_config["_name_or_path"] is not None: |
|
vision_config = AutoConfig.from_pretrained( |
|
config.vision_config["_name_or_path"], trust_remote_code=True |
|
) |
|
else: |
|
raise ValueError("vision_config is not defined") |
|
|
|
self.use_liger = kwargs.pop("use_liger", False) |
|
self.use_fused_ce = kwargs.pop("use_fused_ce", False) |
|
self.reduction = "sum" if kwargs.pop("use_sum_loss", False) else "mean" |
|
|
|
self.vision_config = vision_config |
|
vision_config.anyres = config.anyres |
|
vision_config.max_num_grids = config.max_num_grids |
|
|
|
possible_resolutions = [] |
|
if config.anyres: |
|
assert config.max_num_grids > 0 |
|
for i in range(1, config.max_num_grids + 1): |
|
for j in range(1, config.max_num_grids + 1): |
|
if i == 1 and j == 1 and not config.use_1x1_grid: |
|
continue |
|
if i * j <= config.max_num_grids: |
|
possible_resolutions.append([i, j]) |
|
|
|
possible_resolutions = [ |
|
[ys * vision_config.image_size, xs * vision_config.image_size] for ys, xs in possible_resolutions |
|
] |
|
|
|
self.possible_resolutions = possible_resolutions |
|
|
|
with no_init_weights(): |
|
self.vision_model = AutoModel.from_config( |
|
vision_config, trust_remote_code=True |
|
) |
|
|
|
assert config.language_config["model_type"] == "llama" |
|
language_config = CONFIG_MAPPING["llama"](**config.language_config) |
|
language_config._attn_implementation = kwargs.get("attn_implementation", "sdpa") |
|
language_config.logits_scaling = 1.0 |
|
|
|
self.language_config = language_config |
|
self.language_model = AutoModelForCausalLM.from_config(language_config) |
|
|
|
self.language_model.gradient_checkpointing_enable() |
|
self.num_queries_vis_abstractor = config.num_queries_vis_abstractor |
|
|
|
|
|
input_hidden_size = vision_config.hidden_size |
|
self.mm_projector = HCXVisionCAbstractor( |
|
num_queries=self.num_queries_vis_abstractor, |
|
num_input_tokens=(self.vision_config.image_size // self.vision_config.patch_size) ** 2, |
|
encoder_hidden_size=input_hidden_size, |
|
hidden_size=input_hidden_size, |
|
output_hidden_size=language_config.hidden_size, |
|
pos_emb=config.proj_pos_emb, |
|
prenorm=config.proj_prenorm, |
|
) |
|
self.use_nth_layer = config.use_nth_layer |
|
self.config.update({"vision_config": self.vision_model.config.to_dict()}) |
|
self.config.update({"language_config": self.language_model.config.to_dict()}) |
|
self.lm_head_vocab_size = ( |
|
language_config.padded_vocab_size |
|
if hasattr(language_config, "padded_vocab_size") |
|
else language_config.vocab_size |
|
) |
|
self.language_model.lm_head = nn.Linear(language_config.hidden_size, self.lm_head_vocab_size, bias=False) |
|
self.model_parallel = False |
|
self.device_map = None |
|
self.use_no_grad = None |
|
self.decoder_max_length = config.decoder_max_length |
|
|
|
self.anyres = config.anyres |
|
self.unpad = config.unpad |
|
if self.anyres: |
|
self.image_newline = nn.Parameter(torch.empty(language_config.hidden_size, dtype=self.dtype)) |
|
|
|
self.is_safetensor_save = kwargs.get("is_safetensor_save", True) |
|
self._backward_compatibility_gradient_checkpointing() |
|
|
|
    def _init_weights(self, module):
        # Standard small-normal initialization for conv / embedding / linear weights.
        if isinstance(module, (nn.Conv2d, nn.Embedding, nn.Linear)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if hasattr(module, "bias") and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Parameter):
            # Stand-alone parameters (e.g. image_newline) are initialized with std 1/sqrt(dim).
            embed_std = 1 / torch.sqrt(torch.tensor(module.size(0), dtype=torch.float)).to(module.dtype)
            module.data.normal_(mean=0.0, std=embed_std)
|
|
|
def forward( |
|
self, |
|
input_ids: Optional[torch.LongTensor] = None, |
|
pixel_values: Optional[List[List[torch.FloatTensor]]] = None, |
|
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, |
|
attention_mask: Optional[torch.FloatTensor] = None, |
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
labels: Optional[torch.LongTensor] = None, |
|
use_cache: Optional[bool] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
image_sizes: Optional[List[List[List[int]]]] = None, |
|
vision_query_lengths: Optional[List[List[int]]] = None, |
|
non_vision_query_lengths: Optional[List[int]] = None, |
|
img_start_ids_list: Optional[List[List[int]]] = None, |
|
num_queries_vis_abstractors: Optional[List[List[int]]] = None, |
|
num_queries_vis_abstractors_slow: Optional[List[List[int]]] = None, |
|
first_last_frames_slows: Optional[List[bool]] = None, |
|
is_video_list: Optional[List[bool]] = None, |
|
**kwargs, |
|
) -> Union[Tuple, HCXVisionOutput]: |
|
"""Forward pass of the model. |
|
|
|
This method processes the input tokens and images, combines them into a unified |
|
representation, and generates text output based on the inputs. |
|
|
|
Args: |
|
input_ids: Input token IDs. In positions where images are inputted, the value is replaced by "<|dummy3|>" |
|
pixel_values: List of lists of 4D tensors for images. Each outer list corresponds to a batch and contains |
|
inner lists of image tensors. |
|
past_key_values: Pre-computed key and value states of the attention layers for faster inference. |
|
attention_mask: Mask to avoid performing attention on padding token indices. |
|
inputs_embeds: Input embeddings. If provided, input_ids will not be used. |
|
labels: Labels for computing the language modeling loss. |
|
use_cache: Whether to use past key/values for faster inference. |
|
output_attentions: Whether to return attention weights of each layer. |
|
output_hidden_states: Whether to return hidden states of each layer. |
|
return_dict: Whether to return a ModelOutput instead of a tuple. |
|
image_sizes: List of lists representing image dimensions (width, height). |
|
vision_query_lengths: List of lists containing lengths when each image is converted into visual tokens. |
|
non_vision_query_lengths: List of lengths of text tokens (excluding visual tokens) for each sample. |
|
img_start_ids_list: List of lists containing indices of img_start_id tokens for each sample. |
|
            num_queries_vis_abstractors: List of lists containing the number of visual tokens for each image grid.
                For video frames, this is the number of visual tokens for the fast part.
|
num_queries_vis_abstractors_slow: List of lists containing number of visual tokens for |
|
the slow part when applying the slowfast algorithm to video frames. |
|
first_last_frames_slows: List of booleans indicating whether the slowfast algorithm is |
|
applied to the first or last frames of the video. |
|
is_video_list: List of booleans indicating which inputs are videos. |
|
**kwargs: Additional keyword arguments. |
|
|
|
Returns: |
|
If return_dict=True, returns an HCXVisionOutput object containing: |
|
- loss: Language modeling loss if labels are provided, otherwise None. |
|
- loss_per_sample: Per-sample loss if labels are provided, otherwise None. |
|
- logits: Prediction scores of the language modeling head. |
|
- past_key_values: Past key/values for faster inference if use_cache=True. |
|
- hidden_states: Hidden states of all layers if output_hidden_states=True. |
|
- attentions: Attention weights of all layers if output_attentions=True. |
|
If return_dict=False, returns a tuple containing the above items except loss_per_sample. |
|
""" |
|
output_attentions = ( |
|
output_attentions if output_attentions is not None else self.config.vision_config["output_attentions"] |
|
) |
|
output_hidden_states = ( |
|
output_hidden_states |
|
if output_hidden_states is not None |
|
else self.config.vision_config["output_hidden_states"] |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
if inputs_embeds is None and past_key_values is None: |
|
inputs_embeds = self.extract_inputs_embeds( |
|
input_ids=input_ids, |
|
pixel_values=pixel_values, |
|
past_key_values=past_key_values, |
|
image_sizes=image_sizes, |
|
vision_query_lengths=vision_query_lengths, |
|
non_vision_query_lengths=non_vision_query_lengths, |
|
img_start_ids_list=img_start_ids_list, |
|
num_queries_vis_abstractors=num_queries_vis_abstractors, |
|
num_queries_vis_abstractors_slow=num_queries_vis_abstractors_slow, |
|
first_last_frames_slows=first_last_frames_slows, |
|
is_videos=is_video_list, |
|
) |
|
|
|
if inputs_embeds is not None: |
|
input_ids = None |
|
|
|
|
|
outputs = self.language_model.base_model( |
|
input_ids=input_ids, |
|
inputs_embeds=inputs_embeds, |
|
attention_mask=attention_mask, |
|
past_key_values=past_key_values, |
|
use_cache=use_cache, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
hidden_states = outputs[0] |
|
hidden_states = hidden_states * self.language_config.logits_scaling |
|
|
|
loss = None |
|
loss_per_sample = None |
|
logits = self.language_model.lm_head(hidden_states) |
|
if labels is not None: |
|
|
|
shift_logits = logits[..., :-1, :].contiguous() |
|
shift_labels = labels[..., 1:].contiguous() |
|
|
|
loss_fct = CrossEntropyLoss(reduction="none") |
|
shift_logits = shift_logits.view(-1, self.lm_head_vocab_size) |
|
shift_labels = shift_labels.view(-1) |
|
|
|
shift_labels = shift_labels.to(shift_logits.device) |
|
loss = loss_fct(shift_logits, shift_labels) |
|
if get_rank() == 0: |
|
loss_per_sample = loss.view(logits.shape[0], -1).sum(axis=1) / ( |
|
shift_labels.view(logits.shape[0], -1) != self.config.ignore_index |
|
).sum(axis=1) |
|
loss = loss[shift_labels != self.config.ignore_index].mean() |
|
if not return_dict: |
|
output = (logits,) + outputs[1:] |
|
return (loss,) + output if loss is not None else output |
|
|
|
return HCXVisionOutput( |
|
loss=loss, |
|
loss_per_sample=loss_per_sample, |
|
logits=logits, |
|
past_key_values=outputs.past_key_values, |
|
hidden_states=outputs.hidden_states, |
|
attentions=outputs.attentions, |
|
) |
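
    @staticmethod
    def _demo_shifted_ce_loss() -> torch.Tensor:
        # Illustrative only (never called by the model): the same shift-by-one next-token
        # cross-entropy with ignore_index masking used in forward() above. The vocabulary
        # size (7) and ignore index (-100) are arbitrary for this sketch.
        logits = torch.randn(2, 5, 7)  # (batch, seq, vocab)
        labels = torch.randint(0, 7, (2, 5))
        labels[:, :2] = -100  # mask prompt / padding positions
        shift_logits = logits[..., :-1, :].contiguous().view(-1, 7)
        shift_labels = labels[..., 1:].contiguous().view(-1)
        token_loss = CrossEntropyLoss(reduction="none")(shift_logits, shift_labels)
        return token_loss[shift_labels != -100].mean()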
|
|
|
def determine_non_vision_query_lengths( |
|
self, input_ids: torch.LongTensor, pad_id: int, img_start_id: int |
|
) -> List[int]: |
|
"""Calculate the lengths of non-vision query parts in the input. |
|
|
|
This method calculates the length of text tokens (excluding visual tokens) for each sample. |
|
When input_ids are collated, they are padded with pad_id on the right, so this method finds |
|
these values by identifying pad tokens and img_start_id tokens. |
|
|
|
Args: |
|
input_ids: Input token IDs with img_start_id markers for image positions. |
|
pad_id: Token ID used for padding. |
|
img_start_id: Token ID marking the start of image data. |
|
|
|
Returns: |
|
List of lengths of non-vision query parts for each sample in the batch. |
|
""" |
|
non_vision_query_lengths = [] |
|
batch_size, len_seq = input_ids.size(0), input_ids.size(1) |
|
|
|
for i in range(batch_size): |
|
temp_idx = (input_ids[i] == pad_id).nonzero() |
|
eos_idx = temp_idx[0, 0].item() if len(temp_idx) > 0 else len_seq |
|
num_imgs = (input_ids[i] == img_start_id).sum().item() |
|
non_vision_query_lengths.append(eos_idx - num_imgs) |
|
|
|
if all([pad_id in input_id for input_id in input_ids.tolist()]): |
|
non_vision_query_lengths = [ |
|
non_vision_query_length + 1 for non_vision_query_length in non_vision_query_lengths |
|
] |
|
|
|
return non_vision_query_lengths |
|
|
|
def determine_vision_query_lengths( |
|
self, image_features: List[List[torch.Tensor]], image_cnts: List[int] |
|
) -> List[List[int]]: |
|
"""Calculate the lengths of vision query parts in the input. |
|
|
|
This method calculates the lengths of visual tokens for each image in each sample based on |
|
the shapes of image feature tensors. For samples without any images, a dummy image is included |
|
but then converted to an empty list. |
|
|
|
Args: |
|
image_features: List of lists of image features tensors. |
|
image_cnts: List of counts of images for each sample in the batch. |
|
|
|
Returns: |
|
List of lists of lengths of visual tokens for each image in each sample. |
|
""" |
|
vision_query_lengths = [ |
|
[image_feature.size(0) for image_feature in image_feature_list] for image_feature_list in image_features |
|
] |
|
|
|
for i, image_cnt in enumerate(image_cnts): |
|
if image_cnt == 0: |
|
assert len(vision_query_lengths[i]) == 1 |
|
vision_query_lengths[i] = [] |
|
|
|
return vision_query_lengths |
|
|
|
|
|
def get_input_embeddings(self): |
|
return self.language_model.get_input_embeddings() |
|
|
|
|
|
def set_input_embeddings(self, value): |
|
self.language_model.set_input_embeddings(value) |
|
|
|
|
|
def get_output_embeddings(self): |
|
return self.language_model.get_output_embeddings() |
|
|
|
|
|
def set_output_embeddings(self, new_embeddings): |
|
self.language_model.set_output_embeddings(new_embeddings) |
|
|
|
|
|
def set_decoder(self, decoder): |
|
self.language_model.set_decoder(decoder) |
|
|
|
|
|
def get_decoder(self): |
|
return self.language_model.get_decoder() |
|
|
|
|
|
def tie_weights(self): |
|
return self.language_model.tie_weights() |
|
|
|
|
|
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding: |
|
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of) |
|
        self.config.language_config["vocab_size"] = model_embeds.num_embeddings
|
self.vocab_size = model_embeds.num_embeddings |
|
return model_embeds |
|
|
|
def extract_inputs_embeds( |
|
self, |
|
input_ids: Optional[torch.LongTensor] = None, |
|
pixel_values: Optional[List[List[torch.FloatTensor]]] = None, |
|
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, |
|
image_sizes: Optional[List[List[List[int]]]] = None, |
|
vision_query_lengths: Optional[List[List[int]]] = None, |
|
non_vision_query_lengths: Optional[List[int]] = None, |
|
img_start_ids_list: Optional[List[List[int]]] = None, |
|
num_queries_vis_abstractors: Optional[List[List[int]]] = None, |
|
num_queries_vis_abstractors_slow: Optional[List[List[int]]] = None, |
|
first_last_frames_slows: Optional[List[bool]] = None, |
|
        is_videos: Optional[List[bool]] = None,
|
): |
|
"""Extract input embeddings by processing text tokens and visual features. |
|
|
|
This method processes the input tokens and image features, extracts the visual features |
|
using the vision model, and combines them with the text token embeddings to create |
|
a unified input representation for the language model. |
|
|
|
Args: |
|
input_ids: Input token IDs with img_start_id markers for image positions. |
|
pixel_values: List of lists of image tensors. |
|
past_key_values: Pre-computed key and value states for faster inference. |
|
image_sizes: List of lists of image dimensions (width, height). |
|
vision_query_lengths: List of lists of lengths when each image is converted to visual tokens. |
|
non_vision_query_lengths: List of lengths of text tokens (excluding visual tokens) for each sample. |
|
img_start_ids_list: List of lists containing indices of img_start_id tokens for each sample. |
|
num_queries_vis_abstractors: List of lists containing number of visual tokens for each image grid. |
|
num_queries_vis_abstractors_slow: List of lists containing number of visual tokens for |
|
the slow part when applying the slowfast algorithm to video frames. |
|
first_last_frames_slows: List of booleans indicating whether the slowfast algorithm is |
|
applied to the first or last frames of the video. |
|
is_videos: List of booleans indicating which inputs are videos. |
|
|
|
Returns: |
|
Combined embeddings of text tokens and visual features. |
|
""" |
|
inputs_embeds = None |
|
if past_key_values: |
|
pass |
|
else: |
|
|
|
len_pixel_values = [len(pixel_value) for pixel_value in pixel_values] |
|
concat_pixel_values = torch.cat(list(chain(*pixel_values)), dim=0) |
|
visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1 |
|
|
|
if self.use_no_grad is None: |
|
self.use_no_grad = all(not p.requires_grad for p in self.vision_model.vision_model.encoder.parameters()) |
|
context = torch.no_grad() if self.use_no_grad else contextlib.nullcontext() |
|
with context: |
|
                # The vision tower is currently run in a single chunk whether or not its
                # parameters are frozen (no_grad).
                n_chunks = 1
|
total_len = concat_pixel_values.size(0) |
|
|
|
chunk_size = math.ceil(total_len / n_chunks) if total_len > 0 else 1 |
|
image_forward_outs_chunks = [] |
|
|
|
for i in range(n_chunks): |
|
start = i * chunk_size |
|
end = (i + 1) * chunk_size |
|
|
|
chunk = concat_pixel_values[start:end].to(self.vision_model.dtype) |
|
|
|
if chunk.size(0) < chunk_size: |
|
|
|
pad_size = chunk_size - chunk.size(0) |
|
|
|
dummy_shape = (pad_size,) + tuple(concat_pixel_values.shape[1:]) |
|
dummy = torch.zeros( |
|
dummy_shape, |
|
dtype=concat_pixel_values.dtype, |
|
device=concat_pixel_values.device, |
|
) |
|
chunk = torch.cat([chunk, dummy], dim=0) |
|
|
|
|
|
if self.use_nth_layer == -1: |
|
|
|
self.vision_model.vision_model.post_layernorm = nn.Identity() |
|
outs = self.vision_model(chunk) |
|
outs = outs.last_hidden_state[:, visual_token_idx:] |
|
else: |
|
outs = self.vision_model(chunk, output_hidden_states=True) |
|
outs = outs.hidden_states[self.use_nth_layer][:, visual_token_idx:] |
|
image_forward_outs_chunks.append(outs) |
|
|
|
|
|
image_forward_outs = torch.cat(image_forward_outs_chunks, dim=0).to(image_forward_outs_chunks[0].dtype) |
|
|
|
if num_queries_vis_abstractors is None: |
|
assert num_queries_vis_abstractors_slow is None |
|
image_sizes = list(chain(*image_sizes)) |
|
if is_videos is not None: |
|
is_videos = list(chain(*is_videos)) |
|
group_ids = None |
|
image_forward_outs = image_forward_outs.to(dtype=self.mm_projector.dtype) |
|
image_forward_outs = self.mm_projector(image_forward_outs) |
|
else: |
|
|
|
assert isinstance(self.mm_projector, HCXVisionCAbstractor) |
|
|
|
( |
|
num_queries_vis_abstractors, |
|
num_grids, |
|
image_sizes, |
|
is_videos, |
|
group_ids, |
|
) = self.compute_adaptive_params( |
|
pixel_values, |
|
num_queries_vis_abstractors, |
|
num_queries_vis_abstractors_slow, |
|
image_sizes, |
|
is_videos, |
|
first_last_frames_slows, |
|
) |
|
|
|
image_forward_outs = image_forward_outs.to(dtype=self.mm_projector.dtype) |
|
image_forward_outs = self.mm_projector( |
|
image_forward_outs, |
|
num_queries_vis_abstractors=num_queries_vis_abstractors, |
|
num_grids=num_grids, |
|
) |
|
|
|
if self.anyres: |
|
split_sizes = [pixel_value.shape[0] for pixel_value in chain(*pixel_values)] |
|
|
|
if num_queries_vis_abstractors is None: |
|
image_features = anyres_postprocessing( |
|
image_forward_outs=image_forward_outs, |
|
split_sizes=split_sizes, |
|
image_sizes=image_sizes, |
|
num_queries_vis_abstractor=self.num_queries_vis_abstractor, |
|
unpad=self.unpad, |
|
is_videos=is_videos, |
|
patch_size=self.vision_model.config.patch_size, |
|
grid_size=self.vision_model.config.image_size, |
|
image_newline=self.image_newline, |
|
possible_resolutions=self.possible_resolutions, |
|
) |
|
else: |
|
image_features = adaptive_anyres_postprocessing( |
|
image_forward_outs=image_forward_outs, |
|
image_sizes=image_sizes, |
|
num_queries_vis_abstractors=num_queries_vis_abstractors, |
|
unpad=self.unpad, |
|
is_videos=is_videos, |
|
grid_size=self.vision_model.config.image_size, |
|
image_newline=self.image_newline, |
|
possible_resolutions=self.possible_resolutions, |
|
group_ids=group_ids, |
|
) |
|
else: |
|
if num_queries_vis_abstractors is None: |
|
image_features = [image_forward_out for image_forward_out in image_forward_outs] |
|
else: |
|
image_features = [image_forward_out.unsqueeze(0) for image_forward_out in image_forward_outs] |
|
|
|
|
|
image_features = [ |
|
image_features[sum(len_pixel_values[:i]) : sum(len_pixel_values[: i + 1])] |
|
for i in range(len(len_pixel_values)) |
|
] |
|
|
|
batch_size = input_ids.size(0) |
|
image_feature_dim = image_features[0][0].size(1) |
|
image_feature_dtype = image_features[0][0].dtype |
|
|
|
if img_start_ids_list is None: |
|
image_cnts = (input_ids == self.config.img_start_id).sum(dim=1).tolist() |
|
else: |
|
image_cnts = [len(img_start_ids) for img_start_ids in img_start_ids_list] |
|
|
|
if non_vision_query_lengths is None: |
|
non_vision_query_lengths = self.determine_non_vision_query_lengths( |
|
input_ids, self.tokenizer.pad_token_id, self.config.img_start_id |
|
) |
|
|
|
if vision_query_lengths is None: |
|
vision_query_lengths = self.determine_vision_query_lengths(image_features, image_cnts) |
|
|
|
|
|
len_inputs_embeds = max( |
|
[ |
|
sum(vision_query_length) + non_vision_query_length |
|
for non_vision_query_length, vision_query_length in zip( |
|
non_vision_query_lengths, vision_query_lengths |
|
) |
|
] |
|
) |
|
len_inputs_embeds = min(self.decoder_max_length, len_inputs_embeds) |
|
|
|
inputs_embeds = torch.zeros( |
|
[batch_size, len_inputs_embeds, image_feature_dim], |
|
dtype=image_feature_dtype, |
|
device=self.device, |
|
requires_grad=True, |
|
).clone() |
|
|
|
temp_embeds = self.get_input_embeddings()(input_ids) |
|
|
|
|
|
for batch_idx, sample in enumerate(input_ids): |
|
|
|
non_vision_query_length = non_vision_query_lengths[batch_idx] |
|
|
|
sample = sample[: non_vision_query_length + image_cnts[batch_idx]] |
|
|
|
if image_cnts[batch_idx] == 0: |
|
temp_idx = 0 |
|
|
|
|
|
inputs_embeds[batch_idx, :non_vision_query_length] = temp_embeds[batch_idx][ |
|
:non_vision_query_length |
|
] |
|
                    # Zero-length write: numerically a no-op, but it keeps the (dummy)
                    # image features attached to the computation graph.
                    inputs_embeds[batch_idx, temp_idx:temp_idx] = image_features[batch_idx][0][0:0]
|
else: |
|
if img_start_ids_list is None: |
|
img_start_ids = (sample == self.config.img_start_id).nonzero() |
|
else: |
|
img_start_ids = img_start_ids_list[batch_idx] |
|
assert len(img_start_ids) == image_cnts[batch_idx] == len(image_features[batch_idx]) |
|
|
|
input_start, temp_start = 0, 0 |
|
|
|
|
|
for multi_img_idx, img_start_idx in enumerate(img_start_ids): |
|
|
|
token_len = img_start_idx - temp_start |
|
|
|
|
|
inputs_embeds[batch_idx, input_start : input_start + token_len] = temp_embeds[ |
|
batch_idx, temp_start : temp_start + token_len |
|
] |
|
|
|
inputs_embeds[ |
|
batch_idx, |
|
input_start |
|
+ token_len : input_start |
|
+ token_len |
|
+ vision_query_lengths[batch_idx][multi_img_idx], |
|
] = image_features[batch_idx][multi_img_idx] |
|
|
|
|
|
input_start += token_len + vision_query_lengths[batch_idx][multi_img_idx] |
|
temp_start += token_len + 1 |
|
|
|
|
|
token_len = min(sample[temp_start:].size(0), inputs_embeds.size(1) - input_start) |
|
inputs_embeds[batch_idx, input_start : input_start + token_len] = temp_embeds[ |
|
batch_idx, temp_start : temp_start + token_len |
|
] |
|
return inputs_embeds |
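
    @staticmethod
    def _demo_interleave_text_and_vision() -> torch.Size:
        # Illustrative only (never called by the model): how the visual tokens of one image
        # replace its "<|dummy3|>" placeholder position when building inputs_embeds. The
        # sizes (sequence 6, hidden 4, 3 visual tokens) are arbitrary for this sketch.
        text_embeds = torch.randn(6, 4)  # embeddings of the text token sequence
        vision_tokens = torch.randn(3, 4)  # projected visual features for one image
        img_start_idx = 2  # position of the image placeholder token
        merged = torch.cat(
            (text_embeds[:img_start_idx], vision_tokens, text_embeds[img_start_idx + 1 :]),
            dim=0,
        )
        return merged.shape  # torch.Size([8, 4]): 6 text tokens - 1 placeholder + 3 visual tokens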
|
|
|
@torch.no_grad() |
|
def generate( |
|
self, |
|
input_ids: Optional[torch.LongTensor] = None, |
|
pixel_values: Optional[List[List[torch.FloatTensor]]] = None, |
|
image_sizes: Optional[List[List[List[int]]]] = None, |
|
vision_query_lengths: Optional[List[List[int]]] = None, |
|
non_vision_query_lengths: Optional[List[int]] = None, |
|
num_queries_vis_abstractors: Optional[List[List[int]]] = None, |
|
num_queries_vis_abstractors_slow: Optional[List[List[int]]] = None, |
|
first_last_frames_slows: Optional[List[bool]] = None, |
|
is_videos: Optional[List[bool]] = None, |
|
img_start_ids_list: Optional[List[List[int]]] = None, |
|
pad_token_id: Optional[int] = None, |
|
eos_token_id: Optional[int] = None, |
|
bad_words_ids: Optional[List[List[int]]] = None, |
|
max_length: int = 196, |
|
min_length: int = 2, |
|
do_sample: bool = True, |
|
num_beams: int = 1, |
|
top_p: float = 0.6, |
|
top_k: int = 0, |
|
temperature: float = 0.5, |
|
repetition_penalty: float = 1.0, |
|
length_penalty: int = 1, |
|
use_cache: bool = True, |
|
**kwargs, |
|
) -> torch.LongTensor: |
|
"""Generate text based on input tokens and images. |
|
|
|
This method generates text based on the provided input tokens and images using |
|
beam search and/or sampling strategies. |
|
|
|
Args: |
|
input_ids: Input token IDs with img_start_id markers for image positions. |
|
pixel_values: List of lists of image tensors. |
|
image_sizes: List of lists of image dimensions (width, height). |
|
vision_query_lengths: List of lists of lengths when each image is converted to visual tokens. |
|
non_vision_query_lengths: List of lengths of text tokens (excluding visual tokens) for each sample. |
|
num_queries_vis_abstractors: List of lists containing number of visual tokens for each image grid. |
|
num_queries_vis_abstractors_slow: List of lists containing number of visual tokens for the slow part when |
|
applying the slowfast algorithm to video frames. |
|
first_last_frames_slows: List of booleans indicating whether the slowfast algorithm is applied to the first |
|
or last frames of the video. |
|
is_videos: List of booleans indicating which inputs are videos. |
|
img_start_ids_list: List of lists containing indices of img_start_id tokens for each sample. |
|
pad_token_id: Token ID used for padding. |
|
eos_token_id: Token ID used to signal the end of a sequence. |
|
bad_words_ids: List of token ID sequences that should not be generated. |
|
            max_length: Maximum number of new tokens to generate (forwarded to the language model as max_new_tokens).
|
min_length: Minimum length of the sequence to be generated (input length + min_new_tokens). |
|
do_sample: Whether to use sampling for generation (otherwise uses greedy decoding). |
|
num_beams: Number of beams for beam search. 1 means no beam search. |
|
top_p: Nucleus sampling parameter. Tokens with cumulative probability > top_p are kept. |
|
top_k: Number of highest probability tokens to keep for top-k-filtering. |
|
temperature: Value used to modulate the next token probabilities. |
|
repetition_penalty: Penalty applied to tokens that have already appeared in the sequence. |
|
length_penalty: Exponential penalty applied to sequence length. |
|
use_cache: Whether to use past key/values for faster inference. |
|
**kwargs: Additional keyword arguments. |
|
|
|
Returns: |
|
Generated token IDs. |
|
""" |
|
|
|
if pad_token_id is None: |
|
pad_token_id = self.tokenizer.pad_token_id |
|
if eos_token_id is None: |
|
            eos_token_id = self.tokenizer.encode(EOT)[0]
|
if bad_words_ids is None: |
|
bad_words_ids = [ |
|
[ |
|
self.config.language_config["bos_token_id"], |
|
], |
|
[ |
|
self.config.language_config["eos_token_id"], |
|
], |
|
] |
|
|
|
if pixel_values is None: |
|
return self.language_model.generate( |
|
input_ids, pad_token_id=pad_token_id, eos_token_id=eos_token_id, bad_words_ids=bad_words_ids, **kwargs |
|
) |
|
inputs_embeds = self.extract_inputs_embeds( |
|
input_ids=input_ids, |
|
pixel_values=self.to_vision_model_device(pixel_values), |
|
image_sizes=image_sizes, |
|
vision_query_lengths=vision_query_lengths, |
|
non_vision_query_lengths=non_vision_query_lengths, |
|
img_start_ids_list=img_start_ids_list, |
|
num_queries_vis_abstractors=num_queries_vis_abstractors, |
|
num_queries_vis_abstractors_slow=num_queries_vis_abstractors_slow, |
|
first_last_frames_slows=first_last_frames_slows, |
|
is_videos=is_videos, |
|
) |
|
inputs_embeds = ( |
|
inputs_embeds.to(self.base_model.device) if isinstance(inputs_embeds, torch.Tensor) else inputs_embeds |
|
) |
|
|
|
|
|
pred = self.language_model.generate( |
|
inputs_embeds=inputs_embeds, |
|
pad_token_id=pad_token_id, |
|
eos_token_id=eos_token_id, |
|
bad_words_ids=bad_words_ids, |
|
max_new_tokens=max_length, |
|
min_length=min_length, |
|
num_beams=num_beams, |
|
do_sample=(False if temperature == 0.0 else do_sample), |
|
top_k=top_k, |
|
top_p=top_p, |
|
temperature=temperature, |
|
repetition_penalty=repetition_penalty, |
|
length_penalty=length_penalty, |
|
early_stopping=(False if num_beams <= 1 else True), |
|
use_cache=use_cache, |
|
) |
|
|
|
return pred |
|
|
|
def to_vision_model_device(self, input_tensor: Union[torch.Tensor, List]) -> Union[torch.Tensor, List]: |
|
"""Move input tensors to the vision model's device. |
|
This method recursively moves input tensors or lists of tensors to the vision model's device. |
|
|
|
Args: |
|
input_tensor: Input tensor or list of tensors to be moved to the vision model's device. |
|
|
|
Returns: |
|
The input tensor or list of tensors moved to the vision model's device. |
|
|
|
Raises: |
|
TypeError: If the input is neither a tensor nor a list. |
|
""" |
|
if isinstance(input_tensor, list): |
|
return [self.to_vision_model_device(item) for item in input_tensor] |
|
elif isinstance(input_tensor, torch.Tensor): |
|
return input_tensor.to(self.vision_model.device) |
|
else: |
|
raise TypeError("Unsupported data type. Only tensors and lists are allowed.") |
|
|
|
def prepare_inputs_for_generation( |
|
self, |
|
input_ids: torch.LongTensor, |
|
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, |
|
attention_mask: Optional[torch.FloatTensor] = None, |
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
**kwargs, |
|
) -> Dict[str, Any]: |
|
"""Prepare inputs for the generation algorithm. |
|
|
|
This method prepares the input for each generation step based on the model's needs. |
|
|
|
Args: |
|
input_ids: Input token IDs. |
|
past_key_values: Pre-computed key and value states for faster inference. |
|
attention_mask: Mask to avoid performing attention on padding token indices. |
|
inputs_embeds: Input embeddings. If provided, input_ids will not be used. |
|
**kwargs: Additional keyword arguments. |
|
|
|
Returns: |
|
Dictionary containing the prepared inputs for the model. |
|
""" |
|
input_ids = kwargs.get("decoder_input_ids", input_ids) |
|
|
|
if past_key_values: |
|
input_ids = input_ids[:, -1:] |
|
|
|
|
|
if inputs_embeds is not None and past_key_values is None: |
|
model_inputs = {"inputs_embeds": inputs_embeds} |
|
else: |
|
model_inputs = {"input_ids": input_ids} |
|
|
|
model_inputs.update( |
|
{ |
|
"past_key_values": past_key_values, |
|
"use_cache": kwargs.get("use_cache"), |
|
"attention_mask": attention_mask, |
|
"pixel_values": kwargs.get("pixel_values", None), |
|
} |
|
) |
|
return model_inputs |
|
|
|
@classmethod |
|
def from_config(cls, config, vision_model_name_or_path): |
|
return cls(config, vision_model_name_or_path) |
|
|
|
@classmethod |
|
def from_pretrained( |
|
cls, |
|
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, |
|
*model_args, |
|
**kwargs, |
|
) -> "HCXVisionForCausalLM": |
|
assert pretrained_model_name_or_path is not None |
|
|
|
        save_only_vision = kwargs.pop("save_only_vision", False)
        save_only_qformer = kwargs.pop("save_only_qformer", False)
        save_shard_size = kwargs.pop("save_shard_size", "5GB")
|
|
|
if pretrained_model_name_or_path is not None: |
|
model: HCXVisionForCausalLM = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) |
|
model.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path) |
|
|
|
img_start_id = model.tokenizer.encode(IMG_LOC, add_special_tokens=False) |
|
assert ( |
|
len(img_start_id) == 1 |
|
), f'"<|dummy3|>" was not encoded into a single special token. Encoding result: {img_start_id}' |
|
model.config.img_start_id = img_start_id[0] |
|
|
|
model.save_only_vision = save_only_vision |
|
model.save_only_qformer = save_only_qformer |
|
model.save_shard_size = save_shard_size |
|
|
|
return model |
|
|
|
def get_language_model(self): |
|
return self.language_model.base_model |
|
|
|
def get_vision_model(self): |
|
return self.vision_model |
|
|
|
def save_pretrained( |
|
self, |
|
save_directory: Union[str, os.PathLike], |
|
*args, |
|
**kwargs, |
|
): |
|
state_dict = kwargs["state_dict"] if "state_dict" in kwargs else self.state_dict() |
|
partial_state_dict = self.get_pretrained_state_dict( |
|
state_dict, |
|
save_directory, |
|
) |
|
kwargs["state_dict"] = partial_state_dict |
|
kwargs["safe_serialization"] = self.is_safetensor_save |
|
kwargs.setdefault("max_shard_size", self.save_shard_size) |
|
super().save_pretrained(save_directory, *args, **kwargs) |
|
|
|
def get_pretrained_state_dict(self, state_dict, save_dir): |
|
vision_key = "vision_model." |
|
llm_keys = ["language_model."] |
|
head_key = "lm_head." |
|
|
|
for key in list(state_dict.keys()): |
|
if self.save_only_vision: |
|
for llm_key in llm_keys: |
|
if llm_key in key: |
|
state_dict.pop(key) |
|
if key.startswith(head_key): |
|
state_dict.pop(key) |
|
|
|
elif self.save_only_qformer: |
|
if f"{vision_key}" in key: |
|
state_dict.pop(key) |
|
|
|
return state_dict |
|
|
|
def compute_adaptive_params( |
|
self, |
|
pixel_values: Optional[List[List[torch.FloatTensor]]] = None, |
|
num_queries_vis_abstractors: Optional[List[List[int]]] = None, |
|
num_queries_vis_abstractors_slow: Optional[List[List[int]]] = None, |
|
image_sizes: Optional[List[List[List[int]]]] = None, |
|
is_videos: Optional[List[bool]] = None, |
|
first_last_frames_slows: Optional[List[bool]] = None, |
|
) -> Tuple[List[int], List[int], List[List[int]], List[bool], List[List[int]]]: |
|
"""Compute adaptive parameters for processing different image and video inputs. |
|
|
|
This method calculates parameters needed for adaptive processing, especially when handling |
|
variable resolutions or applying the slowfast algorithm to video frames. It flattens |
|
batch-level inputs (lists of lists) into single lists representing all images/frames |
|
in the batch. Based on slowfast configuration, it may split video frames into 'slow' |
|
and 'fast' components, adjusting query counts and grid indices accordingly. |
|
|
|
Args: |
|
pixel_values: List of lists of image tensors (per sample). Used to determine the initial number of grids per |
|
image/frame. |
|
num_queries_vis_abstractors: List of lists (per sample) containing the base number of visual tokens |
|
generated by the visual abstractor for each image grid |
|
(e.g., 81 for a full grid, 9 for a subsampled/fast grid). |
|
num_queries_vis_abstractors_slow: List of lists (per sample) containing the number of visual tokens for the |
|
'slow' path when applying slowfast. Non-zero values here trigger the slowfast processing logic. |
|
image_sizes: List of lists (per sample) of original image dimensions ([width, height]). |
|
is_videos: List of lists (per sample) of booleans indicating if each input item is part of a video sequence. |
|
first_last_frames_slows: List (per sample) of booleans. If True, slowfast logic |
|
(if active based on `num_queries_vis_abstractors_slow`) is applied only to the first or last frame(s) |
|
within each video sequence. |
|
|
|
Returns: |
|
Tuple containing: |
|
- num_queries_vis_abstractors: Flattened list of final query counts per processed grid. |
|
Values might be adjusted based on slow/fast splitting |
|
(e.g., using values from `num_queries_vis_abstractors_slow` for slow frames). |
|
Example: [81, 81, 81, 9, 81, 9, ...] (Image, Image, Vid_Slow, Vid_Fast, Vid_Slow, Vid_Fast...) |
|
- num_grids: Flattened list representing cumulative grid counts, acting as end indices for slicing the |
|
flattened `image_forward_outs`. Adjusted for slow/fast splits. |
|
Example: [0, 1, 9, 10, 18, 19, 27, ...] (Indices after Grid0_Slow(1), |
|
Grid1_Fast(8), Grid2_Slow(1), Grid3_Fast(8)...). |
|
- image_sizes: Flattened list of image dimensions ([width, height]), potentially duplicated if slow/fast |
|
splitting occurred. |
|
- is_videos: Flattened list of booleans indicating video status, potentially duplicated for |
|
slow/fast splits. Example: [False, False, True, True, True, True, ...] |
|
(Image1, Image2, Vid_grid1_slow, Vid_grid1_fast, Vid_grid2_slow, Vid_grid2_fast...) |
|
- group_ids: List of lists, grouping indices that correspond to the same original image or frame. |
|
If a frame is split into slow/fast, its group will contain multiple indices. |
|
Example: [[0], [1], [2, 3], [4, 5], ...] |
|
(Group for Image1, Group for Image2, Group for Vid1_Slow+Fast, Group for Vid2_Slow+Fast...). |
|
|
|
Raises: |
|
AssertionError: If input validation fails (e.g., negative query counts). |
|
Exception: If an unexpected case is encountered during slowfast processing. |
|
""" |
|
|
|
|
|
assert all( |
|
all(isinstance(value, int) and value >= 0 for value in sublist) for sublist in num_queries_vis_abstractors |
|
), "All values in num_queries_vis_abstractors must be integers >= 0." |
|
|
|
assert all( |
|
all(isinstance(value, int) and value >= 0 for value in sublist) |
|
for sublist in num_queries_vis_abstractors_slow |
|
), "All values in num_queries_vis_abstractors_slow must be integers >= 0." |
|
|
|
assert is_videos is not None |
|
|
|
|
|
is_first_images = [] |
|
is_last_images = [] |
|
for is_video in is_videos: |
|
for idx, is_video_item in enumerate(is_video): |
|
if idx == 0: |
|
is_first_images.append(True) |
|
else: |
|
is_first_images.append(False) |
|
if idx == len(is_video) - 1: |
|
is_last_images.append(True) |
|
else: |
|
is_last_images.append(False) |
|
|
|
num_queries_vis_abstractors = list(chain(*num_queries_vis_abstractors)) |
|
num_queries_vis_abstractors_slow = list(chain(*num_queries_vis_abstractors_slow)) |
|
image_sizes = list(chain(*image_sizes)) |
|
is_videos = list(chain(*is_videos)) |
|
first_last_frames_slows = list(chain(*first_last_frames_slows)) |
|
|
|
|
|
use_slowfast = any([num_query > 0 for num_query in num_queries_vis_abstractors_slow]) |
|
num_grids = [pixel_value.shape[0] for pixel_value in chain(*pixel_values)] |
|
num_grids = [0] + num_grids |
|
group_ids = [] |
|
|
|
if use_slowfast: |
|
new_num_grids = [num_grids[0]] |
|
new_num_queries = [] |
|
new_image_sizes = [] |
|
new_is_videos = [] |
|
|
|
|
|
|
|
for ( |
|
num_query, |
|
num_query_slow, |
|
num_grid, |
|
image_size, |
|
is_video, |
|
first_last_frames_slow, |
|
is_first_image, |
|
is_last_image, |
|
) in zip( |
|
num_queries_vis_abstractors, |
|
num_queries_vis_abstractors_slow, |
|
num_grids[1:], |
|
image_sizes, |
|
is_videos, |
|
first_last_frames_slows, |
|
is_first_images, |
|
is_last_images, |
|
): |
|
|
|
if not first_last_frames_slow and num_query_slow > 0: |
|
assert is_video |
|
|
|
this_group_ids = [group_ids[-1][-1] + 1 if group_ids else 0] |
|
|
|
|
|
new_num_grids.append(new_num_grids[-1] + 1) |
|
new_num_queries.append(num_query_slow) |
|
new_image_sizes.append(image_size) |
|
new_is_videos.append(is_video) |
|
|
|
if num_grid >= 2: |
|
|
|
new_num_grids.append(new_num_grids[-1] + num_grid - 1) |
|
new_num_queries.append(num_query) |
|
new_image_sizes.append(image_size) |
|
new_is_videos.append(is_video) |
|
this_group_ids.append(this_group_ids[-1] + 1) |
|
|
|
group_ids.append(this_group_ids) |
|
elif ( |
|
first_last_frames_slow and num_query_slow > 0 and (is_first_image or is_last_image) |
|
): |
|
|
|
assert is_video |
|
|
|
this_group_ids = [group_ids[-1][-1] + 1 if group_ids else 0] |
|
|
|
if num_grid == 1: |
|
|
|
new_num_grids.append(new_num_grids[-1] + 1) |
|
new_num_queries.append(num_query_slow) |
|
new_image_sizes.append(image_size) |
|
new_is_videos.append(is_video) |
|
|
|
if num_grid >= 2: |
|
|
|
|
|
if is_first_image: |
|
|
|
new_num_grids.append(new_num_grids[-1] + 1) |
|
new_num_queries.append(num_query_slow) |
|
new_image_sizes.append(image_size) |
|
new_is_videos.append(is_video) |
|
|
|
new_num_grids.append(new_num_grids[-1] + num_grid - 1) |
|
new_num_queries.append(num_query) |
|
new_image_sizes.append(image_size) |
|
new_is_videos.append(is_video) |
|
this_group_ids.append(this_group_ids[-1] + 1) |
|
elif is_last_image: |
|
|
|
new_num_grids.append(new_num_grids[-1] + num_grid - 1) |
|
new_num_queries.append(num_query) |
|
new_image_sizes.append(image_size) |
|
new_is_videos.append(is_video) |
|
|
|
new_num_grids.append(new_num_grids[-1] + 1) |
|
new_num_queries.append(num_query_slow) |
|
new_image_sizes.append(image_size) |
|
new_is_videos.append(is_video) |
|
this_group_ids.append(this_group_ids[-1] + 1) |
|
else: |
|
raise Exception("This case should not be reached.") |
|
group_ids.append(this_group_ids) |
|
else: |
|
|
|
new_num_grids.append(new_num_grids[-1] + num_grid) |
|
new_num_queries.append(num_query) |
|
new_image_sizes.append(image_size) |
|
new_is_videos.append(is_video) |
|
|
|
start_group_id = group_ids[-1][-1] + 1 if group_ids else 0 |
|
group_ids.append([start_group_id]) |
|
|
|
num_grids = new_num_grids |
|
num_queries_vis_abstractors = new_num_queries |
|
image_sizes = new_image_sizes |
|
is_videos = new_is_videos |
|
else: |
|
num_grids = [sum(num_grids[:i]) for i in range(1, len(num_grids) + 1)] |
|
group_ids = [[group_id] for group_id in range(len(is_videos))] |
|
|
|
return num_queries_vis_abstractors, num_grids, image_sizes, is_videos, group_ids |
|
|
|
|
|
def load_state_dict_into_model(model_to_load, state_dict, strict=True, start_prefix=""):
    """Recursively load `state_dict` into `model_to_load`, handling DeepSpeed ZeRO-3 gathering.

    Legacy parameter names ("gamma"/"beta") are renamed to "weight"/"bias" before loading.
    Returns the list of error messages collected from `_load_from_state_dict`.
    """
|
|
|
|
|
old_keys = [] |
|
new_keys = [] |
|
for key in state_dict.keys(): |
|
new_key = None |
|
if "gamma" in key: |
|
new_key = key.replace("gamma", "weight") |
|
if "beta" in key: |
|
new_key = key.replace("beta", "bias") |
|
if new_key: |
|
old_keys.append(key) |
|
new_keys.append(new_key) |
|
for old_key, new_key in zip(old_keys, new_keys): |
|
state_dict[new_key] = state_dict.pop(old_key) |
|
|
|
|
|
metadata = getattr(state_dict, "_metadata", None) |
|
state_dict = state_dict.copy() |
|
if metadata is not None: |
|
state_dict._metadata = metadata |
|
|
|
error_msgs = [] |
|
|
|
|
|
|
|
def load(module: nn.Module, state_dict, prefix=""): |
|
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) |
|
args = (state_dict, prefix, local_metadata, strict, [], [], error_msgs) |
|
|
|
|
|
if len([key for key in state_dict if key.startswith(prefix)]) > 0: |
|
if is_deepspeed_zero3_enabled(): |
|
import deepspeed |
|
|
|
|
|
|
|
named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False)) |
|
params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters] |
|
if len(params_to_gather) > 0: |
|
|
|
|
|
|
|
with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0): |
|
if torch.distributed.get_rank() == 0: |
|
module._load_from_state_dict(*args) |
|
else: |
|
module._load_from_state_dict(*args) |
|
|
|
for name, child in module._modules.items(): |
|
if child is not None: |
|
load(child, state_dict, prefix + name + ".") |
|
|
|
load(model_to_load, state_dict, prefix=start_prefix) |
|
|
|
|
|
del state_dict |
|
|
|
return error_msgs |
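# Hypothetical usage sketch for load_state_dict_into_model (paths and names are illustrative):
#
#   state_dict = torch.load("pytorch_model.bin", map_location="cpu")
#   error_msgs = load_state_dict_into_model(model, state_dict, strict=False)
#   if error_msgs:
#       print("\n".join(error_msgs))
#
# The helper first renames legacy "gamma"/"beta" parameters to "weight"/"bias", then walks the
# module tree recursively so that, under DeepSpeed ZeRO-3, each submodule's sharded parameters
# can be gathered on rank 0 before `_load_from_state_dict` is called.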
|
|
|
|
|
class HCXVisionCAbstractor(nn.Module): |
|
""" |
|
This module is based on C-Abstractor, which is licensed under Apache-2.0.

The original code is available at https://github.com/khanrc/honeybee/blob/main/honeybee/projectors/projectors.py,

and we made the necessary modifications.
|
""" |
|
|
|
def __init__( |
|
self, |
|
num_queries: int, |
|
num_input_tokens: int, |
|
encoder_hidden_size: int, |
|
hidden_size: int, |
|
output_hidden_size: int, |
|
pos_emb: bool = True, |
|
prenorm: bool = False, |
|
): |
|
super().__init__() |
|
self.num_input_tokens = num_input_tokens |
|
self.output_hidden_size = output_hidden_size |
|
|
|
|
|
if pos_emb: |
|
self.pos_emb = torch.nn.Parameter(torch.zeros(1, num_input_tokens, encoder_hidden_size)) |
|
self.pos_emb.data.normal_(mean=0.0, std=0.02) |
|
else: |
|
self.pos_emb = None |
|
|
|
|
|
if prenorm: |
|
self.prenorm = LayerNorm(encoder_hidden_size) |
|
else: |
|
self.prenorm = None |
|
|
|
self.build_net(num_queries, encoder_hidden_size, hidden_size, output_hidden_size) |
|
self.dtype = next(self.parameters()).dtype |
|
|
|
def forward( |
|
self, |
|
x: torch.Tensor, |
|
num_queries_vis_abstractors: Optional[List[List[int]]] = None, |
|
num_grids: Optional[List[int]] = None, |
|
) -> torch.Tensor: |
|
""" |
|
Args: |
|
x: (B, L, encoder_hidden_size) tensor from the visual backbone (e.g., a CLIP visual encoder). L is expected to be a perfect square (i.e., any cls token has already been removed).
|
""" |
|
if self.prenorm is not None: |
|
x = self.prenorm(x) |
|
|
|
if self.pos_emb is not None: |
|
x = x + self.pos_emb |
|
|
|
x = self._forward( |
|
x, |
|
num_queries_vis_abstractors=num_queries_vis_abstractors, |
|
num_grids=num_grids, |
|
) |
|
|
|
return x |
|
|
|
def _forward( |
|
self, |
|
x: torch.Tensor, |
|
num_queries_vis_abstractors: Optional[List[List[int]]] = None, |
|
num_grids: Optional[List[int]] = None, |
|
) -> torch.Tensor: |
|
|
|
B, L, dim = x.shape |
|
hw = int(L ** 0.5) |
|
x = rearrange(x, "b (h w) d -> b d h w", h=hw, w=hw) |
|
|
|
if num_queries_vis_abstractors is not None: |
|
assert num_grids is not None |
|
return self._forward_adaptive_num_query(x, num_queries_vis_abstractors, num_grids) |
|
|
|
x = self.net(x) |
|
x = rearrange(x, "b d h w -> b (h w) d") |
|
x = self.readout(x) |
|
return x |
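# Shape walk-through for the default path (no adaptive query counts); hw = sqrt(L), so L must
# be a perfect square:
#   (B, L, encoder_hidden_size) --rearrange--> (B, encoder_hidden_size, hw, hw)
#   --self.net (RegStage -> AdaptiveAvgPool2d -> RegStage)--> (B, hidden_size, out_hw, out_hw)
#   --rearrange + self.readout--> (B, out_hw * out_hw, output_hidden_size)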
|
|
|
def _forward_adaptive_num_query( |
|
self, |
|
x: torch.Tensor, |
|
num_queries_vis_abstractors: Optional[List[List[int]]] = None, |
|
num_grids: Optional[List[int]] = None, |
|
) -> List[torch.Tensor]: |
|
|
|
assert len(self.net) == 3 |
|
|
|
x = self.net[0](x) |
|
new_x = [] |
|
for i, num_queries in enumerate(num_queries_vis_abstractors): |
|
hw = int(num_queries**0.5) |
|
sampler = nn.AdaptiveAvgPool2d((hw, hw)) |
|
out = sampler(x[num_grids[i]:num_grids[i + 1], :]) |
|
out = self.net[2](out) |
|
|
|
out = rearrange(out, "b d h w -> b (h w) d") |
|
out = self.readout(out) |
|
|
|
new_x.append(out) |
|
return new_x |
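# Hypothetical slicing example (values illustrative): with num_queries_vis_abstractors = [81, 9]
# and num_grids = [0, 1, 9], grids x[0:1] are pooled to 9x9 (81 queries) and grids x[1:9] to
# 3x3 (9 queries), so a list of two tensors with different query counts is returned.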
|
|
|
def build_net( |
|
self, |
|
n_queries: int, |
|
encoder_hidden_size: int, |
|
hidden_size: int, |
|
output_hidden_size: int, |
|
depth: int = 3, |
|
mlp_depth: int = 2, |
|
): |
|
assert (n_queries ** 0.5).is_integer(), f"n_queries must be a perfect square, but got {n_queries}"
|
hw = int(n_queries ** 0.5) |
|
|
|
|
|
RegBlock = partial( |
|
RegStage, |
|
stride=1, |
|
dilation=1, |
|
act_layer=nn.SiLU, |
|
norm_layer=LayerNorm2d, |
|
) |
|
|
|
s1 = RegBlock( |
|
depth, |
|
encoder_hidden_size, |
|
hidden_size, |
|
) |
|
sampler = nn.AdaptiveAvgPool2d((hw, hw)) |
|
s2 = RegBlock( |
|
depth, |
|
hidden_size, |
|
hidden_size, |
|
) |
|
|
|
self.net = nn.Sequential(s1, sampler, s2) |
|
self.readout = self.build_mlp(mlp_depth, hidden_size, output_hidden_size) |
|
|
|
def build_mlp( |
|
self, |
|
depth: int, |
|
hidden_size: int, |
|
output_hidden_size: int, |
|
): |
|
layers = [nn.Linear(hidden_size, output_hidden_size)] |
|
for _ in range(1, depth): |
|
layers.append(nn.SiLU()) |
|
layers.append(nn.Linear(output_hidden_size, output_hidden_size)) |
|
return nn.Sequential(*layers) |
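# With the default mlp_depth of 2 (see build_net), the readout MLP built above is equivalent to:
#
#   nn.Sequential(
#       nn.Linear(hidden_size, output_hidden_size),
#       nn.SiLU(),
#       nn.Linear(output_hidden_size, output_hidden_size),
#   )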
|
|
|
def load_sharded_checkpoint( |
|
model, folder, pick_prefix="", replace_prefix_list=[], replace_prefix_dict={}, print_info=True |
|
): |
|
if folder is None: |
|
return {} |
|
|
|
files = os.listdir(folder) |
|
|
|
|
|
pytorch_bin_files = [file for file in files if file.startswith("pytorch_model") and file.endswith(".bin")] |
|
safetensor_files = [file for file in files if file.endswith(".safetensors")] |
|
shard_index_file = [file for file in files if file.endswith(".index.json")] |
|
|
|
|
|
index_present = len(shard_index_file) > 0 |
|
index_file = os.path.join(folder, shard_index_file[0]) if index_present else None
|
|
|
|
|
is_safetensor = len(safetensor_files) > 0 |
|
|
|
model_keys = model.state_dict().keys() |
|
|
|
if is_safetensor: |
|
from safetensors.torch import load_file |
|
|
|
load_function = load_file |
|
shard_files = safetensor_files |
|
else: |
|
load_function = partial(torch.load, map_location="cpu") |
|
shard_files = pytorch_bin_files |
|
|
|
|
|
if index_present: |
|
with open(index_file, "r", encoding="utf-8") as f: |
|
index = json.load(f) |
|
loaded_keys = index["weight_map"].keys() |
|
if pick_prefix: |
|
loaded_keys = [k[len(pick_prefix) :] for k in loaded_keys if k.startswith(pick_prefix)] |
|
if replace_prefix_list: |
|
for rep_prefix in replace_prefix_list: |
|
loaded_keys = [k[len(rep_prefix) :] if k.startswith(rep_prefix) else k for k in loaded_keys] |
|
if replace_prefix_dict: |
|
for rep_prefix in replace_prefix_dict: |
|
loaded_keys = [ |
|
k.replace(rep_prefix, replace_prefix_dict[rep_prefix]) if k.startswith(rep_prefix) else k |
|
for k in loaded_keys |
|
] |
|
|
|
for i, shard_file in enumerate(shard_files): |
|
state_dict = load_function(os.path.join(folder, shard_file)) |
|
|
|
|
|
if pick_prefix: |
|
state_dict = {k[len(pick_prefix) :]: v for k, v in state_dict.items() if k.startswith(pick_prefix)} |
|
|
|
for rep_prefix in replace_prefix_list: |
|
state_dict = {k[len(rep_prefix) :] if k.startswith(rep_prefix) else k: v for k, v in state_dict.items()} |
|
|
|
for rep_prefix in replace_prefix_dict: |
|
state_dict = { |
|
k.replace(rep_prefix, replace_prefix_dict[rep_prefix]) if k.startswith(rep_prefix) else k: v |
|
for k, v in state_dict.items() |
|
} |
|
|
|
if is_deepspeed_zero3_enabled(): |
|
|
|
rank = torch.distributed.get_rank() |
|
print(f"# [info] ZeRo3 - load sharded no {i}, rank {rank}") |
|
load_state_dict_into_model(model, state_dict, strict=False) |
|
elif is_fsdp_enabled(): |
|
if is_local_dist_rank_0(): |
|
model.load_state_dict(state_dict, strict=False) |
|
else: |
|
model.load_state_dict(state_dict, strict=False) |
|
|
|
|
|
if not index_present: |
|
loaded_keys = list(state_dict.keys())  # copy the keys so the shard's tensors can actually be freed below
|
|
|
del state_dict |
|
gc.collect() |
|
|
|
|
|
missing_keys = [key for key in model_keys if key not in loaded_keys] |
|
unexpected_keys = [key for key in loaded_keys if key not in model_keys] |
|
|
|
if get_rank() == 0 and print_info: |
|
print(f"[info] missing_keys: {missing_keys}") |
|
print(f"[info] unexpected_keys: {unexpected_keys}") |
|
|
|
return {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys} |
|
|