# Step 2: Import necessary libraries
import gradio as gr
from PIL import Image
from transformers import CLIPProcessor, CLIPModel, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftConfig, PeftModel
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.cache_utils import DynamicCache, StaticCache
# Step 3: Set device and default dtype (float32 on CPU, float16 on GPU)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_default_dtype(torch.float32 if DEVICE.type == "cpu" else torch.float16)
# Step 4: Load CLIP model and processor
clip_model = CLIPModel.from_pretrained(
    "openai/clip-vit-base-patch32",
    torch_dtype=torch.float32 if DEVICE.type == "cpu" else torch.float16,
).to(DEVICE)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
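# Illustrative sketch (not part of the original pipeline): how an image embedding for the
# model defined below would typically be produced with the CLIP model loaded above.
# The helper name `embed_image` is hypothetical; for this checkpoint, get_image_features
# returns a tensor of shape (1, projection_dim) = (1, 512).
def embed_image(image: Image.Image) -> torch.Tensor:
    inputs = clip_processor(images=image, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(DEVICE, dtype=clip_model.dtype)
    with torch.no_grad():
        return clip_model.get_image_features(pixel_values=pixel_values)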
# Step 5: Define the MultiModalModel class
class MultiModalModel(nn.Module):
    def __init__(self, phi_model_name="microsoft/phi-3-mini-4k-instruct", clip_model_name="openai/clip-vit-base-patch32"):
        super().__init__()
        self.phi = None  # Will be set after loading the PEFT model
        self.tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
        self.tokenizer.add_special_tokens({"additional_special_tokens": ["[IMG]"], "pad_token": "<pad>"})
        self.clip = CLIPModel.from_pretrained(clip_model_name, torch_dtype=torch.float32 if DEVICE.type == "cpu" else torch.float16).eval().to(DEVICE)
        image_embedding_dim = self.clip.config.projection_dim
        phi_hidden_size = 3072  # Hardcoded hidden size of Phi-3 mini
        # Project CLIP image embeddings into the Phi-3 embedding space
        self.image_projection = nn.Sequential(
            nn.Linear(image_embedding_dim, phi_hidden_size, dtype=torch.float32 if DEVICE.type == "cpu" else torch.float16),
            nn.LayerNorm(phi_hidden_size, dtype=torch.float32 if DEVICE.type == "cpu" else torch.float16),
            nn.Dropout(0.1)
        ).to(DEVICE)
        nn.init.xavier_uniform_(self.image_projection[0].weight, gain=1.0)
        nn.init.zeros_(self.image_projection[0].bias)
    def forward(self, text_input_ids, attention_mask=None, image_embedding=None):
        # Sanitize and normalize the incoming CLIP image embedding
        image_embedding = torch.clamp(image_embedding, min=-1e4, max=1e4)
        image_embedding = F.normalize(image_embedding, dim=-1, eps=1e-5).to(torch.float32 if DEVICE.type == "cpu" else torch.float16)
        # Keep the projection weights in a bounded range to avoid numerical blow-ups
        with torch.no_grad():
            self.image_projection[0].weight.clamp_(-1.0, 1.0)
            self.image_projection[0].bias.clamp_(-1.0, 1.0)
        projected_image = self.image_projection(image_embedding)
        projected_image = torch.clamp(projected_image, min=-1e4, max=1e4)
        if torch.isnan(projected_image).any() or torch.isinf(projected_image).any():
            print("Warning: Projected image contains NaN or Inf values after clamping, replacing with zeros")
            projected_image = torch.where(
                torch.logical_or(torch.isnan(projected_image), torch.isinf(projected_image)),
                torch.zeros_like(projected_image),
                projected_image
            )
        if projected_image.dim() == 2:
            projected_image = projected_image.unsqueeze(1)
        # Look up text embeddings and splice the projected image vector into the first [IMG] slot
        text_embeddings = self.phi.get_input_embeddings()(text_input_ids)
        fused_embeddings = text_embeddings.clone()
        img_token_id = self.tokenizer.convert_tokens_to_ids("[IMG]")
        img_token_mask = (text_input_ids == img_token_id)
        for i in range(fused_embeddings.shape[0]):
            img_positions = img_token_mask[i].nonzero(as_tuple=True)[0]
            if img_positions.numel() > 0:
                fused_embeddings[i, img_positions[0], :] = projected_image[i, 0, :]
        if torch.isnan(fused_embeddings).any() or torch.isinf(fused_embeddings).any():
            print("Warning: Fused embeddings contain NaN or Inf values, replacing with zeros")
            fused_embeddings = torch.where(
                torch.logical_or(torch.isnan(fused_embeddings), torch.isinf(fused_embeddings)),
                torch.zeros_like(fused_embeddings),
                fused_embeddings
            )
        return fused_embeddings
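# Illustrative usage sketch (assumptions, not part of the original script): wiring a base
# Phi-3 checkpoint into `self.phi` and running the forward fusion on a prompt containing
# the [IMG] placeholder. The original Space presumably attaches a PEFT-adapted Phi model
# here instead of the bare base model; `_demo_forward` and its inputs are hypothetical.
def _demo_forward():
    model = MultiModalModel()
    model.phi = AutoModelForCausalLM.from_pretrained(
        "microsoft/phi-3-mini-4k-instruct",
        trust_remote_code=True,
        torch_dtype=torch.float32 if DEVICE.type == "cpu" else torch.float16,
    ).to(DEVICE)
    # Account for the added [IMG] and <pad> tokens
    model.phi.resize_token_embeddings(len(model.tokenizer))
    prompt = "[IMG] Describe this image."
    enc = model.tokenizer(prompt, return_tensors="pt").to(DEVICE)
    # A random stand-in for a real CLIP embedding of shape (batch, projection_dim)
    dummy_image_embedding = torch.randn(1, model.clip.config.projection_dim, device=DEVICE)
    fused = model(enc["input_ids"], attention_mask=enc["attention_mask"], image_embedding=dummy_image_embedding)
    print(fused.shape)  # (1, sequence_length, 3072)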