Spaces: Running on Zero

Commit b8ee0a2
Parent(s): 5021e53
prepare for zeroGPU

Files changed:
- requirements.txt (+2 -1)
- utils/models.py (+43 -28)
requirements.txt CHANGED
@@ -6,4 +6,5 @@ numpy==1.26.4
 openai>=1.60.2
 torch>=2.5.1
 tqdm==4.67.1
-vllm>=0.8.5
+vllm>=0.8.5
+spaces
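The new spaces dependency is what ZeroGPU Spaces use to attach a GPU on demand: a function that needs the GPU is wrapped in the @spaces.GPU decorator and a device is allocated only while that call runs. A minimal sketch of the pattern, using a placeholder model that is not part of this commit:

# Minimal ZeroGPU sketch (not from this repo): the decorator requests a GPU
# only for the duration of the call; the rest of the Space runs on CPU.
import spaces
import torch
from transformers import pipeline

pipe = pipeline(
    "text-generation",
    model="HuggingFaceTB/SmolLM2-135M-Instruct",  # placeholder model, not from the commit
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

@spaces.GPU  # a duration can also be requested, e.g. @spaces.GPU(duration=120)
def generate(prompt: str) -> str:
    messages = [{"role": "user", "content": prompt}]
    output = pipe(messages, max_new_tokens=128)
    # The pipeline returns the whole conversation; the assistant turn is last.
    return output[0]["generated_text"][-1]["content"]

In the commit itself the pipeline is constructed inside the decorated run_inference, so model loading also happens within the GPU allocation.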
utils/models.py CHANGED
@@ -1,8 +1,9 @@
 import os
 os.environ['MKL_THREADING_LAYER'] = 'GNU'
+import spaces
 
 import torch
-from transformers import
+from transformers import pipeline, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
 from .prompts import format_rag_prompt
 from .shared import generation_interrupt
 import threading
@@ -104,6 +105,7 @@ def generate_summaries(example, model_a_name, model_b_name):
 
 
 # Modified run_inference to run in a thread and use a queue for results
+@spaces.GPU
 def run_inference(model_name, context, question, result_queue):
     """
     Run inference using the specified model. Designed to be run in a thread.
@@ -125,14 +127,26 @@ def run_inference(model_name, context, question, result_queue):
             "System role not supported" not in tokenizer.chat_template
             if tokenizer.chat_template else False  # Handle missing chat_template
         )
+        outputs = ""
 
-
-
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
 
-        #
-
-
-
+        # Check interrupt before loading the model
+        if generation_interrupt.is_set():
+            result_queue.put("")
+            return
+
+        pipe = pipeline(
+            "text-generation",
+            model=model_name,
+            tokenizer=tokenizer,
+            device_map='auto',
+            max_length=512,
+            do_sample=True,
+            temperature=0.6,
+            top_p=0.9,
+        )
 
         # model = AutoModelForCausalLM.from_pretrained(
         #     model_name, torch_dtype=torch.bfloat16, attn_implementation="eager", token=True
@@ -141,10 +155,10 @@ def run_inference(model_name, context, question, result_queue):
 
         text_input = format_rag_prompt(question, context, accepts_sys)
 
-        #
-
-
-
+        # Check interrupt before tokenization/template application
+        if generation_interrupt.is_set():
+            result_queue.put("")
+            return
 
         # actual_input = tokenizer.apply_chat_template(
         #     text_input,
@@ -156,7 +170,8 @@ def run_inference(model_name, context, question, result_queue):
         #     max_length=2048,  # Keep original max_length for now
        #     add_generation_prompt=True,
         # ).to(device)
-
+        output = pipe(text_input, max_new_tokens=512)
+        result = output[0]['generated_text'][-1]['content']
         # # Ensure input does not exceed model max length after adding generation prompt
         # # This check might be redundant if tokenizer handles it, but good for safety
         # # if actual_input.shape[1] > tokenizer.model_max_length:
@@ -193,23 +208,23 @@ def run_inference(model_name, context, question, result_queue):
         # else:
        #     # Decode the generated tokens, excluding the input tokens
         #     result = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
-        llm = LLM(model_name, dtype=torch.bfloat16, hf_token=True, enforce_eager=True, device="cpu")
-        params = SamplingParams(
-            max_tokens=512,
-        )
+        # llm = LLM(model_name, dtype=torch.bfloat16, hf_token=True, enforce_eager=True, device="cpu")
+        # params = SamplingParams(
+        #     max_tokens=512,
+        # )
 
-        # Check interrupt before generation
-        if generation_interrupt.is_set():
-            result_queue.put("")
-            return
-        # Generate the response
-        outputs = llm.chat(
-            text_input,
-            sampling_params=params,
-            # stopping_criteria=StoppingCriteriaList([InterruptCriteria(generation_interrupt)]),
-        )
-        # Check interrupt immediately after generation finishes or stops
-        result_queue.put(
+        # # Check interrupt before generation
+        # if generation_interrupt.is_set():
+        #     result_queue.put("")
+        #     return
+        # # Generate the response
+        # outputs = llm.chat(
+        #     text_input,
+        #     sampling_params=params,
+        #     # stopping_criteria=StoppingCriteriaList([InterruptCriteria(generation_interrupt)]),
+        # )
+        # # Check interrupt immediately after generation finishes or stops
+        result_queue.put(result)
 
     except Exception as e:
         print(f"Error in inference thread for {model_name}: {e}")
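A note on the new generation path in run_inference: when the text-generation pipeline is given a chat-style list of role/content messages (which is what format_rag_prompt is assumed to produce here), it applies the model's chat template itself and returns the full conversation, so the assistant reply is the last entry of generated_text. A small standalone sketch of that pattern, with a placeholder model and placeholder messages:

# Standalone sketch of the pipeline chat pattern used by the new code path.
from transformers import pipeline

pipe = pipeline(
    "text-generation",
    model="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder, not one of the Space's models
    device_map="auto",
)

messages = [
    {"role": "system", "content": "Answer strictly from the provided context."},
    {"role": "user", "content": "Context: ...\n\nQuestion: ..."},  # placeholders
]

output = pipe(messages, max_new_tokens=512, do_sample=True, temperature=0.6, top_p=0.9)
# generated_text holds the input messages plus the newly generated assistant turn.
reply = output[0]["generated_text"][-1]["content"]
print(reply)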
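The repeated generation_interrupt.is_set() checks implement cooperative cancellation: generation_interrupt is a shared threading.Event (from utils.shared), and the worker returns an empty result if a cancel request arrived before the next expensive step, while results travel back through a queue. A stand-in sketch of that thread-plus-queue shape (the worker body below is illustrative, not the commit's code):

# Stand-in illustration of the interrupt/queue pattern around run_inference.
import threading
import queue
import time

generation_interrupt = threading.Event()  # shared cancel flag (utils.shared in the repo)

def worker(result_queue: queue.Queue) -> None:
    # Check the flag before each expensive step; bail out with an empty result.
    if generation_interrupt.is_set():
        result_queue.put("")
        return
    time.sleep(0.1)  # placeholder for model loading / generation
    if generation_interrupt.is_set():
        result_queue.put("")
        return
    result_queue.put("generated summary")

results: queue.Queue = queue.Queue()
t = threading.Thread(target=worker, args=(results,), daemon=True)
t.start()
# generation_interrupt.set()  # a UI callback could request cancellation here
t.join()
print(results.get())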
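Finally, the tokenizer guards: some chat templates (Gemma's, for example) reject a system role by raising "System role not supported" from the template itself, which the accepts_sys check detects, and many decoder-only tokenizers ship without a pad token, which the eos_token fallback covers. Both guards in isolation, with a placeholder model:

# Isolated sketch of the tokenizer guards from run_inference (placeholder model).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")

# Templates that reject a system role embed that error string; check up front.
accepts_sys = (
    "System role not supported" not in tokenizer.chat_template
    if tokenizer.chat_template else False
)

# Reuse EOS as the pad token so padding during generation does not fail.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(accepts_sys, tokenizer.pad_token)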