import gradio as gr
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel, AutoTokenizer, AutoModelForCausalLM

# Set device to CPU and default dtype to float32
DEVICE = torch.device("cpu")
torch.set_default_dtype(torch.float32)

# Load CLIP model and processor
try:
    clip_model = CLIPModel.from_pretrained(
        "openai/clip-vit-base-patch32",
        torch_dtype=torch.float32
    ).to(DEVICE)
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
except Exception as e:
    raise RuntimeError(f"Error loading CLIP model or processor: {e}") from e

# Load language model and tokenizer
def load_model():
    try:
        # model_name = "distilgpt2"  # lighter alternative for faster CPU inference
        model_name = "microsoft/phi-3-mini-4k-instruct"
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
            trust_remote_code=True
        ).to(DEVICE)
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True
        )
        # Set pad token if the tokenizer does not define one
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            model.config.pad_token_id = model.config.eos_token_id
        model.eval()
        return model, tokenizer
    except Exception as e:
        raise RuntimeError(f"Error loading language model: {e}") from e

# Caption generation logic
def generate_caption(image, model, tokenizer):
    try:
        # Validate the input and normalize to RGB
        if not isinstance(image, Image.Image):
            return "Error: Input must be a valid image."
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Encode the image with CLIP
        image_inputs = clip_processor(images=image, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            image_embedding = clip_model.get_image_features(**image_inputs).to(torch.float32)

        # Tokenize the text prompt
        prompt = "[IMG] Caption this image:"
        inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
        input_ids = inputs["input_ids"].to(DEVICE)
        attention_mask = inputs["attention_mask"].to(DEVICE)
        projection = torch.nn.Linear(512, model.config.hidden_size).to(DEVICE)
        with torch.no_grad():
            image_embedding_projected = projection(image_embedding)

        # Prepend the projected image embedding as a single prefix "token" ahead of the
        # prompt embeddings, and extend the attention mask accordingly
        text_embedding = model.get_input_embeddings()(input_ids)
        fused_embedding = torch.cat([image_embedding_projected.unsqueeze(1), text_embedding], dim=1)
        attention_mask = torch.cat([
            torch.ones(input_ids.size(0), 1, device=DEVICE),
            attention_mask
        ], dim=1)

        # Generate the caption (small beam size keeps CPU latency manageable)
        with torch.no_grad():
            generated_ids = model.generate(
                inputs_embeds=fused_embedding,
                attention_mask=attention_mask,
                max_new_tokens=50,
                min_new_tokens=10,
                num_beams=3,
                repetition_penalty=1.2,
                do_sample=False
            )
        caption = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        return caption.strip()
    except Exception as e:
        return f"Error generating caption: {str(e)}"

# Load model/tokenizer
model, tokenizer = load_model()

# Wrapper for Gradio function call
def gradio_caption(image):
    if image is None:
        return "Please upload an image."
    return generate_caption(image, model, tokenizer)
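
# Quick local test without the UI (hypothetical file path; uncomment to try):
# print(generate_caption(Image.open("example.jpg"), model, tokenizer))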

# Reusable UI component blocks
def create_image_input():
    return gr.Image(
        type="pil",
        label="Upload an Image",
        sources=["upload"]
    )

def create_caption_output():
    return gr.Textbox(
        label="Generated Caption",
        lines=2,
        placeholder="Caption will appear here..."
    )

# Build UI
interface = gr.Interface(
    fn=gradio_caption,
    inputs=create_image_input(),
    outputs=create_caption_output(),
    title="Image Captioning with Fine-Tuned MultiModalModel (Epoch 0)",
    description=(
        "Upload an image to generate a caption using a fine-tuned multimodal model based on Phi-3 and CLIP. "
        "The weights from Epoch_0 are used here, but the model may not generate accurate captions due to limited training."
    )
)

interface.launch()