# luluw's picture
# Update app.py
# bd7af00 verified
import os
import time
import torch
import torchaudio
import gradio as gr
from torchaudio.transforms import Resample
from torchaudio.models.decoder import download_pretrained_files, ctc_decoder
# Constants for decoding.
# Both values are forwarded to torchaudio's `ctc_decoder` in `decode_emission`.
# LM_WEIGHT is passed as `lm_weight` (weight given to the language-model score).
LM_WEIGHT = 1.23
# WORD_SCORE is passed as `word_score` (word-insertion score used by the beam search).
WORD_SCORE = -0.26
def get_featurizer():
    """Build the mel-spectrogram transform used to featurize audio.

    Returns:
        A ``torchaudio.transforms.MelSpectrogram`` configured with a sample
        rate of 16000, ``n_fft`` and ``win_length`` of 400, a hop length of
        160 and 80 mel bins.
    """
    mel_settings = {
        "sample_rate": 16000,
        "n_fft": 400,
        "win_length": 400,
        "hop_length": 160,
        "n_mels": 80,
    }
    return torchaudio.transforms.MelSpectrogram(**mel_settings)
def preprocess_audio(audio_file, featurizer, target_sample_rate=16000):
    """Load an audio file, resample it if needed, and extract features.

    Args:
        audio_file: Path of the recording to load.
        featurizer: Callable applied to the loaded waveform (for example the
            transform returned by ``get_featurizer``).
        target_sample_rate: Rate the waveform is resampled to when the file
            was stored at a different sample rate.

    Returns:
        The featurizer output with its last two dimensions swapped
        (``permute(0, 2, 1)``).

    Raises:
        ValueError: If no file was given, or if loading, resampling or
            feature extraction fails (the original exception is chained).
    """
    # Fail fast with a clear message instead of polling the filesystem for a
    # path that was never supplied (e.g. the user submitted without recording).
    if not audio_file:
        raise ValueError("Error in preprocessing audio: no audio file was provided")

    try:
        # Wait for the file to be saved: poll for up to 3 seconds. A monotonic
        # deadline avoids the drift of accumulating 0.1 floats.
        deadline = time.monotonic() + 3.0
        while not os.path.exists(audio_file) and time.monotonic() < deadline:
            time.sleep(0.1)

        waveform, sample_rate = torchaudio.load(audio_file)
        if sample_rate != target_sample_rate:
            resampler = Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
            waveform = resampler(waveform)
        return featurizer(waveform).permute(0, 2, 1)
    except Exception as e:
        # Keep the ValueError contract callers rely on, but preserve the cause.
        raise ValueError(f"Error in preprocessing audio: {e}") from e
def decode_emission(emission, tokens, files):
    """Decode model emissions into text with a lexicon-constrained beam search.

    Args:
        emission: Model output handed to the CTC decoder.
            NOTE(review): presumably a CPU float tensor of shape
            (batch, frames, num_tokens) -- confirm against the model.
        tokens: Token list passed to ``ctc_decoder`` as ``tokens``.
        files: Object exposing ``lexicon`` and ``lm`` attributes (as returned
            by ``download_pretrained_files``).

    Returns:
        The words of the top hypothesis for the first item, joined by single
        spaces and stripped.

    Raises:
        ValueError: If building the decoder or decoding fails (the original
            exception is chained).
    """
    try:
        # A fresh decoder is built on every call, so no decoder state is
        # shared between requests.
        beam_search_decoder = ctc_decoder(
            lexicon=files.lexicon,
            tokens=tokens,
            lm=files.lm,
            nbest=1,
            beam_size=100,
            beam_threshold=50,
            beam_size_token=25,
            lm_weight=LM_WEIGHT,
            word_score=WORD_SCORE,
        )
        beam_search_result = beam_search_decoder(emission)
        # First index: first item of the result; second index: best (and,
        # with nbest=1, only) hypothesis.
        best_hypothesis = beam_search_result[0][0]
        return " ".join(best_hypothesis.words).strip()
    except Exception as e:
        # Keep the ValueError contract, but preserve the underlying cause.
        raise ValueError(f"Error in decoding: {e}") from e
def transcribe(audio_file, model, featurizer, tokens, files):
    """Run the full pipeline: preprocess audio, run the model, decode text.

    Args:
        audio_file: Path of the recording to transcribe.
        model: Callable acoustic model applied to the extracted features.
        featurizer: Feature extractor forwarded to ``preprocess_audio``.
        tokens: Token list forwarded to ``decode_emission``.
        files: Decoder resource bundle forwarded to ``decode_emission``.

    Returns:
        The decoded transcript, or a string starting with
        ``"Error processing audio: "`` if any stage fails. Errors are returned
        rather than raised so they can be shown as the output text.
    """
    try:
        features = preprocess_audio(audio_file, featurizer)
        # Inference only: disable autograd so no graph is recorded for the
        # forward pass.
        with torch.no_grad():
            emission = model(features)
        return decode_emission(emission, tokens, files)
    except Exception as e:
        return f"Error processing audio: {e}"
def launch_app(model_path, token_path="tokens.txt", share=False):
    """Load the model and decoder resources, then start the Gradio demo.

    Args:
        model_path: Path of the TorchScript model loaded with
            ``torch.jit.load``.
        token_path: Path of a text file with one token per line.
        share: Forwarded to ``Interface.launch(share=...)``.
    """
    # map_location="cpu" lets a model exported from a GPU load on CPU-only
    # hosts; everything in this app runs on CPU.
    model = torch.jit.load(model_path, map_location="cpu")
    model.eval()

    # Explicit encoding so the token list does not depend on the host locale.
    with open(token_path, "r", encoding="utf-8") as f:
        tokens = f.read().splitlines()

    files = download_pretrained_files("librispeech-4-gram")
    featurizer = get_featurizer()

    def gradio_transcribe(audio_file):
        """Bind the loaded resources so Gradio only has to pass the audio path."""
        return transcribe(audio_file, model, featurizer, tokens, files)

    description = (
        "<b>Trained on:</b> Mozilla Corpus, Personal Recordings, and LibriSpeech"
        " — 2900 hrs of audio data.<br>\n"
        "<b>Training Script and Experiment Results</b> available "
        '<a href="https://github.com/LuluW8071/Conformer" target="_blank">here</a>'
    )
    interface = gr.Interface(
        fn=gradio_transcribe,
        inputs=gr.Audio(sources="microphone", type="filepath", label="Speak into the microphone"),
        outputs="text",
        title="Conformer-Small ASR Model",
        description=description,
    )
    interface.launch(share=share)
if __name__ == "__main__":
    # Script entry point: launch the demo with the default local artifacts.
    try:
        model_path = "optimized_model.pt"
        token_path = "tokens.txt"
        share = False
        launch_app(model_path, token_path, share)
    except Exception as e:
        # Chain the original exception so its traceback is not lost.
        raise ValueError(f"Fatal error: {e}") from e