# mexma-siglip2 / mexma_siglip.py
from typing import List, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import (
    PretrainedConfig,
    PreTrainedModel,
    SiglipVisionConfig,
    SiglipVisionModel,
    XLMRobertaConfig,
    XLMRobertaModel,
)


class MexmaSigLIPConfig(PretrainedConfig):
    """Configuration for MexmaSigLIP. When `optimized` is True, the text tower
    uses SDPA attention and the vision tower uses FlashAttention-2 with bfloat16."""

    def __init__(
        self,
        optimized: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.optimized = optimized


class MLP(nn.Module):
    """Two-layer feed-forward block with a SiLU activation."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, intermediate_size)
        self.fc2 = nn.Linear(intermediate_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = F.silu(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class MultiheadAttentionPoolingHead(nn.Module):
    """Pools a token sequence into a single embedding via a learned probe query,
    then projects it to the output embedding size."""

    def __init__(
        self,
        hidden_size: int,
        out_hidden_size: int,
        num_attention_heads: int,
        layer_norm_eps: float,
        intermediate_size: int,
    ):
        super().__init__()
        self.probe = nn.Parameter(torch.randn(1, 1, hidden_size))
        self.attention = torch.nn.MultiheadAttention(hidden_size, num_attention_heads, batch_first=True)
        self.layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.mlp = MLP(hidden_size, intermediate_size)
        self.projector = nn.Linear(hidden_size, out_hidden_size)

    def forward(self, hidden_state):
        batch_size = hidden_state.shape[0]
        # Cross-attend a single learned probe token to the full sequence.
        probe = self.probe.repeat(batch_size, 1, 1)
        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
        residual = hidden_state
        hidden_state = self.layernorm(hidden_state)
        hidden_state = residual + self.mlp(hidden_state)
        hidden_state = self.projector(hidden_state)
        return hidden_state[:, 0]
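
# Illustrative shape check for the pooling head (a sketch, not part of the model;
# the sizes mirror the text-tower values used in MexmaSigLIP.__init__ below,
# the batch and sequence lengths are arbitrary):
#   head = MultiheadAttentionPoolingHead(1024, 1152, 16, 1e-5, 4304)
#   tokens = torch.randn(2, 77, 1024)   # (batch, sequence_length, hidden_size)
#   pooled = head(tokens)               # -> shape (2, 1152), one vector per example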


class MexmaSigLIP(PreTrainedModel):
    """Dual-encoder model pairing a MEXMA (XLM-RoBERTa) text tower, attention-pooled
    into the vision embedding space, with a SigLIP 2 so400m vision tower."""

    config_class = MexmaSigLIPConfig

    def __init__(self, config: MexmaSigLIPConfig):
        super().__init__(config)
        self.config = config
        text_config = XLMRobertaConfig.from_pretrained("facebook/MEXMA")
        if self.config.optimized:
            text_config._attn_implementation = "sdpa"
        self.text_model = XLMRobertaModel(text_config, add_pooling_layer=False)
        # Pool the 1024-dim text token states into a single 1152-dim embedding
        # that matches the SigLIP 2 vision embedding size.
        self.text_projector = MultiheadAttentionPoolingHead(1024, 1152, 16, 1e-5, 4304)
        vision_config = SiglipVisionConfig.from_pretrained(
            "google/siglip2-so400m-patch16-512"
        )
        if self.config.optimized:
            vision_config._attn_implementation = "flash_attention_2"
            vision_config.torch_dtype = "bfloat16"
        self.vision_model = SiglipVisionModel(vision_config).vision_model
        self.logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.logit_bias = torch.nn.Parameter(torch.ones([]) * -10)

    def forward(self, image_inputs, input_ids, attention_mask, normalize=False):
        text_features = self.encode_texts(input_ids, attention_mask, normalize)
        image_features = self.encode_images(image_inputs, normalize)
        return {
            "image_features": image_features,
            "text_features": text_features,
            "logit_scale": self.logit_scale,
            "logit_bias": self.logit_bias,
        }

    def encode_images(
        self,
        pixel_values,
        normalize=False,
    ):
        features = self.vision_model(pixel_values).pooler_output
        return F.normalize(features, dim=-1) if normalize else features

    def encode_texts(
        self,
        input_ids,
        attention_mask,
        normalize=False,
    ):
        features = self.text_model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        # Attention-pool the token states into a single embedding that matches
        # the vision embedding size.
        features = self.text_projector(features)
        return F.normalize(features, dim=-1) if normalize else features

    def get_logits(
        self,
        input_ids,
        attention_mask,
        pixel_values,
    ):
        image_features = self.encode_images(pixel_values, normalize=True)
        text_features = self.encode_texts(input_ids, attention_mask, normalize=True)
        # SigLIP-style pairwise logits: scaled cosine similarity plus a learned bias.
        image_logits = (
            self.logit_scale.exp() * image_features @ text_features.T + self.logit_bias
        )
        text_logits = image_logits.T
        return image_logits, text_logits
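

if __name__ == "__main__":
    # Minimal usage sketch, not part of the model definition. The checkpoint id
    # below is assumed from the repository name, and the tokenizer / image
    # processor are assumed to match the backbones referenced in __init__
    # (facebook/MEXMA for text, google/siglip2-so400m-patch16-512 for vision);
    # adjust these names to your setup.
    from transformers import AutoImageProcessor, AutoTokenizer

    model = MexmaSigLIP.from_pretrained("visheratin/mexma-siglip2")  # assumed repo id
    tokenizer = AutoTokenizer.from_pretrained("facebook/MEXMA")
    processor = AutoImageProcessor.from_pretrained("google/siglip2-so400m-patch16-512")

    image = Image.new("RGB", (512, 512))  # placeholder image
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    texts = tokenizer(
        ["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt"
    )

    with torch.no_grad():
        image_logits, text_logits = model.get_logits(
            input_ids=texts.input_ids,
            attention_mask=texts.attention_mask,
            pixel_values=pixel_values,
        )
    # SigLIP scores each image-text pair independently with a sigmoid,
    # rather than a softmax over candidates.
    print(torch.sigmoid(image_logits))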