Owen committed
Commit c6818dd · 1 Parent(s): e909f31

add conformer

Files changed (6)
  1. .gitattributes +4 -2
  2. app.py +114 -16
  3. jawa.wav +3 -0
  4. requirements.txt +4 -1
  5. sunda.wav +3 -0
  6. test.py +5 -0
.gitattributes CHANGED
@@ -33,5 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
-conformer.png filter=lfs diff=lfs merge=lfs -text
-whisper.png filter=lfs diff=lfs merge=lfs -text
+conformer.png filter=lfs diff=lfs merge=lfs -text
+whisper.png filter=lfs diff=lfs merge=lfs -text
+*.wav filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,12 +1,76 @@
+import os
+import torch
+import torch.nn as nn
+import torchaudio
 import numpy as np # type: ignore
 import gradio as gr # type: ignore
 from transformers import pipeline
+from huggingface_hub import hf_hub_download
+from torchaudio.models import Conformer
+
+class ASRConformerModel(nn.Module):
+    def __init__(self, input_dim, vocab_size):
+        super().__init__()
+        self.encoder = Conformer(
+            input_dim=input_dim,
+            num_heads=4,
+            ffn_dim=512,
+            num_layers=4,
+            depthwise_conv_kernel_size=31,
+            dropout=0.1
+        )
+        self.classifier = nn.Linear(input_dim, vocab_size)  # CTC head over the character vocabulary
+
+    def forward(self, x, lengths):
+        x, lengths = self.encoder(x, lengths=lengths)
+        x = self.classifier(x)
+        return x, lengths
+
+
+VOCAB = set("abcdefghijklmnopqrstuvwxyz '")
+char_to_idx = {ch: i + 1 for i, ch in enumerate(sorted(VOCAB))}  # index 0 is reserved for the CTC blank
+
+def greedy_decode(log_probs, blank=0):
+    pred_ids = log_probs.argmax(dim=-1)  # [T, B]
+    pred_ids = pred_ids.transpose(0, 1)  # [B, T]
+    predictions = []
+    for seq in pred_ids:
+        prev = blank
+        pred = []
+        for i in seq:
+            if i != prev and i != blank:  # collapse repeats, drop blanks
+                pred.append(i.item())
+            prev = i
+        predictions.append(pred)
+    return predictions
+
+def encode(text):
+    return torch.tensor([char_to_idx[c] for c in text.lower() if c in char_to_idx], dtype=torch.long)
+
+def decode_to_text(predictions, idx_to_char):
+    return [''.join(idx_to_char[i] for i in pred if i in idx_to_char) for pred in predictions]
 
 # Load fine-tuned Whisper model
-transcriber = pipeline("automatic-speech-recognition", model="OwLim/whisper-java-SLR41-SLR35")
+transcriber_whisper = pipeline("automatic-speech-recognition", model="OwLim/whisper-sundanese-finetune")
+transcriber_wav2vec = pipeline("automatic-speech-recognition", model="indonesian-nlp/wav2vec2-indonesian-javanese-sundanese")
+
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+SAMPLE_RATE = 16_000
+
+model_path = hf_hub_download(repo_id="Blebbyblub/javanese-conformer-asrV2", filename="pytorch_model.bin")
+model = ASRConformerModel(input_dim=80, vocab_size=29).to(device)
+model.load_state_dict(torch.load(model_path, map_location=device))
+
+examples_audio = [
+    file for file in os.listdir("./") if file.endswith(".wav")
+]
+
+idx_to_char = {v: k for k, v in char_to_idx.items()}
 
-def transcribe(audio):
+def transcribe(audio, model_selection):
     sr, waveform = audio
+
     # Change into Mono Audio
     if waveform.ndim > 1:
         waveform = waveform.mean(axis=1)
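The hunk above pairs a torchaudio Conformer encoder with a linear CTC head and decodes by best path: argmax per frame, collapse consecutive repeats, drop blanks. A minimal sanity check of the decoding helpers, assuming it runs in the context of this app.py (the tensor values are fabricated purely for illustration):

import torch

# With sorted(VOCAB), index 1 = ' ', index 2 = "'", index 3 = 'a', index 4 = 'b', ...; 0 is the CTC blank.
idx_to_char = {v: k for k, v in char_to_idx.items()}

# The frame-level best path "b b <blank> a a" should collapse to "ba".
path = torch.tensor([4, 4, 0, 3, 3])          # argmax index for each of 5 frames
log_probs = torch.full((5, 1, 29), -10.0)     # [T, B, vocab], everything unlikely
log_probs[torch.arange(5), 0, path] = 0.0     # make `path` the per-frame argmax
pred_ids = greedy_decode(log_probs)           # -> [[4, 3]]
print(decode_to_text(pred_ids, idx_to_char))  # -> ['ba']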
 
@@ -15,7 +79,33 @@ def transcribe(audio):
     waveform = waveform.astype(np.float32)
     waveform /= np.max(np.abs(waveform))
 
-    return transcriber({
+    if model_selection == "Conformer":
+        mel_transform = torchaudio.transforms.MelSpectrogram(sample_rate=SAMPLE_RATE, n_mels=80)
+
+        waveform = torch.from_numpy(waveform).float()
+        if sr != SAMPLE_RATE:
+            waveform = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(waveform)
+
+        waveform = waveform.unsqueeze(0)
+        mel = mel_transform(waveform).squeeze(0).transpose(0, 1)  # [time, mel]
+        mel = mel.unsqueeze(0).to(device)  # [1, time, mel] batch for the encoder
+        input_length = torch.tensor([mel.size(1)]).to(device)
+
+        model.eval()
+        with torch.no_grad():
+            output, output_lengths = model(mel, input_length)
+            log_probs = output.log_softmax(2).transpose(0, 1)  # [T, B, vocab] for CTC decoding
+            pred_ids = greedy_decode(log_probs)
+            pred_text = decode_to_text(pred_ids, idx_to_char)[0]
+
+        return pred_text
+
+    if model_selection == "Wav2Vec":
+        selected_model = transcriber_wav2vec
+    elif model_selection == "Whisper":
+        selected_model = transcriber_whisper
+
+    return selected_model({
         "sampling_rate" : sr,
         "raw" : waveform
     })["text"]
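One edge case in transcribe: gr.Radio starts with no selection, so model_selection can be None; then neither assignment runs and the final return hits an unbound selected_model (NameError). A hedged guard for the top of the function (my addition, not part of this commit; gr.Error is standard Gradio and surfaces the message in the UI):

# Hypothetical guard: fail fast with a user-visible error when no model is chosen.
if model_selection not in ("Whisper", "Conformer", "Wav2Vec"):
    raise gr.Error("Please choose a model before transcribing.")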
 
@@ -25,6 +115,7 @@ def clear():
 
 # --- Tab 1: Transcribe ---
 with gr.Blocks() as tab_transcribe:
+    model_selector = gr.Radio(choices=["Whisper", "Conformer", "Wav2Vec"], label="Choose Model", info="This will affect the model used for transcribing")
     with gr.Row():
         with gr.Column(scale=1):
             audio_input = gr.Audio(sources="microphone", label="Record Your Voice")
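An alternative to the guard sketched earlier is to give the radio a default so the first Submit always has a valid choice. A sketch under the assumption that "Whisper" is an acceptable default (value= is standard gr.Radio API):

# Hypothetical variant with a preselected model on page load.
model_selector = gr.Radio(
    choices=["Whisper", "Conformer", "Wav2Vec"],
    value="Whisper",  # assumed default; any of the three works
    label="Choose Model",
    info="This will affect the model used for transcribing",
)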
 
@@ -35,7 +126,13 @@ with gr.Blocks() as tab_transcribe:
         with gr.Column(scale=1):
             output_text = gr.Textbox(label="Transcription", placeholder="Waiting for Input", lines=3)
 
-    subBtn.click(fn=transcribe, inputs=audio_input, outputs=output_text)
+    gr.Examples(
+        examples=examples_audio,  # list of example audio file paths found in the repo root
+        inputs=audio_input,
+        label="Try with Example Audio"
+    )
+
+    subBtn.click(fn=transcribe, inputs=[audio_input, model_selector], outputs=output_text)
     clrBtn.click(fn=clear, outputs=[audio_input, output_text])
 
 # --- Tab 2: Fine-Tuned Model Overview ---
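gr.Examples here only fills audio_input; the user still picks a model and clicks Submit. The examples list is built at startup by scanning the repo root for .wav files, which is why jawa.wav and sunda.wav ship with this commit. Since os.listdir order is arbitrary, a sorted variant (my tweak, not in the commit) keeps the example order stable across restarts:

# Hypothetical deterministic ordering for the example files.
examples_audio = sorted(f for f in os.listdir("./") if f.endswith(".wav"))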
 
@@ -52,19 +149,20 @@ with gr.Blocks() as tab_background:
     The models we fine-tuned are the result of <b>fine-tuning Whisper and Conformer</b> to support local languages in Indonesia, particularly Javanese and Sundanese.
     The models were trained on a combination of the following <b>OpenSLR</b> datasets:
     <br>
-    <a href="https://openslr.org/35/" target="_blank" style="text-decoration:none;>
-    <b>SLR35</b> - Large Javanese ASR
-    </a>
-
+    <a href="https://openslr.org/35/" target="_blank" style="text-decoration:none;">
+    <b>SLR35</b> - Large Javanese ASR training data set
+    </a>
     <br>
-    <a href="https://openslr.org/41/" target="_blank" style="text-decoration:none;">
-    <b>SLR41</b> - High quality TTS data for Javanese
-    </a>
-
+    <a href="https://openslr.org/36/" target="_blank" style="text-decoration:none;">
+    <b>SLR36</b> - Large Sundanese ASR training data set
+    </a>
+    <br>
+    <a href="https://openslr.org/41/" target="_blank" style="text-decoration:none;">
+    <b>SLR41</b> - High quality TTS data for Javanese
+    </a>
     <br>
-    <a href="https://openslr.org/36" target="_blank" style="text-decoration:none;">
-    <b>SLR36</b>
-    <b>SLR44</b> - Bilingual speech datasets
+    <a href="https://openslr.org/44" target="_blank" style="text-decoration:none;">
+    <b>SLR44</b> - High quality TTS data for Sundanese
     </a>
 
     <br>
 
@@ -172,7 +270,7 @@ demo = gr.TabbedInterface(
     [tab_transcribe, tab_background, tab_architecture, tab_results, tab_authors],
     ["Transcribe", "Latar Belakang", "Arsitektur", "Evaluasi", "Fine-Tuned By"],
     theme=gr.themes.Soft(),
-    title="Whisper VS Conformer Model"
+    title="Multilingual ASR Model"
 )
 
 if __name__ == "__main__":
jawa.wav ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:28e88d2a129ae797fde52637b187fef218c30eddc891b8189eed8c0b40bf9dec
+size 200812
requirements.txt CHANGED
@@ -1,3 +1,5 @@
+torch
 numpy
 torchaudio
-transformers
+transformers
+huggingface_hub
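Note that os has been dropped here: it is part of the Python standard library and not pip-installable, so listing it breaks the build. Note also that gradio itself is not listed; Hugging Face Spaces preinstalls it via the SDK setting, but running the app anywhere else needs it declared as well. A sketch of the file under that assumption (left unpinned, since the commit pins no versions):

torch
numpy
torchaudio
transformers
huggingface_hub
gradio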
sunda.wav ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bbc472fcba6f3f5203a9ccde45219f1dac1242a829451f6e397c905c3774eeac
+size 615864
test.py ADDED
@@ -0,0 +1,5 @@
+import os
+examples_audio = [
+    'data/' + file for file in os.listdir("data")
+]
+print(examples_audio)
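This scratch script lists a data/ directory with no extension filter, while app.py scans the repository root for .wav files, so the two disagree unless the audio is moved into data/. A variant aligned with app.py's behavior (my assumption being that the .wav examples sit next to the scripts):

import os

# Mirror app.py: collect example .wav files from the repository root.
examples_audio = [file for file in os.listdir("./") if file.endswith(".wav")]
print(examples_audio)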