Spaces: Running on Zero
File size: 11,450 Bytes
import torch
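# Disable torch.compile / TorchDynamo graph capture; streamed generation on ZeroGPU
# tends to be more reliable in plain eager mode.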
torch._dynamo.config.disable = True
from collections.abc import Iterator
from transformers import (
Gemma3ForConditionalGeneration,
TextIteratorStreamer,
Gemma3Processor,
Gemma3nForConditionalGeneration,
Gemma3nProcessor
)
import spaces
from threading import Thread
import gradio as gr
import os
from dotenv import load_dotenv, find_dotenv
from loguru import logger
from utils import (
    get_preset_prompts,
    process_history,
    process_user_input,
    update_custom_prompt,
)
dotenv_path = find_dotenv()
load_dotenv(dotenv_path)
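# Model IDs can be overridden via environment variables (e.g. in a local .env file);
# the Gemma 3 12B and Gemma 3n E4B instruction-tuned checkpoints are the defaults.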
model_12_id = os.getenv("MODEL_12_ID", "google/gemma-3-12b-it")
model_3n_id = os.getenv("MODEL_3N_ID", "google/gemma-3n-E4B-it")
input_processor_12 = Gemma3Processor.from_pretrained(model_12_id)
input_processor_3n = Gemma3nProcessor.from_pretrained(model_3n_id)
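# Both checkpoints are loaded up front in bfloat16; device_map="auto" lets accelerate
# handle placement, and attn_implementation="eager" avoids flash/sdpa attention kernels.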
model_12 = Gemma3ForConditionalGeneration.from_pretrained(
model_12_id,
torch_dtype=torch.bfloat16,
device_map="auto",
attn_implementation="eager",
)
model_3n = Gemma3nForConditionalGeneration.from_pretrained(
model_3n_id,
torch_dtype=torch.bfloat16,
device_map="auto",
attn_implementation="eager",
)
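# On a ZeroGPU Space, @spaces.GPU requests a GPU for the duration of each call
# (here up to 120 seconds) and releases it afterwards.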
@spaces.GPU(duration=120)
def run(
message: dict,
history: list[dict],
system_prompt_preset: str,
custom_system_prompt: str,
model_choice: str,
max_new_tokens: int,
max_images: int,
temperature: float,
top_p: float,
top_k: int,
repetition_penalty: float,
) -> Iterator[str]:
# Define preset system prompts
preset_prompts = get_preset_prompts()
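    # get_preset_prompts() (from utils) is expected to return a dict mapping preset
    # names to system prompt strings; its exact contents live in utils.py.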
# Determine which system prompt to use
if system_prompt_preset == "Custom Prompt":
system_prompt = custom_system_prompt
else:
system_prompt = preset_prompts.get(system_prompt_preset, custom_system_prompt)
logger.debug(
f"\n message: {message} \n history: {history} \n system_prompt_preset: {system_prompt_preset} \n "
f"system_prompt: {system_prompt} \n model_choice: {model_choice} \n max_new_tokens: {max_new_tokens} \n max_images: {max_images}"
)
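    # If the selected model fails, swap to the other checkpoint (12B <-> 3n E4B) and retry once.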
def try_fallback_model(original_model_choice: str):
fallback_model = model_3n if original_model_choice == "Gemma 3 12B" else model_12
fallback_processor = input_processor_3n if original_model_choice == "Gemma 3 12B" else input_processor_12
fallback_name = "Gemma 3n E4B" if original_model_choice == "Gemma 3 12B" else "Gemma 3 12B"
logger.info(f"Attempting fallback to {fallback_name} model")
return fallback_model, fallback_processor, fallback_name
selected_model = model_12 if model_choice == "Gemma 3 12B" else model_3n
selected_processor = input_processor_12 if model_choice == "Gemma 3 12B" else input_processor_3n
current_model_name = model_choice
try:
messages = []
if system_prompt:
messages.append(
{"role": "system", "content": [{"type": "text", "text": system_prompt}]}
)
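        # process_history() and process_user_input() (from utils) are expected to convert the
        # Gradio chat history and the multimodal textbox payload into chat-template content
        # parts; max_images caps how many images are forwarded to the model.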
messages.extend(process_history(history))
user_content = process_user_input(message, max_images)
messages.append(
{"role": "user", "content": user_content}
)
# Validate messages structure before processing
logger.debug(f"Final messages structure: {len(messages)} messages")
for i, msg in enumerate(messages):
logger.debug(f"Message {i}: role={msg.get('role', 'MISSING')}, content_type={type(msg.get('content', 'MISSING'))}")
inputs = selected_processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(device=selected_model.device, dtype=torch.bfloat16)
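        # BatchFeature.to(dtype=...) only casts floating-point tensors (e.g. pixel values)
        # to bfloat16; integer token ids keep their dtype.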
streamer = TextIteratorStreamer(
selected_processor, skip_prompt=True, skip_special_tokens=True, timeout=60.0
)
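        # The streamer yields decoded text chunks while generate() runs in a worker thread;
        # the timeout guards against a stalled generation.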
generate_kwargs = dict(
inputs,
streamer=streamer,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
do_sample=True,
)
# Wrapper function to catch thread exceptions
def safe_generate():
try:
selected_model.generate(**generate_kwargs)
except Exception as thread_e:
logger.error(f"Exception in generation thread: {thread_e}")
logger.error(f"Thread exception type: {type(thread_e)}")
# Store the exception so we can handle it in the main thread
import traceback
logger.error(f"Thread traceback: {traceback.format_exc()}")
raise
t = Thread(target=safe_generate)
t.start()
output = ""
generation_failed = False
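        # Accumulate the response and yield the growing string; Gradio re-renders the
        # assistant message on every yield, which produces the streaming effect.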
try:
for delta in streamer:
if delta is None:
continue
output += delta
yield output
except Exception as stream_error:
logger.error(f"Streaming failed with {current_model_name}: {stream_error}")
generation_failed = True
# Wait for thread to complete
t.join(timeout=120) # 2 minute timeout
if t.is_alive() or generation_failed or not output.strip():
raise Exception(f"Generation failed or timed out with {current_model_name}")
except Exception as primary_error:
logger.error(f"Primary model ({current_model_name}) failed: {primary_error}")
# Try fallback model
try:
selected_model, fallback_processor, fallback_name = try_fallback_model(model_choice)
logger.info(f"Switching to fallback model: {fallback_name}")
# Rebuild inputs for fallback model
inputs = fallback_processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(device=selected_model.device, dtype=torch.bfloat16)
streamer = TextIteratorStreamer(
fallback_processor, skip_prompt=True, skip_special_tokens=True, timeout=60.0
)
generate_kwargs = dict(
inputs,
streamer=streamer,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
do_sample=True,
)
# Wrapper function to catch thread exceptions in fallback
def safe_fallback_generate():
try:
selected_model.generate(**generate_kwargs)
except Exception as thread_e:
logger.error(f"Exception in fallback generation thread: {thread_e}")
logger.error(f"Fallback thread exception type: {type(thread_e)}")
import traceback
logger.error(f"Fallback thread traceback: {traceback.format_exc()}")
raise
t = Thread(target=safe_fallback_generate)
t.start()
output = f"⚠️ Switched to {fallback_name} due to {current_model_name} failure.\n\n"
yield output
try:
for delta in streamer:
if delta is None:
continue
output += delta
yield output
except Exception as fallback_stream_error:
logger.error(f"Fallback streaming failed: {fallback_stream_error}")
raise fallback_stream_error
# Wait for fallback thread
t.join(timeout=120)
if t.is_alive() or not output.strip():
raise Exception(f"Fallback model {fallback_name} also failed")
except Exception as fallback_error:
logger.error(f"Fallback model also failed: {fallback_error}")
# Final fallback - return error message
error_message = (
"❌ **Generation Failed**\n\n"
f"Both {model_choice} and fallback model encountered errors. "
"This could be due to:\n"
"- High server load\n"
"- Memory constraints\n"
"- Input complexity\n\n"
"**Suggestions:**\n"
"- Try reducing max tokens or image count\n"
"- Simplify your prompt\n"
"- Try again in a few moments\n\n"
f"*Error details: {str(primary_error)[:200]}...*"
)
yield error_message
demo = gr.ChatInterface(
fn=run,
type="messages",
chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
textbox=gr.MultimodalTextbox(
file_types=[".mp4", ".jpg", ".png", ".pdf"], file_count="multiple", autofocus=True
),
multimodal=True,
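    # The components below are passed to run() positionally, after message and history,
    # in exactly this order.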
additional_inputs=[
gr.Dropdown(
label="System Prompt Preset",
choices=[
"General Assistant",
"Document Analyzer",
"Visual Content Expert",
"Educational Tutor",
"Technical Reviewer",
"Creative Storyteller",
"Custom Prompt"
],
value="General Assistant",
info="System prompts define the AI's role and behavior. Choose a preset that matches your task, or select 'Custom Prompt' to write your own specialized instructions."
),
gr.Textbox(
label="Custom System Prompt",
value="You are a helpful AI assistant capable of analyzing images, videos, and PDF documents. Provide clear, accurate, and helpful responses to user queries.",
lines=3,
info="Edit this field when 'Custom Prompt' is selected above, or modify any preset"
),
gr.Dropdown(
label="Model",
choices=["Gemma 3 12B", "Gemma 3n E4B"],
value="Gemma 3 12B",
info="Gemma 3 12B: More powerful and detailed responses, but slower processing. Gemma 3n E4B: Faster processing with efficient performance for most tasks."
),
gr.Slider(
label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700
),
gr.Slider(label="Max Images", minimum=1, maximum=4, step=1, value=2),
gr.Slider(
label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7
),
gr.Slider(
label="Top P", minimum=0.1, maximum=1.0, step=0.05, value=0.9
),
gr.Slider(
label="Top K", minimum=1, maximum=100, step=1, value=50
),
gr.Slider(
label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1
)
],
stop_btn=False,
)
# Connect the dropdown to update the textbox
with demo:
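    # ChatInterface exposes its extra controls via .additional_inputs; grabbing them inside
    # the Blocks context lets us attach an extra event after the interface is built.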
preset_dropdown = demo.additional_inputs[0]
custom_textbox = demo.additional_inputs[1]
preset_dropdown.change(
fn=update_custom_prompt,
inputs=[preset_dropdown],
outputs=[custom_textbox]
)
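    # update_custom_prompt (from utils) is assumed to take the selected preset name and
    # return the matching prompt text (e.g. via get_preset_prompts()), so the textbox
    # mirrors the current preset and can be edited when "Custom Prompt" is chosen.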
if __name__ == "__main__":
demo.launch()