Dominik Macháček
committed on
Commit
·
6387098
1
Parent(s):
7eeb73f
FixedSileroVADIterator to support other than 512-sized chunks with v5
Browse files- silero_vad.py +37 -0
- whisper_online.py +1 -1
silero_vad.py
CHANGED
|
@@ -94,4 +94,41 @@ class VADIterator:
|
|
| 94 |
|
| 95 |
return None
|
| 96 |
|
|
|
|
|
|
|
|
|
|
| 97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
return None
|
| 96 |
|
| 97 |
+
#######################
|
| 98 |
+
# this is our workaround for Silero v5 requiring at least 512-sized audio chunks
|
| 99 |
+
# (see https://github.com/ufal/whisper_streaming/issues/116 )
|
| 100 |
|
| 101 |
+
import numpy as np
|
| 102 |
+
class FixedVADIterator(VADIterator):
    """VADIterator variant that accepts audio chunks of any length.

    Workaround for Silero v5, which rejects chunks that are not the expected
    size (see https://github.com/ufal/whisper_streaming/issues/116 ).
    Incoming audio is accumulated in ``self.buffer`` and handed to the parent
    iterator in windows of exactly ``CHUNK_SIZE`` samples; any remainder stays
    buffered until a later call completes the next window.
    """

    # Number of samples passed to the parent iterator per call.
    # NOTE(review): 512 is the value this workaround was written for;
    # confirm it against the sampling rate actually in use.
    CHUNK_SIZE = 512

    def reset_states(self):
        """Reset the parent's state and discard any buffered audio."""
        super().reset_states()
        self.buffer = np.array([], dtype=np.float32)

    def __call__(self, x, return_seconds=False):
        """Buffer ``x`` and run the parent iterator on every complete window.

        Args:
            x: audio samples to append to the internal buffer.
            return_seconds: forwarded unchanged to the parent ``__call__``.

        Returns:
            ``None`` when no complete window is available yet or the parent
            reported nothing; otherwise the parent's result. When several
            windows are processed in one call their results are merged
            (see the inline notes below).
        """
        self.buffer = np.append(self.buffer, x)
        size = self.CHUNK_SIZE
        merged = None
        while len(self.buffer) >= size:
            window = self.buffer[:size]
            # Keep the tail so no samples are lost between calls.
            self.buffer = self.buffer[size:]
            event = super().__call__(window, return_seconds=return_seconds)
            if event is None:
                continue
            if merged is None:
                merged = dict(event)
                continue
            # NOTE(review): assumes the parent returns a dict holding a
            # 'start' or an 'end' key -- confirm against VADIterator.__call__.
            if 'end' in event:
                # Report the latest end seen in this call.
                merged['end'] = event['end']
            if 'start' in event and 'end' in merged:
                # Speech resumed after an earlier end within the same call:
                # treat it as one continuing segment and drop that end.
                del merged['end']
        # An end cancelled by a following start leaves nothing to report.
        return merged if merged else None
|
| 115 |
+
|
| 116 |
+
if __name__ == "__main__":
    # Demo showing why FixedVADIterator is needed.

    import torch

    model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')

    vac = FixedVADIterator(model)
    # vac = VADIterator(model)  # the second case crashes with this

    # A full 512-sample chunk works with both iterators.
    audio_buffer = np.zeros(512, dtype=np.float32)
    vac(audio_buffer)

    # One sample short: crashes on the non FixedVADIterator with
    # ops.prim.RaiseException("Input audio chunk is too short", "builtins.ValueError")
    audio_buffer = np.zeros(512 - 1, dtype=np.float32)
    vac(audio_buffer)
|
whisper_online.py
CHANGED
|
@@ -531,7 +531,7 @@ class VACOnlineASRProcessor(OnlineASRProcessor):
|
|
| 531 |
# VAC:
|
| 532 |
import torch
|
| 533 |
model, _ = torch.hub.load(
|
| 534 |
-
repo_or_dir='snakers4/silero-vad
|
| 535 |
model='silero_vad'
|
| 536 |
)
|
| 537 |
from silero_vad import VADIterator
|
|
|
|
| 531 |
# VAC:
|
| 532 |
import torch
|
| 533 |
model, _ = torch.hub.load(
|
| 534 |
+
repo_or_dir='snakers4/silero-vad',
|
| 535 |
model='silero_vad'
|
| 536 |
)
|
| 537 |
from silero_vad import VADIterator
|