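"""Gradio app for Nepali-English translation, speech-to-text, and text-to-speech.

Built around a fine-tuned M2M100 translation model, an MMS wav2vec2 ASR
checkpoint for Nepali/English speech, gTTS for speech synthesis, and an
optional custom X-Transformer translator loaded from a local inference.py.
"""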
import gradio as gr
import torch
import tempfile
import logging

from gtts import gTTS

# Import your existing functionality
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
from transformers import Wav2Vec2ForCTC, AutoProcessor
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

# Initialize translation model
checkpoint_dir = "bishaltwr/final_m2m100"  # Hugging Face model ID; point at a local checkpoint directory instead if needed
try:
    tokenizer = M2M100Tokenizer.from_pretrained(checkpoint_dir)
    model_m2m = M2M100ForConditionalGeneration.from_pretrained(checkpoint_dir)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_m2m.to(device)
    m2m_available = True
except Exception as e:
    logging.error(f"Error loading M2M100 model: {e}")
    m2m_available = False

# Initialize ASR model
model_id = "bishaltwr/wav2vec2-large-mms-1b-nepali"
try:
    processor = AutoProcessor.from_pretrained(model_id)
    model_asr = Wav2Vec2ForCTC.from_pretrained(model_id, ignore_mismatched_sizes=True)
    asr_available = True
except Exception as e:
    logging.error(f"Error loading ASR model: {e}")
    asr_available = False

# Initialize X-Transformer model (expects a local inference.py exposing translate(text))
try:
    from inference import translate as xtranslate
    xtransformer_available = True
except Exception as e:
    logging.error(f"Error loading XTransformer model: {e}")
    xtransformer_available = False

def m2m_translate(text, source_lang, target_lang):
    """Translate text using the M2M100 model."""
    if not m2m_available:
        return "M2M100 model not available"
    tokenizer.src_lang = source_lang
    inputs = tokenizer(text, return_tensors="pt").to(device)
    # Force the decoder to begin with the target-language token
    translated_tokens = model_m2m.generate(
        **inputs,
        forced_bos_token_id=tokenizer.get_lang_id(target_lang)
    )
    translated_text = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
    return translated_text
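
# Illustrative usage (the exact output depends on the fine-tuned checkpoint):
#   m2m_translate("How are you?", source_lang="en", target_lang="ne")
# returns the Nepali translation as a plain string.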

def transcribe_audio(audio_path, language="npi"):
    """Transcribe audio using the MMS ASR model ("npi" = Nepali, "eng" = English)."""
    if not asr_available:
        return "ASR model not available"
    import librosa
    # MMS expects 16 kHz mono audio; librosa resamples on load
    audio, sr = librosa.load(audio_path, sr=16000)
    # Switch the tokenizer vocabulary and model adapter to the target language
    processor.tokenizer.set_target_lang(language)
    model_asr.load_adapter(language)
    inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        outputs = model_asr(**inputs).logits
    # Greedy CTC decoding: pick the most likely token at each frame
    ids = torch.argmax(outputs, dim=-1)[0]
    transcription = processor.decode(ids, skip_special_tokens=True)
    if language == "eng":
        transcription = transcription.replace('<pad>', '').replace('<unk>', '')
    else:
        transcription = transcription.replace('<pad>', ' ').replace('<unk>', '')
    return transcription
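
# Illustrative usage ("speech.wav" is a hypothetical path):
#   transcribe_audio("speech.wav")                  -> Nepali transcription
#   transcribe_audio("speech.wav", language="eng")  -> English transcription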

def text_to_speech(text):
    """Convert text to speech using gTTS."""
    if not text:
        return None
    try:
        # gTTS produces MP3 output, so label the temp file accordingly
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_audio:
            # Match the voice to the text; assumes gTTS supports both "en" and "ne"
            tts = gTTS(text=text, lang=detect_language(text))
            tts.save(temp_audio.name)
            return temp_audio.name
    except Exception as e:
        logging.error(f"TTS error: {e}")
        return None
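
# Illustrative usage: text_to_speech("नमस्ते") writes an MP3 to a temporary
# file and returns its path, or None if synthesis fails.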

def detect_language(text):
    """Naive detection: mostly ASCII letters -> English, otherwise Nepali."""
    english_chars = sum(1 for c in text if c.isascii() and c.isalpha())
    return "en" if english_chars > len(text) * 0.5 else "ne"

def translate_text(text, model_choice, source_lang=None, target_lang=None):
    """Main translation function."""
    if not text:
        return "Please enter some text to translate"
    # Auto-detect the source language if not specified
    if not source_lang:
        source_lang = detect_language(text)
    # Auto-select the opposite language as the target if not specified
    if not target_lang:
        target_lang = "ne" if source_lang == "en" else "en"
    # Choose the translation model
    if model_choice == "XTransformer" and xtransformer_available:
        return xtranslate(text)
    elif model_choice == "M2M100" and m2m_available:
        return m2m_translate(text, source_lang=source_lang, target_lang=target_lang)
    else:
        return "Selected model is not available"

# Set up the Gradio interface
with gr.Blocks(title="Nepali-English Translator") as demo:
    gr.Markdown("# Nepali-English Translation Service")
    gr.Markdown("Translate between Nepali and English, transcribe audio, and convert text to speech.")

    # Set up tabs for different functions
    with gr.Tabs():
        # Text Translation Tab
        with gr.TabItem("Text Translation"):
            with gr.Row():
                with gr.Column():
                    text_input = gr.Textbox(label="Input Text", lines=5)
                    with gr.Row():
                        model_choice = gr.Radio(
                            choices=["XTransformer", "M2M100"],
                            value="XTransformer",
                            label="Translation Model"
                        )
                    with gr.Row():
                        source_lang = gr.Dropdown(
                            choices=["Auto-detect", "en", "ne"],
                            value="Auto-detect",
                            label="Source Language",
                            visible=True
                        )
                        target_lang = gr.Dropdown(
                            choices=["Auto-select", "en", "ne"],
                            value="Auto-select",
                            label="Target Language",
                            visible=True
                        )
                    translate_button = gr.Button("Translate")
                with gr.Column():
                    translation_output = gr.Textbox(label="Translation Output", lines=5)
                    tts_button = gr.Button("Convert to Speech")
                    audio_output = gr.Audio(label="Audio Output")

        # Speech to Text Tab
        with gr.TabItem("Speech to Text"):
            with gr.Column():
                audio_input = gr.Audio(label="Upload or Record Audio", type="filepath")
                asr_language = gr.Radio(
                    choices=["eng", "npi"],
                    value="npi",
                    label="Speech Language"
                )
                transcribe_button = gr.Button("Transcribe")
                transcription_output = gr.Textbox(label="Transcription Output", lines=3)

    # Define event handlers
    def process_translation(text, model, src_lang, tgt_lang):
        if src_lang == "Auto-detect":
            src_lang = None
        if tgt_lang == "Auto-select":
            tgt_lang = None
        return translate_text(text, model, src_lang, tgt_lang)

    def process_tts(text):
        return text_to_speech(text)

    def process_transcription(audio_path, language):
        if not audio_path:
            return "Please upload or record audio"
        return transcribe_audio(audio_path, language)

    # Connect the components
    translate_button.click(
        process_translation,
        inputs=[text_input, model_choice, source_lang, target_lang],
        outputs=translation_output
    )
    tts_button.click(
        process_tts,
        inputs=translation_output,
        outputs=audio_output
    )
    transcribe_button.click(
        process_transcription,
        inputs=[audio_input, asr_language],
        outputs=transcription_output
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()
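    # Standard Gradio options if remote access is needed: share=True creates a
    # temporary public URL, and server_name="0.0.0.0" binds to all interfaces.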