import os
import argparse
import logging
import time
from tqdm import tqdm
import pandas as pd

import torch
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
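
# Maps ISO 639-1 language codes to the language names used in the translation prompts.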
_CODE2LANG = {
    "as": "Assamese",
    "bn": "Bengali",
    "en": "English",
    "gu": "Gujarati",
    "hi": "Hindi",
    "kn": "Kannada",
    "ml": "Malayalam",
    "mr": "Marathi",
    "ne": "Nepali",
    "or": "Odia",
    "pa": "Punjabi",
    "sa": "Sanskrit",
    "ta": "Tamil",
    "te": "Telugu",
    "ur": "Urdu"
}

logging.basicConfig(level=logging.INFO)


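# CLI arguments: model path, input/output locations, and decoding parameters.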
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--input_file", type=str, required=True, help="path to the input file")
    parser.add_argument("--output_path", type=str, required=True, help="directory where outputs are written")
    parser.add_argument("--input_column", type=str, required=True, help="name of the column containing the source text")
    parser.add_argument("--src_lang", type=str, required=True)
    parser.add_argument("--tgt_lang", type=str, required=True)
    parser.add_argument("--n_gpus", type=int, default=8)
    parser.add_argument("--input_type", type=str, choices=["tsv", "jsonl", "hf", "txt"], required=True, help="format of the input file")
    parser.add_argument("--temperature", type=float, default=0, help="temperature")
    parser.add_argument("--max_tokens", type=int, default=16384, help="max tokens")
    parser.add_argument("--top_p", type=float, default=1, help="top p")
    parser.add_argument("--top_k", type=int, default=64, help="top k")
    parser.add_argument("--frequency_penalty", type=float, default=0, help="frequency penalty")
    parser.add_argument("--presence_penalty", type=float, default=0, help="presence penalty")
    parser.add_argument("--repetition_penalty", type=float, default=1.0, help="repetition penalty")
    return parser.parse_args()


def main(args):
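    """Translate the input file into the target language with vLLM and write a TSV of predictions."""
    # NOTE: the tokenizer is hard-coded to Llama-3.2-3B-Instruct and is used only to
    # apply the chat template and to count tokens; it is assumed to match the chat
    # format of the model passed via --model.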
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct", padding_side="left")

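    # Read the input into a DataFrame; args.input_column names the source-text column.
    # ("hf" is accepted on the command line but not implemented here.)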
    if args.input_type == "tsv":
        df = pd.read_csv(args.input_file, sep="\t")
    elif args.input_type == "jsonl":
        df = pd.read_json(args.input_file, lines=True)
    elif args.input_type == "txt":
        with open(args.input_file) as f:
            lines = [line.rstrip("\n") for line in f]
        df = pd.DataFrame({args.input_column: lines})
    else:
        raise ValueError(f"Unsupported input type: {args.input_type}")

    logging.info(f"Translating {len(df)} examples")

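    # Build one chat-format translation prompt per source sentence.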
    src = df[args.input_column].tolist()
    prompt_dicts = []
    for s in src:
        prompt_dicts.append([{"role": "user", "content": f"Translate the following text to {_CODE2LANG[args.tgt_lang]}: {s}"}])
    prompts = [tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True) for prompt in prompt_dicts]

    logging.info(f"Loading model from: {args.model}")
    llm = LLM(
        model=args.model,
        trust_remote_code=True,
        tensor_parallel_size=args.n_gpus
    )

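    # Greedy decoding by default (temperature=0); decoding parameters come from the CLI flags.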
    sampling_params = SamplingParams(
        temperature=args.temperature,
        top_p=args.top_p,
        top_k=args.top_k,
        frequency_penalty=args.frequency_penalty,
        presence_penalty=args.presence_penalty,
        max_tokens=args.max_tokens,
        repetition_penalty=args.repetition_penalty,
    )

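    # Generate translations and collect one record per example, including token
    # counts and the decoding configuration.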
    outputs = llm.generate(prompts, sampling_params)
    results = []
    for input_, output in zip(src, outputs):
        generated_text = output.outputs[0].text
        results.append({
            'model_path': args.model,
            'input': input_,
            'input_token_length': len(tokenizer(input_)['input_ids']),
            'output': generated_text,
            'output_token_length': len(tokenizer(generated_text)['input_ids']),
            'meta': {
                'model': args.model,
                'temperature': args.temperature,
                'max_tokens': args.max_tokens,
                'top_p': args.top_p,
                'top_k': args.top_k,
                'frequency_penalty': args.frequency_penalty,
                'presence_penalty': args.presence_penalty,
                'repetition_penalty': args.repetition_penalty
            }
        })

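    # Write predictions as a TSV under {output_path}/{src_lang}-{tgt_lang}/.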
    predictions = pd.DataFrame(results)
    os.makedirs(f"{args.output_path}/{args.src_lang}-{args.tgt_lang}", exist_ok=True)
    predictions.to_csv(f"{args.output_path}/{args.src_lang}-{args.tgt_lang}/outputs.tsv", sep="\t", index=False)


if __name__ == "__main__":
    args = parse_args()
    main(args)