# medbot_meditron/app.py
import spaces
import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit.processor import IndicProcessor
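# Note: IndicProcessor is provided by IndicTransToolkit
# (https://github.com/VarunGumma/IndicTransToolkit), which needs to be installed
# alongside torch/transformers for this Space to start.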
# Model configurations
INDIC_EN_MODEL = "ai4bharat/indictrans2-indic-en-1B"
EN_INDIC_MODEL = "ai4bharat/indictrans2-en-indic-1B"
print("Loading IndicTrans2 models...")
# Load tokenizers
indic_en_tokenizer = AutoTokenizer.from_pretrained(INDIC_EN_MODEL, trust_remote_code=True)
en_indic_tokenizer = AutoTokenizer.from_pretrained(EN_INDIC_MODEL, trust_remote_code=True)
# Load models on CPU
indic_en_model = AutoModelForSeq2SeqLM.from_pretrained(
    INDIC_EN_MODEL,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)
en_indic_model = AutoModelForSeq2SeqLM.from_pretrained(
    EN_INDIC_MODEL,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)
# Initialize IndicProcessor (handles IndicTrans2-specific pre/post-processing)
ip = IndicProcessor(inference=True)
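# Roughly speaking, preprocess_batch normalizes the input and prepends the
# "<src_lang> <tgt_lang>" tag pair that IndicTrans2 expects, while
# postprocess_batch maps the decoded output back into the target script.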
# Language mappings (exact codes from official documentation)
LANGUAGE_CODES = {
"Assamese": "asm_Beng",
"Bengali": "ben_Beng",
"Bodo": "brx_Deva",
"Dogri": "doi_Deva",
"Gujarati": "guj_Gujr",
"Hindi": "hin_Deva",
"Kannada": "kan_Knda",
"Kashmiri (Arabic)": "kas_Arab",
"Kashmiri (Devanagari)": "kas_Deva",
"Konkani": "gom_Deva",
"Maithili": "mai_Deva",
"Malayalam": "mal_Mlym",
"Manipuri (Bengali)": "mni_Beng",
"Manipuri (Meitei)": "mni_Mtei",
"Marathi": "mar_Deva",
"Nepali": "npi_Deva",
"Odia": "ory_Orya",
"Punjabi": "pan_Guru",
"Sanskrit": "san_Deva",
"Santali": "sat_Olck",
"Sindhi (Arabic)": "snd_Arab",
"Sindhi (Devanagari)": "snd_Deva",
"Tamil": "tam_Taml",
"Telugu": "tel_Telu",
"Urdu": "urd_Arab",
"English": "eng_Latn"
}
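# The @spaces.GPU decorator below asks Hugging Face Spaces (ZeroGPU) for a GPU
# for up to `duration` seconds per call; outside ZeroGPU it is effectively a
# no-op and the function falls back to CPU.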
@spaces.GPU(duration=120)
def translate_text(input_text, source_lang, target_lang, max_length):
"""Translate using IndicTrans2 with proper preprocessing"""
if not input_text.strip():
return "Please enter text to translate."
try:
device = "cuda" if torch.cuda.is_available() else "cpu"
# Get language codes
src_code = LANGUAGE_CODES[source_lang]
tgt_code = LANGUAGE_CODES[target_lang]
# Determine direction and select appropriate model/tokenizer
if source_lang == "English" and target_lang != "English":
# English to Indic
model_gpu = en_indic_model.to(device)
tokenizer = en_indic_tokenizer
direction = "en_to_indic"
elif source_lang != "English" and target_lang == "English":
# Indic to English
model_gpu = indic_en_model.to(device)
tokenizer = indic_en_tokenizer
direction = "indic_to_en"
else:
return "Please select English as either source or target language (not both)."
# CRUCIAL: Use IndicProcessor for proper preprocessing
input_sentences = [input_text.strip()]
# Preprocess using IndicProcessor (this handles the proper formatting)
batch = ip.preprocess_batch(
input_sentences,
src_lang=src_code,
tgt_lang=tgt_code,
)
# Tokenize the preprocessed batch
inputs = tokenizer(
batch,
truncation=True,
padding="longest",
return_tensors="pt",
return_attention_mask=True,
).to(device)
# Generate translation
with torch.no_grad():
generated_tokens = model_gpu.generate(
**inputs,
use_cache=True,
min_length=0,
max_length=max_length,
num_beams=5,
num_return_sequences=1,
early_stopping=True,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id
)
# Decode generated tokens
generated_tokens = tokenizer.batch_decode(
generated_tokens,
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
)
# CRUCIAL: Postprocess using IndicProcessor
translations = ip.postprocess_batch(generated_tokens, lang=tgt_code)
# Move model back to CPU
model_gpu.cpu()
torch.cuda.empty_cache()
return translations[0] if translations else "Translation failed."
except Exception as e:
if 'model_gpu' in locals():
model_gpu.cpu()
torch.cuda.empty_cache()
return f"Error during translation: {str(e)}"
# Create Gradio interface
with gr.Blocks(title="IndicTrans2 Official Translator", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🇮🇳 IndicTrans2 - Official AI4Bharat Translator

    High-quality neural machine translation between English and 22 Indian languages,
    using the official IndicTransToolkit for preprocessing and postprocessing.

    **Supported Languages**: Assamese, Bengali, Bodo, Dogri, Gujarati, Hindi, Kannada, Kashmiri,
    Konkani, Maithili, Malayalam, Manipuri, Marathi, Nepali, Odia, Punjabi, Sanskrit, Santali,
    Sindhi, Tamil, Telugu, Urdu.
    """)
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Enter text to translate...",
                lines=5
            )
            with gr.Row():
                source_lang = gr.Dropdown(
                    choices=list(LANGUAGE_CODES.keys()),
                    value="English",
                    label="Source Language"
                )
                target_lang = gr.Dropdown(
                    choices=list(LANGUAGE_CODES.keys()),
                    value="Hindi",
                    label="Target Language"
                )
            max_length = gr.Slider(
                minimum=32,
                maximum=256,
                value=128,
                step=16,
                label="Max Output Length"
            )
            translate_btn = gr.Button("Translate", variant="primary", size="lg")

        with gr.Column():
            output_text = gr.Textbox(
                label="Translation",
                lines=5,
                interactive=False
            )
            clear_btn = gr.Button("Clear", variant="secondary")
    # Examples from the official documentation
    gr.Markdown("### 💡 Official Examples:")
    examples = [
        ["When I was young, I used to go to the park every day.", "English", "Hindi", 128],
        ["We watched a new movie last week, which was very inspiring.", "English", "Bengali", 128],
        ["जब मैं छोटा था, मैं हर रोज़ पार्क जाता था।", "Hindi", "English", 128],
        ["हमने पिछले सप्ताह एक नई फिल्म देखी जो कि बहुत प्रेरणादायक थी।", "Hindi", "English", 128],
        ["Technology is changing our world rapidly.", "English", "Tamil", 128]
    ]
    gr.Examples(
        examples=examples,
        inputs=[input_text, source_lang, target_lang, max_length],
        outputs=output_text,
        fn=translate_text
    )
    # Event handlers
    def clear_all():
        return "", ""

    translate_btn.click(
        translate_text,
        inputs=[input_text, source_lang, target_lang, max_length],
        outputs=output_text
    )
    clear_btn.click(
        clear_all,
        outputs=[input_text, output_text]
    )
if __name__ == "__main__":
    demo.launch(
        share=True,
        show_error=True
    )
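# share=True requests a temporary public gradio.live link, which is mainly useful
# for local runs; a deployed Space is already publicly reachable.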