---
license: apache-2.0
---
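
Satori-round2 can be served with [vLLM](https://github.com/vllm-project/vllm). The snippet below loads the checkpoint, wraps a math question in the tokenizer's chat template, and decodes greedily (temperature 0) for up to 4096 tokens.
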
```python
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams


def generate(question_list, model_path):
    # Load the checkpoint with vLLM's offline inference engine.
    llm = LLM(
        model=model_path,
        trust_remote_code=True,
        tensor_parallel_size=1,
    )
    # Greedy decoding: one sample per prompt, up to 4096 new tokens.
    sampling_params = SamplingParams(
        max_tokens=4096,
        temperature=0.0,
        n=1,
    )
    outputs = llm.generate(question_list, sampling_params, use_tqdm=True)
    # One inner list per prompt, holding its n sampled completions.
    completions = [
        [output.text for output in output_item.outputs]
        for output_item in outputs
    ]
    return completions


def prepare_prompt(question, tokenizer):
    # Keep only the instruction in the message content: apply_chat_template
    # adds the model's <|im_start|>/<|im_end|> chat markers itself, so
    # hard-coding them here would wrap the prompt twice.
    content = (
        "Solve the following math problem efficiently and clearly.\n"
        "Please reason step by step, and put your final answer within \\boxed{}.\n"
        f"Problem: {question}"
    )
    msg = [{"role": "user", "content": content}]
    prompt = tokenizer.apply_chat_template(
        msg,
        tokenize=False,
        add_generation_prompt=True,
    )
    return prompt


def run():
    model_path = "Satori-reasoning/Satori-round2"
    all_problems = [
        "Which number is larger, 9.11 or 9.9?",
    ]
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    completions = generate(
        [prepare_prompt(problem, tokenizer) for problem in all_problems],
        model_path,
    )
    # Print the first (and here only) sample for each problem.
    for completion in completions:
        print(completion[0])


if __name__ == "__main__":
    run()
```
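
To evaluate more questions, add them to `all_problems`; vLLM batches the whole list in a single `generate` call. On multi-GPU machines, raising `tensor_parallel_size` should shard the model across devices.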