brdhaker3 committed
Commit df657b0 · verified · 1 Parent(s): 20df2da

Upload app.py with huggingface_hub

Files changed (1)
  1. app.py +109 -0
app.py ADDED
@@ -0,0 +1,109 @@
+ import torch
+ import gradio as gr
+ import speechbrain as sb
+ import torchaudio
+ from hyperpyyaml import load_hyperpyyaml
+ from pyctcdecode import build_ctcdecoder
+ import os
+
+ # Load hyperparameters from the HyperPyYAML config
+ hparams_file = "train.yaml"
+ with open(hparams_file, "r") as fin:
+     hparams = load_hyperpyyaml(fin)
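+ # NOTE: train.yaml is assumed to define save_folder, blank_index, unk_index,
+ # ngram_lm_path, log_softmax, checkpointer, and a modules dict containing
+ # wav2vec2, enc, and ctc_lin; the config file itself is not part of this commit.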
+
+ # Initialize the label encoder
+ label_encoder = sb.dataio.encoder.CTCTextEncoder()
+ lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
+ special_labels = {
+     "blank_label": hparams["blank_index"],
+     "unk_label": hparams["unk_index"],
+ }
+ label_encoder.load_or_create(
+     path=lab_enc_file,
+     from_didatasets=[[]],
+     output_key="char_list",
+     special_labels=special_labels,
+     sequence_input=True,
+ )
+
+ # Prepare labels for the CTC decoder
+ ind2lab = label_encoder.ind2lab
+ labels = [ind2lab[x] for x in range(len(ind2lab))]
+ labels = [""] + labels[1:-1] + ["1"]
+
+ # Initialize the CTC beam-search decoder with a KenLM n-gram language model
+ decoder = build_ctcdecoder(
+     labels,
+     kenlm_model_path=hparams["ngram_lm_path"],
+     alpha=0.5,  # language-model weight (shallow fusion)
+     beta=1.0,   # length score adjustment during scoring
+ )
+
+
+ # Define the ASR class with the `treat_wav` method
+ class ASR(sb.core.Brain):
+     def treat_wav(self, sig):
+         """Process a waveform and return the transcribed text."""
+         # The second argument is the relative length: 1.0 for one full utterance
+         feats = self.modules.wav2vec2(sig.to("cpu"), torch.tensor([1]).to("cpu"))
+         feats = self.modules.enc(feats)
+         logits = self.modules.ctc_lin(feats)
+         p_ctc = self.hparams.log_softmax(logits)
+         predicted_words = []
+         for logs in p_ctc:
+             text = decoder.decode(logs.detach().cpu().numpy())
+             predicted_words.append(text.split(" "))
+         return " ".join(predicted_words[0])
+
+
+ # Instantiate the Brain, restore the checkpoint, and switch to eval mode
+ asr_brain = ASR(
+     modules=hparams["modules"],
+     hparams=hparams,
+     run_opts={"device": "cpu"},
+     checkpointer=hparams["checkpointer"],
+ )
+ asr_brain.tokenizer = label_encoder
+ asr_brain.checkpointer.recover_if_possible()
+ asr_brain.modules.eval()
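+ # recover_if_possible() restores the latest checkpoint if one exists in the
+ # checkpointer's save folder; eval() disables dropout for inference.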
+
+
+ # Transcribe an audio file from the microphone or an upload
+ def treat_wav_file(file_mic, file_upload, asr=asr_brain, device="cpu"):
+     if file_mic is not None:
+         wav = file_mic
+     elif file_upload is not None:
+         wav = file_upload
+     else:
+         return "ERROR: You must either use the microphone or upload an audio file"
+
+     # Read the audio, downmix to mono, and resample to 16 kHz
+     info = torchaudio.info(wav)
+     sr = info.sample_rate
+     sig = sb.dataio.dataio.read_audio(wav)
+     if len(sig.shape) > 1:
+         sig = torch.mean(sig, dim=1)
+     sig = torch.unsqueeze(sig, 0)
+     tensor_wav = sig.to(device)
+     resampled = torchaudio.functional.resample(tensor_wav, sr, 16000)
+
+     # Transcribe the audio
+     sentence = asr.treat_wav(resampled)
+     return sentence
+
+
+ # Gradio interface
+ title = "Tunisian Speech Recognition"
+ description = '''This is a Tunisian ASR system based on the WavLM model,
+ fine-tuned on a 2.5-hour dataset and reaching a WER of 24% and a CER of 9%.
+
+ Interesting, isn't it?'''
+
+ gr.Interface(
+     fn=treat_wav_file,
+     inputs=[
+         gr.Audio(sources=["microphone"], type="filepath", label="Record"),
+         gr.Audio(sources=["upload"], type="filepath", label="Upload File"),
+     ],
+     outputs="text",
+     title=title,
+     description=description,
+ ).launch()
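
For a quick local check without the browser UI, the transcription path can be called directly; a minimal sketch, where "sample.wav" is a hypothetical audio file on disk:

    # Hypothetical smoke test: transcribe a local file ("sample.wav" is a placeholder).
    print(treat_wav_file(file_mic=None, file_upload="sample.wav"))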