# syl-eng / app.py
# Author: shbhro
# Last commit: b1d5be4 "Update app.py" (verified)
import gradio as gr
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
import torch
import subprocess
import sys
import os
# --- Configuration ---
# Hugging Face Hub repository IDs of the two models chained by the translator.
SYLHETI_TO_BN_MODEL: str = "shbhro/sylhetit5"  # hop 1: Sylheti -> Bengali
BN_TO_EN_MODEL: str = "csebuetnlp/banglat5_nmt_bn_en"  # hop 2: Bengali -> English
# Git repository of the Bengali text normalizer applied before hop 2; also the
# source used for a runtime `pip install` when the package is missing.
NORMALIZER_REPO: str = "https://github.com/csebuetnlp/normalizer.git"
# --- Helper function to install/import normalizer ---
# Despite its name, `normalizer_module` ends up holding a *function*: either the
# real `normalizer.normalize`, or `dummy_normalize_func` when loading failed.
normalizer_module = None
dummy_normalizer_flag = False  # True when the dummy fallback is in use


def dummy_normalize_func(text):
    """Stand-in normalizer used when the real library could not be loaded.

    Always raises, so a missing normalizer is never silently skipped.

    Raises:
        RuntimeError: unconditionally.
    """
    raise RuntimeError("Normalizer library could not be loaded. Please check installation and logs.")


try:
    from normalizer import normalize as normalize_fn_imported
    normalizer_module = normalize_fn_imported
    print("Normalizer imported successfully.")
except ImportError:
    print(f"Normalizer library not found. Attempting to install from {NORMALIZER_REPO}...")
    # Built from NORMALIZER_REPO so the install command and the hint below
    # cannot drift apart from the configured repository.
    _normalizer_requirement = f"git+{NORMALIZER_REPO}#egg=normalizer"
    try:
        # NOTE(review): this installs third-party code from a fixed Git URL at
        # runtime; pinning it in requirements.txt is the more reliable option.
        subprocess.check_call([sys.executable, "-m", "pip", "install", _normalizer_requirement])
        # The import system caches directory listings; a package installed
        # after interpreter start-up may be invisible until caches are reset.
        import importlib
        importlib.invalidate_caches()
        from normalizer import normalize as normalize_fn_imported_after_install
        normalizer_module = normalize_fn_imported_after_install
        print("Normalizer installed and imported successfully after pip install.")
    except Exception as e:
        # Deliberate best-effort: the app still starts, and the translate
        # callback reports the missing normalizer via `dummy_normalizer_flag`.
        print(f"Failed to install or import normalizer: {e}")
        print(f"Please ensure '{_normalizer_requirement}' is in your requirements.txt for Hugging Face Spaces.")
        normalizer_module = dummy_normalize_func
        dummy_normalizer_flag = True
# --- Model Loading (Globally, when the script starts) ---
# Each translation hop is loaded in its own try-block. A failure in one hop
# therefore leaves the other usable. It also keeps the per-hop "not loaded"
# messages in the translate callback accurate.
sylheti_to_bn_pipe = None
bn_to_en_model = None
bn_to_en_tokenizer = None
model_device = None
print("Loading translation models...")

model_device_type = "cuda" if torch.cuda.is_available() else "cpu"
model_device = torch.device(model_device_type)
# pipeline() takes a GPU index, or -1 for CPU.
hf_device_param = 0 if model_device_type == "cuda" else -1
print(f"Using device: {model_device_type}")

# Hop 1: Sylheti -> Bengali, wrapped in a text2text-generation pipeline.
try:
    sylheti_to_bn_pipe = pipeline(
        "text2text-generation",
        model=SYLHETI_TO_BN_MODEL,
        device=hf_device_param,
    )
    print(f"Sylheti-to-Bengali model ({SYLHETI_TO_BN_MODEL}) loaded.")
except Exception as e:
    print(f"FATAL: Error loading Sylheti-to-Bengali model ({SYLHETI_TO_BN_MODEL}): {e}")
    sylheti_to_bn_pipe = None

# Hop 2: Bengali -> English, with model and tokenizer kept separately so the
# callback can tokenize normalized text itself.
try:
    bn_to_en_model = AutoModelForSeq2SeqLM.from_pretrained(BN_TO_EN_MODEL)
    bn_to_en_tokenizer = AutoTokenizer.from_pretrained(BN_TO_EN_MODEL, use_fast=False)
    bn_to_en_model.to(model_device)
    print(f"Bengali-to-English model ({BN_TO_EN_MODEL}) loaded.")
except Exception as e:
    print(f"FATAL: Error loading Bengali-to-English model ({BN_TO_EN_MODEL}): {e}")
    # Reset both so the callback never sees a half-loaded model/tokenizer pair.
    bn_to_en_model = None
    bn_to_en_tokenizer = None
