# IndicTrans3-beta / vllm-inference.py
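# Batched translation inference with vLLM.
# Example invocation (file paths and the model id below are illustrative):
#   python vllm-inference.py \
#       --model ai4bharat/IndicTrans3-beta \
#       --input_file data/dev.tsv --input_type tsv --input_column sentence \
#       --src_lang en --tgt_lang hi --output_path outputs --n_gpus 8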
import os
import argparse
import logging
import pandas as pd
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
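
# ISO 639-1 language codes mapped to the language names used in the translation prompt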
_CODE2LANG = {
"as": "Assamese",
"bn": "Bengali",
"en": "English",
"gu": "Gujarati",
"hi": "Hindi",
"kn": "Kannada",
"ml": "Malayalam",
"mr": "Marathi",
"ne": "Nepali",
"or": "Odia",
"pa": "Punjabi",
"sa": "Sanskrit",
"ta": "Tamil",
"te": "Telugu",
"ur": "Urdu"
}
logging.basicConfig(level=logging.INFO)
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True, help="model name or path")
    parser.add_argument("--input_file", type=str, required=True, help="path to the input file")
    parser.add_argument("--output_path", type=str, required=True, help="directory to write outputs to")
    parser.add_argument("--input_column", type=str, required=True, help="name of the column holding the source text")
    parser.add_argument("--src_lang", type=str, required=True, help="source language code, e.g. en")
    parser.add_argument("--tgt_lang", type=str, required=True, choices=sorted(_CODE2LANG), help="target language code, e.g. hi")
    parser.add_argument("--n_gpus", type=int, default=8, help="tensor parallel size")
    parser.add_argument("--input_type", type=str, choices=["tsv", "jsonl", "txt"], required=True, help="input file format")
    parser.add_argument("--temperature", type=float, default=0, help="sampling temperature (0 = greedy decoding)")
    parser.add_argument("--max_tokens", type=int, default=16384, help="maximum number of generated tokens")
    parser.add_argument("--top_p", type=float, default=1, help="nucleus sampling probability mass")
    parser.add_argument("--top_k", type=int, default=64, help="top-k sampling cutoff")
    parser.add_argument("--frequency_penalty", type=float, default=0, help="frequency penalty")
    parser.add_argument("--presence_penalty", type=float, default=0, help="presence penalty")
    parser.add_argument("--repetition_penalty", type=float, default=1.0, help="repetition penalty")
    return parser.parse_args()

def main(args):
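    # The chat template comes from the base Llama-3.2-3B-Instruct tokenizer;
    # left padding matches decoder-only batched generation.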
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct", padding_side="left")

    if args.input_type == "tsv":
        df = pd.read_csv(args.input_file, sep="\t")
    elif args.input_type == "jsonl":
        df = pd.read_json(args.input_file, lines=True)
    elif args.input_type == "txt":
        with open(args.input_file) as f:
            lines = [line.rstrip("\n") for line in f]
        df = pd.DataFrame({args.input_column: lines})
    else:
        raise ValueError(f"Invalid input type: {args.input_type}")
logging.info(f"Translating {len(df)} examples")
src = df[args.input_column].tolist()
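    # Wrap each source sentence in a single-turn chat message and render it
    # with the tokenizer's chat template (generation prompt appended).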
    prompt_dicts = []
    for s in src:
        prompt_dicts.append([{"role": "user", "content": f"Translate the following text to {_CODE2LANG[args.tgt_lang]}: {s}"}])
    prompts = [tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True) for prompt in prompt_dicts]
logging.info(f"Loading model from: {args.model}")
llm = LLM(
model=args.model,
trust_remote_code=True,
tensor_parallel_size=args.n_gpus
)
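    # Pass every decoding knob from the CLI through to vLLM (the original
    # parsed top_k and the frequency/presence penalties but never used them).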
    sampling_params = SamplingParams(
        temperature=args.temperature,
        top_p=args.top_p,
        top_k=args.top_k,
        max_tokens=args.max_tokens,
        frequency_penalty=args.frequency_penalty,
        presence_penalty=args.presence_penalty,
        repetition_penalty=args.repetition_penalty,
    )
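    # A single generate() call lets vLLM schedule all prompts with continuous batching.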
    outputs = llm.generate(prompts, sampling_params)

    results = []
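    # Record one row per example: the source text, the translation, token
    # lengths under the same tokenizer, and the decoding settings used.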
    for input_, output in zip(src, outputs):
        generated_text = output.outputs[0].text
        results.append({
            'model_path': args.model,
            'input': input_,
            'input_token_length': len(tokenizer(input_)['input_ids']),
            'output': generated_text,
            'output_token_length': len(tokenizer(generated_text)['input_ids']),
            'meta': {
                'model': args.model,
                'temperature': args.temperature,
                'max_tokens': args.max_tokens,
                'top_p': args.top_p,
                'top_k': args.top_k,
                'frequency_penalty': args.frequency_penalty,
                'presence_penalty': args.presence_penalty,
                'repetition_penalty': args.repetition_penalty
            }
        })
    predictions = pd.DataFrame(results)
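    # Write results as TSV under <output_path>/<src_lang>-<tgt_lang>/outputs.tsv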
    out_dir = os.path.join(args.output_path, f"{args.src_lang}-{args.tgt_lang}")
    os.makedirs(out_dir, exist_ok=True)
    predictions.to_csv(os.path.join(out_dir, "outputs.tsv"), sep="\t", index=False)

if __name__ == "__main__":
    args = parse_args()
    main(args)