AC2513 committed on
Commit
78ef809
·
1 Parent(s): c0fc237

added more models

Browse files
Files changed (2) hide show
  1. app.py +24 -6
  2. requirements.txt +2 -1
app.py CHANGED
@@ -5,6 +5,8 @@ from transformers import (
5
  Gemma3ForConditionalGeneration,
6
  TextIteratorStreamer,
7
  Gemma3Processor,
 
 
8
  )
9
  import spaces
10
  import tempfile
@@ -20,12 +22,28 @@ dotenv_path = find_dotenv()
20
 
21
  load_dotenv(dotenv_path)
22
 
23
- model_id = os.getenv("MODEL_ID", "google/gemma-3-4b-it")
 
 
24
 
25
- input_processor = Gemma3Processor.from_pretrained(model_id)
26
 
27
- model = Gemma3ForConditionalGeneration.from_pretrained(
28
- model_id,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  torch_dtype=torch.bfloat16,
30
  device_map="auto",
31
  attn_implementation="eager",
@@ -157,7 +175,7 @@ def run(
157
  tokenize=True,
158
  return_dict=True,
159
  return_tensors="pt",
160
- ).to(device=model.device, dtype=torch.bfloat16)
161
 
162
  streamer = TextIteratorStreamer(
163
  input_processor, timeout=60.0, skip_prompt=True, skip_special_tokens=True
@@ -172,7 +190,7 @@ def run(
172
  repetition_penalty=repetition_penalty,
173
  do_sample=True,
174
  )
175
- t = Thread(target=model.generate, kwargs=generate_kwargs)
176
  t.start()
177
 
178
  output = ""
 
5
  Gemma3ForConditionalGeneration,
6
  TextIteratorStreamer,
7
  Gemma3Processor,
8
+ Gemma3nForConditionalGeneration,
9
+ Gemma3ForCausalLM
10
  )
11
  import spaces
12
  import tempfile
 
22
 
23
  load_dotenv(dotenv_path)
24
 
25
+ model_27_id = os.getenv("MODEL_27_ID", "google/gemma-3-4b-it")
26
+ model_12_id = os.getenv("MODEL_12_ID", "google/gemma-3-4b-it")
27
+ model_3n_id = os.getenv("MODEL_3N_ID", "google/gemma-3-4b-it")
28
 
29
+ input_processor = Gemma3Processor.from_pretrained(model_27_id)
30
 
31
+ model_27 = Gemma3ForConditionalGeneration.from_pretrained(
32
+ model_27_id,
33
+ torch_dtype=torch.bfloat16,
34
+ device_map="auto",
35
+ attn_implementation="eager",
36
+ )
37
+
38
+ model_12 = Gemma3ForCausalLM.from_pretrained(
39
+ model_12_id,
40
+ torch_dtype=torch.bfloat16,
41
+ device_map="auto",
42
+ attn_implementation="eager",
43
+ )
44
+
45
+ model_3n = Gemma3nForConditionalGeneration.from_pretrained(
46
+ model_3n_id,
47
  torch_dtype=torch.bfloat16,
48
  device_map="auto",
49
  attn_implementation="eager",
 
175
  tokenize=True,
176
  return_dict=True,
177
  return_tensors="pt",
178
+ ).to(device=model_27.device, dtype=torch.bfloat16)
179
 
180
  streamer = TextIteratorStreamer(
181
  input_processor, timeout=60.0, skip_prompt=True, skip_special_tokens=True
 
190
  repetition_penalty=repetition_penalty,
191
  do_sample=True,
192
  )
193
+ t = Thread(target=model_27.generate, kwargs=generate_kwargs)
194
  t.start()
195
 
196
  output = ""
requirements.txt CHANGED
@@ -6,4 +6,5 @@ accelerate
6
  pytest
7
  loguru
8
  python-dotenv
9
- opencv-python
 
 
6
  pytest
7
  loguru
8
  python-dotenv
9
+ opencv-python
10
+ timm