Spaces:
Runtime error
Runtime error
Fix unit test
Browse files
- src/whisper/abstractWhisperContainer.py +10 -2
- tests/vad_test.py +10 -4
src/whisper/abstractWhisperContainer.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
import abc
|
| 2 |
-
from typing import List
|
| 3 |
|
| 4 |
from src.config import ModelConfig, VadInitialPromptMode
|
| 5 |
|
|
@@ -9,7 +9,7 @@ from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
|
|
| 9 |
|
| 10 |
class AbstractWhisperCallback:
|
| 11 |
def __init__(self):
|
| 12 |
-
|
| 13 |
|
| 14 |
@abc.abstractmethod
|
| 15 |
def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
|
|
@@ -29,6 +29,14 @@ class AbstractWhisperCallback:
|
|
| 29 |
"""
|
| 30 |
raise NotImplementedError()
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
class AbstractWhisperContainer:
|
| 33 |
def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
|
| 34 |
download_root: str = None,
|
|
|
|
| 1 |
import abc
|
| 2 |
+
from typing import Any, Callable, List
|
| 3 |
|
| 4 |
from src.config import ModelConfig, VadInitialPromptMode
|
| 5 |
|
|
|
|
| 9 |
|
| 10 |
class AbstractWhisperCallback:
|
| 11 |
def __init__(self):
|
| 12 |
+
pass
|
| 13 |
|
| 14 |
@abc.abstractmethod
|
| 15 |
def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
|
|
|
|
| 29 |
"""
|
| 30 |
raise NotImplementedError()
|
| 31 |
|
| 32 |
+
class LambdaWhisperCallback(AbstractWhisperCallback):
|
| 33 |
+
def __init__(self, callback_lambda: Callable[[Any, int, str, str, ProgressListener], None]):
|
| 34 |
+
super().__init__()
|
| 35 |
+
self.callback_lambda = callback_lambda
|
| 36 |
+
|
| 37 |
+
def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
|
| 38 |
+
return self.callback_lambda(audio, segment_index, prompt, detected_language, progress_listener)
|
| 39 |
+
|
| 40 |
class AbstractWhisperContainer:
|
| 41 |
def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
|
| 42 |
download_root: str = None,
|
tests/vad_test.py
CHANGED
|
@@ -1,10 +1,11 @@
|
|
| 1 |
-
import pprint
|
| 2 |
import unittest
|
| 3 |
import numpy as np
|
| 4 |
import sys
|
| 5 |
|
| 6 |
sys.path.append('../whisper-webui')
|
|
|
|
| 7 |
|
|
|
|
| 8 |
from src.vad import AbstractTranscription, TranscriptionConfig, VadSileroTranscription
|
| 9 |
|
| 10 |
class TestVad(unittest.TestCase):
|
|
@@ -13,10 +14,11 @@ class TestVad(unittest.TestCase):
|
|
| 13 |
self.transcribe_calls = []
|
| 14 |
|
| 15 |
def test_transcript(self):
|
| 16 |
-
mock = MockVadTranscription()
|
|
|
|
| 17 |
|
| 18 |
self.transcribe_calls.clear()
|
| 19 |
-
result = mock.transcribe("mock", lambda segment : self.transcribe_segments(segment))
|
| 20 |
|
| 21 |
self.assertListEqual(self.transcribe_calls, [
|
| 22 |
[30, 30],
|
|
@@ -45,8 +47,9 @@ class TestVad(unittest.TestCase):
|
|
| 45 |
}
|
| 46 |
|
| 47 |
class MockVadTranscription(AbstractTranscription):
|
| 48 |
-
def __init__(self):
|
| 49 |
super().__init__()
|
|
|
|
| 50 |
|
| 51 |
def get_audio_segment(self, str, start_time: str = None, duration: str = None):
|
| 52 |
start_time_seconds = float(start_time.removesuffix("s"))
|
|
@@ -61,6 +64,9 @@ class MockVadTranscription(AbstractTranscription):
|
|
| 61 |
result.append( { 'start': 30, 'end': 60 } )
|
| 62 |
result.append( { 'start': 100, 'end': 200 } )
|
| 63 |
return result
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
if __name__ == '__main__':
|
| 66 |
unittest.main()
|
|
|
|
|
|
|
| 1 |
import unittest
|
| 2 |
import numpy as np
|
| 3 |
import sys
|
| 4 |
|
| 5 |
sys.path.append('../whisper-webui')
|
| 6 |
+
#print("Sys path: " + str(sys.path))
|
| 7 |
|
| 8 |
+
from src.whisper.abstractWhisperContainer import LambdaWhisperCallback
|
| 9 |
from src.vad import AbstractTranscription, TranscriptionConfig, VadSileroTranscription
|
| 10 |
|
| 11 |
class TestVad(unittest.TestCase):
|
|
|
|
| 14 |
self.transcribe_calls = []
|
| 15 |
|
| 16 |
def test_transcript(self):
|
| 17 |
+
mock = MockVadTranscription(mock_audio_length=120)
|
| 18 |
+
config = TranscriptionConfig()
|
| 19 |
|
| 20 |
self.transcribe_calls.clear()
|
| 21 |
+
result = mock.transcribe("mock", LambdaWhisperCallback(lambda segment, _1, _2, _3, _4: self.transcribe_segments(segment)), config)
|
| 22 |
|
| 23 |
self.assertListEqual(self.transcribe_calls, [
|
| 24 |
[30, 30],
|
|
|
|
| 47 |
}
|
| 48 |
|
| 49 |
class MockVadTranscription(AbstractTranscription):
|
| 50 |
+
def __init__(self, mock_audio_length: float = 1000):
|
| 51 |
super().__init__()
|
| 52 |
+
self.mock_audio_length = mock_audio_length
|
| 53 |
|
| 54 |
def get_audio_segment(self, str, start_time: str = None, duration: str = None):
|
| 55 |
start_time_seconds = float(start_time.removesuffix("s"))
|
|
|
|
| 64 |
result.append( { 'start': 30, 'end': 60 } )
|
| 65 |
result.append( { 'start': 100, 'end': 200 } )
|
| 66 |
return result
|
| 67 |
+
|
| 68 |
+
def get_audio_duration(self, audio: str, config: TranscriptionConfig):
|
| 69 |
+
return self.mock_audio_length
|
| 70 |
|
| 71 |
if __name__ == '__main__':
|
| 72 |
unittest.main()
|