import subprocess
import sys

import gradio as gr
import torch
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer

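# Model checkpoints for the two translation stages, plus the csebuetnlp
# normalizer repo whose normalize() is applied to Bengali text before it is
# passed to BanglaT5.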
SYLHETI_TO_BN_MODEL = "shbhro/sylhetit5"
BN_TO_EN_MODEL = "csebuetnlp/banglat5_nmt_bn_en"
NORMALIZER_REPO = "https://github.com/csebuetnlp/normalizer.git"

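# The normalizer is imported lazily: if it is missing, attempt a runtime pip
# install from GitHub, and otherwise fall back to a stub that raises, so
# translation requests fail with a clear message instead of crashing at import.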
normalizer_module = None
dummy_normalizer_flag = False


def dummy_normalize_func(text):
    raise RuntimeError("Normalizer library could not be loaded. Please check installation and logs.")


try:
    from normalizer import normalize as normalize_fn_imported

    normalizer_module = normalize_fn_imported
    print("Normalizer imported successfully.")
except ImportError:
    print(f"Normalizer library not found. Attempting to install from {NORMALIZER_REPO}...")
    try:
        subprocess.check_call([sys.executable, "-m", "pip", "install", f"git+{NORMALIZER_REPO}#egg=normalizer"])
        from normalizer import normalize as normalize_fn_imported_after_install

        normalizer_module = normalize_fn_imported_after_install
        print("Normalizer installed and imported successfully after pip install.")
    except Exception as e:
        print(f"Failed to install or import normalizer: {e}")
        print("Please ensure 'git+https://github.com/csebuetnlp/normalizer.git#egg=normalizer' is in your requirements.txt for Hugging Face Spaces.")
        normalizer_module = dummy_normalize_func
        dummy_normalizer_flag = True

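# Load both models once at startup. On GPU hardware the pipeline gets device
# index 0 and the seq2seq model is moved to CUDA; otherwise everything runs on
# CPU. If loading fails, the globals stay None so the UI can report the error.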
sylheti_to_bn_pipe = None
bn_to_en_model = None
bn_to_en_tokenizer = None
model_device = None

print("Loading translation models...")
try:
    model_device_type = "cuda" if torch.cuda.is_available() else "cpu"
    model_device = torch.device(model_device_type)
    hf_device_param = 0 if model_device_type == "cuda" else -1

    print(f"Using device: {model_device_type}")

    sylheti_to_bn_pipe = pipeline(
        "text2text-generation",
        model=SYLHETI_TO_BN_MODEL,
        device=hf_device_param,
    )
    print(f"Sylheti-to-Bengali model ({SYLHETI_TO_BN_MODEL}) loaded.")

    bn_to_en_model = AutoModelForSeq2SeqLM.from_pretrained(BN_TO_EN_MODEL)
    bn_to_en_tokenizer = AutoTokenizer.from_pretrained(BN_TO_EN_MODEL, use_fast=False)
    bn_to_en_model.to(model_device)
    print(f"Bengali-to-English model ({BN_TO_EN_MODEL}) loaded.")
except Exception as e:
    print(f"FATAL: Error loading one or more models: {e}")
    sylheti_to_bn_pipe = None
    bn_to_en_model = None
    bn_to_en_tokenizer = None

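# Two-step inference: Sylheti -> Bengali with the T5 pipeline, then Bengali ->
# English with BanglaT5 after running the intermediate text through normalize().
# Both the intermediate and final texts are returned so the UI can show each step.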
def translate_sylheti_to_english_gradio(sylheti_text_input):
    if not sylheti_text_input.strip():
        return "Please enter some Sylheti text.", ""

    if sylheti_to_bn_pipe is None:
        return "Error: Sylheti-to-Bengali model not loaded. Check logs.", ""
    if bn_to_en_model is None or bn_to_en_tokenizer is None:
        return "Error: Bengali-to-English model not loaded. Check logs.", ""

    if dummy_normalizer_flag or normalizer_module is None:
        return "Error: Bengali normalizer library not available. Check logs.", ""

    bengali_text_intermediate = "Error in Sylheti to Bengali step."
    english_text_final = "Error in Bengali to English step."

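    # Step 1: Sylheti -> Bengali.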
    try:
        print(f"Translating Sylheti to Bengali: '{sylheti_text_input}'")
        bengali_translation_outputs = sylheti_to_bn_pipe(
            sylheti_text_input,
            max_length=128,
            num_beams=5,
            early_stopping=True,
        )
        bengali_text_intermediate = bengali_translation_outputs[0]["generated_text"]
        print(f"Intermediate Bengali: '{bengali_text_intermediate}'")
    except Exception as e:
        print(f"Error during Sylheti to Bengali translation: {e}")
        bengali_text_intermediate = f"Sylheti->Bengali Error: {str(e)}"
        return bengali_text_intermediate, english_text_final

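    # Step 2: normalize the intermediate Bengali, then translate it to English.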
    try:
        print(f"Normalizing and translating Bengali to English: '{bengali_text_intermediate}'")

        if callable(normalizer_module):
            normalized_bn_text = normalizer_module(bengali_text_intermediate)
        else:
            raise RuntimeError("Normalizer function is not callable.")

        print(f"Normalized Bengali: '{normalized_bn_text}'")

        input_ids = bn_to_en_tokenizer(
            normalized_bn_text,
            return_tensors="pt"
        ).input_ids.to(model_device)

        generated_tokens = bn_to_en_model.generate(
            input_ids,
            max_length=128,
            num_beams=5,
            early_stopping=True,
        )
        english_text_list = bn_to_en_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        english_text_final = english_text_list[0] if english_text_list else "No English output generated."
        print(f"Final English: '{english_text_final}'")
    except Exception as e:
        print(f"Error during Bengali to English translation: {e}")
        english_text_final = f"Bengali->English Error: {str(e)}"

    return bengali_text_intermediate, english_text_final

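# Gradio UI: one Sylheti input box and two output boxes (intermediate Bengali
# and final English) so users can inspect both stages of the pipeline.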
iface = gr.Interface(
    fn=translate_sylheti_to_english_gradio,
    inputs=gr.Textbox(
        lines=4,
        label="Enter Sylheti Text",
        placeholder="কিতা কিতা কিনলায় তে?"
    ),
    outputs=[
        gr.Textbox(label="Intermediate Bengali Output", lines=4),
        gr.Textbox(label="Final English Output", lines=4)
    ],
    title="🌍 Sylheti to English Translator (via Bengali)",
    description=(
        "Translates Sylheti text to English in two steps:\n"
        f"1. Sylheti → Bengali (using `{SYLHETI_TO_BN_MODEL}`)\n"
        f"2. Bengali → English (using `{BN_TO_EN_MODEL}` with text normalization from `{NORMALIZER_REPO.split('/')[-1].removesuffix('.git')}`)"
    ),
    examples=[
        ["কিতা কিতা কিনলায় তে?"],
        ["তুমি কিতা কররায়?"],
        ["আমি ভাত খাইছি।"],
        ["আফনে ভালা আছনি?"]
    ],
    allow_flagging="never",
    cache_examples=False,
    theme=gr.themes.Soft(),
)

if __name__ == "__main__":
    iface.launch()