Spaces:
Running
on
Zero
Running
on
Zero
added more models
Browse files- app.py +24 -6
- requirements.txt +2 -1
app.py
CHANGED
@@ -5,6 +5,8 @@ from transformers import (
|
|
5 |
Gemma3ForConditionalGeneration,
|
6 |
TextIteratorStreamer,
|
7 |
Gemma3Processor,
|
|
|
|
|
8 |
)
|
9 |
import spaces
|
10 |
import tempfile
|
@@ -20,12 +22,28 @@ dotenv_path = find_dotenv()
|
|
20 |
|
21 |
load_dotenv(dotenv_path)
|
22 |
|
23 |
-
|
|
|
|
|
24 |
|
25 |
-
input_processor = Gemma3Processor.from_pretrained(
|
26 |
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
torch_dtype=torch.bfloat16,
|
30 |
device_map="auto",
|
31 |
attn_implementation="eager",
|
@@ -157,7 +175,7 @@ def run(
|
|
157 |
tokenize=True,
|
158 |
return_dict=True,
|
159 |
return_tensors="pt",
|
160 |
-
).to(device=
|
161 |
|
162 |
streamer = TextIteratorStreamer(
|
163 |
input_processor, timeout=60.0, skip_prompt=True, skip_special_tokens=True
|
@@ -172,7 +190,7 @@ def run(
|
|
172 |
repetition_penalty=repetition_penalty,
|
173 |
do_sample=True,
|
174 |
)
|
175 |
-
t = Thread(target=
|
176 |
t.start()
|
177 |
|
178 |
output = ""
|
|
|
5 |
Gemma3ForConditionalGeneration,
|
6 |
TextIteratorStreamer,
|
7 |
Gemma3Processor,
|
8 |
+
Gemma3nForConditionalGeneration,
|
9 |
+
Gemma3ForCausalLM
|
10 |
)
|
11 |
import spaces
|
12 |
import tempfile
|
|
|
22 |
|
23 |
load_dotenv(dotenv_path)
|
24 |
|
25 |
+
model_27_id = os.getenv("MODEL_27_ID", "google/gemma-3-4b-it")
|
26 |
+
model_12_id = os.getenv("MODEL_12_ID", "google/gemma-3-4b-it")
|
27 |
+
model_3n_id = os.getenv("MODEL_3N_ID", "google/gemma-3-4b-it")
|
28 |
|
29 |
+
input_processor = Gemma3Processor.from_pretrained(model_27_id)
|
30 |
|
31 |
+
model_27 = Gemma3ForConditionalGeneration.from_pretrained(
|
32 |
+
model_27_id,
|
33 |
+
torch_dtype=torch.bfloat16,
|
34 |
+
device_map="auto",
|
35 |
+
attn_implementation="eager",
|
36 |
+
)
|
37 |
+
|
38 |
+
model_12 = Gemma3ForCausalLM.from_pretrained(
|
39 |
+
model_12_id,
|
40 |
+
torch_dtype=torch.bfloat16,
|
41 |
+
device_map="auto",
|
42 |
+
attn_implementation="eager",
|
43 |
+
)
|
44 |
+
|
45 |
+
model_3n = Gemma3nForConditionalGeneration.from_pretrained(
|
46 |
+
model_3n_id,
|
47 |
torch_dtype=torch.bfloat16,
|
48 |
device_map="auto",
|
49 |
attn_implementation="eager",
|
|
|
175 |
tokenize=True,
|
176 |
return_dict=True,
|
177 |
return_tensors="pt",
|
178 |
+
).to(device=model_27.device, dtype=torch.bfloat16)
|
179 |
|
180 |
streamer = TextIteratorStreamer(
|
181 |
input_processor, timeout=60.0, skip_prompt=True, skip_special_tokens=True
|
|
|
190 |
repetition_penalty=repetition_penalty,
|
191 |
do_sample=True,
|
192 |
)
|
193 |
+
t = Thread(target=model_27.generate, kwargs=generate_kwargs)
|
194 |
t.start()
|
195 |
|
196 |
output = ""
|
requirements.txt
CHANGED
@@ -6,4 +6,5 @@ accelerate
|
|
6 |
pytest
|
7 |
loguru
|
8 |
python-dotenv
|
9 |
-
opencv-python
|
|
|
|
6 |
pytest
|
7 |
loguru
|
8 |
python-dotenv
|
9 |
+
opencv-python
|
10 |
+
timm
|