altered processor due to huggingface update
app.py CHANGED

@@ -6,6 +6,7 @@ from transformers import (
     TextIteratorStreamer,
     Gemma3Processor,
     Gemma3nForConditionalGeneration,
+    Gemma3nProcessor
 )
 import spaces
 from threading import Thread
@@ -22,7 +23,8 @@ load_dotenv(dotenv_path)
 model_12_id = os.getenv("MODEL_12_ID", "google/gemma-3-12b-it")
 model_3n_id = os.getenv("MODEL_3N_ID", "google/gemma-3n-E4B-it")
 
-
+input_processor_12 = Gemma3Processor.from_pretrained(model_12_id)
+input_processor_3n = Gemma3nProcessor.from_pretrained(model_3n_id)
 
 model_12 = Gemma3ForConditionalGeneration.from_pretrained(
     model_12_id,
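
Note: the hunk above pairs each checkpoint with its own processor class, apparently following the transformers update (per the commit title) that gives Gemma 3n a dedicated Gemma3nProcessor. A minimal standalone sketch of the new loading pattern, using the env-var defaults from the diff:

# Standalone sketch of the per-model processor setup introduced above;
# model IDs are the defaults from the diff.
from transformers import Gemma3Processor, Gemma3nProcessor

input_processor_12 = Gemma3Processor.from_pretrained("google/gemma-3-12b-it")
input_processor_3n = Gemma3nProcessor.from_pretrained("google/gemma-3n-E4B-it")
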
@@ -70,11 +72,13 @@ def run(
 
     def try_fallback_model(original_model_choice: str):
         fallback_model = model_3n if original_model_choice == "Gemma 3 12B" else model_12
+        fallback_processor = input_processor_3n if original_model_choice == "Gemma 3 12B" else input_processor_12
         fallback_name = "Gemma 3n E4B" if original_model_choice == "Gemma 3 12B" else "Gemma 3 12B"
         logger.info(f"Attempting fallback to {fallback_name} model")
-        return fallback_model, fallback_name
+        return fallback_model, fallback_processor, fallback_name
 
     selected_model = model_12 if model_choice == "Gemma 3 12B" else model_3n
+    selected_processor = input_processor_12 if model_choice == "Gemma 3 12B" else input_processor_3n
     current_model_name = model_choice
 
     try:
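
The helper now returns the model, its processor, and the display name as one triple, so the fallback path can never pair a model with the other model's processor. A hypothetical registry-based variant (the MODELS dict is not in the commit; it is shown only to illustrate the pairing invariant):

# Hypothetical alternative to the diff's ternaries: keep each model and its
# processor together in one registry so they cannot drift apart.
MODELS = {
    "Gemma 3 12B": (model_12, input_processor_12),
    "Gemma 3n E4B": (model_3n, input_processor_3n),
}

def try_fallback_model(original_model_choice: str):
    fallback_name = "Gemma 3n E4B" if original_model_choice == "Gemma 3 12B" else "Gemma 3 12B"
    fallback_model, fallback_processor = MODELS[fallback_name]
    logger.info(f"Attempting fallback to {fallback_name} model")
    return fallback_model, fallback_processor, fallback_name
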
@@ -94,7 +98,7 @@ def run(
         for i, msg in enumerate(messages):
             logger.debug(f"Message {i}: role={msg.get('role', 'MISSING')}, content_type={type(msg.get('content', 'MISSING'))}")
 
-        inputs =
+        inputs = selected_processor.apply_chat_template(
             messages,
             add_generation_prompt=True,
             tokenize=True,
@@ -103,7 +107,7 @@ def run(
         ).to(device=selected_model.device, dtype=torch.bfloat16)
 
         streamer = TextIteratorStreamer(
-
+            selected_processor, skip_prompt=True, skip_special_tokens=True, timeout=60.0
         )
         generate_kwargs = dict(
             inputs,
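
For context, here is roughly how the rebuilt pieces fit together end to end. This is a sketch, not the Space's exact code: the diff truncates the remaining apply_chat_template keywords, so return_dict=True and return_tensors="pt" are assumptions, and max_new_tokens is a placeholder.

import torch
from threading import Thread
from transformers import TextIteratorStreamer

messages = [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}]

inputs = selected_processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,      # assumed; this keyword is truncated in the diff
    return_tensors="pt",   # assumed; this keyword is truncated in the diff
).to(device=selected_model.device, dtype=torch.bfloat16)

streamer = TextIteratorStreamer(
    selected_processor, skip_prompt=True, skip_special_tokens=True, timeout=60.0
)
generate_kwargs = dict(inputs, streamer=streamer, max_new_tokens=512)

# generate() runs on a worker thread; iterating the streamer on the main
# thread yields decoded text chunks as they are produced.
Thread(target=selected_model.generate, kwargs=generate_kwargs).start()
for chunk in streamer:
    print(chunk, end="", flush=True)
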
@@ -156,11 +160,11 @@ def run(
 
         # Try fallback model
         try:
-            selected_model, fallback_name = try_fallback_model(model_choice)
+            selected_model, fallback_processor, fallback_name = try_fallback_model(model_choice)
             logger.info(f"Switching to fallback model: {fallback_name}")
 
             # Rebuild inputs for fallback model
-            inputs =
+            inputs = fallback_processor.apply_chat_template(
                 messages,
                 add_generation_prompt=True,
                 tokenize=True,
@@ -169,7 +173,7 @@ def run(
             ).to(device=selected_model.device, dtype=torch.bfloat16)
 
             streamer = TextIteratorStreamer(
-
+                fallback_processor, skip_prompt=True, skip_special_tokens=True, timeout=60.0
             )
             generate_kwargs = dict(
                 inputs,
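
One detail worth noting in the fallback path: the TextIteratorStreamer is rebuilt rather than reused. A streamer instance is single-use (its internal queue is closed with a stop signal when generation ends), and it must decode with the tokenizer of whichever model is actually generating, so constructing a fresh streamer around the fallback processor is required on both counts.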
|