import spaces
import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit.processor import IndicProcessor

# Model configurations
INDIC_EN_MODEL = "ai4bharat/indictrans2-indic-en-1B"
EN_INDIC_MODEL = "ai4bharat/indictrans2-en-indic-1B"

print("Loading IndicTrans2 models...")

# Load tokenizers
indic_en_tokenizer = AutoTokenizer.from_pretrained(INDIC_EN_MODEL, trust_remote_code=True)
en_indic_tokenizer = AutoTokenizer.from_pretrained(EN_INDIC_MODEL, trust_remote_code=True)

# Load models on CPU
indic_en_model = AutoModelForSeq2SeqLM.from_pretrained(
    INDIC_EN_MODEL,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="cpu"
)
en_indic_model = AutoModelForSeq2SeqLM.from_pretrained(
    EN_INDIC_MODEL,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="cpu"
)

# Initialize IndicProcessor (CRUCIAL for proper preprocessing)
ip = IndicProcessor(inference=True)
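# Illustration (assumption, not executed here): preprocess_batch normalizes each sentence and tags
# it with the language pair so the model knows the translation direction, roughly:
#   ip.preprocess_batch(["How are you?"], src_lang="eng_Latn", tgt_lang="hin_Deva")
#   -> ["eng_Latn hin_Deva How are you?"]
# postprocess_batch then undoes the normalization/transliteration on the decoded model output.
# The exact tagged format shown above is an assumption; see the IndicTransToolkit documentation.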

# Language mappings (exact codes from official documentation)
LANGUAGE_CODES = {
    "Assamese": "asm_Beng",
    "Bengali": "ben_Beng",
    "Bodo": "brx_Deva",
    "Dogri": "doi_Deva",
    "Gujarati": "guj_Gujr",
    "Hindi": "hin_Deva",
    "Kannada": "kan_Knda",
    "Kashmiri (Arabic)": "kas_Arab",
    "Kashmiri (Devanagari)": "kas_Deva",
    "Konkani": "gom_Deva",
    "Maithili": "mai_Deva",
    "Malayalam": "mal_Mlym",
    "Manipuri (Bengali)": "mni_Beng",
    "Manipuri (Meitei)": "mni_Mtei",
    "Marathi": "mar_Deva",
    "Nepali": "npi_Deva",
    "Odia": "ory_Orya",
    "Punjabi": "pan_Guru",
    "Sanskrit": "san_Deva",
    "Santali": "sat_Olck",
    "Sindhi (Arabic)": "snd_Arab",
    "Sindhi (Devanagari)": "snd_Deva",
    "Tamil": "tam_Taml",
    "Telugu": "tel_Telu",
    "Urdu": "urd_Arab",
    "English": "eng_Latn"
}

@spaces.GPU  # ZeroGPU: allocate a GPU for the duration of each call (this is why `spaces` is imported)
def translate_text(input_text, source_lang, target_lang, max_length):
    """Translate using IndicTrans2 with proper preprocessing."""
    if not input_text.strip():
        return "Please enter text to translate."
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Get language codes
        src_code = LANGUAGE_CODES[source_lang]
        tgt_code = LANGUAGE_CODES[target_lang]
        # Determine direction and select the appropriate model/tokenizer
        if source_lang == "English" and target_lang != "English":
            # English to Indic
            model_gpu = en_indic_model.to(device)
            tokenizer = en_indic_tokenizer
        elif source_lang != "English" and target_lang == "English":
            # Indic to English
            model_gpu = indic_en_model.to(device)
            tokenizer = indic_en_tokenizer
        else:
            return "Please select English as either the source or the target language (not both)."
        # CRUCIAL: use IndicProcessor for proper preprocessing
        input_sentences = [input_text.strip()]
        # Preprocess using IndicProcessor (this handles the proper formatting)
        batch = ip.preprocess_batch(
            input_sentences,
            src_lang=src_code,
            tgt_lang=tgt_code,
        )
        # Tokenize the preprocessed batch
        inputs = tokenizer(
            batch,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(device)
        # Generate translation
        with torch.no_grad():
            generated_tokens = model_gpu.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=int(max_length),  # Gradio sliders may deliver floats
                num_beams=5,
                num_return_sequences=1,
                early_stopping=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
        # Decode generated tokens
        generated_tokens = tokenizer.batch_decode(
            generated_tokens,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        # CRUCIAL: postprocess using IndicProcessor
        translations = ip.postprocess_batch(generated_tokens, lang=tgt_code)
        # Move the model back to CPU and free GPU memory
        model_gpu.cpu()
        torch.cuda.empty_cache()
        return translations[0] if translations else "Translation failed."
    except Exception as e:
        if 'model_gpu' in locals():
            model_gpu.cpu()
        torch.cuda.empty_cache()
        return f"Error during translation: {str(e)}"

# Create Gradio interface
with gr.Blocks(title="IndicTrans2 Official Translator", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🇮🇳 IndicTrans2 - Official AI4Bharat Translator

    High-quality neural machine translation between English and 22 Indian languages.
    Uses the official IndicTransToolkit for proper preprocessing.

    **Supported Languages**: Assamese, Bengali, Bodo, Dogri, Gujarati, Hindi, Kannada, Kashmiri,
    Konkani, Maithili, Malayalam, Manipuri, Marathi, Nepali, Odia, Punjabi, Sanskrit, Santali,
    Sindhi, Tamil, Telugu, Urdu.
    """)
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Enter text to translate...",
                lines=5
            )
            with gr.Row():
                source_lang = gr.Dropdown(
                    choices=list(LANGUAGE_CODES.keys()),
                    value="English",
                    label="Source Language"
                )
                target_lang = gr.Dropdown(
                    choices=list(LANGUAGE_CODES.keys()),
                    value="Hindi",
                    label="Target Language"
                )
            max_length = gr.Slider(
                minimum=32,
                maximum=256,
                value=128,
                step=16,
                label="Max Output Length"
            )
            translate_btn = gr.Button("Translate", variant="primary", size="lg")
        with gr.Column():
            output_text = gr.Textbox(
                label="Translation",
                lines=5,
                interactive=False
            )
            clear_btn = gr.Button("Clear", variant="secondary")

    # Examples from the official documentation
    gr.Markdown("### 💡 Official Examples:")
    examples = [
        ["When I was young, I used to go to the park every day.", "English", "Hindi", 128],
        ["We watched a new movie last week, which was very inspiring.", "English", "Bengali", 128],
        ["जब मैं छोटा था, मैं हर रोज़ पार्क जाता था।", "Hindi", "English", 128],
        ["हमने पिछले सप्ताह एक नई फिल्म देखी जो कि बहुत प्रेरणादायक थी।", "Hindi", "English", 128],
        ["Technology is changing our world rapidly.", "English", "Tamil", 128]
    ]
    gr.Examples(
        examples=examples,
        inputs=[input_text, source_lang, target_lang, max_length],
        outputs=output_text,
        fn=translate_text
    )

    # Event handlers
    def clear_all():
        return "", ""

    translate_btn.click(
        translate_text,
        inputs=[input_text, source_lang, target_lang, max_length],
        outputs=output_text
    )
    clear_btn.click(
        clear_all,
        outputs=[input_text, output_text]
    )

if __name__ == "__main__":
    demo.launch(
        share=True,
        show_error=True
    )
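
# Deployment note (assumption, not part of the original app): on a Hugging Face ZeroGPU Space this
# file is typically saved as app.py alongside a requirements.txt roughly like:
#   torch
#   transformers
#   gradio
#   spaces
#   IndicTransToolkit
# The exact package names and versions (especially for IndicTransToolkit) should be checked against
# the Space's actual configuration.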