File size: 4,252 Bytes
3f64837
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
import argparse
import logging
import time
from tqdm import tqdm
import pandas as pd

import torch
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer

# ISO 639-1 language codes -> English language names. Used to render the
# target-language name inside the translation prompt (see main()).
# Covers English plus 14 Indic languages.
_CODE2LANG = {
    "as": "Assamese",
    "bn": "Bengali",
    "en": "English",
    "gu": "Gujarati",
    "hi": "Hindi",
    "kn": "Kannada",
    "ml": "Malayalam",
    "mr": "Marathi",
    "ne": "Nepali",
    "or": "Odia",
    "pa": "Punjabi",
    "sa": "Sanskrit",
    "ta": "Tamil",
    "te": "Telugu",
    "ur": "Urdu"
}

# Root-logger config so the logging.info() progress messages below are visible.
logging.basicConfig(level=logging.INFO)


def parse_args(argv=None):
    """Parse command-line arguments for the translation run.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``, in
            which case argparse reads ``sys.argv[1:]`` as before — the extra
            parameter only makes the function callable from tests.

    Returns:
        argparse.Namespace with all run settings.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--input_file", type=str, required=True, help="input file")
    parser.add_argument("--output_path", type=str, required=True, help="output path")
    parser.add_argument("--input_column", type=str, required=True, help="input column")
    parser.add_argument("--src_lang", type=str, required=True)
    parser.add_argument("--tgt_lang", type=str, required=True)
    parser.add_argument("--n_gpus", type=int, default=8)
    parser.add_argument("--input_type", type=str, choices=["tsv", "jsonl", "hf", "txt"], required=True, help="input type")
    # Sampling controls; defaults give greedy decoding (temperature 0).
    parser.add_argument("--temperature", type=float, default=0, help="temperature")
    parser.add_argument("--max_tokens", type=int, default=16384, help="max tokens")
    parser.add_argument("--top_p", type=float, default=1, help="top p")
    parser.add_argument("--top_k", type=int, default=64, help="top k")
    parser.add_argument("--frequency_penalty", type=float, default=0, help="frequency penalty")
    parser.add_argument("--presence_penalty", type=float, default=0, help="presence penalty")
    parser.add_argument("--repetition_penalty", type=float, default=1.0, help="repetition penalty")
    return parser.parse_args(argv)


def main(args):
    """Translate one column of an input file with a vLLM model and write a TSV.

    Reads the input (tsv/jsonl/txt), wraps each row of ``args.input_column``
    in a chat-style translation prompt targeting ``args.tgt_lang``, generates
    with vLLM, and writes ``<output_path>/<src>-<tgt>/outputs.tsv``.
    """
    # NOTE(review): tokenizer repo is hard-coded instead of args.model —
    # confirm intentional (prompts are templated with the Llama 3.2 chat
    # template regardless of which model weights are served).
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct", padding_side="left")

    if args.input_type == "tsv":
        df = pd.read_csv(args.input_file, sep="\t")
    elif args.input_type == "jsonl":
        # Fix: "jsonl" is an advertised --input_type choice but previously
        # fell through to the ValueError below.
        df = pd.read_json(args.input_file, lines=True)
    elif args.input_type == "txt":
        with open(args.input_file, encoding="utf-8") as f:
            # Fix: strip trailing newlines so they do not leak into prompts.
            lines = [line.rstrip("\n") for line in f]
        df = pd.DataFrame({args.input_column: lines})
    else:
        # "hf" is accepted by argparse but not implemented here.
        raise ValueError(f"Unsupported input type: {args.input_type}")

    logging.info(f"Translating {len(df)} examples")

    src = df[args.input_column].tolist()
    prompt_dicts = [
        [{"role": "user", "content": f"Translate the following text to {_CODE2LANG[args.tgt_lang]}: {s}"}]
        for s in src
    ]
    prompts = [tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True) for prompt in prompt_dicts]

    logging.info(f"Loading model from: {args.model}")
    llm = LLM(
        model=args.model,
        trust_remote_code=True,
        tensor_parallel_size=args.n_gpus
    )

    # Fix: top_k / frequency_penalty / presence_penalty were parsed (and
    # recorded in meta) but never actually passed to SamplingParams, so the
    # CLI flags silently had no effect. With the default temperature of 0
    # decoding stays greedy, so defaults are unchanged in practice.
    sampling_params = SamplingParams(
        temperature=args.temperature,
        top_p=args.top_p,
        top_k=args.top_k,
        frequency_penalty=args.frequency_penalty,
        presence_penalty=args.presence_penalty,
        max_tokens=args.max_tokens,
        repetition_penalty=args.repetition_penalty,
    )

    outputs = llm.generate(prompts, sampling_params)
    results = []
    for input_, output in zip(src, outputs):
        generated_text = output.outputs[0].text
        results.append({
            'model_path': args.model,
            'input': input_,
            'input_token_length': len(tokenizer(input_)['input_ids']),
            'output': generated_text,
            'output_token_length': len(tokenizer(generated_text)['input_ids']),
            # Record the full sampling configuration alongside each row for
            # reproducibility.
            'meta': {
                'model': args.model,
                'temperature': args.temperature,
                'max_tokens': args.max_tokens,
                'top_p': args.top_p,
                'top_k': args.top_k,
                'frequency_penalty': args.frequency_penalty,
                'presence_penalty': args.presence_penalty,
                'repetition_penalty': args.repetition_penalty
            }
        })

    predictions = pd.DataFrame(results)
    os.makedirs(f"{args.output_path}/{args.src_lang}-{args.tgt_lang}", exist_ok=True)
    predictions.to_csv(f"{args.output_path}/{args.src_lang}-{args.tgt_lang}/outputs.tsv", sep="\t", index=False)


# Script entry point: parse CLI flags and run the translation pipeline.
if __name__ == "__main__":
    main(parse_args())