AC2513 committed
Commit a1ac37a · Parent: 1a184e0

altered processor due to huggingface update

Files changed (1): app.py (+11, -7)
app.py CHANGED
@@ -6,6 +6,7 @@ from transformers import (
     TextIteratorStreamer,
     Gemma3Processor,
     Gemma3nForConditionalGeneration,
+    Gemma3nProcessor
 )
 import spaces
 from threading import Thread
@@ -22,7 +23,8 @@ load_dotenv(dotenv_path)
 model_12_id = os.getenv("MODEL_12_ID", "google/gemma-3-12b-it")
 model_3n_id = os.getenv("MODEL_3N_ID", "google/gemma-3n-E4B-it")
 
-input_processor = Gemma3Processor.from_pretrained(model_12_id)
+input_processor_12 = Gemma3Processor.from_pretrained(model_12_id)
+input_processor_3n = Gemma3nProcessor.from_pretrained(model_3n_id)
 
 model_12 = Gemma3ForConditionalGeneration.from_pretrained(
     model_12_id,
@@ -70,11 +72,13 @@ def run(
 
     def try_fallback_model(original_model_choice: str):
         fallback_model = model_3n if original_model_choice == "Gemma 3 12B" else model_12
+        fallback_processor = input_processor_3n if original_model_choice == "Gemma 3 12B" else input_processor_12
         fallback_name = "Gemma 3n E4B" if original_model_choice == "Gemma 3 12B" else "Gemma 3 12B"
         logger.info(f"Attempting fallback to {fallback_name} model")
-        return fallback_model, fallback_name
+        return fallback_model, fallback_processor, fallback_name
 
     selected_model = model_12 if model_choice == "Gemma 3 12B" else model_3n
+    selected_processor = input_processor_12 if model_choice == "Gemma 3 12B" else input_processor_3n
     current_model_name = model_choice
 
     try:
@@ -94,7 +98,7 @@ def run(
         for i, msg in enumerate(messages):
             logger.debug(f"Message {i}: role={msg.get('role', 'MISSING')}, content_type={type(msg.get('content', 'MISSING'))}")
 
-        inputs = input_processor.apply_chat_template(
+        inputs = selected_processor.apply_chat_template(
             messages,
             add_generation_prompt=True,
             tokenize=True,
@@ -103,7 +107,7 @@ def run(
         ).to(device=selected_model.device, dtype=torch.bfloat16)
 
         streamer = TextIteratorStreamer(
-            input_processor, skip_prompt=True, skip_special_tokens=True, timeout=60.0
+            selected_processor, skip_prompt=True, skip_special_tokens=True, timeout=60.0
         )
         generate_kwargs = dict(
             inputs,
@@ -156,11 +160,11 @@ def run(
 
         # Try fallback model
         try:
-            selected_model, fallback_name = try_fallback_model(model_choice)
+            selected_model, fallback_processor, fallback_name = try_fallback_model(model_choice)
             logger.info(f"Switching to fallback model: {fallback_name}")
 
             # Rebuild inputs for fallback model
-            inputs = input_processor.apply_chat_template(
+            inputs = fallback_processor.apply_chat_template(
                 messages,
                 add_generation_prompt=True,
                 tokenize=True,
@@ -169,7 +173,7 @@ def run(
             ).to(device=selected_model.device, dtype=torch.bfloat16)
 
             streamer = TextIteratorStreamer(
-                input_processor, skip_prompt=True, skip_special_tokens=True, timeout=60.0
+                fallback_processor, skip_prompt=True, skip_special_tokens=True, timeout=60.0
             )
             generate_kwargs = dict(
                 inputs,
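
The gist of the change: app.py previously built a single Gemma3Processor from the 12B checkpoint and reused it for both models; after the Hugging Face transformers update, the Gemma 3n checkpoint needs its own Gemma3nProcessor, so the commit loads one processor per model and threads the matching processor through run() and the fallback path. Below is a minimal, self-contained sketch of the resulting pattern. The model IDs, variable names, and generation wiring come from the diff; the from_pretrained loading options, the return_dict/return_tensors arguments, the max_new_tokens value, and the message shape are assumptions, since the diff truncates those lines.

# Sketch of the per-model processor pattern from this commit.
# Assumed (not shown in the diff): torch_dtype/device_map loading options,
# return_dict/return_tensors kwargs, max_new_tokens, and the message shape.
from threading import Thread

import torch
from transformers import (
    Gemma3ForConditionalGeneration,
    Gemma3Processor,
    Gemma3nForConditionalGeneration,
    Gemma3nProcessor,
    TextIteratorStreamer,
)

model_12_id = "google/gemma-3-12b-it"
model_3n_id = "google/gemma-3n-E4B-it"

# One processor per checkpoint: the 3n model can no longer share the
# 12B model's Gemma3Processor after the transformers update.
input_processor_12 = Gemma3Processor.from_pretrained(model_12_id)
input_processor_3n = Gemma3nProcessor.from_pretrained(model_3n_id)

model_12 = Gemma3ForConditionalGeneration.from_pretrained(
    model_12_id, torch_dtype=torch.bfloat16, device_map="auto"
)
model_3n = Gemma3nForConditionalGeneration.from_pretrained(
    model_3n_id, torch_dtype=torch.bfloat16, device_map="auto"
)


def run(model_choice: str, messages: list[dict]) -> TextIteratorStreamer:
    # Keep each model paired with its matching processor.
    if model_choice == "Gemma 3 12B":
        selected_model, selected_processor = model_12, input_processor_12
    else:
        selected_model, selected_processor = model_3n, input_processor_3n

    inputs = selected_processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(device=selected_model.device, dtype=torch.bfloat16)

    # The streamer decodes with the same processor that tokenized the prompt.
    streamer = TextIteratorStreamer(
        selected_processor, skip_prompt=True, skip_special_tokens=True, timeout=60.0
    )
    generate_kwargs = dict(inputs, streamer=streamer, max_new_tokens=700)
    Thread(target=selected_model.generate, kwargs=generate_kwargs).start()
    return streamer


# Hypothetical call, text-only turn:
#   stream = run("Gemma 3 12B",
#                [{"role": "user", "content": [{"type": "text", "text": "Hi"}]}])
#   print("".join(stream))

Selecting the model and processor as a pair in one place is what the new three-value return of try_fallback_model enables: the fallback path only has to swap the pair, and the rebuilt inputs and streamer automatically use the processor that matches the fallback model.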