# NOTE(review): removed non-Python extraction artifacts that preceded this file
# (file-size banner, commit hash, and a line-number gutter) — they broke parsing.
import os
import argparse
import logging
import time
from tqdm import tqdm
import pandas as pd
import torch
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
_CODE2LANG = {
"as": "Assamese",
"bn": "Bengali",
"en": "English",
"gu": "Gujarati",
"hi": "Hindi",
"kn": "Kannada",
"ml": "Malayalam",
"mr": "Marathi",
"ne": "Nepali",
"or": "Odia",
"pa": "Punjabi",
"sa": "Sanskrit",
"ta": "Tamil",
"te": "Telugu",
"ur": "Urdu"
}
logging.basicConfig(level=logging.INFO)
def parse_args(argv=None):
    """Parse command-line options for the batch-translation script.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``,
            in which case argparse reads ``sys.argv[1:]`` — so existing
            callers of ``parse_args()`` are unaffected. Exposing it makes
            the parser testable in isolation.

    Returns:
        argparse.Namespace with model/path options and sampling knobs.
    """
    parser = argparse.ArgumentParser(
        description="Batch-translate a column of text with a vLLM-served model."
    )
    parser.add_argument("--model", type=str, required=True, help="model name or local path")
    parser.add_argument("--input_file", type=str, required=True, help="input file")
    parser.add_argument("--output_path", type=str, required=True, help="output path")
    parser.add_argument("--input_column", type=str, required=True, help="input column")
    parser.add_argument("--src_lang", type=str, required=True, help="source language code (e.g. 'en')")
    parser.add_argument("--tgt_lang", type=str, required=True,
                        help="target language code; must be a key of _CODE2LANG")
    parser.add_argument("--n_gpus", type=int, default=8, help="tensor-parallel size for vLLM")
    parser.add_argument("--input_type", type=str, choices=["tsv", "jsonl", "hf", "txt"], required=True, help="input type")
    parser.add_argument("--temperature", type=float, default=0, help="temperature")
    parser.add_argument("--max_tokens", type=int, default=16384, help="max tokens")
    parser.add_argument("--top_p", type=float, default=1, help="top p")
    parser.add_argument("--top_k", type=int, default=64, help="top k")
    parser.add_argument("--frequency_penalty", type=float, default=0, help="frequency penalty")
    parser.add_argument("--presence_penalty", type=float, default=0, help="presence penalty")
    parser.add_argument("--repetition_penalty", type=float, default=1.0, help="repetition penalty")
    return parser.parse_args(argv)
def main(args):
    """Translate one column of an input file into ``args.tgt_lang`` via vLLM.

    Reads the input file, wraps each row in a chat-formatted translation
    instruction, runs batched generation, and writes a TSV with the inputs,
    outputs, token lengths, and sampling metadata to
    ``{output_path}/{src_lang}-{tgt_lang}/outputs.tsv``.

    Args:
        args: Parsed namespace from ``parse_args``.

    Raises:
        ValueError: If ``args.tgt_lang`` is not a known language code, or
            ``args.input_type`` is not supported by this function
            (``hf`` is accepted by the CLI but not implemented here).
        KeyError: If ``args.input_column`` is missing from the input data.
    """
    # Fail fast with a clear message instead of a KeyError deep in the loop.
    if args.tgt_lang not in _CODE2LANG:
        raise ValueError(f"Unknown target language code: {args.tgt_lang}")
    # Tokenizer is used only for the chat template and token-length stats;
    # the generating model itself is loaded separately below.
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct", padding_side="left")
    if args.input_type == "tsv":
        df = pd.read_csv(args.input_file, sep="\t")
    elif args.input_type == "jsonl":
        # "jsonl" was advertised in --input_type choices but previously
        # fell through to the error branch; one JSON object per line.
        df = pd.read_json(args.input_file, lines=True)
    elif args.input_type == "txt":
        with open(args.input_file) as f:
            # Strip the trailing newline so it neither leaks into the
            # prompt nor corrupts the tab-separated output file.
            lines = [line.rstrip("\n") for line in f]
        df = pd.DataFrame({args.input_column: lines})
    else:
        raise ValueError(f"Unsupported input type: {args.input_type}")
    logging.info(f"Translating {len(df)} examples")
    src = df[args.input_column].tolist()
    prompt_dicts = [
        [{"role": "user", "content": f"Translate the following text to {_CODE2LANG[args.tgt_lang]}: {s}"}]
        for s in src
    ]
    prompts = [tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True) for prompt in prompt_dicts]
    logging.info(f"Loading model from: {args.model}")
    llm = LLM(
        model=args.model,
        trust_remote_code=True,
        tensor_parallel_size=args.n_gpus
    )
    # Forward every sampling knob exposed on the CLI. Previously top_k,
    # frequency_penalty and presence_penalty were parsed (and recorded in
    # the output meta) but silently ignored at generation time.
    sampling_params = SamplingParams(
        temperature=args.temperature,
        top_p=args.top_p,
        top_k=args.top_k,
        max_tokens=args.max_tokens,
        frequency_penalty=args.frequency_penalty,
        presence_penalty=args.presence_penalty,
        repetition_penalty=args.repetition_penalty,
    )
    outputs = llm.generate(prompts, sampling_params)
    results = []
    for input_, output in zip(src, outputs):
        generated_text = output.outputs[0].text
        results.append({
            'model_path': args.model,
            'input': input_,
            'input_token_length': len(tokenizer(input_)['input_ids']),
            'output': generated_text,
            'output_token_length': len(tokenizer(generated_text)['input_ids']),
            'meta': {
                'model': args.model,
                'temperature': args.temperature,
                'max_tokens': args.max_tokens,
                'top_p': args.top_p,
                'top_k': args.top_k,
                'frequency_penalty': args.frequency_penalty,
                'presence_penalty': args.presence_penalty,
                'repetition_penalty': args.repetition_penalty
            }
        })
    predictions = pd.DataFrame(results)
    out_dir = os.path.join(args.output_path, f"{args.src_lang}-{args.tgt_lang}")
    os.makedirs(out_dir, exist_ok=True)
    predictions.to_csv(os.path.join(out_dir, "outputs.tsv"), sep="\t", index=False)
if __name__ == "__main__":
    # Script entry point: parse CLI flags and run the translation job.
    main(parse_args())