import os
import torch
import gradio as gr
from PIL import Image
from transformers import CLIPProcessor, CLIPModel, AutoTokenizer, AutoModelForCausalLM
from peft import PeftConfig, PeftModel
import torch.nn as nn
import torch.nn.functional as F

# Force CPU usage
DEVICE = torch.device("cpu")

# Use float32 for CPU compatibility
torch.set_default_dtype(torch.float32)
class MultiModalModel(nn.Module):
    def __init__(self, phi_model_name="microsoft/phi-3-mini-4k-instruct", clip_model_name="openai/clip-vit-base-patch32", peft_model_path=None):
        super().__init__()

        # Load CLIP model and processor first (smaller model)
        print("Loading CLIP model...")
        self.clip = CLIPModel.from_pretrained(clip_model_name).to(DEVICE)
        self.clip_processor = CLIPProcessor.from_pretrained(clip_model_name, use_fast=True)

        # Load base LLM or PEFT model
        print("Loading language model...")
        if peft_model_path and os.path.exists(peft_model_path):
            print(f"Loading PEFT model from {peft_model_path}")
            # Load PEFT config
            config = PeftConfig.from_pretrained(peft_model_path)
            # Load base model
            base_model = AutoModelForCausalLM.from_pretrained(
                config.base_model_name_or_path,
                return_dict=True,
                device_map="cpu",
                low_cpu_mem_usage=True,
                trust_remote_code=False
            )
            # Load PEFT adapter on top of the base model
            self.phi = PeftModel.from_pretrained(base_model, peft_model_path)
            # Get tokenizer from either the PEFT model path or the base model
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(peft_model_path)
            except Exception:
                self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path, trust_remote_code=False)
        else:
            print(f"Loading base model {phi_model_name}")
            self.phi = AutoModelForCausalLM.from_pretrained(
                phi_model_name,
                return_dict=True,
                device_map="cpu",
                low_cpu_mem_usage=True,
                trust_remote_code=False
            )
            self.tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=False)
        # Add [IMG] token and make sure pad token is available
        print("Configuring tokenizer...")
        num_added = self.tokenizer.add_special_tokens({"additional_special_tokens": ["[IMG]"], "pad_token": "<pad>"})
        print(f"Added {num_added} special tokens")

        # Resize token embeddings if needed
        if num_added > 0:
            print("Resizing token embeddings...")
            self.phi.resize_token_embeddings(len(self.tokenizer))

        # Verify [IMG] token was added correctly
        img_token_id = self.tokenizer.convert_tokens_to_ids("[IMG]")
        assert img_token_id != self.tokenizer.unk_token_id, "[IMG] not added correctly"

        # Image projection layer
        print("Setting up image projection layer...")
        image_embedding_dim = self.clip.config.projection_dim
        phi_hidden_size = self.phi.config.hidden_size
        self.image_projection = nn.Sequential(
            nn.Linear(image_embedding_dim, image_embedding_dim * 2),
            nn.GELU(),
            nn.Linear(image_embedding_dim * 2, phi_hidden_size),
            nn.LayerNorm(phi_hidden_size),
            nn.Dropout(0.1)
        )

        # Initialize weights
        nn.init.xavier_uniform_(self.image_projection[0].weight, gain=1.0)
        nn.init.zeros_(self.image_projection[0].bias)
        nn.init.xavier_uniform_(self.image_projection[2].weight, gain=1.0)
        nn.init.zeros_(self.image_projection[2].bias)
    def forward(self, text_input_ids, attention_mask=None, image_embedding=None, labels=None):
        # Normalize the CLIP embedding, project it into the LM embedding space,
        # and amplify it so the image signal is not drowned out by the text embeddings
        image_embedding = F.normalize(image_embedding, dim=-1)
        projected_image = 10.0 * self.image_projection(image_embedding)  # Amplify image signal
        if projected_image.dim() == 2:
            projected_image = projected_image.unsqueeze(1)

        # Embed the text tokens and locate the [IMG] placeholder token
        text_embeddings = self.phi.get_input_embeddings()(text_input_ids)
        img_token_id = self.tokenizer.convert_tokens_to_ids("[IMG]")
        img_token_mask = (text_input_ids == img_token_id)

        # Replace the embedding at the first [IMG] position of each sequence
        # with the projected image embedding
        fused_embeddings = text_embeddings.clone()
        for i in range(fused_embeddings.shape[0]):
            img_positions = img_token_mask[i].nonzero(as_tuple=True)[0]
            if img_positions.numel() > 0:
                fused_embeddings[i, img_positions[0], :] = projected_image[i, 0, :]
        return fused_embeddings
    def process_image(self, image):
        """Process an image through CLIP to get embeddings."""
        image_inputs = self.clip_processor(images=image, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            image_embedding = self.clip.get_image_features(**image_inputs)
        return image_embedding
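
    # Note: for the default openai/clip-vit-base-patch32 checkpoint, get_image_features
    # returns a [batch_size, 512] tensor (CLIP's projection_dim); image_projection then
    # maps this vector into the language model's hidden size.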
    def generate_description(self, image, prompt_template="[IMG] A detailed description of this image is:", max_tokens=100):
        """End-to-end generation from image to text description."""
        # Process image input (path, PIL image, or numpy array)
        if isinstance(image, str):
            image = Image.open(image).convert("RGB")
        elif not isinstance(image, Image.Image):
            # Handle numpy array or other formats
            if hasattr(image, 'shape') and len(image.shape) == 3:
                image = Image.fromarray(image).convert("RGB")
            else:
                raise ValueError("Unsupported image format")

        # Process text prompt
        tokenized = self.tokenizer(prompt_template, return_tensors="pt", truncation=True, max_length=128)
        text_input_ids = tokenized["input_ids"].to(DEVICE)
        attention_mask = tokenized["attention_mask"].to(DEVICE)

        # Get image embedding
        image_embedding = self.process_image(image)

        # Generate description
        with torch.no_grad():
            fused_embeddings = self(
                text_input_ids=text_input_ids,
                attention_mask=attention_mask,
                image_embedding=image_embedding
            )
            generated_ids = self.phi.generate(
                inputs_embeds=fused_embeddings,
                attention_mask=attention_mask,
                max_new_tokens=max_tokens,
                do_sample=False,  # Greedy decoding for deterministic output
                repetition_penalty=1.2,
                pad_token_id=self.tokenizer.pad_token_id  # Avoid pad-token warnings during generation
            )
        output = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        return output
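
    # Example usage (a sketch; "example.jpg" is a placeholder path, not a file shipped with this app):
    #   model = MultiModalModel()
    #   print(model.generate_description("example.jpg", max_tokens=80))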
    def load_projection_weights(self, weights_path):
        """Load saved weights for the image projection layer."""
        try:
            state_dict = torch.load(weights_path, map_location=DEVICE)
            self.image_projection.load_state_dict(state_dict)
            return True
        except Exception as e:
            print(f"Failed to load projection weights: {e}")
            return False
    def load_peft_model(self, peft_model_path):
        """Load a PEFT model while keeping the current image projection."""
        try:
            # Save current image projection weights
            current_projection = self.image_projection.state_dict()

            # Load PEFT config
            config = PeftConfig.from_pretrained(peft_model_path)

            # Load base model
            base_model = AutoModelForCausalLM.from_pretrained(
                config.base_model_name_or_path,
                return_dict=True,
                device_map="cpu",
                low_cpu_mem_usage=True,
                trust_remote_code=False
            )

            # Load PEFT adapter
            self.phi = PeftModel.from_pretrained(base_model, peft_model_path)

            # Restore image projection
            self.image_projection.load_state_dict(current_projection)

            # Ensure the tokenizer has the [IMG] token
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(peft_model_path)
            except Exception:
                self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

            # Add special tokens if needed
            num_added = self.tokenizer.add_special_tokens({"additional_special_tokens": ["[IMG]"], "pad_token": "<pad>"})
            if num_added > 0:
                self.phi.resize_token_embeddings(len(self.tokenizer))
            return True
        except Exception as e:
            print(f"Failed to load PEFT model: {e}")
            import traceback
            traceback.print_exc()
            return False
# Global model instance (will be loaded on demand)
model = None

def load_model(peft_model_path=None):
    """Load the model if not already loaded."""
    global model
    if model is None:
        print("Loading models. This may take a few minutes...")
        model = MultiModalModel(peft_model_path=peft_model_path).to(DEVICE)
        print("Models loaded!")
    return model
def generate_description(image, prompt, max_length):
    """Generate a description for the given image."""
    try:
        # Make sure the model is loaded
        model = load_model()
        if image is None:
            return "Error: No image provided"
        result = model.generate_description(image, prompt, int(max_length))
        return result
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print("Error in generate_description:", error_details)
        return f"Error generating description: {str(e)}"
def load_projection_weights(weights_file):
    """Load custom projection weights."""
    try:
        if weights_file is None:
            return "No file uploaded"
        model = load_model()
        # gr.File may provide a file path string or a file-like object depending on the Gradio version
        weights_path = weights_file if isinstance(weights_file, str) else weights_file.name
        success = model.load_projection_weights(weights_path)
        if success:
            return "✅ Projection weights loaded successfully!"
        else:
            return "❌ Failed to load weights"
    except Exception as e:
        return f"❌ Error: {str(e)}"
def load_peft_model_weights(peft_model_path):
    """Load a new PEFT model."""
    try:
        if not peft_model_path or peft_model_path.strip() == "":
            return "No path provided"
        global model
        # If the model is not loaded yet, load it with the PEFT path
        if model is None:
            load_model(peft_model_path)
            return "✅ PEFT model loaded successfully!"
        else:
            # Model already loaded, update the PEFT model
            success = model.load_peft_model(peft_model_path)
            if success:
                return "✅ PEFT model updated successfully!"
            else:
                return "❌ Failed to update PEFT model"
    except Exception as e:
        return f"❌ Error: {str(e)}"
def create_interface():
    """Create and return the Gradio interface."""
    # Define the interface without loading the model immediately
    with gr.Blocks() as demo:
        gr.Markdown("# Multimodal Image Description with Phi-3 Mini and PEFT")

        with gr.Tab("Generate"):
            with gr.Row():
                with gr.Column():
                    image_input = gr.Image(label="Upload Image")
                    prompt_input = gr.Textbox(
                        label="Prompt (use [IMG] for image placement)",
                        value="[IMG] A detailed description of this image is:",
                        lines=2
                    )
                    max_length = gr.Slider(
                        minimum=10, maximum=300, value=100, step=10,
                        label="Maximum Output Length"
                    )
                    submit_btn = gr.Button("Generate Description")
                with gr.Column():
                    output_text = gr.Textbox(label="Generated Description", lines=12)

            submit_btn.click(
                generate_description,
                inputs=[image_input, prompt_input, max_length],
                outputs=output_text
            )

        with gr.Tab("Advanced"):
            # Model loading section
            gr.Markdown("### Load PEFT Model")
            peft_path = gr.Textbox(
                label="PEFT Model Path (local path or HF Hub ID)",
                value="/content/drive/MyDrive/V6_Checkpoints/Epoch_peft_1"  # Your default path
            )
            load_peft_btn = gr.Button("Load PEFT Model")
            peft_status = gr.Textbox(label="Status")
            load_peft_btn.click(
                load_peft_model_weights,
                inputs=[peft_path],
                outputs=[peft_status]
            )

            # Projection weights section
            gr.Markdown("### Load Custom Projection Weights")
            weights_file = gr.File(label="Upload Projection Weights (.pt file)")
            load_btn = gr.Button("Load Weights")
            weight_status = gr.Textbox(label="Status")
            load_btn.click(
                load_projection_weights,
                inputs=[weights_file],
                outputs=[weight_status]
            )

        gr.Markdown("""
        ### About This Model

        This app uses:
        - CLIP (ViT-B/32) to extract image features
        - Phi-3 Mini with PEFT adaptations for text generation
        - A projection layer to connect image and text spaces

        The app can load different PEFT models and projection weights for various tasks.
        """)

    return demo
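
# The file as published ends with create_interface(). A launch block along these lines is
# assumed so the interface actually starts when the script runs (e.g. as app.py on a
# Hugging Face Space); it is a minimal sketch, not the author's confirmed configuration.
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()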