AC2513 committed on
Commit 4322777 · 1 Parent(s): 147ddab

added fallback for generation
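In short, run() now wraps generation in a primary/fallback cascade: it streams from the selected Gemma model and, if streaming raises, times out, or produces no output, it rebuilds the inputs and retries once with the other model before yielding an error message to the chat. A minimal, self-contained sketch of that control flow (the function names and toy models below are illustrative stand-ins, not the Space's actual code):

from typing import Callable, Iterator


def run_with_fallback(
    primary: Callable[[str], Iterator[str]],
    fallback: Callable[[str], Iterator[str]],
    prompt: str,
) -> Iterator[str]:
    output = ""
    try:
        for delta in primary(prompt):          # normal path: stream the primary model
            output += delta
            yield output
        if not output.strip():                 # treat empty output as a failure too
            raise RuntimeError("primary model produced no output")
    except Exception as primary_error:
        try:
            output = "⚠️ Switched to fallback model.\n\n"
            yield output
            for delta in fallback(prompt):     # one retry with the other model
                output += delta
                yield output
        except Exception:
            yield f"❌ Generation failed with both models: {primary_error}"


# Toy stand-ins so the sketch runs on its own.
def broken_model(prompt: str) -> Iterator[str]:
    raise RuntimeError("out of memory")


def working_model(prompt: str) -> Iterator[str]:
    yield "echo: "
    yield prompt


print(list(run_with_fallback(broken_model, working_model, "hi"))[-1])
# prints the fallback warning followed by "echo: hi"

The actual implementation in app.py additionally runs model.generate on a background Thread and streams tokens through a TextIteratorStreamer, as shown in the diff below.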

Files changed (1)
app.py +129 -33
app.py CHANGED
@@ -239,46 +239,142 @@ def run(
         f"system_prompt: {system_prompt} \n model_choice: {model_choice} \n max_new_tokens: {max_new_tokens} \n max_images: {max_images}"
     )

+    def try_fallback_model(original_model_choice: str):
+        fallback_model = model_3n if original_model_choice == "Gemma 3 12B" else model_12
+        fallback_name = "Gemma 3n E4B" if original_model_choice == "Gemma 3 12B" else "Gemma 3 12B"
+        logger.info(f"Attempting fallback to {fallback_name} model")
+        return fallback_model, fallback_name
+
     selected_model = model_12 if model_choice == "Gemma 3 12B" else model_3n
+    current_model_name = model_choice

-    messages = []
-    if system_prompt:
+    try:
+        messages = []
+        if system_prompt:
+            messages.append(
+                {"role": "system", "content": [{"type": "text", "text": system_prompt}]}
+            )
+        messages.extend(process_history(history))
         messages.append(
-            {"role": "system", "content": [{"type": "text", "text": system_prompt}]}
+            {"role": "user", "content": process_user_input(message, max_images)}
         )
-    messages.extend(process_history(history))
-    messages.append(
-        {"role": "user", "content": process_user_input(message, max_images)}
-    )

-    inputs = input_processor.apply_chat_template(
-        messages,
-        add_generation_prompt=True,
-        tokenize=True,
-        return_dict=True,
-        return_tensors="pt",
-    ).to(device=selected_model.device, dtype=torch.bfloat16)
+        inputs = input_processor.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+        ).to(device=selected_model.device, dtype=torch.bfloat16)

-    streamer = TextIteratorStreamer(
-        input_processor, skip_prompt=True, skip_special_tokens=True, timeout=60.0
-    )
-    generate_kwargs = dict(
-        inputs,
-        streamer=streamer,
-        max_new_tokens=max_new_tokens,
-        temperature=temperature,
-        top_p=top_p,
-        top_k=top_k,
-        repetition_penalty=repetition_penalty,
-        do_sample=True,
-    )
-    t = Thread(target=selected_model.generate, kwargs=generate_kwargs)
-    t.start()
+        streamer = TextIteratorStreamer(
+            input_processor, skip_prompt=True, skip_special_tokens=True, timeout=60.0
+        )
+        generate_kwargs = dict(
+            inputs,
+            streamer=streamer,
+            max_new_tokens=max_new_tokens,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            repetition_penalty=repetition_penalty,
+            do_sample=True,
+        )
+
+        t = Thread(target=selected_model.generate, kwargs=generate_kwargs)
+        t.start()
+
+        output = ""
+        generation_failed = False
+
+        try:
+            for delta in streamer:
+                if delta is None:
+                    continue
+                output += delta
+                yield output
+
+        except Exception as stream_error:
+            logger.error(f"Streaming failed with {current_model_name}: {stream_error}")
+            generation_failed = True
+
+        # Wait for thread to complete
+        t.join(timeout=120)  # 2 minute timeout
+
+        if t.is_alive() or generation_failed or not output.strip():
+            raise Exception(f"Generation failed or timed out with {current_model_name}")
+
+    except Exception as primary_error:
+        logger.error(f"Primary model ({current_model_name}) failed: {primary_error}")
+
+        # Try fallback model
+        try:
+            selected_model, fallback_name = try_fallback_model(model_choice)
+            logger.info(f"Switching to fallback model: {fallback_name}")
+
+            # Rebuild inputs for fallback model
+            inputs = input_processor.apply_chat_template(
+                messages,
+                add_generation_prompt=True,
+                tokenize=True,
+                return_dict=True,
+                return_tensors="pt",
+            ).to(device=selected_model.device, dtype=torch.bfloat16)
+
+            streamer = TextIteratorStreamer(
+                input_processor, skip_prompt=True, skip_special_tokens=True, timeout=60.0
+            )
+            generate_kwargs = dict(
+                inputs,
+                streamer=streamer,
+                max_new_tokens=max_new_tokens,
+                temperature=temperature,
+                top_p=top_p,
+                top_k=top_k,
+                repetition_penalty=repetition_penalty,
+                do_sample=True,
+            )
+
+            t = Thread(target=selected_model.generate, kwargs=generate_kwargs)
+            t.start()

-    output = ""
-    for delta in streamer:
-        output += delta
-        yield output
+            output = f"⚠️ Switched to {fallback_name} due to {current_model_name} failure.\n\n"
+            yield output
+
+            try:
+                for delta in streamer:
+                    if delta is None:
+                        continue
+                    output += delta
+                    yield output
+            except Exception as fallback_stream_error:
+                logger.error(f"Fallback streaming failed: {fallback_stream_error}")
+                raise fallback_stream_error
+
+            # Wait for fallback thread
+            t.join(timeout=120)
+
+            if t.is_alive() or not output.strip():
+                raise Exception(f"Fallback model {fallback_name} also failed")
+
+        except Exception as fallback_error:
+            logger.error(f"Fallback model also failed: {fallback_error}")
+
+            # Final fallback - return error message
+            error_message = (
+                "❌ **Generation Failed**\n\n"
+                f"Both {model_choice} and fallback model encountered errors. "
+                "This could be due to:\n"
+                "- High server load\n"
+                "- Memory constraints\n"
+                "- Input complexity\n\n"
+                "**Suggestions:**\n"
+                "- Try reducing max tokens or image count\n"
+                "- Simplify your prompt\n"
+                "- Try again in a few moments\n\n"
+                f"*Error details: {str(primary_error)[:200]}...*"
+            )
+            yield error_message


 demo = gr.ChatInterface(