fix buggy inference
- asr.py +18 -6
- requirements.txt +1 -0
asr.py
CHANGED
@@ -2,6 +2,7 @@ from transformers import Wav2Vec2ForCTC, AutoProcessor
 import torchaudio
 import torch
 import os
+import librosa
 
 hf_token = os.getenv("HUGGING_FACE_HUB_TOKEN")
 
@@ -18,12 +19,23 @@ def load_model():
 
 
 def inference(processor, model, audio_path):
-    arr, rate = read_audio_data(audio_path)
-    inputs = processor(arr.squeeze().numpy(), sampling_rate=16_000, return_tensors="pt")
-
+    audio, sampling_rate = librosa.load(audio_path, sr=16000)  # Ensure the correct sampling rate
+    inputs = processor(audio, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
+
     with torch.no_grad():
-        outputs = model(**inputs).logits
-    ids = torch.argmax(outputs, dim=-1)[0]
-    transcription = processor.decode(ids)
+        logits = model(inputs.input_values).logits
+
+    # Decode predicted tokens
+    predicted_ids = torch.argmax(logits, dim=-1)
+    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
+
+
+    #arr, rate = read_audio_data(audio_path)
+    #inputs = processor(arr.squeeze().numpy(), sampling_rate=16_000, return_tensors="pt")
+
+    #with torch.no_grad():
+    #    outputs = model(**inputs).logits
+    #ids = torch.argmax(outputs, dim=-1)[0]
+    #transcription = processor.decode(ids)
 
     return transcription
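For reference, a minimal usage sketch of the patched inference function. This is not part of the commit: the import wiring, the file name sample.wav, and the assumption that load_model() returns a (processor, model) pair are illustrative only. Because librosa.load(..., sr=16000) resamples the input, common audio formats and sample rates should be handled.

    # Hypothetical usage sketch (assumes load_model() returns (processor, model)).
    from asr import load_model, inference

    processor, model = load_model()
    text = inference(processor, model, "sample.wav")  # path to an audio file (hypothetical)
    print(text)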
requirements.txt
CHANGED
@@ -4,3 +4,4 @@ torch
 torchaudio
 streamlit_webrtc
 audio_recorder_streamlit
+librosa