Spaces:
Running
on
Zero
Running
on
Zero
added dropdown for user to select models
Browse files
app.py
CHANGED
@@ -22,8 +22,8 @@ dotenv_path = find_dotenv()
|
|
22 |
|
23 |
load_dotenv(dotenv_path)
|
24 |
|
25 |
-
model_12_id = os.getenv("MODEL_12_ID", "google/gemma-3-
|
26 |
-
model_3n_id = os.getenv("MODEL_3N_ID", "google/gemma-
|
27 |
|
28 |
input_processor = Gemma3Processor.from_pretrained(model_12_id)
|
29 |
|
@@ -138,6 +138,7 @@ def run(
|
|
138 |
message: dict,
|
139 |
history: list[dict],
|
140 |
system_prompt: str,
|
|
|
141 |
max_new_tokens: int,
|
142 |
max_images: int,
|
143 |
temperature: float,
|
@@ -148,9 +149,11 @@ def run(
|
|
148 |
|
149 |
logger.debug(
|
150 |
f"\n message: {message} \n history: {history} \n system_prompt: {system_prompt} \n "
|
151 |
-
f"max_new_tokens: {max_new_tokens} \n max_images: {max_images}"
|
152 |
)
|
153 |
|
|
|
|
|
154 |
messages = []
|
155 |
if system_prompt:
|
156 |
messages.append(
|
@@ -167,7 +170,7 @@ def run(
|
|
167 |
tokenize=True,
|
168 |
return_dict=True,
|
169 |
return_tensors="pt",
|
170 |
-
).to(device=
|
171 |
|
172 |
streamer = TextIteratorStreamer(
|
173 |
input_processor, skip_prompt=True, skip_special_tokens=True, timeout=60.0
|
@@ -182,7 +185,7 @@ def run(
|
|
182 |
repetition_penalty=repetition_penalty,
|
183 |
do_sample=True,
|
184 |
)
|
185 |
-
t = Thread(target=
|
186 |
t.start()
|
187 |
|
188 |
output = ""
|
@@ -201,6 +204,11 @@ demo = gr.ChatInterface(
|
|
201 |
multimodal=True,
|
202 |
additional_inputs=[
|
203 |
gr.Textbox(label="System Prompt", value="You are a helpful assistant."),
|
|
|
|
|
|
|
|
|
|
|
204 |
gr.Slider(
|
205 |
label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700
|
206 |
),
|
|
|
22 |
|
23 |
load_dotenv(dotenv_path)
|
24 |
|
25 |
+
model_12_id = os.getenv("MODEL_12_ID", "google/gemma-3-12b-it")
|
26 |
+
model_3n_id = os.getenv("MODEL_3N_ID", "google/gemma-3n-E4B-it")
|
27 |
|
28 |
input_processor = Gemma3Processor.from_pretrained(model_12_id)
|
29 |
|
|
|
138 |
message: dict,
|
139 |
history: list[dict],
|
140 |
system_prompt: str,
|
141 |
+
model_choice: str,
|
142 |
max_new_tokens: int,
|
143 |
max_images: int,
|
144 |
temperature: float,
|
|
|
149 |
|
150 |
logger.debug(
|
151 |
f"\n message: {message} \n history: {history} \n system_prompt: {system_prompt} \n "
|
152 |
+
f"model_choice: {model_choice} \n max_new_tokens: {max_new_tokens} \n max_images: {max_images}"
|
153 |
)
|
154 |
|
155 |
+
selected_model = model_12 if model_choice == "Gemma 3 12B" else model_3n
|
156 |
+
|
157 |
messages = []
|
158 |
if system_prompt:
|
159 |
messages.append(
|
|
|
170 |
tokenize=True,
|
171 |
return_dict=True,
|
172 |
return_tensors="pt",
|
173 |
+
).to(device=selected_model.device, dtype=torch.bfloat16)
|
174 |
|
175 |
streamer = TextIteratorStreamer(
|
176 |
input_processor, skip_prompt=True, skip_special_tokens=True, timeout=60.0
|
|
|
185 |
repetition_penalty=repetition_penalty,
|
186 |
do_sample=True,
|
187 |
)
|
188 |
+
t = Thread(target=selected_model.generate, kwargs=generate_kwargs)
|
189 |
t.start()
|
190 |
|
191 |
output = ""
|
|
|
204 |
multimodal=True,
|
205 |
additional_inputs=[
|
206 |
gr.Textbox(label="System Prompt", value="You are a helpful assistant."),
|
207 |
+
gr.Dropdown(
|
208 |
+
label="Model",
|
209 |
+
choices=["Gemma 3 12B", "Gemma 3n E4B"],
|
210 |
+
value="Gemma 3 12B"
|
211 |
+
),
|
212 |
gr.Slider(
|
213 |
label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700
|
214 |
),
|