sumanthd commited on
Commit
3f64837
·
1 Parent(s): 315903b

add README

Browse files
Files changed (2) hide show
  1. README.md +39 -0
  2. vllm-inference.py +116 -0
README.md CHANGED
@@ -1,3 +1,42 @@
1
  ---
2
  license: cc-by-4.0
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: cc-by-4.0
3
  ---
4
+
5
+ # IndicTrans3
6
+
7
+ **IndicTrans3** is a multilingual translation model for 15 Indic languages. This repository provides an inference script that leverages [vLLM](https://github.com/vllm-project/vllm) for efficient and scalable translation.
8
+
9
+ The model is built on top of **Gemma-3** and fine-tuned for **document-level translation** tasks. It supports both **sentence-level** and **document-level** translation in **both directions**:
10
+ - English ↔ Indic Languages
11
+ - Indic Languages ↔ English
12
+
13
+ ---
14
+
15
+ ## 🌐 Supported Languages
16
+ The model supports the following Indic languages: Assamese, Bengali, Gujarati, Hindi, Kannada, Maithili, Malayalam, Marathi, Nepali, Odia, Punjabi, Sanskrit, Tamil, Telugu, Urdu.
17
+
18
+ ## 🛠️ Installation
19
+
20
+ 1. **Install PyTorch**
21
+ Follow the instructions based on your system and CUDA version from the official [PyTorch website](https://pytorch.org/get-started/locally/).
22
+
23
+ 2. **Install required dependencies**
24
+
25
+ ```bash
26
+ pip install vllm transformers
27
+ ```
28
+
29
+ 3. Run Inference with vllm
30
+ ```bash
31
+ python vllm-inference.py \
32
+ --model <model_path> \
33
+ --input_file <input_file> \
34
+ --output_path <output_file> \
35
+ --src_lang <source_language> \
36
+ --tgt_lang <target_language> \
37
+ --input_column <input_column> \
38
+ --input_type <input_type> \
39
+ ```
40
+
41
+ ## License
42
+ This model is licensed under the [CC BY 4.0](https://creativecommons.org/licenses/by/4.0/) license. You are free to share and adapt the material for any purpose, even commercially, as long as you provide appropriate credit, indicate if changes were made, and distribute your contributions under the same license.
vllm-inference.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import logging
4
+ import time
5
+ from tqdm import tqdm
6
+ import pandas as pd
7
+
8
+ import torch
9
+ from vllm import LLM, SamplingParams
10
+ from transformers import AutoTokenizer
11
+
12
+ _CODE2LANG = {
13
+ "as": "Assamese",
14
+ "bn": "Bengali",
15
+ "en": "English",
16
+ "gu": "Gujarati",
17
+ "hi": "Hindi",
18
+ "kn": "Kannada",
19
+ "ml": "Malayalam",
20
+ "mr": "Marathi",
21
+ "ne": "Nepali",
22
+ "or": "Odia",
23
+ "pa": "Punjabi",
24
+ "sa": "Sanskrit",
25
+ "ta": "Tamil",
26
+ "te": "Telugu",
27
+ "ur": "Urdu"
28
+ }
29
+
30
+ logging.basicConfig(level=logging.INFO)
31
+
32
+
33
+ def parse_args():
34
+ parser = argparse.ArgumentParser()
35
+ parser.add_argument("--model", type=str, required=True)
36
+ parser.add_argument("--input_file", type=str, required=True, help="input file")
37
+ parser.add_argument("--output_path", type=str, required=True, help="output path")
38
+ parser.add_argument("--input_column", type=str, required=True, help="input column")
39
+ parser.add_argument("--src_lang", type=str, required=True)
40
+ parser.add_argument("--tgt_lang", type=str, required=True)
41
+ parser.add_argument("--n_gpus", type=int, default=8)
42
+ parser.add_argument("--input_type", type=str, choices=["tsv", "jsonl", "hf", "txt"], required=True, help="input type")
43
+ parser.add_argument("--temperature", type=float, default=0, help="temperature")
44
+ parser.add_argument("--max_tokens", type=int, default=16384, help="max tokens")
45
+ parser.add_argument("--top_p", type=float, default=1, help="top p")
46
+ parser.add_argument("--top_k", type=int, default=64, help="top k")
47
+ parser.add_argument("--frequency_penalty", type=float, default=0, help="frequency penalty")
48
+ parser.add_argument("--presence_penalty", type=float, default=0, help="presence penalty")
49
+ parser.add_argument("--repetition_penalty", type=float, default=1.0, help="repetition penalty")
50
+ return parser.parse_args()
51
+
52
+
53
+ def main(args):
54
+ tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct", padding_side="left")
55
+
56
+ if args.input_type == "tsv":
57
+ df = pd.read_csv(args.input_file, sep="\t")
58
+ elif args.input_type == "txt":
59
+ with open(args.input_file) as f:
60
+ lines = f.readlines()
61
+ df = pd.DataFrame({args.input_column: lines})
62
+ else:
63
+ raise ValueError("Invalid input type")
64
+
65
+ logging.info(f"Translating {len(df)} examples")
66
+
67
+ src = df[args.input_column].tolist()
68
+ prompt_dicts = []
69
+ for s in src:
70
+ prompt_dicts.append([{"role": "user", "content": f"Translate the following text to {_CODE2LANG[args.tgt_lang]}: {s}"}])
71
+ prompts = [tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True) for prompt in prompt_dicts]
72
+
73
+ logging.info(f"Loading model from: {args.model}")
74
+ llm = LLM(
75
+ model=args.model,
76
+ trust_remote_code=True,
77
+ tensor_parallel_size=args.n_gpus
78
+ )
79
+
80
+ sampling_params = SamplingParams(
81
+ temperature=args.temperature,
82
+ top_p=args.top_p,
83
+ max_tokens=args.max_tokens,
84
+ repetition_penalty=args.repetition_penalty,
85
+ )
86
+
87
+ outputs = llm.generate(prompts, sampling_params)
88
+ results = []
89
+ for input_, output in zip(src, outputs):
90
+ generated_text = output.outputs[0].text
91
+ results.append({
92
+ 'model_path': args.model,
93
+ 'input': input_,
94
+ 'input_token_length': len(tokenizer(input_)['input_ids']),
95
+ 'output': generated_text,
96
+ 'output_token_length': len(tokenizer(generated_text)['input_ids']),
97
+ 'meta': {
98
+ 'model': args.model,
99
+ 'temperature': args.temperature,
100
+ 'max_tokens': args.max_tokens,
101
+ 'top_p': args.top_p,
102
+ 'top_k': args.top_k,
103
+ 'frequency_penalty': args.frequency_penalty,
104
+ 'presence_penalty': args.presence_penalty,
105
+ 'repetition_penalty': args.repetition_penalty
106
+ }
107
+ })
108
+
109
+ predictions = pd.DataFrame(results)
110
+ os.makedirs(f"{args.output_path}/{args.src_lang}-{args.tgt_lang}", exist_ok=True)
111
+ predictions.to_csv(f"{args.output_path}/{args.src_lang}-{args.tgt_lang}/outputs.tsv", sep="\t", index=False)
112
+
113
+
114
+ if __name__ == "__main__":
115
+ args = parse_args()
116
+ main(args)