from typing import List, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import (
    PretrainedConfig,
    PreTrainedModel,
    SiglipVisionConfig,
    SiglipVisionModel,
    XLMRobertaConfig,
    XLMRobertaModel,
)


class MexmaSigLIPConfig(PretrainedConfig):
    def __init__(
        self,
        optimized: bool = False,
        **kwargs,
    ):
        # `optimized` switches the towers to faster attention backends
        # (SDPA for text, FlashAttention-2 + bfloat16 for vision); see
        # MexmaSigLIP.__init__ below.
        super().__init__(**kwargs)
        self.optimized = optimized


class MLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, intermediate_size)
        self.fc2 = nn.Linear(intermediate_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = F.silu(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class MultiheadAttentionPoolingHead(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        out_hidden_size: int,
        num_attention_heads: int,
        layer_norm_eps: float,
        intermediate_size: int,
    ):
        super().__init__()

        # Learned query ("probe") that attends over the token sequence to pool it.
        self.probe = nn.Parameter(torch.randn(1, 1, hidden_size))
        self.attention = torch.nn.MultiheadAttention(
            hidden_size, num_attention_heads, batch_first=True
        )
        self.layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.mlp = MLP(hidden_size, intermediate_size)
        self.projector = nn.Linear(hidden_size, out_hidden_size)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        batch_size = hidden_state.shape[0]
        probe = self.probe.repeat(batch_size, 1, 1)

        # Cross-attention: the probe queries the token sequence.
        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]

        # Residual MLP block, then projection into the shared embedding space.
        residual = hidden_state
        hidden_state = self.layernorm(hidden_state)
        hidden_state = residual + self.mlp(hidden_state)
        hidden_state = self.projector(hidden_state)
        return hidden_state[:, 0]
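

# Shape sketch for MultiheadAttentionPoolingHead (illustrative only; not called
# anywhere in this file). The learned probe (1, 1, hidden_size) is broadcast to
# (B, 1, hidden_size), attends over the token sequence (B, L, hidden_size), and
# the projector maps the pooled vector to (B, out_hidden_size) after the [:, 0]
# squeeze. Dimensions below are small placeholders, not the real model sizes.
def _pooling_head_shape_demo():
    head = MultiheadAttentionPoolingHead(
        hidden_size=32,
        out_hidden_size=64,
        num_attention_heads=4,
        layer_norm_eps=1e-5,
        intermediate_size=128,
    )
    tokens = torch.randn(2, 10, 32)  # (batch, seq_len, hidden_size)
    pooled = head(tokens)            # -> (batch, out_hidden_size)
    assert pooled.shape == (2, 64)
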
class MexmaSigLIP(PreTrainedModel):
    config_class = MexmaSigLIPConfig

    def __init__(self, config: MexmaSigLIPConfig):
        super().__init__(config)
        self.config = config

        # Text tower: MEXMA (XLM-RoBERTa-based), without the pooling layer.
        text_config = XLMRobertaConfig.from_pretrained("facebook/MEXMA")
        if self.config.optimized:
            text_config._attn_implementation = "sdpa"
        self.text_model = XLMRobertaModel(text_config, add_pooling_layer=False)
        # Attention pooling head projecting 1024-d text states into the 1152-d
        # SigLIP embedding space.
        self.text_projector = MultiheadAttentionPoolingHead(1024, 1152, 16, 1e-5, 4304)

        # Vision tower: SigLIP2 so400m, patch 16, 512px resolution.
        vision_config = SiglipVisionConfig.from_pretrained(
            "google/siglip2-so400m-patch16-512"
        )
        if self.config.optimized:
            vision_config._attn_implementation = "flash_attention_2"
            vision_config.torch_dtype = "bfloat16"
        self.vision_model = SiglipVisionModel(vision_config).vision_model

        # SigLIP-style learnable temperature and bias.
        self.logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.logit_bias = torch.nn.Parameter(torch.ones([]) * -10)

    def forward(self, image_inputs, input_ids, attention_mask, normalize=False):
        text_features = self.encode_texts(input_ids, attention_mask, normalize)
        image_features = self.encode_images(image_inputs, normalize)
        return {
            "image_features": image_features,
            "text_features": text_features,
            "logit_scale": self.logit_scale,
            "logit_bias": self.logit_bias,
        }

    def encode_images(
        self,
        pixel_values,
        normalize=False,
    ):
        features = self.vision_model(pixel_values).pooler_output
        return F.normalize(features, dim=-1) if normalize else features

    def encode_texts(
        self,
        input_ids,
        attention_mask,
        normalize=False,
    ):
        features = self.text_model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        features = self.text_projector(features)
        return F.normalize(features, dim=-1) if normalize else features

    def get_logits(
        self,
        input_ids,
        attention_mask,
        pixel_values,
    ):
        image_features = self.encode_images(pixel_values, normalize=True)
        text_features = self.encode_texts(input_ids, attention_mask, normalize=True)
        image_logits = (
            self.logit_scale.exp() * image_features @ text_features.T + self.logit_bias
        )
        text_logits = image_logits.T
        return image_logits, text_logits
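

# ---------------------------------------------------------------------------
# Illustrative sketches; not part of the model definition. The pairwise sigmoid
# objective below follows the published SigLIP recipe and is an assumption about
# how this model would be trained, not a verbatim training script. The tokenizer
# and image-processor checkpoints mirror the config names used in __init__ and
# are likewise assumptions.
# ---------------------------------------------------------------------------
def siglip_sigmoid_loss(outputs: dict) -> torch.Tensor:
    """SigLIP-style pairwise sigmoid loss for a batch of matched image/text pairs."""
    image_features = F.normalize(outputs["image_features"], dim=-1)
    text_features = F.normalize(outputs["text_features"], dim=-1)
    logits = (
        outputs["logit_scale"].exp() * image_features @ text_features.T
        + outputs["logit_bias"]
    )
    # +1 on the diagonal (matching pairs), -1 everywhere else.
    labels = 2.0 * torch.eye(logits.shape[0], device=logits.device) - 1.0
    return -F.logsigmoid(labels * logits).mean()


if __name__ == "__main__":
    # Minimal inference sketch. The model is randomly initialized here; in
    # practice you would load a trained checkpoint via MexmaSigLIP.from_pretrained.
    from transformers import AutoImageProcessor, AutoTokenizer

    model = MexmaSigLIP(MexmaSigLIPConfig()).eval()
    tokenizer = AutoTokenizer.from_pretrained("facebook/MEXMA")
    image_processor = AutoImageProcessor.from_pretrained(
        "google/siglip2-so400m-patch16-512"
    )

    image = Image.new("RGB", (512, 512))
    pixel_values = image_processor(images=image, return_tensors="pt")["pixel_values"]
    text = tokenizer(["a photo of a cat"], return_tensors="pt", padding=True)

    with torch.no_grad():
        image_logits, text_logits = model.get_logits(
            text.input_ids, text.attention_mask, pixel_values
        )
    print(image_logits.shape, text_logits.shape)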