File size: 1,527 Bytes
ad6607c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
---
license: apache-2.0
---
```python
import os
from tqdm import tqdm
import torch
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
def generate(question_list, model_path):
    """Run batched greedy decoding over a list of prompts with vLLM.

    Args:
        question_list: Fully formatted prompt strings, one per problem.
        model_path: Hugging Face model id or local checkpoint path.

    Returns:
        One list of completion strings per input prompt (n=1, so each
        inner list holds a single completion).
    """
    engine = LLM(
        model=model_path,
        trust_remote_code=True,
        tensor_parallel_size=1,
    )
    # temperature=0.0 -> deterministic greedy decoding, one sample per prompt.
    params = SamplingParams(max_tokens=4096, temperature=0.0, n=1)
    results = engine.generate(question_list, params, use_tqdm=True)
    completions = []
    for item in results:
        completions.append([choice.text for choice in item.outputs])
    return completions
def prepare_prompt(question, tokenizer):
    """Build a generation-ready chat prompt for one math problem.

    The plain instruction text is passed as the user message and the
    tokenizer's chat template supplies the role markers, so tokens like
    `<|im_start|>`/`<|im_end|>` appear exactly once. (The original code
    embedded those markers inside the message content and then applied
    the chat template on top, double-wrapping the user turn.)

    Args:
        question: The math problem statement.
        tokenizer: A tokenizer exposing ``apply_chat_template``.

    Returns:
        The fully formatted prompt string, ending with the assistant
        generation header (``add_generation_prompt=True``).
    """
    content = (
        "Solve the following math problem efficiently and clearly.\n"
        "Please reason step by step, and put your final answer within "
        f"\\boxed{{}}.\nProblem: {question}"
    )
    msg = [
        {"role": "user", "content": content}
    ]
    # tokenize=False returns a string rather than token ids; the
    # generation prompt makes the model start answering immediately.
    prompt = tokenizer.apply_chat_template(
        msg,
        tokenize=False,
        add_generation_prompt=True
    )
    return prompt
def run():
    """Demo entry point: format the problems, generate answers, print them."""
    model_path = "Satori-reasoning/Satori-round2"
    all_problems = [
        "which number is larger? 9.11 or 9.9?",
    ]
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # Turn each raw problem into a chat-formatted prompt before batching.
    prompts = [prepare_prompt(problem, tokenizer) for problem in all_problems]
    completions = generate(prompts, model_path)
    # n=1 sampling: each entry holds a single completion.
    for answers in completions:
        print(answers[0])


if __name__ == "__main__":
    run()
``` |