Spaces:
Runtime error
Runtime error
Vaibhav Srivastav
committed on
Commit
·
b8af00e
1
Parent(s):
851eb15
adding decoding w lm
Browse files
- 4gram_small.arpa.gz +3 -0
- app.py +24 -2
4gram_small.arpa.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f4c4fe64751abecdeb7040fe6ed7f2440c2d3f36ed35c43e3510f7cf95578f2a
|
| 3 |
+
size 18358716
|
app.py
CHANGED
|
@@ -42,6 +42,28 @@ def predict_and_ctc_decode(input_file, model_name):
|
|
| 42 |
|
| 43 |
return transcribed_text
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
def predict_and_greedy_decode(input_file, model_name):
|
| 46 |
processor, model = return_processor_and_model(model_name)
|
| 47 |
speech = load_and_fix_data(input_file)
|
|
@@ -57,12 +79,12 @@ def predict_and_greedy_decode(input_file, model_name):
|
|
| 57 |
return transcribed_text
|
| 58 |
|
| 59 |
def return_all_predictions(input_file, model_name):
    """Run both decoding functions on the same input and return their results.

    Args:
        input_file: Passed unchanged to each decoding function.
        model_name: Passed unchanged to each decoding function.

    Returns:
        A 2-tuple ``(predict_and_ctc_decode result, predict_and_greedy_decode result)``.
    """
    # Evaluated in the same order as the tuple positions.
    ctc_text = predict_and_ctc_decode(input_file, model_name)
    greedy_text = predict_and_greedy_decode(input_file, model_name)
    return ctc_text, greedy_text
|
| 61 |
|
| 62 |
|
| 63 |
gr.Interface(return_all_predictions,
|
| 64 |
inputs = [gr.inputs.Audio(source="microphone", type="filepath", label="Record/ Drop audio"), gr.inputs.Dropdown(["facebook/wav2vec2-base-960h", "facebook/hubert-large-ls960-ft"], label="Model Name")],
|
| 65 |
-
outputs = [gr.outputs.Textbox(label="Beam CTC decoding"), gr.outputs.Textbox(label="Greedy decoding")],
|
| 66 |
title="ASR using Wav2Vec2/ Hubert & pyctcdecode",
|
| 67 |
description = "Comparing greedy decoder with beam search CTC decoder, record/ drop your audio!",
|
| 68 |
layout = "horizontal",
|
|
|
|
| 42 |
|
| 43 |
return transcribed_text
|
| 44 |
|
| 45 |
+
def predict_and_ctc_lm_decode(input_file, model_name):
    """Transcribe audio by decoding model logits with a language-model-backed decoder.

    Args:
        input_file: Audio input, forwarded to ``load_and_fix_data``.
        model_name: Model identifier, forwarded to ``return_processor_and_model``.

    Returns:
        The decoded text, lower-cased and then passed through
        ``fix_transcription_casing``.
    """
    processor, model = return_processor_and_model(model_name)
    speech = load_and_fix_data(input_file)

    input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values
    # Keep only the first item of the model output (presumably the batch axis
    # -- TODO confirm) as a NumPy array for the decoder.
    logits = model(input_values).logits.cpu().detach().numpy()[0]

    # Lower-case the tokens and order them by their vocabulary id; the decoder
    # presumably expects labels in logit-column order -- verify against pyctcdecode.
    vocab_dict = processor.tokenizer.get_vocab()
    sorted_dict = {k.lower(): v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}

    decoder = build_ctcdecoder(
        list(sorted_dict),
        "4gram_small.arpa.gz",
    )

    pred = decoder.decode(logits)

    return fix_transcription_casing(pred.lower())
|
| 66 |
+
|
| 67 |
def predict_and_greedy_decode(input_file, model_name):
|
| 68 |
processor, model = return_processor_and_model(model_name)
|
| 69 |
speech = load_and_fix_data(input_file)
|
|
|
|
| 79 |
return transcribed_text
|
| 80 |
|
| 81 |
def return_all_predictions(input_file, model_name):
    """Run all three decoding functions on the same input and return their results.

    Args:
        input_file: Passed unchanged to each decoding function.
        model_name: Passed unchanged to each decoding function.

    Returns:
        A 3-tuple of the results of ``predict_and_ctc_decode``,
        ``predict_and_ctc_lm_decode`` and ``predict_and_greedy_decode``,
        in that order.
    """
    # Evaluated in the same order as the tuple positions.
    ctc_text = predict_and_ctc_decode(input_file, model_name)
    ctc_lm_text = predict_and_ctc_lm_decode(input_file, model_name)
    greedy_text = predict_and_greedy_decode(input_file, model_name)
    return ctc_text, ctc_lm_text, greedy_text
|
| 83 |
|
| 84 |
|
| 85 |
gr.Interface(return_all_predictions,
|
| 86 |
inputs = [gr.inputs.Audio(source="microphone", type="filepath", label="Record/ Drop audio"), gr.inputs.Dropdown(["facebook/wav2vec2-base-960h", "facebook/hubert-large-ls960-ft"], label="Model Name")],
|
| 87 |
+
outputs = [gr.outputs.Textbox(label="Beam CTC decoding"), gr.outputs.Textbox(label="Beam CTC decoding w/ LM"), gr.outputs.Textbox(label="Greedy decoding")],
|
| 88 |
title="ASR using Wav2Vec2/ Hubert & pyctcdecode",
|
| 89 |
description = "Comparing greedy decoder with beam search CTC decoder, record/ drop your audio!",
|
| 90 |
layout = "horizontal",
|