# --- Main Translation Logic ---
def _sylheti_to_bengali(sylheti_text):
    """Translate Sylheti text to Bengali with the global pipeline; may raise."""
    print(f"Translating Sylheti to Bengali: '{sylheti_text}'")
    bengali_translation_outputs = sylheti_to_bn_pipe(
        sylheti_text,
        max_length=128,
        num_beams=5,
        early_stopping=True,
    )
    bengali_text = bengali_translation_outputs[0]['generated_text']
    print(f"Intermediate Bengali: '{bengali_text}'")
    return bengali_text


def _bengali_to_english(bengali_text):
    """Normalize Bengali text, then translate it to English; may raise."""
    print(f"Normalizing and translating Bengali to English: '{bengali_text}'")
    # Safeguard: the caller already rejects the dummy/None normalizer.
    if not callable(normalizer_module):
        raise RuntimeError("Normalizer function is not callable.")
    normalized_bn_text = normalizer_module(bengali_text)
    print(f"Normalized Bengali: '{normalized_bn_text}'")
    input_ids = bn_to_en_tokenizer(
        normalized_bn_text,
        return_tensors="pt",
    ).input_ids.to(model_device)
    generated_tokens = bn_to_en_model.generate(
        input_ids,
        max_length=128,
        num_beams=5,
        early_stopping=True,
    )
    english_text_list = bn_to_en_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    english_text = english_text_list[0] if english_text_list else "No English output generated."
    print(f"Final English: '{english_text}'")
    return english_text


def translate_sylheti_to_english_gradio(sylheti_text_input):
    """Gradio callback: translate Sylheti text to English via Bengali.

    Args:
        sylheti_text_input: Text from the input Textbox. May be None, e.g.
            when the endpoint is called through the API with a null value.

    Returns:
        A ``(bengali_text, english_text)`` tuple for the two output Textboxes.
        Failures are reported as text in these fields rather than raised.
    """
    # Treat None like empty input instead of crashing on `.strip()`.
    if sylheti_text_input is None or not sylheti_text_input.strip():
        return "Please enter some Sylheti text.", ""
    if not sylheti_to_bn_pipe:
        return "Error: Sylheti-to-Bengali model not loaded. Check logs.", ""
    if not bn_to_en_model or not bn_to_en_tokenizer:
        return "Error: Bengali-to-English model not loaded. Check logs.", ""
    # The dummy normalizer always raises, so report it up front.
    if dummy_normalizer_flag or normalizer_module is None:
        return "Error: Bengali normalizer library not available. Check logs.", ""

    # Step 1: Sylheti -> Bengali. On failure, step 2 is never attempted, so
    # say it was skipped rather than claiming it errored.
    try:
        bengali_text_intermediate = _sylheti_to_bengali(sylheti_text_input)
    except Exception as e:
        print(f"Error during Sylheti to Bengali translation: {e}")
        return f"Sylheti->Bengali Error: {str(e)}", "Skipped: Sylheti->Bengali step failed."

    # Step 2: Bengali -> English. Keep the intermediate Bengali visible even
    # if this step fails.
    try:
        english_text_final = _bengali_to_english(bengali_text_intermediate)
    except Exception as e:
        print(f"Error during Bengali to English translation: {e}")
        english_text_final = f"Bengali->English Error: {str(e)}"
    return bengali_text_intermediate, english_text_final
# --- Gradio Interface Definition ---
# Display name of the normalizer repo. `.split('/')[-1]` alone yields
# "normalizer.git", so the ".git" suffix is dropped for the UI text.
_normalizer_display_name = NORMALIZER_REPO.rstrip("/").split("/")[-1]
if _normalizer_display_name.endswith(".git"):
    _normalizer_display_name = _normalizer_display_name[:-len(".git")]

iface = gr.Interface(
    fn=translate_sylheti_to_english_gradio,
    inputs=gr.Textbox(
        lines=4,
        label="Enter Sylheti Text",
        placeholder="কিতা কিতা কিনলায় তে?",
    ),
    # Order must match the (bengali, english) tuple returned by the callback.
    outputs=[
        gr.Textbox(label="Intermediate Bengali Output", lines=4),
        gr.Textbox(label="Final English Output", lines=4),
    ],
    title="🌍 Sylheti to English Translator (via Bengali)",
    description=(
        "Translates Sylheti text to English in two steps:\n"
        f"1. Sylheti → Bengali (using `{SYLHETI_TO_BN_MODEL}`)\n"
        f"2. Bengali → English (using `{BN_TO_EN_MODEL}` with text normalization from `{_normalizer_display_name}`)"
    ),
    examples=[
        ["কিতা কিতা কিনলায় তে?"],
        ["তুমি কিতা কররায়?"],
        ["আমি ভাত খাইছি।"],
        ["আফনে ভালা আছনি?"],
    ],
    # NOTE(review): newer Gradio releases deprecate `allow_flagging` in favour
    # of `flagging_mode`. Confirm against the installed version before changing.
    allow_flagging="never",
    cache_examples=False,  # Explicitly disable example caching
    theme=gr.themes.Soft(),
)

# --- Launch the Gradio app ---
# Start the server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()