add README
- README.md +39 -0
- vllm-inference.py +116 -0
README.md
CHANGED
---
license: cc-by-4.0
---

# IndicTrans3

**IndicTrans3** is a multilingual translation model for 15 Indic languages. This repository provides an inference script that leverages [vLLM](https://github.com/vllm-project/vllm) for efficient and scalable translation.

The model is built on top of **Gemma-3** and fine-tuned for **document-level translation**. It supports both **sentence-level** and **document-level** translation in both directions (a sketch of the prompt format follows this list):

- English → Indic languages
- Indic languages → English
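
Under the hood, the bundled script (`vllm-inference.py`, below) drives the model with a single-turn instruction prompt rendered through the tokenizer's chat template. A minimal sketch of that prompt shape; the Hindi example text is illustrative:

```python
# Prompt shape used by vllm-inference.py: one user turn naming the target
# language, rendered via the tokenizer's chat template before generation.
prompt = [{
    "role": "user",
    "content": "Translate the following text to Hindi: The weather is pleasant today.",
}]
```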

---

## 🌐 Supported Languages

The model supports the following Indic languages: Assamese, Bengali, Gujarati, Hindi, Kannada, Maithili, Malayalam, Marathi, Nepali, Odia, Punjabi, Sanskrit, Tamil, Telugu, and Urdu.
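
The inference script's `--src_lang` and `--tgt_lang` flags take short language codes rather than full names. The mapping below mirrors the `_CODE2LANG` table in `vllm-inference.py`:

```python
# Language codes accepted by --src_lang / --tgt_lang, mirroring _CODE2LANG
# in vllm-inference.py; "en" is English, the rest are the Indic languages.
LANG_CODES = {
    "as": "Assamese", "bn": "Bengali", "en": "English", "gu": "Gujarati",
    "hi": "Hindi", "kn": "Kannada", "mai": "Maithili", "ml": "Malayalam",
    "mr": "Marathi", "ne": "Nepali", "or": "Odia", "pa": "Punjabi",
    "sa": "Sanskrit", "ta": "Tamil", "te": "Telugu", "ur": "Urdu",
}
```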

## 🛠️ Installation

1. **Install PyTorch**

   Follow the instructions for your system and CUDA version on the official [PyTorch website](https://pytorch.org/get-started/locally/).

2. **Install the required dependencies**

   ```bash
   pip install vllm transformers
   ```
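
   After installing, a quick import check can confirm the environment (version numbers will vary):

   ```python
   # Both packages expose __version__; an ImportError here means the
   # installation step above did not complete.
   import transformers
   import vllm

   print("vllm:", vllm.__version__, "| transformers:", transformers.__version__)
   ```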

3. **Run inference with vLLM**

   ```bash
   python vllm-inference.py \
       --model <model_path> \
       --input_file <input_file> \
       --output_path <output_dir> \
       --src_lang <source_language> \
       --tgt_lang <target_language> \
       --input_column <input_column> \
       --input_type <input_type>
   ```

   `--n_gpus` sets the tensor-parallel degree and defaults to 8; set it to the number of GPUs available on your machine.
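
   The script writes its results to `<output_dir>/<src_lang>-<tgt_lang>/outputs.tsv`. A minimal sketch of reading them back; the `outputs/en-hi` path assumes an English-to-Hindi run with `--output_path outputs`:

   ```python
   import pandas as pd

   # Columns written by vllm-inference.py: model_path, input,
   # input_token_length, output, output_token_length, meta.
   df = pd.read_csv("outputs/en-hi/outputs.tsv", sep="\t")
   print(df[["input", "output"]].head())
   ```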

## License

This model is licensed under the [CC BY 4.0](https://creativecommons.org/licenses/by/4.0/) license. You are free to share and adapt the material for any purpose, even commercially, as long as you provide appropriate credit and indicate if changes were made.
vllm-inference.py
ADDED
import os
import argparse
import logging

import pandas as pd
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer

# Maps the short language codes accepted on the command line to the
# language names used in the translation prompt.
_CODE2LANG = {
    "as": "Assamese",
    "bn": "Bengali",
    "en": "English",
    "gu": "Gujarati",
    "hi": "Hindi",
    "kn": "Kannada",
    "mai": "Maithili",
    "ml": "Malayalam",
    "mr": "Marathi",
    "ne": "Nepali",
    "or": "Odia",
    "pa": "Punjabi",
    "sa": "Sanskrit",
    "ta": "Tamil",
    "te": "Telugu",
    "ur": "Urdu",
}

logging.basicConfig(level=logging.INFO)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--input_file", type=str, required=True, help="input file")
    parser.add_argument("--output_path", type=str, required=True, help="output directory")
    parser.add_argument("--input_column", type=str, required=True, help="input column")
    parser.add_argument("--src_lang", type=str, required=True)
    parser.add_argument("--tgt_lang", type=str, required=True)
    parser.add_argument("--n_gpus", type=int, default=8)
    parser.add_argument("--input_type", type=str, choices=["tsv", "jsonl", "txt"], required=True, help="input type")
    parser.add_argument("--temperature", type=float, default=0, help="temperature")
    parser.add_argument("--max_tokens", type=int, default=16384, help="max tokens")
    parser.add_argument("--top_p", type=float, default=1, help="top p")
    parser.add_argument("--top_k", type=int, default=64, help="top k")
    parser.add_argument("--frequency_penalty", type=float, default=0, help="frequency penalty")
    parser.add_argument("--presence_penalty", type=float, default=0, help="presence penalty")
    parser.add_argument("--repetition_penalty", type=float, default=1.0, help="repetition penalty")
    return parser.parse_args()


def main(args):
    # Load the tokenizer from the model being served so the chat template
    # matches the fine-tuned checkpoint.
    tokenizer = AutoTokenizer.from_pretrained(args.model, padding_side="left")

    if args.input_type == "tsv":
        df = pd.read_csv(args.input_file, sep="\t")
    elif args.input_type == "jsonl":
        df = pd.read_json(args.input_file, lines=True)
    elif args.input_type == "txt":
        with open(args.input_file) as f:
            lines = f.read().splitlines()
        df = pd.DataFrame({args.input_column: lines})
    else:
        raise ValueError(f"Invalid input type: {args.input_type}")

    logging.info(f"Translating {len(df)} examples")

    # Build one single-turn chat prompt per source text.
    src = df[args.input_column].tolist()
    prompt_dicts = []
    for s in src:
        prompt_dicts.append([{"role": "user", "content": f"Translate the following text to {_CODE2LANG[args.tgt_lang]}: {s}"}])
    prompts = [tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True) for prompt in prompt_dicts]

logging.info(f"Loading model from: {args.model}")
|
74 |
+
llm = LLM(
|
75 |
+
model=args.model,
|
76 |
+
trust_remote_code=True,
|
77 |
+
tensor_parallel_size=args.n_gpus
|
78 |
+
)
|
79 |
+
|
80 |
+
sampling_params = SamplingParams(
|
81 |
+
temperature=args.temperature,
|
82 |
+
top_p=args.top_p,
|
83 |
+
max_tokens=args.max_tokens,
|
84 |
+
repetition_penalty=args.repetition_penalty,
|
85 |
+
)
|
86 |
+
|
87 |
+
outputs = llm.generate(prompts, sampling_params)
|
88 |
+
results = []
|
89 |
+
for input_, output in zip(src, outputs):
|
90 |
+
generated_text = output.outputs[0].text
|
91 |
+
results.append({
|
92 |
+
'model_path': args.model,
|
93 |
+
'input': input_,
|
94 |
+
'input_token_length': len(tokenizer(input_)['input_ids']),
|
95 |
+
'output': generated_text,
|
96 |
+
'output_token_length': len(tokenizer(generated_text)['input_ids']),
|
97 |
+
'meta': {
|
98 |
+
'model': args.model,
|
99 |
+
'temperature': args.temperature,
|
100 |
+
'max_tokens': args.max_tokens,
|
101 |
+
'top_p': args.top_p,
|
102 |
+
'top_k': args.top_k,
|
103 |
+
'frequency_penalty': args.frequency_penalty,
|
104 |
+
'presence_penalty': args.presence_penalty,
|
105 |
+
'repetition_penalty': args.repetition_penalty
|
106 |
+
}
|
107 |
+
})
|
108 |
+
|
109 |
+
predictions = pd.DataFrame(results)
|
110 |
+
os.makedirs(f"{args.output_path}/{args.src_lang}-{args.tgt_lang}", exist_ok=True)
|
111 |
+
predictions.to_csv(f"{args.output_path}/{args.src_lang}-{args.tgt_lang}/outputs.tsv", sep="\t", index=False)
|
112 |
+
|
113 |
+
|
114 |
+
if __name__ == "__main__":
|
115 |
+
args = parse_args()
|
116 |
+
main(args)
|