# translators / app.py
# oslohaze's picture
# Update app.py
# 059c884 verified
import gradio as gr
import torchaudio
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, MarianMTModel, MarianTokenizer
# Load the English -> Malayalam translation model and its tokenizer.
# Both are created with from_pretrained() at module level and are used as
# globals by translate_text().
translation_model_name = "Helsinki-NLP/opus-mt-en-ml"
translation_tokenizer = MarianTokenizer.from_pretrained(translation_model_name)
translation_model = MarianMTModel.from_pretrained(translation_model_name)
# Load the speech recognition (CTC) model and its processor.
# speech_to_text() uses the processor both to prepare the audio input and to
# decode the predicted token ids back into text.
asr_model_name = "facebook/wav2vec2-large-960h"
asr_processor = Wav2Vec2Processor.from_pretrained(asr_model_name)
asr_model = Wav2Vec2ForCTC.from_pretrained(asr_model_name)
# Translation function
def translate_text(text):
    """Translate English ``text`` to Malayalam with the module-level Marian model.

    Args:
        text: The English string to translate. ``None`` and blank strings are
            accepted because the UI calls this function even when the user
            supplied only audio.

    Returns:
        The decoded translation of the first generated sequence, or ``""``
        when there is nothing to translate.
    """
    # Skip the model entirely for missing/blank input; generating from an
    # empty prompt wastes a beam search and can produce meaningless output.
    if text is None or (isinstance(text, str) and not text.strip()):
        return ""

    # truncation=True keeps over-long inputs within the tokenizer's maximum
    # length instead of handing the model more positions than it supports.
    inputs = translation_tokenizer(
        text, return_tensors="pt", padding=True, truncation=True
    )
    outputs = translation_model.generate(
        **inputs, max_length=128, num_beams=4, early_stopping=True
    )
    return translation_tokenizer.decode(outputs[0], skip_special_tokens=True)
# Speech-to-text function
def speech_to_text(audio_path):
    """Transcribe the audio file at ``audio_path`` with the module-level ASR model.

    The audio is down-mixed to a single channel and resampled to the rate the
    processor's feature extractor declares before being passed to the model.

    Args:
        audio_path: Path to an audio file readable by ``torchaudio.load``.

    Returns:
        The decoded transcription of the (single) input utterance.
    """
    # Local import: this file does not import torch at module level.
    import torch

    waveform, sample_rate = torchaudio.load(audio_path)

    # torchaudio.load returns (channels, frames). Averaging over the channel
    # axis yields a 1-D signal for mono and stereo files alike, whereas
    # squeeze() would leave multi-channel audio 2-D.
    waveform = waveform.mean(dim=0)

    # The feature extractor expects one fixed sampling rate; resample instead
    # of forwarding whatever rate the file happened to be recorded at.
    target_rate = asr_processor.feature_extractor.sampling_rate
    if sample_rate != target_rate:
        waveform = torchaudio.functional.resample(waveform, sample_rate, target_rate)

    input_values = asr_processor(
        waveform.numpy(), sampling_rate=target_rate, return_tensors="pt"
    ).input_values

    # Inference only: no need to record an autograd graph.
    with torch.no_grad():
        logits = asr_model(input_values).logits

    # Greedy CTC decoding: most likely token per frame, then collapse to text.
    predicted_ids = logits.argmax(dim=-1)
    return asr_processor.batch_decode(predicted_ids)[0]
# Combined function for Gradio interface
def translate_speech(audio_path):
    """Transcribe the audio at ``audio_path`` and return the translated text.

    Chains ``speech_to_text`` and ``translate_text``: the transcription of the
    audio file is fed directly into the translator.
    """
    return translate_text(speech_to_text(audio_path))
# Gradio interface
iface = gr.Interface(
    # One callback serves both inputs: the text box is always translated,
    # while the speech path runs only when an audio file path was provided.
    fn=lambda text, audio_path: (
        translate_text(text),
        translate_speech(audio_path) if audio_path else None,
    ),
    inputs=[
        gr.Textbox(label="Input English Text"),
        gr.Audio(type="filepath"),  # callback receives a path, not raw samples
    ],
    outputs=[
        gr.Textbox(label="Translated Malayalam Text (from Text)"),
        gr.Textbox(label="Translated Malayalam Text (from Speech)"),
    ],
    title="English to Malayalam Translator",
    description="Translate English text or speech to Malayalam. Either enter text or speak into the microphone.",
)

# Start the web UI only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()