# gemma-demo/app.py
import os
import traceback
from collections.abc import Iterator
from threading import Thread

import torch

# Force eager execution everywhere: with dynamo disabled, torch.compile
# becomes a no-op and generation never attempts graph capture.
torch._dynamo.config.disable = True

import gradio as gr
import spaces
from dotenv import find_dotenv, load_dotenv
from loguru import logger
from transformers import (
    Gemma3ForConditionalGeneration,
    Gemma3nForConditionalGeneration,
    Gemma3Processor,
    Gemma3nProcessor,
    TextIteratorStreamer,
)

from utils import (
    get_preset_prompts,
    process_history,
    process_user_input,
    update_custom_prompt,
)
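# Load environment variables (e.g. overridden model IDs) from a local .env file.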
dotenv_path = find_dotenv()
load_dotenv(dotenv_path)
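# Model checkpoints are configurable via env vars, defaulting to the public
# Hugging Face IDs.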
model_12_id = os.getenv("MODEL_12_ID", "google/gemma-3-12b-it")
model_3n_id = os.getenv("MODEL_3N_ID", "google/gemma-3n-E4B-it")
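# Each model family has its own processor class (tokenizer + image handling).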
input_processor_12 = Gemma3Processor.from_pretrained(model_12_id)
input_processor_3n = Gemma3nProcessor.from_pretrained(model_3n_id)
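# Load both models once at startup: bfloat16 weights, eager attention, and
# device_map="auto" to place them on the available devices.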
model_12 = Gemma3ForConditionalGeneration.from_pretrained(
model_12_id,
torch_dtype=torch.bfloat16,
device_map="auto",
attn_implementation="eager",
)
model_3n = Gemma3nForConditionalGeneration.from_pretrained(
model_3n_id,
torch_dtype=torch.bfloat16,
device_map="auto",
attn_implementation="eager",
)
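# Hugging Face ZeroGPU: request a GPU for each call, for up to 120 seconds.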
@spaces.GPU(duration=120)
def run(
message: dict,
history: list[dict],
system_prompt_preset: str,
custom_system_prompt: str,
model_choice: str,
max_new_tokens: int,
max_images: int,
temperature: float,
top_p: float,
top_k: int,
repetition_penalty: float,
) -> Iterator[str]:
# Define preset system prompts
preset_prompts = get_preset_prompts()
# Determine which system prompt to use
if system_prompt_preset == "Custom Prompt":
system_prompt = custom_system_prompt
else:
system_prompt = preset_prompts.get(system_prompt_preset, custom_system_prompt)
logger.debug(
f"\n message: {message} \n history: {history} \n system_prompt_preset: {system_prompt_preset} \n "
f"system_prompt: {system_prompt} \n model_choice: {model_choice} \n max_new_tokens: {max_new_tokens} \n max_images: {max_images}"
)
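    # On failure we retry once with the other model (12B <-> 3n E4B).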
def try_fallback_model(original_model_choice: str):
fallback_model = model_3n if original_model_choice == "Gemma 3 12B" else model_12
fallback_processor = input_processor_3n if original_model_choice == "Gemma 3 12B" else input_processor_12
fallback_name = "Gemma 3n E4B" if original_model_choice == "Gemma 3 12B" else "Gemma 3 12B"
logger.info(f"Attempting fallback to {fallback_name} model")
return fallback_model, fallback_processor, fallback_name
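    # Resolve the dropdown choice to a concrete model/processor pair.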
selected_model = model_12 if model_choice == "Gemma 3 12B" else model_3n
selected_processor = input_processor_12 if model_choice == "Gemma 3 12B" else input_processor_3n
current_model_name = model_choice
try:
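        # Build the conversation in chat-template format: optional system
        # turn, prior history, then the new user message (text + files).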
messages = []
if system_prompt:
messages.append(
{"role": "system", "content": [{"type": "text", "text": system_prompt}]}
)
messages.extend(process_history(history))
user_content = process_user_input(message, max_images)
messages.append(
{"role": "user", "content": user_content}
)
# Validate messages structure before processing
logger.debug(f"Final messages structure: {len(messages)} messages")
for i, msg in enumerate(messages):
logger.debug(f"Message {i}: role={msg.get('role', 'MISSING')}, content_type={type(msg.get('content', 'MISSING'))}")
inputs = selected_processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(device=selected_model.device, dtype=torch.bfloat16)
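        # Stream decoded text as it is generated; the processor forwards
        # decode() calls to its underlying tokenizer.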
streamer = TextIteratorStreamer(
selected_processor, skip_prompt=True, skip_special_tokens=True, timeout=60.0
)
generate_kwargs = dict(
inputs,
streamer=streamer,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
do_sample=True,
)
        # Run generate() behind a wrapper so exceptions raised inside the
        # worker thread are logged instead of disappearing silently.
        def safe_generate():
            try:
                selected_model.generate(**generate_kwargs)
            except Exception as thread_e:
                logger.error(f"Exception in generation thread: {thread_e}")
                logger.error(f"Thread exception type: {type(thread_e)}")
                logger.error(f"Thread traceback: {traceback.format_exc()}")
                raise
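        # Generate on a background thread so this generator can yield
        # partial output from the streamer as tokens arrive.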
t = Thread(target=safe_generate)
t.start()
output = ""
generation_failed = False
try:
for delta in streamer:
if delta is None:
continue
output += delta
yield output
except Exception as stream_error:
logger.error(f"Streaming failed with {current_model_name}: {stream_error}")
generation_failed = True
        # Wait for the worker; a hung thread, a streaming error, or an
        # empty completion all count as a failed generation.
        t.join(timeout=120)  # 2-minute timeout
        if t.is_alive() or generation_failed or not output.strip():
            raise RuntimeError(f"Generation failed or timed out with {current_model_name}")
except Exception as primary_error:
logger.error(f"Primary model ({current_model_name}) failed: {primary_error}")
# Try fallback model
try:
selected_model, fallback_processor, fallback_name = try_fallback_model(model_choice)
logger.info(f"Switching to fallback model: {fallback_name}")
# Rebuild inputs for fallback model
inputs = fallback_processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(device=selected_model.device, dtype=torch.bfloat16)
streamer = TextIteratorStreamer(
fallback_processor, skip_prompt=True, skip_special_tokens=True, timeout=60.0
)
generate_kwargs = dict(
inputs,
streamer=streamer,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
do_sample=True,
)
            # Same thread-exception wrapper, for the fallback model.
            def safe_fallback_generate():
                try:
                    selected_model.generate(**generate_kwargs)
                except Exception as thread_e:
                    logger.error(f"Exception in fallback generation thread: {thread_e}")
                    logger.error(f"Fallback thread exception type: {type(thread_e)}")
                    logger.error(f"Fallback thread traceback: {traceback.format_exc()}")
                    raise
t = Thread(target=safe_fallback_generate)
t.start()
output = f"⚠️ Switched to {fallback_name} due to {current_model_name} failure.\n\n"
yield output
try:
for delta in streamer:
if delta is None:
continue
output += delta
yield output
            except Exception as fallback_stream_error:
                logger.error(f"Fallback streaming failed: {fallback_stream_error}")
                raise
            # Wait for the fallback thread. The original check `not output.strip()`
            # could never fire here because `output` always starts with the
            # notice; compare against the notice itself instead.
            t.join(timeout=120)
            if t.is_alive() or output == fallback_notice:
                raise RuntimeError(f"Fallback model {fallback_name} also failed")
except Exception as fallback_error:
logger.error(f"Fallback model also failed: {fallback_error}")
# Final fallback - return error message
error_message = (
"❌ **Generation Failed**\n\n"
f"Both {model_choice} and fallback model encountered errors. "
"This could be due to:\n"
"- High server load\n"
"- Memory constraints\n"
"- Input complexity\n\n"
"**Suggestions:**\n"
"- Try reducing max tokens or image count\n"
"- Simplify your prompt\n"
"- Try again in a few moments\n\n"
f"*Error details: {str(primary_error)[:200]}...*"
)
yield error_message
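# Chat UI: `run` streams cumulative partial responses into the Chatbot.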
demo = gr.ChatInterface(
fn=run,
type="messages",
chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
textbox=gr.MultimodalTextbox(
file_types=[".mp4", ".jpg", ".png", ".pdf"], file_count="multiple", autofocus=True
),
multimodal=True,
additional_inputs=[
gr.Dropdown(
label="System Prompt Preset",
choices=[
"General Assistant",
"Document Analyzer",
"Visual Content Expert",
"Educational Tutor",
"Technical Reviewer",
"Creative Storyteller",
"Custom Prompt"
],
value="General Assistant",
info="System prompts define the AI's role and behavior. Choose a preset that matches your task, or select 'Custom Prompt' to write your own specialized instructions."
),
gr.Textbox(
label="Custom System Prompt",
value="You are a helpful AI assistant capable of analyzing images, videos, and PDF documents. Provide clear, accurate, and helpful responses to user queries.",
lines=3,
info="Edit this field when 'Custom Prompt' is selected above, or modify any preset"
),
gr.Dropdown(
label="Model",
choices=["Gemma 3 12B", "Gemma 3n E4B"],
value="Gemma 3 12B",
info="Gemma 3 12B: More powerful and detailed responses, but slower processing. Gemma 3n E4B: Faster processing with efficient performance for most tasks."
),
gr.Slider(
label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700
),
gr.Slider(label="Max Images", minimum=1, maximum=4, step=1, value=2),
gr.Slider(
label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7
),
gr.Slider(
label="Top P", minimum=0.1, maximum=1.0, step=0.05, value=0.9
),
gr.Slider(
label="Top K", minimum=1, maximum=100, step=1, value=50
),
gr.Slider(
label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1
)
],
stop_btn=False,
)
# Connect the dropdown to update the textbox
with demo:
preset_dropdown = demo.additional_inputs[0]
custom_textbox = demo.additional_inputs[1]
preset_dropdown.change(
fn=update_custom_prompt,
inputs=[preset_dropdown],
outputs=[custom_textbox]
)
if __name__ == "__main__":
demo.launch()