---
license: apache-2.0
---
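The script below runs the model with vLLM on a sample math problem: it builds a chat-formatted prompt for each question, decodes greedily (temperature 0) with a single completion per prompt, and prints the result.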
```python
import os
from tqdm import tqdm
import torch
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams


def generate(question_list, model_path):
    # Load the model with vLLM; raise tensor_parallel_size to shard across multiple GPUs.
    llm = LLM(
        model=model_path,
        trust_remote_code=True,
        tensor_parallel_size=1,
    )
    # Greedy decoding: one completion per prompt, up to 4096 new tokens.
    sampling_params = SamplingParams(
        max_tokens=4096,
        temperature=0.0,
        n=1,
    )
    outputs = llm.generate(question_list, sampling_params, use_tqdm=True)
    # Each result holds n completions; keep just the generated text.
    completions = [[output.text for output in output_item.outputs] for output_item in outputs]
    return completions

def prepare_prompt(question, tokenizer):
    # apply_chat_template adds the chat markers (e.g. <|im_start|>) itself,
    # so pass only the raw instruction text as the user message.
    content = (
        "Solve the following math problem efficiently and clearly.\n"
        "Please reason step by step, and put your final answer within \\boxed{}.\n"
        f"Problem: {question}"
    )
    msg = [
        {"role": "user", "content": content},
    ]
    prompt = tokenizer.apply_chat_template(
        msg,
        tokenize=False,
        add_generation_prompt=True,
    )
    return prompt

def run():
    model_path = "Satori-reasoning/Satori-round2"
    all_problems = [
        "which number is larger? 9.11 or 9.9?",
    ]
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    completions = generate(
        [prepare_prompt(problem_data, tokenizer) for problem_data in all_problems],
        model_path,
    )
    # Print the single completion returned for each problem.
    for completion in completions:
        print(completion[0])


if __name__ == "__main__":
    run()
```
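Assuming the snippet is saved as, say, `satori_vllm_demo.py` (a file name chosen here only for illustration), running `python satori_vllm_demo.py` loads the model once and prints one completion per entry in `all_problems`.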