AC2513 committed on
Commit
0df6b7c
·
1 Parent(s): 29fdefa

update model sizes

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -22,13 +22,13 @@ dotenv_path = find_dotenv()
22
 
23
  load_dotenv(dotenv_path)
24
 
25
- model_27_id = os.getenv("MODEL_27_ID", "google/gemma-3-4b-it")
26
- model_3n_id = os.getenv("MODEL_3N_ID", "google/gemma-3-4b-it")
27
 
28
- input_processor = Gemma3Processor.from_pretrained(model_27_id)
29
 
30
- model_27 = Gemma3ForConditionalGeneration.from_pretrained(
31
- model_27_id,
32
  torch_dtype=torch.bfloat16,
33
  device_map="auto",
34
  attn_implementation="eager",
@@ -167,7 +167,7 @@ def run(
167
  tokenize=True,
168
  return_dict=True,
169
  return_tensors="pt",
170
- ).to(device=model_27.device, dtype=torch.bfloat16)
171
 
172
  streamer = TextIteratorStreamer(
173
  input_processor, skip_prompt=True, skip_special_tokens=True
@@ -182,7 +182,7 @@ def run(
182
  repetition_penalty=repetition_penalty,
183
  do_sample=True,
184
  )
185
- t = Thread(target=model_27.generate, kwargs=generate_kwargs)
186
  t.start()
187
 
188
  output = ""
 
22
 
23
  load_dotenv(dotenv_path)
24
 
25
+ model_12_id = os.getenv("MODEL_12_ID", "google/gemma-3-1b-it")
26
+ model_3n_id = os.getenv("MODEL_3N_ID", "google/gemma-3-1b-it")
27
 
28
+ input_processor = Gemma3Processor.from_pretrained(model_12_id)
29
 
30
+ model_12 = Gemma3ForConditionalGeneration.from_pretrained(
31
+ model_12_id,
32
  torch_dtype=torch.bfloat16,
33
  device_map="auto",
34
  attn_implementation="eager",
 
167
  tokenize=True,
168
  return_dict=True,
169
  return_tensors="pt",
170
+ ).to(device=model_12.device, dtype=torch.bfloat16)
171
 
172
  streamer = TextIteratorStreamer(
173
  input_processor, skip_prompt=True, skip_special_tokens=True
 
182
  repetition_penalty=repetition_penalty,
183
  do_sample=True,
184
  )
185
+ t = Thread(target=model_12.generate, kwargs=generate_kwargs)
186
  t.start()
187
 
188
  output = ""