Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
1898f4f
1
Parent(s):
0276240
fixed qwen, disabled gemma3-1b and minicpm, re-enabled cogito
Browse files- utils/models.py +31 -26
utils/models.py
CHANGED
@@ -8,24 +8,24 @@ from .prompts import format_rag_prompt
|
|
8 |
from .shared import generation_interrupt
|
9 |
|
10 |
models = {
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
"Gemma-3-1b-it": "google/gemma-3-1b-it",
|
16 |
#"Gemma-3-4b-it": "google/gemma-3-4b-it",
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
#"Bitnet-b1.58-2B4T": "microsoft/bitnet-b1.58-2B-4T",
|
22 |
-
#"MiniCPM3-RAG-LoRA": "openbmb/MiniCPM3-RAG-LoRA",
|
23 |
"Qwen3-0.6b": "qwen/qwen3-0.6b",
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
|
30 |
}
|
31 |
|
@@ -145,23 +145,28 @@ def run_inference(model_name, context, question):
|
|
145 |
device_map='cuda',
|
146 |
trust_remote_code=True,
|
147 |
torch_dtype=torch.bfloat16,
|
|
|
|
|
|
|
148 |
)
|
149 |
|
150 |
text_input = format_rag_prompt(question, context, accepts_sys)
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
|
|
156 |
|
157 |
-
|
158 |
# Check interrupt before generation
|
159 |
-
if generation_interrupt.is_set():
|
160 |
-
return ""
|
161 |
|
162 |
-
|
163 |
#print(outputs[0]['generated_text'])
|
164 |
-
|
|
|
|
|
|
|
165 |
|
166 |
except Exception as e:
|
167 |
print(f"Error in inference for {model_name}: {e}")
|
|
|
8 |
from .shared import generation_interrupt
|
9 |
|
10 |
models = {
|
11 |
+
"Qwen2.5-1.5b-Instruct": "qwen/qwen2.5-1.5b-instruct",
|
12 |
+
"Qwen2.5-3b-Instruct": "qwen/qwen2.5-3b-instruct",
|
13 |
+
"Llama-3.2-1b-Instruct": "meta-llama/llama-3.2-1b-instruct",
|
14 |
+
"Llama-3.2-3b-Instruct": "meta-llama/llama-3.2-3b-instruct",
|
15 |
+
#"Gemma-3-1b-it": "google/gemma-3-1b-it",
|
16 |
#"Gemma-3-4b-it": "google/gemma-3-4b-it",
|
17 |
+
"Gemma-2-2b-it": "google/gemma-2-2b-it",
|
18 |
+
"Phi-4-mini-instruct": "microsoft/phi-4-mini-instruct",
|
19 |
+
"Cogito-v1-preview-llama-3b": "deepcogito/cogito-v1-preview-llama-3b",
|
20 |
+
"IBM Granite-3.3-2b-instruct": "ibm-granite/granite-3.3-2b-instruct",
|
21 |
+
# #"Bitnet-b1.58-2B4T": "microsoft/bitnet-b1.58-2B-4T",
|
22 |
+
# #"MiniCPM3-RAG-LoRA": "openbmb/MiniCPM3-RAG-LoRA",
|
23 |
"Qwen3-0.6b": "qwen/qwen3-0.6b",
|
24 |
+
"Qwen3-1.7b": "qwen/qwen3-1.7b",
|
25 |
+
"Qwen3-4b": "qwen/qwen3-4b",
|
26 |
+
"SmolLM2-1.7b-Instruct": "HuggingFaceTB/SmolLM2-1.7B-Instruct",
|
27 |
+
"EXAONE-3.5-2.4B-instruct": "LGAI-EXAONE/EXAONE-3.5-2.4B-Instruct",
|
28 |
+
"OLMo-2-1B-Instruct": "allenai/OLMo-2-0425-1B-Instruct",
|
29 |
|
30 |
}
|
31 |
|
|
|
145 |
device_map='cuda',
|
146 |
trust_remote_code=True,
|
147 |
torch_dtype=torch.bfloat16,
|
148 |
+
model_kwargs={
|
149 |
+
"attn_implementation": "eager",
|
150 |
+
}
|
151 |
)
|
152 |
|
153 |
text_input = format_rag_prompt(question, context, accepts_sys)
|
154 |
+
if "Gemma-3".lower() not in model_name.lower():
|
155 |
+
formatted = pipe.tokenizer.apply_chat_template(
|
156 |
+
text_input,
|
157 |
+
tokenize=False,
|
158 |
+
**tokenizer_kwargs,
|
159 |
+
)
|
160 |
|
161 |
+
input_length = len(formatted)
|
162 |
# Check interrupt before generation
|
|
|
|
|
163 |
|
164 |
+
outputs = pipe(formatted, max_new_tokens=512, generation_kwargs={"skip_special_tokens": True})
|
165 |
#print(outputs[0]['generated_text'])
|
166 |
+
result = outputs[0]['generated_text'][input_length:]
|
167 |
+
else: # don't use apply chat template? I don't know why gemma keeps breaking
|
168 |
+
result = pipe(text_input, max_new_tokens=512, generation_kwargs={"skip_special_tokens": True})[0]['generated_text']
|
169 |
+
result = result[0]['generated_text'][-1]['content']
|
170 |
|
171 |
except Exception as e:
|
172 |
print(f"Error in inference for {model_name}: {e}")
|