oliver-aizip committed
Commit: b8ee0a2
Parent: 5021e53

prepare for zeroGPU
Files changed (2):
  1. requirements.txt (+2 -1)
  2. utils/models.py (+43 -28)
requirements.txt CHANGED
@@ -6,4 +6,5 @@ numpy==1.26.4
 openai>=1.60.2
 torch>=2.5.1
 tqdm==4.67.1
-vllm>=0.8.5
+vllm>=0.8.5
+spaces
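
Note: the new spaces dependency is the Hugging Face ZeroGPU helper used in utils/models.py below; a function decorated with @spaces.GPU is allocated a GPU only for the duration of each call. A minimal sketch of the pattern, assuming a placeholder model id that is not part of this repo:

    import spaces
    import torch
    from transformers import pipeline

    # Loaded once at startup; on a ZeroGPU Space the GPU is attached only
    # while a @spaces.GPU-decorated function is executing.
    pipe = pipeline(
        "text-generation",
        model="HuggingFaceTB/SmolLM2-135M-Instruct",  # placeholder model id
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )

    @spaces.GPU  # request a GPU slice for the duration of this call
    def generate(prompt: str) -> str:
        out = pipe([{"role": "user", "content": prompt}], max_new_tokens=64)
        return out[0]["generated_text"][-1]["content"]
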
utils/models.py CHANGED
@@ -1,8 +1,9 @@
 import os
 os.environ['MKL_THREADING_LAYER'] = 'GNU'
+import spaces
 
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
+from transformers import pipeline, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
 from .prompts import format_rag_prompt
 from .shared import generation_interrupt
 import threading
@@ -104,6 +105,7 @@ def generate_summaries(example, model_a_name, model_b_name):
 
 
 # Modified run_inference to run in a thread and use a queue for results
+@spaces.GPU
 def run_inference(model_name, context, question, result_queue):
     """
     Run inference using the specified model. Designed to be run in a thread.
@@ -125,14 +127,26 @@ def run_inference(model_name, context, question, result_queue):
             "System role not supported" not in tokenizer.chat_template
             if tokenizer.chat_template else False # Handle missing chat_template
         )
+        outputs = ""
 
-        # if tokenizer.pad_token is None:
-        #     tokenizer.pad_token = tokenizer.eos_token
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
 
-        # # Check interrupt before loading the model
-        # if generation_interrupt.is_set():
-        #     result_queue.put("")
-        #     return
+        # Check interrupt before loading the model
+        if generation_interrupt.is_set():
+            result_queue.put("")
+            return
+
+        pipe = pipeline(
+            "text-generation",
+            model=model_name,
+            tokenizer=tokenizer,
+            device_map='auto',
+            max_length=512,
+            do_sample=True,
+            temperature=0.6,
+            top_p=0.9,
+        )
 
         # model = AutoModelForCausalLM.from_pretrained(
         #     model_name, torch_dtype=torch.bfloat16, attn_implementation="eager", token=True
@@ -141,10 +155,10 @@ def run_inference(model_name, context, question, result_queue):
 
         text_input = format_rag_prompt(question, context, accepts_sys)
 
-        # # Check interrupt before tokenization/template application
-        # if generation_interrupt.is_set():
-        #     result_queue.put("")
-        #     return
+        # Check interrupt before tokenization/template application
+        if generation_interrupt.is_set():
+            result_queue.put("")
+            return
 
         # actual_input = tokenizer.apply_chat_template(
         #     text_input,
@@ -156,7 +170,8 @@ def run_inference(model_name, context, question, result_queue):
         #     max_length=2048, # Keep original max_length for now
         #     add_generation_prompt=True,
         # ).to(device)
-
+        output = pipe(text_input, max_new_tokens=512)
+        result = output[0]['generated_text'][-1]['content']
         # # Ensure input does not exceed model max length after adding generation prompt
         # # This check might be redundant if tokenizer handles it, but good for safety
         # # if actual_input.shape[1] > tokenizer.model_max_length:
@@ -193,23 +208,23 @@ def run_inference(model_name, context, question, result_queue):
         # else:
        #     # Decode the generated tokens, excluding the input tokens
        #     result = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
-        llm = LLM(model_name, dtype=torch.bfloat16, hf_token=True, enforce_eager=True, device="cpu")
-        params = SamplingParams(
-            max_tokens=512,
-        )
+        # llm = LLM(model_name, dtype=torch.bfloat16, hf_token=True, enforce_eager=True, device="cpu")
+        # params = SamplingParams(
+        #     max_tokens=512,
+        # )
 
-        # Check interrupt before generation
-        if generation_interrupt.is_set():
-            result_queue.put("")
-            return
-        # Generate the response
-        outputs = llm.chat(
-            text_input,
-            sampling_params=params,
-            # stopping_criteria=StoppingCriteriaList([InterruptCriteria(generation_interrupt)]),
-        )
-        # Check interrupt immediately after generation finishes or stops
-        result_queue.put(outputs[0].outputs[0].text)
+        # # Check interrupt before generation
+        # if generation_interrupt.is_set():
+        #     result_queue.put("")
+        #     return
+        # # Generate the response
+        # outputs = llm.chat(
+        #     text_input,
+        #     sampling_params=params,
+        #     # stopping_criteria=StoppingCriteriaList([InterruptCriteria(generation_interrupt)]),
+        # )
+        # # Check interrupt immediately after generation finishes or stops
+        result_queue.put(result)
 
     except Exception as e:
         print(f"Error in inference thread for {model_name}: {e}")