"""Batch-translate a TSV/JSONL/text file into a target language with a vLLM-served chat model."""

import os
import argparse
import logging

import pandas as pd
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer

# Language codes for the supported target languages.
_CODE2LANG = {
    "as": "Assamese",
    "bn": "Bengali",
    "en": "English",
    "gu": "Gujarati",
    "hi": "Hindi",
    "kn": "Kannada",
    "ml": "Malayalam",
    "mr": "Marathi",
    "ne": "Nepali",
    "or": "Odia",
    "pa": "Punjabi",
    "sa": "Sanskrit",
    "ta": "Tamil",
    "te": "Telugu",
    "ur": "Urdu",
}

logging.basicConfig(level=logging.INFO)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--input_file", type=str, required=True, help="input file")
    parser.add_argument("--output_path", type=str, required=True, help="output path")
    parser.add_argument("--input_column", type=str, required=True, help="input column")
    parser.add_argument("--src_lang", type=str, required=True)
    parser.add_argument("--tgt_lang", type=str, required=True)
    parser.add_argument("--n_gpus", type=int, default=8)
    parser.add_argument("--input_type", type=str, choices=["tsv", "jsonl", "hf", "txt"], required=True, help="input type")
    parser.add_argument("--temperature", type=float, default=0, help="temperature")
    parser.add_argument("--max_tokens", type=int, default=16384, help="max tokens")
    parser.add_argument("--top_p", type=float, default=1, help="top p")
    parser.add_argument("--top_k", type=int, default=64, help="top k")
    parser.add_argument("--frequency_penalty", type=float, default=0, help="frequency penalty")
    parser.add_argument("--presence_penalty", type=float, default=0, help="presence penalty")
    parser.add_argument("--repetition_penalty", type=float, default=1.0, help="repetition penalty")
    return parser.parse_args()


def main(args):
    # Load the tokenizer from the same checkpoint as the model so the chat
    # template and token counts match what vLLM serves.
    tokenizer = AutoTokenizer.from_pretrained(args.model, padding_side="left")

    if args.input_type == "tsv":
        df = pd.read_csv(args.input_file, sep="\t")
    elif args.input_type == "jsonl":
        df = pd.read_json(args.input_file, lines=True)
    elif args.input_type == "txt":
        with open(args.input_file) as f:
            # splitlines() drops trailing newlines so they do not leak into the prompts.
            lines = f.read().splitlines()
        df = pd.DataFrame({args.input_column: lines})
    elif args.input_type == "hf":
        raise NotImplementedError("Loading datasets from the Hugging Face hub is not implemented")
    else:
        raise ValueError(f"Invalid input type: {args.input_type}")

    logging.info(f"Translating {len(df)} examples")

    # Wrap each source text in a single-turn chat prompt and render it with
    # the model's chat template.
    src = df[args.input_column].tolist()
    prompt_dicts = [
        [{"role": "user", "content": f"Translate the following text to {_CODE2LANG[args.tgt_lang]}: {s}"}]
        for s in src
    ]
    prompts = [
        tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
        for prompt in prompt_dicts
    ]

    logging.info(f"Loading model from: {args.model}")
    llm = LLM(
        model=args.model,
        trust_remote_code=True,
        tensor_parallel_size=args.n_gpus,
    )
    # Forward every sampling knob exposed on the CLI to vLLM; with the default
    # temperature of 0, decoding is greedy and top_p/top_k have no effect.
    sampling_params = SamplingParams(
        temperature=args.temperature,
        top_p=args.top_p,
        top_k=args.top_k,
        max_tokens=args.max_tokens,
        frequency_penalty=args.frequency_penalty,
        presence_penalty=args.presence_penalty,
        repetition_penalty=args.repetition_penalty,
    )

    outputs = llm.generate(prompts, sampling_params)

    results = []
    for input_, output in zip(src, outputs):
        generated_text = output.outputs[0].text
        results.append({
            "model_path": args.model,
            "input": input_,
            "input_token_length": len(tokenizer(input_)["input_ids"]),
            "output": generated_text,
            "output_token_length": len(tokenizer(generated_text)["input_ids"]),
            "meta": {
                "model": args.model,
                "temperature": args.temperature,
                "max_tokens": args.max_tokens,
                "top_p": args.top_p,
                "top_k": args.top_k,
                "frequency_penalty": args.frequency_penalty,
                "presence_penalty": args.presence_penalty,
                "repetition_penalty": args.repetition_penalty,
            },
        })

    predictions = pd.DataFrame(results)
    out_dir = f"{args.output_path}/{args.src_lang}-{args.tgt_lang}"
    os.makedirs(out_dir, exist_ok=True)
    predictions.to_csv(f"{out_dir}/outputs.tsv", sep="\t", index=False)


if __name__ == "__main__":
    args = parse_args()
    main(args)
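# Example invocation (a sketch: the script name, file paths, and GPU count
# below are placeholders, not values defined by this script; adjust them to
# your setup):
#
#   python translate_vllm.py \
#       --model meta-llama/Llama-3.2-3B-Instruct \
#       --input_file data/sentences.en.txt \
#       --output_path outputs \
#       --input_column text \
#       --src_lang en --tgt_lang hi \
#       --input_type txt \
#       --n_gpus 1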