import gradio as gr
import logging
# Same logger setup as the original run script
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)
try:
    # We use "NllbTokenizer" and "AutoModelForSeq2SeqLM" from HF
    from transformers import NllbTokenizer, AutoModelForSeq2SeqLM
except ImportError:
    logger.error("transformers library not found. Ensure 'transformers' is in requirements.txt.")
    raise
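# For reference, a minimal requirements.txt for this Space would look roughly like the
# sketch below (an assumption based on the imports used here; versions left unpinned):
#
#   gradio
#   transformers
#   torch
#   sentencepiece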
# We'll assume your HF model is publicly available at "zensalaria/my-nllb-distilled"
MODEL_NAME = "zensalaria/my-nllb-distilled"
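# If the repo were private instead, both from_pretrained calls below would need an access
# token, e.g. token=os.environ.get("HF_TOKEN") read from a Space secret (the `token`
# argument is an assumption about a reasonably recent transformers release; older versions
# used use_auth_token). Nothing extra is needed for a public model.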
logger.info("Loading NLLB model from Hugging Face...")
try:
    tokenizer = NllbTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
    logger.info("Model and tokenizer loaded successfully.")
except Exception as e:
    logger.error(f"Error loading model: {e}")
    raise
def translate_text(input_text, target_lang="urd_Arab", max_length=512):
    """
    Replicates your run script's translation logic, but in-memory (no local file writes).
    """
    logger.info(f"Translating text to {target_lang}...")
    try:
        inputs = tokenizer(
            input_text,
            return_tensors="pt",
            max_length=max_length,
            truncation=True,
        )
        # NLLB selects the target language by forcing its code as the first generated
        # (BOS) token. Newer tokenizer versions may not expose lang_code_to_id, so fall
        # back to a plain token-id lookup of the FLORES-200 code.
        if hasattr(tokenizer, "lang_code_to_id"):
            forced_bos_token_id = tokenizer.lang_code_to_id[target_lang]
        else:
            forced_bos_token_id = tokenizer.convert_tokens_to_ids(target_lang)
        outputs = model.generate(
            **inputs,
            forced_bos_token_id=forced_bos_token_id,
            max_length=max_length,
        )
        translated = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logger.info("Translation complete!")
        return translated
    except Exception as e:
        logger.error(f"Error during translation: {e}")
        return "Error translating text"
def process_translation_request(input_text, target_lang="urd_Arab"):
    """
    Same logic as your run script's process_translation_request, but uses the in-memory translate_text.
    """
    if not input_text.strip():
        return "Error: No text provided."
    return translate_text(input_text, target_lang)
def gradio_interface(text, lang):
    return process_translation_request(text, lang)
# Example language list for Gradio
LANG_CHOICES = [
    ("English (Latin)", "eng_Latn"),
    ("Urdu (Arabic)", "urd_Arab"),
    ("Spanish (Latin)", "spa_Latn"),
]
demo = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(label="Input Text", lines=3),
        gr.Dropdown(
            choices=[c[1] for c in LANG_CHOICES],
            label="Target Language",
            value="urd_Arab",
        ),
    ],
    outputs="text",
    title="NLLB-200 Translator",
    description="Translate text using your NLLB-200-distilled-600M model.",
)
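# Note: the dropdown above lists raw FLORES-200 codes. If the Space pins Gradio 4+, the
# (label, value) pairs in LANG_CHOICES could be passed directly, e.g.
# gr.Dropdown(choices=LANG_CHOICES, value="urd_Arab"), so users see "Urdu (Arabic)" while
# the callback still receives "urd_Arab". Left unchanged here to avoid assuming the
# installed Gradio version.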
if __name__ == "__main__":
    demo.launch()