# Hugging Face Space: Running on Zero (ZeroGPU)
import torch

# Disable TorchDynamo (torch.compile) globally
torch._dynamo.config.disable = True

from collections.abc import Iterator
from transformers import (
    Gemma3ForConditionalGeneration,
    TextIteratorStreamer,
    Gemma3Processor,
    Gemma3nForConditionalGeneration,
    Gemma3nProcessor,
)
import spaces
from threading import Thread
import gradio as gr
import os
from dotenv import load_dotenv, find_dotenv
from loguru import logger

# Local helpers: get_preset_prompts, process_history, process_user_input, update_custom_prompt
from utils import *

dotenv_path = find_dotenv()
load_dotenv(dotenv_path)
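
# Two Gemma checkpoints are loaded eagerly at startup; the IDs can be overridden
# via the MODEL_12_ID and MODEL_3N_ID environment variables.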
model_12_id = os.getenv("MODEL_12_ID", "google/gemma-3-12b-it")
model_3n_id = os.getenv("MODEL_3N_ID", "google/gemma-3n-E4B-it")

input_processor_12 = Gemma3Processor.from_pretrained(model_12_id)
input_processor_3n = Gemma3nProcessor.from_pretrained(model_3n_id)

model_12 = Gemma3ForConditionalGeneration.from_pretrained(
    model_12_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="eager",
)
model_3n = Gemma3nForConditionalGeneration.from_pretrained(
    model_3n_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="eager",
)
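
# ZeroGPU entry point: `spaces.GPU` requests a GPU for the duration of each call.
# (Assumed decorator; the otherwise unused `spaces` import and the Space's
# "Running on Zero" badge imply a ZeroGPU deployment.)
@spaces.GPU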
def run(
    message: dict,
    history: list[dict],
    system_prompt_preset: str,
    custom_system_prompt: str,
    model_choice: str,
    max_new_tokens: int,
    max_images: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
) -> Iterator[str]:
    # Fetch the preset system prompts
    preset_prompts = get_preset_prompts()

    # Determine which system prompt to use
    if system_prompt_preset == "Custom Prompt":
        system_prompt = custom_system_prompt
    else:
        system_prompt = preset_prompts.get(system_prompt_preset, custom_system_prompt)

    logger.debug(
        f"\n message: {message} \n history: {history} \n system_prompt_preset: {system_prompt_preset} \n "
        f"system_prompt: {system_prompt} \n model_choice: {model_choice} \n "
        f"max_new_tokens: {max_new_tokens} \n max_images: {max_images}"
    )

    def try_fallback_model(original_model_choice: str):
        fallback_model = model_3n if original_model_choice == "Gemma 3 12B" else model_12
        fallback_processor = input_processor_3n if original_model_choice == "Gemma 3 12B" else input_processor_12
        fallback_name = "Gemma 3n E4B" if original_model_choice == "Gemma 3 12B" else "Gemma 3 12B"
        logger.info(f"Attempting fallback to {fallback_name} model")
        return fallback_model, fallback_processor, fallback_name

    selected_model = model_12 if model_choice == "Gemma 3 12B" else model_3n
    selected_processor = input_processor_12 if model_choice == "Gemma 3 12B" else input_processor_3n
    current_model_name = model_choice

    try:
        messages = []
        if system_prompt:
            messages.append(
                {"role": "system", "content": [{"type": "text", "text": system_prompt}]}
            )
        messages.extend(process_history(history))
        user_content = process_user_input(message, max_images)
        messages.append({"role": "user", "content": user_content})

        # Validate messages structure before processing
        logger.debug(f"Final messages structure: {len(messages)} messages")
        for i, msg in enumerate(messages):
            logger.debug(
                f"Message {i}: role={msg.get('role', 'MISSING')}, "
                f"content_type={type(msg.get('content', 'MISSING'))}"
            )

        inputs = selected_processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(device=selected_model.device, dtype=torch.bfloat16)

        streamer = TextIteratorStreamer(
            selected_processor, skip_prompt=True, skip_special_tokens=True, timeout=60.0
        )
        generate_kwargs = dict(
            inputs,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            do_sample=True,
        )

        # Wrapper so exceptions raised inside the generation thread are logged; the main
        # thread detects the failure via an empty/partial stream and the join timeout below
        def safe_generate():
            try:
                selected_model.generate(**generate_kwargs)
            except Exception as thread_e:
                logger.error(f"Exception in generation thread: {thread_e}")
                logger.error(f"Thread exception type: {type(thread_e)}")
                import traceback

                logger.error(f"Thread traceback: {traceback.format_exc()}")
                raise

        t = Thread(target=safe_generate)
        t.start()

        output = ""
        generation_failed = False
        try:
            for delta in streamer:
                if delta is None:
                    continue
                output += delta
                yield output
        except Exception as stream_error:
            logger.error(f"Streaming failed with {current_model_name}: {stream_error}")
            generation_failed = True

        # Wait for the generation thread to complete (2 minute timeout)
        t.join(timeout=120)
        if t.is_alive() or generation_failed or not output.strip():
            raise Exception(f"Generation failed or timed out with {current_model_name}")
    except Exception as primary_error:
        logger.error(f"Primary model ({current_model_name}) failed: {primary_error}")

        # Try the fallback model
        try:
            selected_model, fallback_processor, fallback_name = try_fallback_model(model_choice)
            logger.info(f"Switching to fallback model: {fallback_name}")

            # Rebuild inputs for the fallback model
            inputs = fallback_processor.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            ).to(device=selected_model.device, dtype=torch.bfloat16)

            streamer = TextIteratorStreamer(
                fallback_processor, skip_prompt=True, skip_special_tokens=True, timeout=60.0
            )
            generate_kwargs = dict(
                inputs,
                streamer=streamer,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                do_sample=True,
            )

            # Wrapper to catch and log exceptions from the fallback generation thread
            def safe_fallback_generate():
                try:
                    selected_model.generate(**generate_kwargs)
                except Exception as thread_e:
                    logger.error(f"Exception in fallback generation thread: {thread_e}")
                    logger.error(f"Fallback thread exception type: {type(thread_e)}")
                    import traceback

                    logger.error(f"Fallback thread traceback: {traceback.format_exc()}")
                    raise

            t = Thread(target=safe_fallback_generate)
            t.start()

            output = f"⚠️ Switched to {fallback_name} due to {current_model_name} failure.\n\n"
            yield output

            try:
                for delta in streamer:
                    if delta is None:
                        continue
                    output += delta
                    yield output
            except Exception as fallback_stream_error:
                logger.error(f"Fallback streaming failed: {fallback_stream_error}")
                raise fallback_stream_error

            # Wait for the fallback thread
            t.join(timeout=120)
            if t.is_alive() or not output.strip():
                raise Exception(f"Fallback model {fallback_name} also failed")
        except Exception as fallback_error:
            logger.error(f"Fallback model also failed: {fallback_error}")
            # Last resort: yield a user-facing error message instead of raising
            error_message = (
                "❌ **Generation Failed**\n\n"
                f"Both {model_choice} and the fallback model encountered errors. "
                "This could be due to:\n"
                "- High server load\n"
                "- Memory constraints\n"
                "- Input complexity\n\n"
                "**Suggestions:**\n"
                "- Try reducing max tokens or image count\n"
                "- Simplify your prompt\n"
                "- Try again in a few moments\n\n"
                f"*Error details: {str(primary_error)[:200]}...*"
            )
            yield error_message
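
# Gradio UI: a multimodal ChatInterface; the additional inputs below surface the
# system-prompt preset, custom prompt, model choice, and sampling parameters,
# which are passed to run() in this order after the message and history.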
demo = gr.ChatInterface(
    fn=run,
    type="messages",
    chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
    textbox=gr.MultimodalTextbox(
        file_types=[".mp4", ".jpg", ".png", ".pdf"], file_count="multiple", autofocus=True
    ),
    multimodal=True,
    additional_inputs=[
        gr.Dropdown(
            label="System Prompt Preset",
            choices=[
                "General Assistant",
                "Document Analyzer",
                "Visual Content Expert",
                "Educational Tutor",
                "Technical Reviewer",
                "Creative Storyteller",
                "Custom Prompt",
            ],
            value="General Assistant",
            info=(
                "System prompts define the AI's role and behavior. Choose a preset that matches your task, "
                "or select 'Custom Prompt' to write your own specialized instructions."
            ),
        ),
        gr.Textbox(
            label="Custom System Prompt",
            value=(
                "You are a helpful AI assistant capable of analyzing images, videos, and PDF documents. "
                "Provide clear, accurate, and helpful responses to user queries."
            ),
            lines=3,
            info="Edit this field when 'Custom Prompt' is selected above, or modify any preset",
        ),
        gr.Dropdown(
            label="Model",
            choices=["Gemma 3 12B", "Gemma 3n E4B"],
            value="Gemma 3 12B",
            info=(
                "Gemma 3 12B: More powerful and detailed responses, but slower processing. "
                "Gemma 3n E4B: Faster processing with efficient performance for most tasks."
            ),
        ),
        gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700),
        gr.Slider(label="Max Images", minimum=1, maximum=4, step=1, value=2),
        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7),
        gr.Slider(label="Top P", minimum=0.1, maximum=1.0, step=0.05, value=0.9),
        gr.Slider(label="Top K", minimum=1, maximum=100, step=1, value=50),
        gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1),
    ],
    stop_btn=False,
)

# Connect the dropdown to update the textbox
with demo:
    preset_dropdown = demo.additional_inputs[0]
    custom_textbox = demo.additional_inputs[1]
    preset_dropdown.change(
        fn=update_custom_prompt,
        inputs=[preset_dropdown],
        outputs=[custom_textbox],
    )

if __name__ == "__main__":
    demo.launch()