import os
import torch
import gradio as gr
from PIL import Image
from transformers import CLIPProcessor, CLIPModel, AutoTokenizer, AutoModelForCausalLM
from peft import PeftConfig, PeftModel
import torch.nn as nn
import torch.nn.functional as F
import gc
import logging

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.info(f"Gradio version: {gr.__version__}")

# Device setup
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_default_dtype(torch.float16)
logger.info(f"Using device: {DEVICE}")
class MultiModalModel(nn.Module):
    def __init__(self, phi_model_name="microsoft/phi-3-mini-4k-instruct",
                 clip_model_name="openai/clip-vit-base-patch32", peft_model_path=None):
        super().__init__()
        logger.info("Loading CLIP model...")
        self.clip = CLIPModel.from_pretrained(clip_model_name, torch_dtype=torch.float16).to(DEVICE)
        self.clip_processor = CLIPProcessor.from_pretrained(clip_model_name, use_fast=True)

        logger.info("Loading language model...")
        if peft_model_path:
            logger.info(f"Loading PEFT model from {peft_model_path}")
            try:
                config = PeftConfig.from_pretrained(peft_model_path)
                base_model = AutoModelForCausalLM.from_pretrained(
                    config.base_model_name_or_path,
                    torch_dtype=torch.float16,
                    low_cpu_mem_usage=True,
                    device_map=DEVICE
                )
                self.phi = PeftModel.from_pretrained(base_model, peft_model_path)
                self.tokenizer = AutoTokenizer.from_pretrained(peft_model_path)
            except Exception as e:
                logger.error(f"Failed to load PEFT model: {str(e)}", exc_info=True)
                raise
        else:
            logger.info(f"Loading base model {phi_model_name}")
            self.phi = AutoModelForCausalLM.from_pretrained(
                phi_model_name,
                torch_dtype=torch.float16,
                low_cpu_mem_usage=True,
                device_map=DEVICE
            )
            self.tokenizer = AutoTokenizer.from_pretrained(phi_model_name)

        # Register the [IMG] placeholder token (and a pad token), then resize the LM embeddings to match
        self.tokenizer.add_special_tokens({"additional_special_tokens": ["[IMG]"], "pad_token": "<pad>"})
        self.phi.resize_token_embeddings(len(self.tokenizer))

        # Projection head: map CLIP image embeddings (projection_dim) into the language model's hidden size
        image_embedding_dim = self.clip.config.projection_dim
        phi_hidden_size = self.phi.config.hidden_size
        self.image_projection = nn.Sequential(
            nn.Linear(image_embedding_dim, image_embedding_dim * 2),
            nn.GELU(),
            nn.Linear(image_embedding_dim * 2, phi_hidden_size),
            nn.LayerNorm(phi_hidden_size),
            nn.Dropout(0.1)
        ).to(DEVICE)
    def forward(self, text_input_ids, attention_mask=None, image_embedding=None):
        # Normalize and scale the CLIP embedding, then project it into the LM embedding space
        image_embedding = F.normalize(image_embedding, dim=-1)
        projected_image = 10.0 * self.image_projection(image_embedding)
        if projected_image.dim() == 2:
            projected_image = projected_image.unsqueeze(1)

        # Replace the embedding of the first [IMG] token in each sequence with the projected image embedding
        text_embeddings = self.phi.get_input_embeddings()(text_input_ids)
        img_token_id = self.tokenizer.convert_tokens_to_ids("[IMG]")
        img_token_mask = (text_input_ids == img_token_id)
        fused_embeddings = text_embeddings.clone()
        for i in range(fused_embeddings.shape[0]):
            img_positions = img_token_mask[i].nonzero(as_tuple=True)[0]
            if img_positions.numel() > 0:
                fused_embeddings[i, img_positions[0], :] = projected_image[i, 0, :]
        return fused_embeddings
    def process_image(self, image):
        image_inputs = self.clip_processor(images=image, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            image_embedding = self.clip.get_image_features(**image_inputs)
        return image_embedding
    def generate_description(self, image, prompt_template="[IMG] A detailed description of this image is:", max_tokens=100):
        if isinstance(image, str):
            image = Image.open(image).convert("RGB")
        elif not isinstance(image, Image.Image):
            image = Image.fromarray(image).convert("RGB")
        image = image.resize((224, 224), Image.LANCZOS)

        tokenized = self.tokenizer(prompt_template, return_tensors="pt", truncation=True, max_length=128)
        text_input_ids = tokenized["input_ids"].to(DEVICE)
        attention_mask = tokenized["attention_mask"].to(DEVICE)
        image_embedding = self.process_image(image)

        with torch.no_grad():
            fused_embeddings = self(
                text_input_ids=text_input_ids,
                attention_mask=attention_mask,
                image_embedding=image_embedding
            )
            generated_ids = self.phi.generate(
                inputs_embeds=fused_embeddings,
                attention_mask=attention_mask,
                max_new_tokens=max_tokens,
                do_sample=False,
                repetition_penalty=1.2
            )
        output = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        return output
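# Illustrative standalone usage of MultiModalModel (a sketch; assumes a local file
# named "example.jpg" exists — the app itself goes through load_model() below):
#   vlm = MultiModalModel()
#   print(vlm.generate_description("example.jpg", max_tokens=80))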
model = None

def load_model(peft_model_path=None):
    global model
    if model is None:
        logger.info("Loading model...")
        try:
            model = MultiModalModel(peft_model_path=peft_model_path)
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}", exc_info=True)
            raise
    gc.collect()
    if DEVICE.type == "cuda":
        torch.cuda.empty_cache()
    return model
def generate_description(image, prompt, max_length):
    logger.info("Generating description...")
    try:
        # An optional PEFT checkpoint path can be supplied via the "model_V1" environment variable
        model = load_model(peft_model_path=os.getenv("model_V1", None))
        if image is None:
            logger.error("No image provided")
            return "Error: No image provided"
        result = model.generate_description(image, prompt, int(max_length))
        logger.info("Description generated successfully")
        gc.collect()
        return result
    except Exception as e:
        logger.error(f"Error generating description: {str(e)}", exc_info=True)
        return f"Error: {str(e)}"
# Gradio interface
def create_gradio_interface(generate_fn, color_theme=COLOR_THEME):
    # COLOR_THEME ("blue" or "orange") is assumed to be defined elsewhere in the Space's configuration
    # Set color variables based on theme
    if color_theme == "blue":
        primary_gradient = "linear-gradient(145deg, #e0f2fe, #dbeafe)"
        header_gradient = "linear-gradient(135deg, #bfdbfe, #93c5fd)"
        header_background = "#dbeafe"  # Light blue for section headers
        button_gradient = "linear-gradient(135deg, #3b82f6, #1d4ed8)"
        button_hover_gradient = "linear-gradient(135deg, #2563eb, #1e40af)"
        primary_color = "#1e40af"
        icon_color = "#2563eb"
        shadow_color = "rgba(59, 130, 246, 0.15)"
        button_shadow = "rgba(29, 78, 216, 0.25)"
    else:
        primary_gradient = "linear-gradient(145deg, #fff7ed, #ffedd5)"
        header_gradient = "linear-gradient(135deg, #fed7aa, #fdba74)"
        header_background = "#ffedd5"  # Light orange for section headers
        button_gradient = "linear-gradient(135deg, #f97316, #ea580c)"
        button_hover_gradient = "linear-gradient(135deg, #ea580c, #c2410c)"
        primary_color = "#9a3412"
        icon_color = "#ea580c"
        shadow_color = "rgba(249, 115, 22, 0.15)"
        button_shadow = "rgba(234, 88, 12, 0.25)"
    # Custom CSS with dynamic color variables
    custom_css = f"""
    body {{
        font-family: 'Inter', 'Segoe UI', sans-serif;
        background-color: #f8fafc;
    }}
    .container {{
        background: {primary_gradient};
        border-radius: 16px;
        padding: 30px;
        max-width: 1200px;
        margin: 0 auto;
        box-shadow: 0 10px 25px {shadow_color};
    }}
    .app-header {{
        text-align: center;
        margin-bottom: 30px;
        background: {header_gradient};
        border-radius: 12px;
        padding: 20px;
        box-shadow: 0 4px 12px {shadow_color};
    }}
    .app-title {{
        color: {primary_color};
        font-size: 2.2em;
        font-weight: 700;
        margin-bottom: 10px;
    }}
    .app-description {{
        color: #334155;
        font-size: 1.1em;
        line-height: 1.5;
        max-width: 700px;
        margin: 0 auto;
    }}
    .card {{
        background: #ffffff;
        border-radius: 12px;
        padding: 20px;
        margin-bottom: 20px;
        box-shadow: 0 4px 12px rgba(0,0,0,0.05);
        border: 1px solid rgba(226, 232, 240, 0.8);
        transition: transform 0.2s, box-shadow 0.2s;
        height: 100%;
    }}
    .card:hover {{
        transform: translateY(-2px);
        box-shadow: 0 6px 16px rgba(0,0,0,0.08);
    }}
    .input-label {{
        color: {primary_color};
        font-weight: 600;
        margin-bottom: 8px;
        font-size: 1.05em;
        background: {header_background}; /* Add background to section headers */
        padding: 5px 10px;
        border-radius: 6px;
        display: inline-block;
    }}
    .output-card {{
        background: #ffffff;
        border-radius: 12px;
        padding: 25px;
        border: 1px solid rgba(226, 232, 240, 0.8);
        box-shadow: 0 4px 15px rgba(0,0,0,0.05);
        height: 100%;
        display: flex;
        flex-direction: column;
    }}
    .output-content {{
        font-size: 1.1em;
        line-height: 1.6;
        color: #1e293b;
        flex-grow: 1;
    }}
    .btn-generate {{
        background: {button_gradient} !important;
        color: white !important;
        border-radius: 8px !important;
        padding: 12px 24px !important;
        font-weight: 600 !important;
        font-size: 1.05em !important;
        border: none !important;
        box-shadow: 0 4px 12px {button_shadow} !important;
        transition: all 0.3s ease !important;
        width: 100% !important;
        margin-top: 15px;
    }}
    .btn-generate:hover {{
        background: {button_hover_gradient} !important;
        box-shadow: 0 6px 16px {button_shadow} !important;
        transform: translateY(-2px) !important;
    }}
    .footer {{
        text-align: center;
        margin-top: 30px;
        color: #64748b;
        font-size: 0.9em;
    }}
    .model-selector {{
        margin-bottom: 15px;
    }}
    .input-icon {{
        font-size: 1.5em;
        margin-right: 8px;
        color: {icon_color};
    }}
    .divider {{
        border-top: 1px solid #e2e8f0;
        margin: 15px 0;
    }}
    .input-section {{
        height: 100%;
    }}
    .result-heading {{
        margin-bottom: 15px;
        color: {primary_color};
        background: {header_background}; /* Add background to result header */
        padding: 5px 10px;
        border-radius: 6px;
        display: inline-block;
    }}
    """
    # Create Blocks interface with improved structure and parallel layout
    with gr.Blocks(css=custom_css) as iface:
        with gr.Group():
            icon = "🔷" if color_theme == "blue" else "🔶"
            app_name = "OmniPhi Blue" if color_theme == "blue" else "OmniPhi Orange"
            gr.Markdown(
                f"""
                <div class="app-header">
                    <div class="app-title">{icon} {app_name}</div>
                    <div class="app-description">Advanced multi-modal AI with BLIP or custom model integration. Upload an image and provide instructions through text or voice to generate detailed descriptions.</div>
                </div>
                """
            )

        # Main content in a 2-column layout (inputs and output side by side)
        with gr.Row():
            # Left column for all inputs
            with gr.Column(scale=3):
                with gr.Group():
                    # Image upload card
                    with gr.Group():
                        gr.Markdown('<span class="input-icon">🖼️</span><span class="input-label">Upload Image</span>')
                        image_input = gr.Image(
                            type="pil",
                            label=None
                        )

                    # Text and voice input card
                    with gr.Group():
                        gr.Markdown('<span class="input-icon">💬</span><span class="input-label">Text Instruction</span>')
                        text_input = gr.Textbox(
                            label=None,
                            placeholder="e.g., Describe this image in detail, focusing on the environment...",
                            lines=3
                        )
                        gr.Markdown('<div class="divider"></div>')
                        gr.Markdown('<span class="input-icon">🎙️</span><span class="input-label">Voice Instruction (optional)</span>')
                        audio_input = gr.Audio(
                            sources=["microphone"],  # Gradio 4.x API assumed; on Gradio 3.x this would be source="microphone"
                            type="filepath",
                            label=None
                        )
                        gr.Markdown('<div class="divider"></div>')
                        gr.Markdown('<span class="input-icon">⚙️</span><span class="input-label">Model Selection</span>')
                        with gr.Group():
                            model_choice = gr.Radio(
                                choices=["BLIP", "OmniPhi"],
                                value="BLIP",
                                label=None,
                                interactive=True
                            )

                    submit_btn = gr.Button("Generate Description")

            # Right column for output
            with gr.Column(scale=2):
                with gr.Group():
                    gr.Markdown('<span class="input-icon">✨</span><span class="input-label result-heading">Generated Description</span>')
                    output = gr.Textbox(
                        label=None,
                        lines=12,
                        placeholder="Your description will appear here after generation..."
                    )

        # Footer
        gr.Markdown(
            f"""
            <div class="footer">
                Powered by OmniPhi Technology • Upload your image and provide instructions through text or voice
            </div>
            """
        )

        # Connect the button to the generate function
        submit_btn.click(
            fn=generate_fn,
            inputs=[image_input, text_input, audio_input, model_choice],
            outputs=output
        )

    return iface
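# NOTE (assumptions): the helpers and constants used in the main block below — initialize_transcriber,
# load_blip, load_clip, load_omniphi, the multi-argument generate_description(image, text, audio,
# model_choice, ...) dispatcher, WHISPER_MODEL, BLIP_MODEL, CLIP_MODEL, PHI_MODEL, CHECKPOINT_DIR,
# TORCH_DTYPE and COLOR_THEME — are not defined in this file and are assumed to be imported from
# elsewhere in the Space. That dispatcher is a different function from the three-argument
# generate_description wrapper defined above.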
# Main execution
if __name__ == "__main__":
    # Load models
    transcriber = initialize_transcriber(WHISPER_MODEL)
    blip_model, blip_processor = load_blip(BLIP_MODEL, DEVICE, TORCH_DTYPE)
    clip_model, clip_processor = load_clip(CLIP_MODEL, DEVICE, TORCH_DTYPE)
    omniphi_model, omniphi_tokenizer = load_omniphi(CHECKPOINT_DIR, PHI_MODEL, CLIP_MODEL, DEVICE)

    # Define generate function
    generate_fn = lambda image, text_prompt, audio, model_choice: generate_description(
        image, text_prompt, audio, model_choice, transcriber, blip_model, blip_processor,
        clip_model, clip_processor, omniphi_model, omniphi_tokenizer, DEVICE
    )

    # Launch Gradio interface
    iface = create_gradio_interface(generate_fn, color_theme=COLOR_THEME)
    iface.launch(server_name="0.0.0.0", server_port=7860)