oliver-aizip committed
Commit 1898f4f · 1 Parent(s): 0276240

fixed qwen, disabled gemma3-1b and minicpm, re-enabled cogito

Files changed (1)
  1. utils/models.py +31 -26
utils/models.py CHANGED
@@ -8,24 +8,24 @@ from .prompts import format_rag_prompt
 from .shared import generation_interrupt
 
 models = {
-    "Qwen2.5-1.5b-Instruct": "qwen/qwen2.5-1.5b-instruct",
-    "Qwen2.5-3b-Instruct": "qwen/qwen2.5-3b-instruct",
-    "Llama-3.2-1b-Instruct": "meta-llama/llama-3.2-1b-instruct",
-    "Llama-3.2-3b-Instruct": "meta-llama/llama-3.2-3b-instruct",
-    "Gemma-3-1b-it": "google/gemma-3-1b-it",
+    "Qwen2.5-1.5b-Instruct": "qwen/qwen2.5-1.5b-instruct",
+    "Qwen2.5-3b-Instruct": "qwen/qwen2.5-3b-instruct",
+    "Llama-3.2-1b-Instruct": "meta-llama/llama-3.2-1b-instruct",
+    "Llama-3.2-3b-Instruct": "meta-llama/llama-3.2-3b-instruct",
+    #"Gemma-3-1b-it": "google/gemma-3-1b-it",
     #"Gemma-3-4b-it": "google/gemma-3-4b-it",
-    "Gemma-2-2b-it": "google/gemma-2-2b-it",
-    "Phi-4-mini-instruct": "microsoft/phi-4-mini-instruct",
-    #"Cogito-v1-preview-llama-3b": "deepcogito/cogito-v1-preview-llama-3b",
-    "IBM Granite-3.3-2b-instruct": "ibm-granite/granite-3.3-2b-instruct",
-    #"Bitnet-b1.58-2B4T": "microsoft/bitnet-b1.58-2B-4T",
-    #"MiniCPM3-RAG-LoRA": "openbmb/MiniCPM3-RAG-LoRA",
+    "Gemma-2-2b-it": "google/gemma-2-2b-it",
+    "Phi-4-mini-instruct": "microsoft/phi-4-mini-instruct",
+    "Cogito-v1-preview-llama-3b": "deepcogito/cogito-v1-preview-llama-3b",
+    "IBM Granite-3.3-2b-instruct": "ibm-granite/granite-3.3-2b-instruct",
+    # #"Bitnet-b1.58-2B4T": "microsoft/bitnet-b1.58-2B-4T",
+    # #"MiniCPM3-RAG-LoRA": "openbmb/MiniCPM3-RAG-LoRA",
     "Qwen3-0.6b": "qwen/qwen3-0.6b",
-    "Qwen3-1.7b": "qwen/qwen3-1.7b",
-    "Qwen3-4b": "qwen/qwen3-4b",
-    "SmolLM2-1.7b-Instruct": "HuggingFaceTB/SmolLM2-1.7B-Instruct",
-    "EXAONE-3.5-2.4B-instruct": "LGAI-EXAONE/EXAONE-3.5-2.4B-Instruct",
-    "OLMo-2-1B-Instruct": "allenai/OLMo-2-0425-1B-Instruct",
+    "Qwen3-1.7b": "qwen/qwen3-1.7b",
+    "Qwen3-4b": "qwen/qwen3-4b",
+    "SmolLM2-1.7b-Instruct": "HuggingFaceTB/SmolLM2-1.7B-Instruct",
+    "EXAONE-3.5-2.4B-instruct": "LGAI-EXAONE/EXAONE-3.5-2.4B-Instruct",
+    "OLMo-2-1B-Instruct": "allenai/OLMo-2-0425-1B-Instruct",
 
 }
 
@@ -145,23 +145,28 @@ def run_inference(model_name, context, question):
             device_map='cuda',
             trust_remote_code=True,
             torch_dtype=torch.bfloat16,
+            model_kwargs={
+                "attn_implementation": "eager",
+            }
         )
 
         text_input = format_rag_prompt(question, context, accepts_sys)
-        formatted = tokenizer.apply_chat_template(
-            text_input,
-            tokenize=False,
-            **tokenizer_kwargs,
-        )
+        if "Gemma-3".lower() not in model_name.lower():
+            formatted = pipe.tokenizer.apply_chat_template(
+                text_input,
+                tokenize=False,
+                **tokenizer_kwargs,
+            )
 
-        input_length = len(formatted)
+            input_length = len(formatted)
         # Check interrupt before generation
-        if generation_interrupt.is_set():
-            return ""
 
-        outputs = pipe(formatted, **generation_kwargs)
+            outputs = pipe(formatted, max_new_tokens=512, generation_kwargs={"skip_special_tokens": True})
         #print(outputs[0]['generated_text'])
-        result = outputs[0]['generated_text'][input_length:]
+            result = outputs[0]['generated_text'][input_length:]
+        else: # don't use apply chat template? I don't know why gemma keeps breaking
+            result = pipe(text_input, max_new_tokens=512, generation_kwargs={"skip_special_tokens": True})[0]['generated_text']
+            result = result[0]['generated_text'][-1]['content']
 
     except Exception as e:
         print(f"Error in inference for {model_name}: {e}")