Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -2,18 +2,18 @@ import spaces
|
|
2 |
import gradio as gr
|
3 |
import torch
|
4 |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
5 |
-
import
|
6 |
|
7 |
# Model configurations
|
8 |
INDIC_EN_MODEL = "ai4bharat/indictrans2-indic-en-1B"
|
9 |
EN_INDIC_MODEL = "ai4bharat/indictrans2-en-indic-1B"
|
10 |
|
11 |
-
|
12 |
-
|
13 |
indic_en_tokenizer = AutoTokenizer.from_pretrained(INDIC_EN_MODEL, trust_remote_code=True)
|
14 |
en_indic_tokenizer = AutoTokenizer.from_pretrained(EN_INDIC_MODEL, trust_remote_code=True)
|
15 |
|
16 |
-
|
17 |
indic_en_model = AutoModelForSeq2SeqLM.from_pretrained(
|
18 |
INDIC_EN_MODEL,
|
19 |
trust_remote_code=True,
|
@@ -28,7 +28,10 @@ en_indic_model = AutoModelForSeq2SeqLM.from_pretrained(
|
|
28 |
device_map="cpu"
|
29 |
)
|
30 |
|
31 |
-
#
|
|
|
|
|
|
|
32 |
LANGUAGE_CODES = {
|
33 |
"Assamese": "asm_Beng",
|
34 |
"Bengali": "ben_Beng",
|
@@ -58,22 +61,9 @@ LANGUAGE_CODES = {
|
|
58 |
"English": "eng_Latn"
|
59 |
}
|
60 |
|
61 |
-
def format_input_for_indictrans2(text, src_lang, tgt_lang, direction):
|
62 |
-
"""Format input text according to IndicTrans2 requirements"""
|
63 |
-
text = text.strip()
|
64 |
-
|
65 |
-
if direction == "en_to_indic":
|
66 |
-
# For English to Indic: format as "text </s> src_lang"
|
67 |
-
formatted_input = f"{text} </s> {src_lang}"
|
68 |
-
else: # indic_to_en
|
69 |
-
# For Indic to English: format as "text </s> src_lang"
|
70 |
-
formatted_input = f"{text} </s> {src_lang}"
|
71 |
-
|
72 |
-
return formatted_input
|
73 |
-
|
74 |
@spaces.GPU(duration=120)
|
75 |
def translate_text(input_text, source_lang, target_lang, max_length):
|
76 |
-
"""Translate
|
77 |
|
78 |
if not input_text.strip():
|
79 |
return "Please enter text to translate."
|
@@ -85,101 +75,86 @@ def translate_text(input_text, source_lang, target_lang, max_length):
|
|
85 |
src_code = LANGUAGE_CODES[source_lang]
|
86 |
tgt_code = LANGUAGE_CODES[target_lang]
|
87 |
|
88 |
-
# Determine direction and model
|
89 |
if source_lang == "English" and target_lang != "English":
|
90 |
-
# English to Indic
|
91 |
model_gpu = en_indic_model.to(device)
|
92 |
tokenizer = en_indic_tokenizer
|
93 |
direction = "en_to_indic"
|
94 |
elif source_lang != "English" and target_lang == "English":
|
95 |
-
# Indic to English
|
96 |
model_gpu = indic_en_model.to(device)
|
97 |
tokenizer = indic_en_tokenizer
|
98 |
direction = "indic_to_en"
|
99 |
else:
|
100 |
return "Please select English as either source or target language (not both)."
|
101 |
|
102 |
-
#
|
103 |
-
|
104 |
-
|
|
|
|
|
|
|
|
|
|
|
105 |
)
|
106 |
|
107 |
-
# Tokenize
|
108 |
inputs = tokenizer(
|
109 |
-
|
110 |
-
return_tensors="pt",
|
111 |
-
padding=True,
|
112 |
truncation=True,
|
113 |
-
|
114 |
-
|
|
|
115 |
).to(device)
|
116 |
|
117 |
-
# Remove any unwanted keys
|
118 |
-
if 'token_type_ids' in inputs:
|
119 |
-
del inputs['token_type_ids']
|
120 |
-
|
121 |
-
# Set up generation parameters based on direction
|
122 |
-
if direction == "en_to_indic":
|
123 |
-
# Get target language token for decoder start
|
124 |
-
tgt_lang_token = tokenizer.convert_tokens_to_ids(tgt_code)
|
125 |
-
else:
|
126 |
-
# For Indic to English, use English token
|
127 |
-
tgt_lang_token = tokenizer.convert_tokens_to_ids("eng_Latn")
|
128 |
-
|
129 |
# Generate translation
|
130 |
with torch.no_grad():
|
131 |
generated_tokens = model_gpu.generate(
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
max_length=max_length,
|
136 |
-
min_length=1,
|
137 |
num_beams=5,
|
138 |
num_return_sequences=1,
|
139 |
early_stopping=True,
|
140 |
-
do_sample=False,
|
141 |
pad_token_id=tokenizer.pad_token_id,
|
142 |
-
eos_token_id=tokenizer.eos_token_id
|
143 |
-
use_cache=True
|
144 |
)
|
145 |
|
146 |
-
# Decode
|
147 |
-
|
148 |
-
generated_tokens
|
149 |
skip_special_tokens=True,
|
150 |
-
clean_up_tokenization_spaces=True
|
151 |
)
|
152 |
|
153 |
-
#
|
154 |
-
|
155 |
-
cleaned_output = re.sub(r'<.*?>', '', translated_text)
|
156 |
-
cleaned_output = cleaned_output.strip()
|
157 |
|
158 |
# Move model back to CPU
|
159 |
model_gpu.cpu()
|
160 |
torch.cuda.empty_cache()
|
161 |
|
162 |
-
return
|
163 |
|
164 |
except Exception as e:
|
165 |
-
# Clean up GPU memory in case of error
|
166 |
if 'model_gpu' in locals():
|
167 |
model_gpu.cpu()
|
168 |
torch.cuda.empty_cache()
|
169 |
return f"Error during translation: {str(e)}"
|
170 |
|
171 |
# Create Gradio interface
|
172 |
-
with gr.Blocks(title="IndicTrans2 Translator", theme=gr.themes.Soft()) as demo:
|
173 |
gr.Markdown("""
|
174 |
# 🇮🇳 IndicTrans2 - Official AI4Bharat Translator
|
175 |
|
176 |
High-quality neural machine translation between English and 22 Indian languages.
|
|
|
177 |
|
178 |
**Supported Languages**: Assamese, Bengali, Bodo, Dogri, Gujarati, Hindi, Kannada, Kashmiri,
|
179 |
Konkani, Maithili, Malayalam, Manipuri, Marathi, Nepali, Odia, Punjabi, Sanskrit, Santali,
|
180 |
Sindhi, Tamil, Telugu, Urdu.
|
181 |
-
|
182 |
-
**Note**: Select English as either source OR target language (not both).
|
183 |
""")
|
184 |
|
185 |
with gr.Row():
|
@@ -197,8 +172,6 @@ with gr.Blocks(title="IndicTrans2 Translator", theme=gr.themes.Soft()) as demo:
|
|
197 |
label="Source Language"
|
198 |
)
|
199 |
|
200 |
-
swap_btn = gr.Button("⇄", size="sm")
|
201 |
-
|
202 |
target_lang = gr.Dropdown(
|
203 |
choices=list(LANGUAGE_CODES.keys()),
|
204 |
value="Hindi",
|
@@ -222,17 +195,17 @@ with gr.Blocks(title="IndicTrans2 Translator", theme=gr.themes.Soft()) as demo:
|
|
222 |
interactive=False
|
223 |
)
|
224 |
|
225 |
-
clear_btn = gr.Button("Clear
|
226 |
|
227 |
-
# Examples
|
228 |
-
gr.Markdown("### 💡
|
229 |
|
230 |
examples = [
|
231 |
-
["
|
232 |
-
["
|
233 |
-
["
|
234 |
-
["
|
235 |
-
["Technology is
|
236 |
]
|
237 |
|
238 |
gr.Examples(
|
@@ -243,9 +216,6 @@ with gr.Blocks(title="IndicTrans2 Translator", theme=gr.themes.Soft()) as demo:
|
|
243 |
)
|
244 |
|
245 |
# Event handlers
|
246 |
-
def swap_languages(src, tgt):
|
247 |
-
return tgt, src
|
248 |
-
|
249 |
def clear_all():
|
250 |
return "", ""
|
251 |
|
@@ -255,12 +225,6 @@ with gr.Blocks(title="IndicTrans2 Translator", theme=gr.themes.Soft()) as demo:
|
|
255 |
outputs=output_text
|
256 |
)
|
257 |
|
258 |
-
swap_btn.click(
|
259 |
-
swap_languages,
|
260 |
-
inputs=[source_lang, target_lang],
|
261 |
-
outputs=[source_lang, target_lang]
|
262 |
-
)
|
263 |
-
|
264 |
clear_btn.click(
|
265 |
clear_all,
|
266 |
outputs=[input_text, output_text]
|
|
|
2 |
import gradio as gr
|
3 |
import torch
|
4 |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
5 |
+
from IndicTransToolkit.processor import IndicProcessor
|
6 |
|
7 |
# Model configurations
|
8 |
INDIC_EN_MODEL = "ai4bharat/indictrans2-indic-en-1B"
|
9 |
EN_INDIC_MODEL = "ai4bharat/indictrans2-en-indic-1B"
|
10 |
|
11 |
+
print("Loading IndicTrans2 models...")
|
12 |
+
# Load tokenizers
|
13 |
indic_en_tokenizer = AutoTokenizer.from_pretrained(INDIC_EN_MODEL, trust_remote_code=True)
|
14 |
en_indic_tokenizer = AutoTokenizer.from_pretrained(EN_INDIC_MODEL, trust_remote_code=True)
|
15 |
|
16 |
+
# Load models on CPU
|
17 |
indic_en_model = AutoModelForSeq2SeqLM.from_pretrained(
|
18 |
INDIC_EN_MODEL,
|
19 |
trust_remote_code=True,
|
|
|
28 |
device_map="cpu"
|
29 |
)
|
30 |
|
31 |
+
# Initialize IndicProcessor (CRUCIAL for proper preprocessing)
|
32 |
+
ip = IndicProcessor(inference=True)
|
33 |
+
|
34 |
+
# Language mappings (exact codes from official documentation)
|
35 |
LANGUAGE_CODES = {
|
36 |
"Assamese": "asm_Beng",
|
37 |
"Bengali": "ben_Beng",
|
|
|
61 |
"English": "eng_Latn"
|
62 |
}
|
63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
@spaces.GPU(duration=120)
|
65 |
def translate_text(input_text, source_lang, target_lang, max_length):
|
66 |
+
"""Translate using IndicTrans2 with proper preprocessing"""
|
67 |
|
68 |
if not input_text.strip():
|
69 |
return "Please enter text to translate."
|
|
|
75 |
src_code = LANGUAGE_CODES[source_lang]
|
76 |
tgt_code = LANGUAGE_CODES[target_lang]
|
77 |
|
78 |
+
# Determine direction and select appropriate model/tokenizer
|
79 |
if source_lang == "English" and target_lang != "English":
|
80 |
+
# English to Indic
|
81 |
model_gpu = en_indic_model.to(device)
|
82 |
tokenizer = en_indic_tokenizer
|
83 |
direction = "en_to_indic"
|
84 |
elif source_lang != "English" and target_lang == "English":
|
85 |
+
# Indic to English
|
86 |
model_gpu = indic_en_model.to(device)
|
87 |
tokenizer = indic_en_tokenizer
|
88 |
direction = "indic_to_en"
|
89 |
else:
|
90 |
return "Please select English as either source or target language (not both)."
|
91 |
|
92 |
+
# CRUCIAL: Use IndicProcessor for proper preprocessing
|
93 |
+
input_sentences = [input_text.strip()]
|
94 |
+
|
95 |
+
# Preprocess using IndicProcessor (this handles the proper formatting)
|
96 |
+
batch = ip.preprocess_batch(
|
97 |
+
input_sentences,
|
98 |
+
src_lang=src_code,
|
99 |
+
tgt_lang=tgt_code,
|
100 |
)
|
101 |
|
102 |
+
# Tokenize the preprocessed batch
|
103 |
inputs = tokenizer(
|
104 |
+
batch,
|
|
|
|
|
105 |
truncation=True,
|
106 |
+
padding="longest",
|
107 |
+
return_tensors="pt",
|
108 |
+
return_attention_mask=True,
|
109 |
).to(device)
|
110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
# Generate translation
|
112 |
with torch.no_grad():
|
113 |
generated_tokens = model_gpu.generate(
|
114 |
+
**inputs,
|
115 |
+
use_cache=True,
|
116 |
+
min_length=0,
|
117 |
max_length=max_length,
|
|
|
118 |
num_beams=5,
|
119 |
num_return_sequences=1,
|
120 |
early_stopping=True,
|
|
|
121 |
pad_token_id=tokenizer.pad_token_id,
|
122 |
+
eos_token_id=tokenizer.eos_token_id
|
|
|
123 |
)
|
124 |
|
125 |
+
# Decode generated tokens
|
126 |
+
generated_tokens = tokenizer.batch_decode(
|
127 |
+
generated_tokens,
|
128 |
skip_special_tokens=True,
|
129 |
+
clean_up_tokenization_spaces=True,
|
130 |
)
|
131 |
|
132 |
+
# CRUCIAL: Postprocess using IndicProcessor
|
133 |
+
translations = ip.postprocess_batch(generated_tokens, lang=tgt_code)
|
|
|
|
|
134 |
|
135 |
# Move model back to CPU
|
136 |
model_gpu.cpu()
|
137 |
torch.cuda.empty_cache()
|
138 |
|
139 |
+
return translations[0] if translations else "Translation failed."
|
140 |
|
141 |
except Exception as e:
|
|
|
142 |
if 'model_gpu' in locals():
|
143 |
model_gpu.cpu()
|
144 |
torch.cuda.empty_cache()
|
145 |
return f"Error during translation: {str(e)}"
|
146 |
|
147 |
# Create Gradio interface
|
148 |
+
with gr.Blocks(title="IndicTrans2 Official Translator", theme=gr.themes.Soft()) as demo:
|
149 |
gr.Markdown("""
|
150 |
# 🇮🇳 IndicTrans2 - Official AI4Bharat Translator
|
151 |
|
152 |
High-quality neural machine translation between English and 22 Indian languages.
|
153 |
+
Uses official IndicTransToolkit for proper preprocessing.
|
154 |
|
155 |
**Supported Languages**: Assamese, Bengali, Bodo, Dogri, Gujarati, Hindi, Kannada, Kashmiri,
|
156 |
Konkani, Maithili, Malayalam, Manipuri, Marathi, Nepali, Odia, Punjabi, Sanskrit, Santali,
|
157 |
Sindhi, Tamil, Telugu, Urdu.
|
|
|
|
|
158 |
""")
|
159 |
|
160 |
with gr.Row():
|
|
|
172 |
label="Source Language"
|
173 |
)
|
174 |
|
|
|
|
|
175 |
target_lang = gr.Dropdown(
|
176 |
choices=list(LANGUAGE_CODES.keys()),
|
177 |
value="Hindi",
|
|
|
195 |
interactive=False
|
196 |
)
|
197 |
|
198 |
+
clear_btn = gr.Button("Clear", variant="secondary")
|
199 |
|
200 |
+
# Examples from official documentation
|
201 |
+
gr.Markdown("### 💡 Official Examples:")
|
202 |
|
203 |
examples = [
|
204 |
+
["When I was young, I used to go to the park every day.", "English", "Hindi", 128],
|
205 |
+
["We watched a new movie last week, which was very inspiring.", "English", "Bengali", 128],
|
206 |
+
["जब मैं छोटा था, मैं हर रोज़ पार्क जाता था।", "Hindi", "English", 128],
|
207 |
+
["हमने पिछले सप्ताह एक नई फिल्म देखी जो कि बहुत प्रेरणादायक थी।", "Hindi", "English", 128],
|
208 |
+
["Technology is changing our world rapidly.", "English", "Tamil", 128]
|
209 |
]
|
210 |
|
211 |
gr.Examples(
|
|
|
216 |
)
|
217 |
|
218 |
# Event handlers
|
|
|
|
|
|
|
219 |
def clear_all():
|
220 |
return "", ""
|
221 |
|
|
|
225 |
outputs=output_text
|
226 |
)
|
227 |
|
|
|
|
|
|
|
|
|
|
|
|
|
228 |
clear_btn.click(
|
229 |
clear_all,
|
230 |
outputs=[input_text, output_text]
|