import torch
import gradio as gr
import speechbrain as sb
import torchaudio
from hyperpyyaml import load_hyperpyyaml
from pyctcdecode import build_ctcdecoder
import os
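
# Load hyperparameters (paths, label indices, and module definitions) from the training YAML.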
hparams_file = "train.yaml"
with open(hparams_file, "r") as fin:
    hparams = load_hyperpyyaml(fin)
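
# Restore the character-level label encoder saved at training time; from_didatasets=[[]]
# is an empty placeholder because the encoder is loaded from file, not rebuilt from data.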
label_encoder = sb.dataio.encoder.CTCTextEncoder()
lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
special_labels = {
    "blank_label": hparams["blank_index"],
    "unk_label": hparams["unk_index"],
}
label_encoder.load_or_create(
    path=lab_enc_file,
    from_didatasets=[[]],
    output_key="char_list",
    special_labels=special_labels,
    sequence_input=True,
)
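
# Build the vocabulary for pyctcdecode: index 0 (the CTC blank) must map to the
# empty string, and the final special label is swapped for a placeholder token.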
ind2lab = label_encoder.ind2lab
labels = [ind2lab[x] for x in range(len(ind2lab))]
labels = [""] + labels[1:-1] + ["1"]
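
# Beam-search decoder with KenLM shallow fusion: alpha weights the LM score and
# beta is the word-insertion bonus.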
decoder = build_ctcdecoder(
    labels,
    kenlm_model_path=hparams["ngram_lm_path"],
    alpha=0.5,
    beta=1.0,
)
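

# Inference-only Brain: wav2vec2 features -> encoder -> CTC head, decoded with the
# LM-fused beam search built above.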
class ASR(sb.core.Brain):
    def treat_wav(self, sig):
        """Process a waveform and return the transcribed text."""
        # Relative length 1 means the single signal in the batch is used in full.
        feats = self.modules.wav2vec2(sig.to("cpu"), torch.tensor([1]).to("cpu"))
        feats = self.modules.enc(feats)
        logits = self.modules.ctc_lin(feats)
        p_ctc = self.hparams.log_softmax(logits)
        # Beam-search decode each item in the batch (here, exactly one).
        predicted_words = []
        for logs in p_ctc:
            text = decoder.decode(logs.detach().cpu().numpy())
            predicted_words.append(text.split(" "))
        return " ".join(predicted_words[0])
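

# Build the Brain on CPU and recover the trained checkpoint before switching to eval mode.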
asr_brain = ASR(
    modules=hparams["modules"],
    hparams=hparams,
    run_opts={"device": "cpu"},
    checkpointer=hparams["checkpointer"],
)
asr_brain.tokenizer = label_encoder
asr_brain.checkpointer.recover_if_possible()
asr_brain.modules.eval()
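

# Gradio callback: prefer the microphone recording, fall back to the uploaded file.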
def treat_wav_file(file_mic, file_upload, asr=asr_brain, device="cpu"):
    if file_mic is not None:
        wav = file_mic
    elif file_upload is not None:
        wav = file_upload
    else:
        return "ERROR: Please either record with the microphone or upload an audio file."
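
    # Read the audio, downmix to mono if needed, and resample to the 16 kHz
    # rate the wav2vec2 front end expects.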
    info = torchaudio.info(wav)
    sr = info.sample_rate
    sig = sb.dataio.dataio.read_audio(wav)
    if len(sig.shape) > 1:
        sig = torch.mean(sig, dim=1)
    sig = torch.unsqueeze(sig, 0)
    tensor_wav = sig.to(device)
    resampled = torchaudio.functional.resample(tensor_wav, sr, 16000)
    return asr.treat_wav(resampled)
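

# Text shown on the Gradio page.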
title = "Tunisian Speech Recognition"
description = (
    "This is a Tunisian ASR system based on the WavLM model, fine-tuned on 2.5 hours "
    "of speech and reaching a WER of 24% and a CER of 9%.\n\n"
    "Interesting, isn't it?"
)
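
# Two optional audio inputs; type="filepath" passes a file path the callback can read.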
gr.Interface(
    fn=treat_wav_file,
    inputs=[
        gr.Audio(sources=["microphone"], type="filepath", label="Record"),
        gr.Audio(sources=["upload"], type="filepath", label="Upload File"),
    ],
    outputs="text",
    title=title,
    description=description,
).launch()