Finetuning_Multimodal_LLM / multimodal_app.py
seemggoel's picture
Update multimodal_app.py
45469a3 verified
import os
import torch
import gradio as gr
from PIL import Image
from transformers import CLIPProcessor, CLIPModel, AutoTokenizer, AutoModelForCausalLM
from peft import PeftConfig, PeftModel
import torch.nn as nn
import torch.nn.functional as F
# Inference is CPU-only: make float32 the default floating-point dtype and
# expose one shared device handle used by every `.to(...)` call in this file.
torch.set_default_dtype(torch.float32)
DEVICE = torch.device("cpu")
class MultiModalModel(nn.Module):
    """Phi-3 causal LM conditioned on a CLIP image embedding.

    The CLIP image feature vector is passed through ``image_projection`` to the
    language model's hidden size and written into the input-embedding position
    of the first ``[IMG]`` token of the prompt. The fused embeddings are then
    handed to ``self.phi.generate`` via ``inputs_embeds``.

    Args:
        phi_model_name: Model ID of the base causal LM, used when no PEFT
            adapter is loaded.
        clip_model_name: Model ID of the CLIP model used for image features.
        peft_model_path: Optional path to a PEFT adapter. It is only used when
            it exists on the local filesystem (``os.path.exists``); otherwise
            ``phi_model_name`` is loaded without an adapter.
    """

    def __init__(self, phi_model_name="microsoft/phi-3-mini-4k-instruct", clip_model_name="openai/clip-vit-base-patch32", peft_model_path=None):
        super().__init__()
        # Load CLIP model and processor first (smaller model)
        print("Loading CLIP model...")
        self.clip = CLIPModel.from_pretrained(clip_model_name).to(DEVICE)
        self.clip_processor = CLIPProcessor.from_pretrained(clip_model_name, use_fast=True)

        # Load base LLM or PEFT model
        print("Loading language model...")
        if peft_model_path and os.path.exists(peft_model_path):
            print(f"Loading PEFT model from {peft_model_path}")
            self.phi, self.tokenizer = self._load_peft_language_model(peft_model_path)
        else:
            print(f"Loading base model {phi_model_name}")
            self.phi = self._load_causal_lm(phi_model_name)
            self.tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=False)

        # Add [IMG] token and make sure pad token is available
        print("Configuring tokenizer...")
        self._add_special_tokens(self.tokenizer, self.phi, verbose=True)

        # Image projection layer: CLIP projection_dim -> LM hidden size
        print("Setting up image projection layer...")
        image_embedding_dim = self.clip.config.projection_dim
        phi_hidden_size = self.phi.config.hidden_size
        self.image_projection = nn.Sequential(
            nn.Linear(image_embedding_dim, image_embedding_dim * 2),
            nn.GELU(),
            nn.Linear(image_embedding_dim * 2, phi_hidden_size),
            nn.LayerNorm(phi_hidden_size),
            # Only active in training mode; generate_description runs in eval mode.
            nn.Dropout(0.1),
        )
        # Initialize weights
        nn.init.xavier_uniform_(self.image_projection[0].weight, gain=1.0)
        nn.init.zeros_(self.image_projection[0].bias)
        nn.init.xavier_uniform_(self.image_projection[2].weight, gain=1.0)
        nn.init.zeros_(self.image_projection[2].bias)

    @staticmethod
    def _load_causal_lm(model_name):
        """Load a causal LM on the CPU with the loading options this app uses."""
        return AutoModelForCausalLM.from_pretrained(
            model_name,
            return_dict=True,
            device_map="cpu",
            low_cpu_mem_usage=True,
            trust_remote_code=False,
        )

    @classmethod
    def _load_peft_language_model(cls, peft_model_path):
        """Load a PEFT adapter on top of its base model; return ``(model, tokenizer)``.

        The tokenizer is read from the adapter path when possible, otherwise
        from the base model named in the adapter's ``PeftConfig``.
        """
        config = PeftConfig.from_pretrained(peft_model_path)
        base_model = cls._load_causal_lm(config.base_model_name_or_path)
        phi = PeftModel.from_pretrained(base_model, peft_model_path)
        try:
            tokenizer = AutoTokenizer.from_pretrained(peft_model_path)
        except Exception:
            # No loadable tokenizer saved alongside the adapter.
            tokenizer = AutoTokenizer.from_pretrained(
                config.base_model_name_or_path, trust_remote_code=False
            )
        return phi, tokenizer

    @staticmethod
    def _add_special_tokens(tokenizer, phi, verbose=False):
        """Ensure ``[IMG]`` and a ``<pad>`` token exist; resize ``phi`` if the vocab grew.

        Returns:
            The number of tokens added to ``tokenizer``.

        Raises:
            RuntimeError: If ``[IMG]`` still maps to the unknown token.
        """
        num_added = tokenizer.add_special_tokens(
            {"additional_special_tokens": ["[IMG]"], "pad_token": "<pad>"}
        )
        if verbose:
            print(f"Added {num_added} special tokens")
        if num_added > 0:
            if verbose:
                print("Resizing token embeddings...")
            phi.resize_token_embeddings(len(tokenizer))
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if tokenizer.convert_tokens_to_ids("[IMG]") == tokenizer.unk_token_id:
            raise RuntimeError("[IMG] not added correctly")
        return num_added

    def forward(self, text_input_ids, attention_mask=None, image_embedding=None, labels=None):
        """Build prompt input embeddings with the image fused into the ``[IMG]`` slot.

        Args:
            text_input_ids: Token ids, indexed as (batch, seq) below.
            attention_mask: Accepted for interface compatibility; not used here.
            image_embedding: CLIP image features with one row per batch item
                (presumably (batch, projection_dim) -- TODO confirm for batched
                use). When None, the plain text embeddings are returned.
            labels: Accepted for interface compatibility; not used here.

        Returns:
            Embeddings from ``self.phi.get_input_embeddings()``. In each row,
            the first ``[IMG]`` position is replaced by the projected image
            embedding. Rows without ``[IMG]`` are left unchanged.
        """
        text_embeddings = self.phi.get_input_embeddings()(text_input_ids)
        if image_embedding is None:
            # No image: nothing to fuse.
            return text_embeddings

        image_embedding = F.normalize(image_embedding, dim=-1)
        # Amplify the image signal relative to the text token embeddings.
        projected_image = 10.0 * self.image_projection(image_embedding)
        if projected_image.dim() == 2:
            projected_image = projected_image.unsqueeze(1)

        img_token_id = self.tokenizer.convert_tokens_to_ids("[IMG]")
        img_token_mask = text_input_ids == img_token_id
        fused_embeddings = text_embeddings.clone()
        for i in range(fused_embeddings.shape[0]):
            img_positions = img_token_mask[i].nonzero(as_tuple=True)[0]
            if img_positions.numel() > 0:
                fused_embeddings[i, img_positions[0], :] = projected_image[i, 0, :]
        return fused_embeddings

    def process_image(self, image):
        """Process an image through CLIP to get embeddings"""
        image_inputs = self.clip_processor(images=image, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            image_embedding = self.clip.get_image_features(**image_inputs)
        return image_embedding

    @staticmethod
    def _to_pil_image(image):
        """Return ``image`` as a PIL image.

        A path string is opened and converted to RGB. A PIL image is returned
        as-is. Anything with a 3-D ``shape`` (e.g. a numpy array) goes through
        ``Image.fromarray`` and is converted to RGB.

        Raises:
            ValueError: For any other input.
        """
        if isinstance(image, str):
            return Image.open(image).convert("RGB")
        if isinstance(image, Image.Image):
            return image
        # Handle numpy array or other formats
        if hasattr(image, 'shape') and len(image.shape) == 3:
            return Image.fromarray(image).convert("RGB")
        raise ValueError("Unsupported image format")

    def generate_description(self, image, prompt_template="[IMG] A detailed description of this image is:", max_tokens=100):
        """Generate a text description of ``image`` with greedy decoding.

        Args:
            image: A file path, a PIL image, or a 3-D array accepted by
                ``PIL.Image.fromarray``.
            prompt_template: Prompt text. The image is injected at the first
                ``[IMG]`` token. If the prompt contains no ``[IMG]``, one is
                prepended so the image is not silently ignored.
            max_tokens: Maximum number of new tokens to generate.

        Returns:
            The decoded output of ``self.phi.generate`` with special tokens skipped.

        Raises:
            ValueError: If ``image`` has an unsupported type.
        """
        image = self._to_pil_image(image)

        if "[IMG]" not in prompt_template:
            prompt_template = "[IMG] " + prompt_template

        # Process text prompt
        tokenized = self.tokenizer(prompt_template, return_tensors="pt", truncation=True, max_length=128)
        text_input_ids = tokenized["input_ids"].to(DEVICE)
        attention_mask = tokenized["attention_mask"].to(DEVICE)

        # Get image embedding
        image_embedding = self.process_image(image)

        # A freshly constructed nn.Module is in training mode, so the Dropout in
        # image_projection would otherwise be active here and make the greedy
        # output nondeterministic. Run in eval mode, then restore every
        # submodule's previous mode.
        previous_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            with torch.no_grad():
                fused_embeddings = self(
                    text_input_ids=text_input_ids,
                    attention_mask=attention_mask,
                    image_embedding=image_embedding,
                )
                generated_ids = self.phi.generate(
                    inputs_embeds=fused_embeddings,
                    attention_mask=attention_mask,
                    max_new_tokens=max_tokens,
                    do_sample=False,  # Greedy decoding for deterministic output
                    repetition_penalty=1.2,
                )
        finally:
            for module, was_training in previous_modes:
                module.training = was_training

        return self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    def load_projection_weights(self, weights_path):
        """Load saved weights for the image projection layer.

        Returns:
            True on success, False on any error (the error is printed).
        """
        try:
            # The path may point at a user-uploaded file. weights_only=True
            # restricts unpickling to tensors and plain containers instead of
            # arbitrary Python objects.
            state_dict = torch.load(weights_path, map_location=DEVICE, weights_only=True)
            self.image_projection.load_state_dict(state_dict)
            return True
        except Exception as e:
            print(f"Failed to load projection weights: {e}")
            return False

    def load_peft_model(self, peft_model_path):
        """Swap in a PEFT-adapted language model while keeping ``image_projection``.

        The adapter, its base model and tokenizer are fully loaded and checked
        before ``self.phi`` / ``self.tokenizer`` are replaced. On failure, the
        previous language model and tokenizer stay in use.

        Returns:
            True on success, False on any error (the traceback is printed).
        """
        try:
            phi, tokenizer = self._load_peft_language_model(peft_model_path)
            self._add_special_tokens(tokenizer, phi)
            # The projection's output width must match the new LM's hidden size.
            expected_hidden = self.image_projection[2].out_features
            if phi.config.hidden_size != expected_hidden:
                raise ValueError(
                    f"Language model hidden size {phi.config.hidden_size} does not "
                    f"match image projection output size {expected_hidden}"
                )
        except Exception as e:
            print(f"Failed to load PEFT model: {e}")
            import traceback
            traceback.print_exc()
            return False
        self.phi = phi
        self.tokenizer = tokenizer
        return True
# Global model instance (will be loaded on demand)
model = None


def load_model(peft_model_path=None):
    """Return the shared ``MultiModalModel``, creating it on first use.

    Args:
        peft_model_path: Optional PEFT adapter path forwarded to
            ``MultiModalModel``. It is only used on the call that creates the
            model; once a model exists it is returned as-is and this argument
            is ignored.

    Returns:
        The global ``MultiModalModel`` instance.
    """
    global model
    if model is None:
        print("Loading models. This may take a few minutes...")
        # A new nn.Module starts in training mode, which would keep the Dropout
        # in image_projection active at inference. nn.Module.eval() returns self.
        model = MultiModalModel(peft_model_path=peft_model_path).to(DEVICE).eval()
        print("Models loaded!")
    return model
def generate_description(image, prompt, max_length):
    """Gradio handler: describe ``image`` using ``prompt``.

    Args:
        image: Value of the ``gr.Image`` input; None when nothing was uploaded.
        prompt: Prompt text; ``[IMG]`` marks where the image is injected.
        max_length: Maximum number of new tokens (slider value, cast to int).

    Returns:
        The generated text, or a human-readable error string. Exceptions are
        not propagated: they are printed with a traceback and returned as text.
    """
    # Validate before load_model() so a missing image does not trigger the
    # (slow) first-time model load.
    if image is None:
        return "Error: No image provided"
    try:
        active_model = load_model()
        return active_model.generate_description(image, prompt, int(max_length))
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print("Error in generate_description:", error_details)
        return f"Error generating description: {str(e)}"
def load_projection_weights(weights_file):
    """Gradio handler: load custom projection weights from an uploaded file.

    Args:
        weights_file: Value of the ``gr.File`` input; None when nothing was
            uploaded. Depending on the Gradio version/configuration, this is
            either a file wrapper exposing the path as ``.name`` or the path
            string itself. Both are accepted.

    Returns:
        A status string for the UI. Exceptions are reported as text.
    """
    if weights_file is None:
        return "No file uploaded"
    try:
        weights_path = getattr(weights_file, "name", weights_file)
        active_model = load_model()
        if active_model.load_projection_weights(weights_path):
            return "✅ Projection weights loaded successfully!"
        return "❌ Failed to load weights"
    except Exception as e:
        return f"❌ Error: {str(e)}"
def load_peft_model_weights(peft_model_path):
    """Gradio handler: load a PEFT adapter into the shared model.

    Args:
        peft_model_path: Local directory or Hugging Face Hub ID of the adapter,
            as typed in the UI. Surrounding whitespace is ignored.

    Returns:
        A status string for the UI. Exceptions are reported as text.
    """
    try:
        path = (peft_model_path or "").strip()
        if not path:
            return "No path provided"
        if model is None and os.path.exists(path):
            # Build the model directly on top of the local adapter
            # (a single base-model load).
            load_model(path)
            return "✅ PEFT model loaded successfully!"
        # Otherwise, go through MultiModalModel.load_peft_model. It does not
        # require the path to exist locally (so Hub IDs work) and reports
        # failure. The MultiModalModel constructor, by contrast, silently falls
        # back to the base model for paths that don't exist on disk.
        newly_loaded = model is None
        active_model = load_model()
        if not active_model.load_peft_model(path):
            return "❌ Failed to update PEFT model"
        if newly_loaded:
            return "✅ PEFT model loaded successfully!"
        return "✅ PEFT model updated successfully!"
    except Exception as e:
        return f"❌ Error: {str(e)}"
def create_interface():
    """Assemble the Gradio UI and return it without launching it.

    Two tabs are built: "Generate" (image + prompt -> description) and
    "Advanced" (load a PEFT adapter or projection weights). No model is loaded
    here; the event handlers load it lazily on first use.

    Returns:
        The ``gr.Blocks`` application.
    """

    def build_generate_tab():
        # Inputs in the left column, generated text in the right column.
        with gr.Tab("Generate"):
            with gr.Row():
                with gr.Column():
                    image_component = gr.Image(label="Upload Image")
                    prompt_component = gr.Textbox(
                        label="Prompt (use [IMG] for image placement)",
                        value="[IMG] A detailed description of this image is:",
                        lines=2,
                    )
                    length_component = gr.Slider(
                        minimum=10,
                        maximum=300,
                        value=100,
                        step=10,
                        label="Maximum Output Length",
                    )
                    run_button = gr.Button("Generate Description")
                with gr.Column():
                    result_component = gr.Textbox(label="Generated Description", lines=12)
            run_button.click(
                generate_description,
                inputs=[image_component, prompt_component, length_component],
                outputs=result_component,
            )

    def build_advanced_tab():
        with gr.Tab("Advanced"):
            # PEFT adapter loading.
            gr.Markdown("### Load PEFT Model")
            path_component = gr.Textbox(
                label="PEFT Model Path (local path or HF Hub ID)",
                value="/content/drive/MyDrive/V6_Checkpoints/Epoch_peft_1",  # Your default path
            )
            peft_button = gr.Button("Load PEFT Model")
            peft_message = gr.Textbox(label="Status")
            peft_button.click(
                load_peft_model_weights,
                inputs=[path_component],
                outputs=[peft_message],
            )

            # Projection-layer weights upload.
            gr.Markdown("### Load Custom Projection Weights")
            file_component = gr.File(label="Upload Projection Weights (.pt file)")
            weights_button = gr.Button("Load Weights")
            weights_message = gr.Textbox(label="Status")
            weights_button.click(
                load_projection_weights,
                inputs=[file_component],
                outputs=[weights_message],
            )

            gr.Markdown("""
### About This Model
This app uses:
- CLIP (ViT-B/32) to extract image features
- Phi-3 Mini with PEFT adaptations for text generation
- A projection layer to connect image and text spaces
The app can load different PEFT models and projection weights for various tasks.
""")

    with gr.Blocks() as demo:
        gr.Markdown("# Multimodal Image Description with Phi-3 Mini and PEFT")
        build_generate_tab()
        build_advanced_tab()
    return demo