# Step 2: Import necessary libraries
import gradio as gr
from PIL import Image
from transformers import CLIPProcessor, CLIPModel, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftConfig, PeftModel
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.cache_utils import DynamicCache, StaticCache
# Step 3: Set device and a shared dtype (float32 on CPU, float16 on GPU)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE = torch.float32 if DEVICE.type == "cpu" else torch.float16
torch.set_default_dtype(DTYPE)
# Step 4: Load CLIP model and processor
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=DTYPE).to(DEVICE)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
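# Illustrative sketch (not part of the original file): how a CLIP image embedding of the
# shape expected by MultiModalModel.forward() below could be computed with the model and
# processor loaded above. The helper name `encode_image` is an assumption for clarity.
def encode_image(image: Image.Image) -> torch.Tensor:
    """Return a (1, projection_dim) CLIP image embedding on DEVICE."""
    inputs = clip_processor(images=image, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(DEVICE, dtype=clip_model.dtype)
    with torch.no_grad():
        # get_image_features runs CLIP's vision encoder plus its projection head
        return clip_model.get_image_features(pixel_values=pixel_values)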
# Step 5: Define the MultiModalModel class
class MultiModalModel(nn.Module):
    def __init__(self, phi_model_name="microsoft/phi-3-mini-4k-instruct", clip_model_name="openai/clip-vit-base-patch32"):
        super().__init__()
        self.phi = None  # Set after the PEFT model is loaded
        self.tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
        self.tokenizer.add_special_tokens({"additional_special_tokens": ["[IMG]"], "pad_token": "<pad>"})
        self.clip = CLIPModel.from_pretrained(clip_model_name, torch_dtype=DTYPE).eval().to(DEVICE)
        image_embedding_dim = self.clip.config.projection_dim  # 512 for CLIP ViT-B/32
        phi_hidden_size = 3072  # Hidden size of Phi-3 mini (hardcoded because self.phi is not loaded yet)
        # Project CLIP image embeddings into the Phi-3 embedding space
        self.image_projection = nn.Sequential(
            nn.Linear(image_embedding_dim, phi_hidden_size, dtype=DTYPE),
            nn.LayerNorm(phi_hidden_size, dtype=DTYPE),
            nn.Dropout(0.1)
        ).to(DEVICE)
        nn.init.xavier_uniform_(self.image_projection[0].weight, gain=1.0)
        nn.init.zeros_(self.image_projection[0].bias)
    def forward(self, text_input_ids, attention_mask=None, image_embedding=None):
        # Without an image embedding, fall back to plain text embeddings
        if image_embedding is None:
            return self.phi.get_input_embeddings()(text_input_ids)
        # Clamp and normalize the CLIP embedding to keep the projection numerically stable
        image_embedding = torch.clamp(image_embedding, min=-1e4, max=1e4)
        image_embedding = F.normalize(image_embedding, dim=-1, eps=1e-5).to(DTYPE)
        with torch.no_grad():
            self.image_projection[0].weight.clamp_(-1.0, 1.0)
            self.image_projection[0].bias.clamp_(-1.0, 1.0)
        projected_image = self.image_projection(image_embedding)
        projected_image = torch.clamp(projected_image, min=-1e4, max=1e4)
        if torch.isnan(projected_image).any() or torch.isinf(projected_image).any():
            print("Warning: Projected image contains NaN or Inf values after clamping, replacing with zeros")
            projected_image = torch.nan_to_num(projected_image, nan=0.0, posinf=0.0, neginf=0.0)
        if projected_image.dim() == 2:
            projected_image = projected_image.unsqueeze(1)  # (batch, 1, hidden)
        text_embeddings = self.phi.get_input_embeddings()(text_input_ids)
        fused_embeddings = text_embeddings.clone()
        # Replace the first [IMG] token embedding in each sequence with the projected image embedding
        img_token_id = self.tokenizer.convert_tokens_to_ids("[IMG]")
        img_token_mask = (text_input_ids == img_token_id)
        for i in range(fused_embeddings.shape[0]):
            img_positions = img_token_mask[i].nonzero(as_tuple=True)[0]
            if img_positions.numel() > 0:
                fused_embeddings[i, img_positions[0], :] = projected_image[i, 0, :]
        if torch.isnan(fused_embeddings).any() or torch.isinf(fused_embeddings).any():
            print("Warning: Fused embeddings contain NaN or Inf values, replacing with zeros")
            fused_embeddings = torch.nan_to_num(fused_embeddings, nan=0.0, posinf=0.0, neginf=0.0)
        return fused_embeddings
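# Illustrative usage sketch (assumptions, not part of the original file): the adapter path and
# image path below are placeholders, and `encode_image` is the helper sketched above. __init__
# leaves self.phi as None, so a Phi-3 model (optionally wrapped with a PEFT adapter) is attached here.
if __name__ == "__main__":
    model = MultiModalModel()
    base_phi = AutoModelForCausalLM.from_pretrained(
        "microsoft/phi-3-mini-4k-instruct", torch_dtype=DTYPE, trust_remote_code=True
    ).to(DEVICE)
    base_phi.resize_token_embeddings(len(model.tokenizer))  # account for the added [IMG] and <pad> tokens
    # model.phi = PeftModel.from_pretrained(base_phi, "<path-or-repo-of-your-adapter>")  # placeholder adapter id
    model.phi = base_phi  # plain base model, for illustration only
    image = Image.open("example.jpg")  # placeholder image path
    image_embedding = encode_image(image)
    enc = model.tokenizer("[IMG] Describe the image.", return_tensors="pt").to(DEVICE)
    fused = model(enc["input_ids"], attention_mask=enc["attention_mask"], image_embedding=image_embedding)
    with torch.no_grad():
        output_ids = model.phi.generate(inputs_embeds=fused, attention_mask=enc["attention_mask"], max_new_tokens=64)
    print(model.tokenizer.decode(output_ids[0], skip_special_tokens=True))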