Lguyogiro committed
Commit 819cd84 · 1 Parent(s): e678dfe

fix buggy inference

Files changed (2):
1. asr.py +18 -6
2. requirements.txt +1 -0
asr.py CHANGED
@@ -2,6 +2,7 @@ from transformers import Wav2Vec2ForCTC, AutoProcessor
 import torchaudio
 import torch
 import os
+import librosa
 
 hf_token = os.getenv("HUGGING_FACE_HUB_TOKEN")
 
@@ -18,12 +19,23 @@ def load_model():
 
 
 def inference(processor, model, audio_path):
-    arr, rate = read_audio_data(audio_path)
-    inputs = processor(arr.squeeze().numpy(), sampling_rate=16_000, return_tensors="pt")
-
+    audio, sampling_rate = librosa.load(audio_path, sr=16000)  # Ensure the correct sampling rate
+    inputs = processor(audio, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
+
     with torch.no_grad():
-        outputs = model(**inputs).logits
-    ids = torch.argmax(outputs, dim=-1)[0]
-    transcription = processor.decode(ids)
+        logits = model(inputs.input_values).logits
+
+    # Decode predicted tokens
+    predicted_ids = torch.argmax(logits, dim=-1)
+    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
+
+
+    # arr, rate = read_audio_data(audio_path)
+    # inputs = processor(arr.squeeze().numpy(), sampling_rate=16_000, return_tensors="pt")
+
+    # with torch.no_grad():
+    #     outputs = model(**inputs).logits
+    # ids = torch.argmax(outputs, dim=-1)[0]
+    # transcription = processor.decode(ids)
 
     return transcription
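For context, a minimal sketch of how the fixed inference path could be exercised end to end; load_model() returning a (processor, model) pair is an assumption inferred from the hunk header, and "sample.wav" is a placeholder path:

# Usage sketch (assumptions: load_model() returns (processor, model) as the
# hunk header suggests; "sample.wav" is a placeholder audio file).
from asr import load_model, inference

processor, model = load_model()
print(inference(processor, model, "sample.wav"))

The swap to librosa.load(audio_path, sr=16000) is what does the heavy lifting here: librosa decodes to a mono float32 1-D array and resamples to the requested rate, which is the input format the Wav2Vec2 processor expects.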
requirements.txt CHANGED
@@ -4,3 +4,4 @@ torch
 torchaudio
 streamlit_webrtc
 audio_recorder_streamlit
+librosa