AC2513 committed
Commit d7dcf58 · 1 Parent(s): 686226c

added dropdown for user to select models

Files changed (1):
app.py (+13 -5)
app.py CHANGED
@@ -22,8 +22,8 @@ dotenv_path = find_dotenv()
 
 load_dotenv(dotenv_path)
 
-model_12_id = os.getenv("MODEL_12_ID", "google/gemma-3-1b-it")
-model_3n_id = os.getenv("MODEL_3N_ID", "google/gemma-3-1b-it")
+model_12_id = os.getenv("MODEL_12_ID", "google/gemma-3-12b-it")
+model_3n_id = os.getenv("MODEL_3N_ID", "google/gemma-3n-E4B-it")
 
 input_processor = Gemma3Processor.from_pretrained(model_12_id)
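The new defaults swap the 1B placeholder for the checkpoints the app actually targets: the instruction-tuned Gemma 3 12B and Gemma 3n E4B. The loading code sits outside this hunk; a minimal sketch of how `model_12` might be constructed from the ID, with the class name and kwargs as assumptions about the surrounding app.py:

```python
# Hypothetical loading sketch -- not part of this diff. Assumes a
# transformers release with native Gemma 3 support.
import os

import torch
from transformers import Gemma3ForConditionalGeneration

model_12_id = os.getenv("MODEL_12_ID", "google/gemma-3-12b-it")

model_12 = Gemma3ForConditionalGeneration.from_pretrained(
    model_12_id,
    torch_dtype=torch.bfloat16,  # matches the .to(dtype=torch.bfloat16) later in run()
    device_map="auto",           # requires accelerate; places weights automatically
)
```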
 
@@ -138,6 +138,7 @@ def run(
     message: dict,
     history: list[dict],
     system_prompt: str,
+    model_choice: str,
     max_new_tokens: int,
     max_images: int,
     temperature: float,
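Note where `model_choice` lands in the signature: `gr.ChatInterface` forwards the values of `additional_inputs` to the handler positionally, after `message` and `history`, so the parameter has to occupy the same slot here that the new Dropdown occupies in the `additional_inputs` list (see the interface sketch after the last hunk).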
@@ -148,9 +149,11 @@ def run(
 
     logger.debug(
         f"\n message: {message} \n history: {history} \n system_prompt: {system_prompt} \n "
-        f"max_new_tokens: {max_new_tokens} \n max_images: {max_images}"
+        f"model_choice: {model_choice} \n max_new_tokens: {max_new_tokens} \n max_images: {max_images}"
     )
 
+    selected_model = model_12 if model_choice == "Gemma 3 12B" else model_3n
+
     messages = []
     if system_prompt:
         messages.append(
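The ternary is fine for two models, but note it routes any unrecognized label to `model_3n` silently. A dict lookup is a common alternative that scales to more choices and fails loudly; a sketch, assuming `model_12` and `model_3n` are the loaded model objects:

```python
# Alternative selection strategy (a sketch, not what this commit does).
MODELS = {
    "Gemma 3 12B": model_12,   # assumed loaded earlier in app.py
    "Gemma 3n E4B": model_3n,  # assumed loaded earlier in app.py
}

try:
    selected_model = MODELS[model_choice]
except KeyError:
    # Surface a readable error in the Gradio UI instead of a silent fallback.
    raise gr.Error(f"Unknown model choice: {model_choice!r}")
```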
@@ -167,7 +170,7 @@
         tokenize=True,
         return_dict=True,
         return_tensors="pt",
-    ).to(device=model_12.device, dtype=torch.bfloat16)
+    ).to(device=selected_model.device, dtype=torch.bfloat16)
 
     streamer = TextIteratorStreamer(
         input_processor, skip_prompt=True, skip_special_tokens=True, timeout=60.0
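This `.to(...)` change is what makes the dropdown actually work when the two models sit on different devices: the processed inputs must be moved to whichever model will consume them, not unconditionally to `model_12`'s device.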
@@ -182,7 +185,7 @@
         repetition_penalty=repetition_penalty,
         do_sample=True,
     )
-    t = Thread(target=model_12.generate, kwargs=generate_kwargs)
+    t = Thread(target=selected_model.generate, kwargs=generate_kwargs)
     t.start()
 
     output = ""
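For readers new to the pattern: `generate` blocks until completion, so it runs on a worker thread while the caller drains the `TextIteratorStreamer`. A self-contained sketch of the same pattern (the gpt2 checkpoint is a stand-in for illustration):

```python
# Standard transformers streaming pattern, as used in run().
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # small stand-in model
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() runs on the worker thread; the streamer yields decoded text here.
t = Thread(target=model.generate, kwargs=dict(**inputs, streamer=streamer, max_new_tokens=30))
t.start()

output = ""
for chunk in streamer:
    output += chunk
    print(chunk, end="", flush=True)
t.join()
```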
@@ -201,6 +204,11 @@ demo = gr.ChatInterface(
     multimodal=True,
     additional_inputs=[
         gr.Textbox(label="System Prompt", value="You are a helpful assistant."),
+        gr.Dropdown(
+            label="Model",
+            choices=["Gemma 3 12B", "Gemma 3n E4B"],
+            value="Gemma 3 12B"
+        ),
         gr.Slider(
             label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700
         ),
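Putting the UI side together, a minimal runnable sketch of how the Dropdown's value reaches the handler; the echo stub stands in for the real streaming generation:

```python
import gradio as gr

def run(message: dict, history: list, system_prompt: str, model_choice: str, max_new_tokens: int):
    # additional_inputs arrive positionally after message and history,
    # so model_choice receives the Dropdown's selected label.
    return f"[{model_choice}] you said: {message['text']}"

demo = gr.ChatInterface(
    fn=run,
    multimodal=True,  # message arrives as {"text": ..., "files": [...]}
    additional_inputs=[
        gr.Textbox(label="System Prompt", value="You are a helpful assistant."),
        gr.Dropdown(label="Model", choices=["Gemma 3 12B", "Gemma 3n E4B"], value="Gemma 3 12B"),
        gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700),
    ],
)

if __name__ == "__main__":
    demo.launch()
```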
 