# Spaces:
# Runtime error
# Runtime error
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
from typing import Dict, List, Literal, Optional, Tuple, Union

import torch
import torch.nn as nn
import torchvision.transforms
from einops import rearrange

from janus.models.siglip_vit import create_siglip_vit
class CLIPVisionTower(nn.Module):
    """Wrap a vision backbone and return the features selected from its output.

    The backbone is chosen from the prefix of ``model_name``: names starting
    with ``"siglip"`` are built with ``create_siglip_vit``, names starting with
    ``"sam"`` need a ``create_sam_vit`` factory, and every other name is loaded
    through ``transformers.CLIPVisionModel``.
    """

    def __init__(
        self,
        model_name: str = "siglip_large_patch16_384",
        image_size: Union[Tuple[int, int], int] = 336,
        select_feature: str = "patch",
        select_layer: int = -2,
        select_layers: Optional[list] = None,
        ckpt_path: str = "",
        pixel_mean: Optional[List[float]] = None,
        pixel_std: Optional[List[float]] = None,
        **kwargs,
    ):
        """Build the backbone and the optional input normalization.

        Args:
            model_name: Backbone identifier; its prefix selects the builder.
            image_size: Forwarded to the backbone builder as ``image_size``.
            select_feature: One of ``"patch"``, ``"cls_patch"`` or ``"same"``;
                see ``feature_select``. Overridden to ``"same"`` for
                ``siglip*`` models.
            select_layer: Index into the backbone's hidden states, also
                forwarded to the backbone builder.
            select_layers: Stored on the instance; not otherwise used by this
                class.
            ckpt_path: Forwarded to the backbone builder as ``ckpt_path``.
            pixel_mean: Per-channel mean for input normalization.
            pixel_std: Per-channel std for input normalization.
            **kwargs: Extra builder parameters; they override the entries
                above when the keys collide.
        """
        super().__init__()

        self.model_name = model_name
        self.select_feature = select_feature
        self.select_layer = select_layer
        self.select_layers = select_layers

        vision_tower_params = {
            "model_name": model_name,
            "image_size": image_size,
            "ckpt_path": ckpt_path,
            "select_layer": select_layer,
        }
        vision_tower_params.update(kwargs)
        self.vision_tower, self.forward_kwargs = self.build_vision_tower(
            vision_tower_params
        )

        # Normalization is enabled only when BOTH statistics are supplied;
        # passing just one of them leaves the inputs untouched.
        if pixel_mean is not None and pixel_std is not None:
            image_norm = torchvision.transforms.Normalize(
                mean=pixel_mean, std=pixel_std
            )
        else:
            image_norm = None
        self.image_norm = image_norm

    def build_vision_tower(self, vision_tower_params):
        """Instantiate the backbone selected by ``self.model_name``.

        Args:
            vision_tower_params: Keyword arguments forwarded to the builder.

        Returns:
            A ``(vision_tower, forward_kwargs)`` tuple, where ``forward_kwargs``
            are the extra keyword arguments to pass on every forward call.

        Raises:
            NotImplementedError: For a ``sam*`` model when no ``create_sam_vit``
                factory is available in this module.
        """
        if self.model_name.startswith("siglip"):
            # The SigLIP output is used unchanged (see ``feature_select``).
            self.select_feature = "same"
            vision_tower = create_siglip_vit(**vision_tower_params)
            forward_kwargs = dict()

        elif self.model_name.startswith("sam"):
            # This module does not import a SAM factory. Look it up
            # dynamically so the branch fails with an explicit message
            # instead of a bare NameError.
            sam_factory = globals().get("create_sam_vit")
            if sam_factory is None:
                raise NotImplementedError(
                    f"Vision tower '{self.model_name}' requires create_sam_vit, "
                    "which is not available in this module."
                )
            vision_tower = sam_factory(**vision_tower_params)
            forward_kwargs = dict()

        else:  # huggingface
            from transformers import CLIPVisionModel

            # NOTE(review): the parameters are forwarded verbatim as keyword
            # arguments (model_name, image_size, ckpt_path, select_layer, ...);
            # confirm they match what ``from_pretrained`` accepts.
            vision_tower = CLIPVisionModel.from_pretrained(**vision_tower_params)
            forward_kwargs = dict(output_hidden_states=True)

        return vision_tower, forward_kwargs

    def feature_select(self, image_forward_outs):
        """Pick the features to return from the backbone output.

        Args:
            image_forward_outs: Either a tensor, which is used as is, or an
                object exposing ``hidden_states``, which is indexed with
                ``self.select_layer``.

        Returns:
            The selected features; for ``"patch"`` the first position along
            dimension 1 is dropped.

        Raises:
            ValueError: If ``self.select_feature`` is not a supported mode.
        """
        if isinstance(image_forward_outs, torch.Tensor):
            # The output already holds the features of self.select_layer.
            image_features = image_forward_outs
        else:
            image_features = image_forward_outs.hidden_states[self.select_layer]

        if self.select_feature == "patch":
            # Drop the leading token, on the premise that it is the cls token.
            image_features = image_features[:, 1:]
        elif self.select_feature in ("cls_patch", "same"):
            # Both modes return the features unchanged.
            pass
        else:
            raise ValueError(f"Unexpected select feature: {self.select_feature}")
        return image_features

    def forward(self, images):
        """
        Args:
            images (torch.Tensor): [b, 3, H, W]

        Returns:
            image_features (torch.Tensor): [b, n_patch, d]
        """
        if self.image_norm is not None:
            images = self.image_norm(images)

        image_forward_outs = self.vision_tower(images, **self.forward_kwargs)
        image_features = self.feature_select(image_forward_outs)
        return image_features