Spaces: Running on Zero

added fallback for generation
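The change wraps the body of run() in error handling: generation first runs on the selected Gemma model; if streaming raises, times out, or produces empty output, run() retries once with the other Gemma model, and if that also fails it yields a formatted error message to the chat.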
app.py CHANGED
@@ -239,46 +239,142 @@ def run(
Old version (lines 239-284; several of the removed lines were not preserved in this view):

        f"system_prompt: {system_prompt} \n model_choice: {model_choice} \n max_new_tokens: {max_new_tokens} \n max_images: {max_images}"
    )

    selected_model = model_12 if model_choice == "Gemma 3 12B" else model_3n

    [old lines 244-245: not preserved]
        messages.append(
            {"role": "   [remainder of old line 247 not preserved]
        )
    messages.extend(process_history(history))
    messages.append(
        {"role": "user", "content": process_user_input(message, max_images)}
    )

    [old lines 254-281: not preserved]

demo = gr.ChatInterface(
New version (lines 239-380):

        f"system_prompt: {system_prompt} \n model_choice: {model_choice} \n max_new_tokens: {max_new_tokens} \n max_images: {max_images}"
    )

    def try_fallback_model(original_model_choice: str):
        fallback_model = model_3n if original_model_choice == "Gemma 3 12B" else model_12
        fallback_name = "Gemma 3n E4B" if original_model_choice == "Gemma 3 12B" else "Gemma 3 12B"
        logger.info(f"Attempting fallback to {fallback_name} model")
        return fallback_model, fallback_name

    selected_model = model_12 if model_choice == "Gemma 3 12B" else model_3n
    current_model_name = model_choice

    try:
        messages = []
        if system_prompt:
            messages.append(
                {"role": "system", "content": [{"type": "text", "text": system_prompt}]}
            )
        messages.extend(process_history(history))
        messages.append(
            {"role": "user", "content": process_user_input(message, max_images)}
        )

        inputs = input_processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(device=selected_model.device, dtype=torch.bfloat16)

        streamer = TextIteratorStreamer(
            input_processor, skip_prompt=True, skip_special_tokens=True, timeout=60.0
        )
        generate_kwargs = dict(
            inputs,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            do_sample=True,
        )

        t = Thread(target=selected_model.generate, kwargs=generate_kwargs)
        t.start()

        output = ""
        generation_failed = False

        try:
            for delta in streamer:
                if delta is None:
                    continue
                output += delta
                yield output

        except Exception as stream_error:
            logger.error(f"Streaming failed with {current_model_name}: {stream_error}")
            generation_failed = True

        # Wait for thread to complete
        t.join(timeout=120)  # 2 minute timeout

        if t.is_alive() or generation_failed or not output.strip():
            raise Exception(f"Generation failed or timed out with {current_model_name}")

    except Exception as primary_error:
        logger.error(f"Primary model ({current_model_name}) failed: {primary_error}")

        # Try fallback model
        try:
            selected_model, fallback_name = try_fallback_model(model_choice)
            logger.info(f"Switching to fallback model: {fallback_name}")

            # Rebuild inputs for fallback model
            inputs = input_processor.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            ).to(device=selected_model.device, dtype=torch.bfloat16)

            streamer = TextIteratorStreamer(
                input_processor, skip_prompt=True, skip_special_tokens=True, timeout=60.0
            )
            generate_kwargs = dict(
                inputs,
                streamer=streamer,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                do_sample=True,
            )

            t = Thread(target=selected_model.generate, kwargs=generate_kwargs)
            t.start()

            output = f"⚠️ Switched to {fallback_name} due to {current_model_name} failure.\n\n"
            yield output

            try:
                for delta in streamer:
                    if delta is None:
                        continue
                    output += delta
                    yield output
            except Exception as fallback_stream_error:
                logger.error(f"Fallback streaming failed: {fallback_stream_error}")
                raise fallback_stream_error

            # Wait for fallback thread
            t.join(timeout=120)

            if t.is_alive() or not output.strip():
                raise Exception(f"Fallback model {fallback_name} also failed")

        except Exception as fallback_error:
            logger.error(f"Fallback model also failed: {fallback_error}")

            # Final fallback - return error message
            error_message = (
                "❌ **Generation Failed**\n\n"
                f"Both {model_choice} and fallback model encountered errors. "
                "This could be due to:\n"
                "- High server load\n"
                "- Memory constraints\n"
                "- Input complexity\n\n"
                "**Suggestions:**\n"
                "- Try reducing max tokens or image count\n"
                "- Simplify your prompt\n"
                "- Try again in a few moments\n\n"
                f"*Error details: {str(primary_error)[:200]}...*"
            )
            yield error_message

demo = gr.ChatInterface(
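
The primary and fallback paths in this commit repeat the same tokenize / streamer / Thread / stream-loop setup almost verbatim. Below is a minimal sketch, not part of the commit, of how that duplication could be factored into one shared generator. It reuses names already defined in app.py (input_processor, TextIteratorStreamer, Thread, torch, try_fallback_model); the helper name stream_generation is hypothetical.

def stream_generation(model, messages, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
    # Build model inputs from the chat messages (the same call the commit makes twice).
    inputs = input_processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(device=model.device, dtype=torch.bfloat16)

    # Run generate() in a background thread and stream partial text out of it.
    streamer = TextIteratorStreamer(
        input_processor, skip_prompt=True, skip_special_tokens=True, timeout=60.0
    )
    t = Thread(
        target=model.generate,
        kwargs=dict(
            inputs,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            do_sample=True,
        ),
    )
    t.start()

    output = ""
    for delta in streamer:
        if delta is None:
            continue
        output += delta
        yield output

    t.join(timeout=120)
    if t.is_alive() or not output.strip():
        raise RuntimeError("generation failed, timed out, or returned empty output")

Because any streaming error propagates out of the generator into the caller, run() could then express the retry as a single try/except (the warning prefix and the final formatted error message from the commit are omitted here for brevity):

    try:
        yield from stream_generation(selected_model, messages, max_new_tokens,
                                     temperature, top_p, top_k, repetition_penalty)
    except Exception:
        fallback_model, _ = try_fallback_model(model_choice)
        yield from stream_generation(fallback_model, messages, max_new_tokens,
                                     temperature, top_p, top_k, repetition_penalty)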