Text Generation
Transformers
Safetensors
qwen2
conversational
text-generation-inference
Satori-7B-Round2 / README.md
chaoscodes's picture
Update README.md
ad6607c verified
|
raw
history blame
1.53 kB
metadata
license: apache-2.0

import os
from tqdm import tqdm
import torch
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

def generate(question_list, model_path, *, max_tokens=4096, temperature=0.0,
             tensor_parallel_size=1):
    """Generate completions for a batch of prompts with vLLM.

    Args:
        question_list: Iterable of already formatted prompt strings.
        model_path: Model name or path handed to ``LLM(model=...)``.
        max_tokens: Upper bound on generated tokens per prompt, forwarded to
            ``SamplingParams``.
        temperature: Sampling temperature, forwarded to ``SamplingParams``.
        tensor_parallel_size: Forwarded to ``LLM(tensor_parallel_size=...)``.

    Returns:
        A list with one entry per prompt, in the order returned by
        ``llm.generate``; each entry is the list of generated texts for that
        prompt (``n=1`` is requested, so one text each). An empty input
        yields ``[]``.
    """
    prompts = list(question_list)
    if not prompts:
        # Nothing to generate: return before constructing the model at all.
        return []

    llm = LLM(
        model=model_path,
        trust_remote_code=True,
        tensor_parallel_size=tensor_parallel_size,
    )
    sampling_params = SamplingParams(
        max_tokens=max_tokens,
        temperature=temperature,
        n=1,
    )
    outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
    return [
        [candidate.text for candidate in request_output.outputs]
        for request_output in outputs
    ]

def prepare_prompt(question, tokenizer):
    """Build the chat-formatted prompt string for one math problem.

    The instruction text and the problem are placed in a single user message
    and rendered through ``tokenizer.apply_chat_template``. The content
    deliberately carries no hand-written ``<|im_start|>`` / ``<|im_end|>``
    markers: the chat template (with ``add_generation_prompt=True``) is
    relied on to add the role markup and the assistant header, so embedding
    them in the content as well would nest one chat turn inside another.

    Args:
        question: The problem statement to solve.
        tokenizer: Object exposing ``apply_chat_template`` (here the
            tokenizer loaded with ``AutoTokenizer.from_pretrained``).

    Returns:
        Whatever ``apply_chat_template`` returns for ``tokenize=False``,
        i.e. the rendered prompt text.
    """
    content = (
        "Solve the following math problem efficiently and clearly.\n"
        "Please reason step by step, and put your final answer within \\boxed{}.\n"
        f"Problem: {question}"
    )
    messages = [{"role": "user", "content": content}]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    
def run():
    """Run the demo: format each sample problem, generate, and print the first completion of each."""
    model_path = "Satori-reasoning/Satori-round2"
    all_problems = ["which number is larger? 9.11 or 9.9?"]

    # The tokenizer is only used to render prompts; generation itself is
    # delegated to generate().
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    prompts = []
    for problem in all_problems:
        prompts.append(prepare_prompt(problem, tokenizer))

    for samples in generate(prompts, model_path):
        # Each entry is the list of texts for one prompt; show the first.
        print(samples[0])
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    run()