Add support for parallel execution on multiple GPUs
Files changed:
- app.py +37 -14
- cli.py +2 -0
- src/vad.py +42 -24
- src/vadParallel.py +81 -0
- src/whisperContainer.py +91 -0
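
This commit threads a comma-delimited CUDA device list (e.g. "0,1") through both the web UI and the CLI, and fans the VAD segments out to one worker process per device. As a rough usage sketch, mirroring the new create_ui signature and the __main__ block added in app.py below (argument values are illustrative only):

from app import create_ui  # this repository's app.py

# server_port and vad_parallel_devices are the parameters added by this commit
create_ui(inputAudioMaxDuration=600, share=False, server_name=None,
          server_port=7860, vad_parallel_devices="0,1")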
app.py
CHANGED
@@ -1,9 +1,13 @@
 from typing import Iterator
+import argparse
 
 from io import StringIO
 import os
 import pathlib
 import tempfile
+from src.vadParallel import ParallelTranscription
+
+from src.whisperContainer import WhisperContainer
 
 # External programs
 import whisper
@@ -14,7 +18,7 @@ import gradio as gr
 
 from src.download import ExceededMaximumDuration, download_url
 from src.utils import slugify, write_srt, write_vtt
-from src.vad import NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
+from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
 
 # Limitations (set to -1 to disable)
 DEFAULT_INPUT_AUDIO_MAX_DURATION = 600 # seconds
@@ -48,6 +52,7 @@ LANGUAGES = [
 class WhisperTranscriber:
     def __init__(self, inputAudioMaxDuration: float = DEFAULT_INPUT_AUDIO_MAX_DURATION, deleteUploadedFiles: bool = DELETE_UPLOADED_FILES):
         self.model_cache = dict()
+        self.parallel_device_list = None
 
         self.vad_model = None
         self.inputAudioMaxDuration = inputAudioMaxDuration
@@ -64,7 +69,7 @@
         model = self.model_cache.get(selectedModel, None)
 
         if not model:
-            model =
+            model = WhisperContainer(selectedModel)
            self.model_cache[selectedModel] = model
 
         # Execute whisper
@@ -87,7 +92,7 @@
         except ExceededMaximumDuration as e:
             return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
 
-    def transcribe_file(self, model:
+    def transcribe_file(self, model: WhisperContainer, audio_path: str, language: str, task: str = None, vad: str = None,
                         vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1, **decodeOptions: dict):
 
         initial_prompt = decodeOptions.pop('initial_prompt', None)
@@ -96,35 +101,42 @@
         task = decodeOptions.pop('task')
 
         # Callable for processing an audio file
-        whisperCallable =
-            language=language if language else detected_language, task=task, \
-            initial_prompt=self._concat_prompt(initial_prompt, prompt) if segment_index == 0 else prompt, \
-            **decodeOptions)
+        whisperCallable = model.create_callback(language, task, initial_prompt, **decodeOptions)
 
         # The results
         if (vad == 'silero-vad'):
             # Silero VAD where non-speech gaps are transcribed
             process_gaps = self._create_silero_config(NonSpeechStrategy.CREATE_SEGMENT, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
-            result = self.
+            result = self.process_vad(audio_path, whisperCallable, self.vad_model, process_gaps)
         elif (vad == 'silero-vad-skip-gaps'):
             # Silero VAD where non-speech gaps are simply ignored
             skip_gaps = self._create_silero_config(NonSpeechStrategy.SKIP, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
-            result = self.
+            result = self.process_vad(audio_path, whisperCallable, self.vad_model, skip_gaps)
        elif (vad == 'silero-vad-expand-into-gaps'):
             # Use Silero VAD where speech-segments are expanded into non-speech gaps
             expand_gaps = self._create_silero_config(NonSpeechStrategy.EXPAND_SEGMENT, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
-            result = self.
+            result = self.process_vad(audio_path, whisperCallable, self.vad_model, expand_gaps)
         elif (vad == 'periodic-vad'):
             # Very simple VAD - mark every 5 minutes as speech. This makes it less likely that Whisper enters an infinite loop, but
             # it may create a break in the middle of a sentence, causing some artifacts.
             periodic_vad = VadPeriodicTranscription()
-
+            period_config = PeriodicTranscriptionConfig(periodic_duration=vadMaxMergeSize, max_prompt_window=vadPromptWindow)
+            result = self.process_vad(audio_path, whisperCallable, periodic_vad, period_config)
+
         else:
             # Default VAD
             result = whisperCallable(audio_path, 0, None, None)
 
         return result
 
+    def process_vad(self, audio_path, whisperCallable, vadModel: AbstractTranscription, vadConfig: TranscriptionConfig):
+        if (self.parallel_device_list is None or len(self.parallel_device_list) == 0):
+            # No parallel devices, so just run the VAD and Whisper in sequence
+            return vadModel.transcribe(audio_path, whisperCallable, vadConfig)
+
+        parallell_vad = ParallelTranscription()
+        return parallell_vad.transcribe_parallel(transcription=vadModel, audio=audio_path, whisperCallable=whisperCallable, config=vadConfig, devices=self.parallel_device_list)
+
     def _concat_prompt(self, prompt1, prompt2):
         if (prompt1 is None):
             return prompt2
@@ -218,9 +230,12 @@
         return file.name
 
 
-def create_ui(inputAudioMaxDuration, share=False, server_name: str = None):
+def create_ui(inputAudioMaxDuration, share=False, server_name: str = None, server_port: int = 7860, vad_parallel_devices: str = None):
     ui = WhisperTranscriber(inputAudioMaxDuration)
 
+    # Specify a list of devices to use for parallel processing
+    ui.parallel_device_list = [ device.strip() for device in vad_parallel_devices.split(",") ] if vad_parallel_devices else None
+
     ui_description = "Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse "
     ui_description += " audio and is also a multi-task model that can perform multilingual speech recognition "
     ui_description += " as well as speech translation and language identification. "
@@ -250,7 +265,15 @@ def create_ui(inputAudioMaxDuration, share=False, server_name: str = None):
         gr.Text(label="Segments")
     ])
 
-    demo.launch(share=share, server_name=server_name)
+    demo.launch(share=share, server_name=server_name, server_port=server_port)
 
 if __name__ == '__main__':
-
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument("--inputAudioMaxDuration", type=int, default=600, help="Maximum audio file length in seconds, or -1 for no limit.")
+    parser.add_argument("--share", type=bool, default=False, help="True to share the app on HuggingFace.")
+    parser.add_argument("--server_name", type=str, default=None, help="The host or IP to bind to. If None, bind to localhost.")
+    parser.add_argument("--server_port", type=int, default=7860, help="The port to bind to.")
+    parser.add_argument("--vad_parallel_devices", type=str, default="0,1", help="A comma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.")
+
+    args = parser.parse_args().__dict__
+    create_ui(**args)
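
As a minimal, self-contained sketch of the dispatch rule that the new process_vad method introduces: the device string is parsed the same way create_ui does, and an empty or missing list falls back to the sequential path. The two functions below are illustrative stand-ins, not code from this diff:

def parse_devices(vad_parallel_devices):
    # Same parsing as create_ui: None or "" disables parallel processing
    return [d.strip() for d in vad_parallel_devices.split(",")] if vad_parallel_devices else None

def process_vad(audio_path, device_list):
    if device_list is None or len(device_list) == 0:
        # Stand-in for vadModel.transcribe(audio_path, whisperCallable, vadConfig)
        return "sequential: " + audio_path
    # Stand-in for ParallelTranscription().transcribe_parallel(...)
    return "parallel on " + ",".join(device_list) + ": " + audio_path

print(process_vad("audio.mp3", parse_devices(None)))   # sequential: audio.mp3
print(process_vad("audio.mp3", parse_devices("0,1")))  # parallel on 0,1: audio.mp3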
cli.py
CHANGED
@@ -31,6 +31,7 @@ def cli():
     parser.add_argument("--vad_max_merge_size", type=optional_float, default=30, help="The maximum size (in seconds) of a voice segment")
     parser.add_argument("--vad_padding", type=optional_float, default=1, help="The padding (in seconds) to add to each voice segment")
     parser.add_argument("--vad_prompt_window", type=optional_float, default=3, help="The window size of the prompt to pass to Whisper")
+    parser.add_argument("--vad_parallel_devices", type=str, default="0", help="A comma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.")
 
     parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
     parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
@@ -74,6 +75,7 @@ def cli():
 
     model = whisper.load_model(model_name, device=device, download_root=model_dir)
     transcriber = WhisperTranscriber(deleteUploadedFiles=False)
+    transcriber.parallel_device_list = args.pop("vad_parallel_devices")
 
     for audio_path in args.pop("audio"):
         sources = []
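
The CLI pops the raw --vad_parallel_devices string and hands it straight to the transcriber. A small sketch of splitting that value into a device list first, mirroring the list comprehension used in app.py's create_ui (the splitting step here is an assumption about intended usage, not code from this diff):

vad_parallel_devices = "0,1"  # value of --vad_parallel_devices
parallel_device_list = [d.strip() for d in vad_parallel_devices.split(",")] if vad_parallel_devices else None
print(parallel_device_list)  # ['0', '1']
# transcriber.parallel_device_list = parallel_device_list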
src/vad.py
CHANGED
@@ -6,6 +6,7 @@ from typing import Any, Deque, Iterator, List, Dict
 from pprint import pprint
 
 from src.segments import merge_timestamps
+from src.whisperContainer import WhisperCallback
 
 # Workaround for https://github.com/tensorflow/tensorflow/issues/48797
 try:
@@ -51,19 +52,20 @@ VAD_MAX_PROCESSING_CHUNK = 60 * 60 # 60 minutes of audio
 class TranscriptionConfig(ABC):
     def __init__(self, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP,
                  segment_padding_left: float = None, segment_padding_right = None, max_silent_period: float = None,
-                 max_merge_size: float = None, max_prompt_window: float = None):
+                 max_merge_size: float = None, max_prompt_window: float = None, initial_segment_index = -1):
         self.non_speech_strategy = non_speech_strategy
         self.segment_padding_left = segment_padding_left
         self.segment_padding_right = segment_padding_right
         self.max_silent_period = max_silent_period
         self.max_merge_size = max_merge_size
         self.max_prompt_window = max_prompt_window
+        self.initial_segment_index = initial_segment_index
 
 class PeriodicTranscriptionConfig(TranscriptionConfig):
     def __init__(self, periodic_duration: float, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP,
                  segment_padding_left: float = None, segment_padding_right = None, max_silent_period: float = None,
-                 max_merge_size: float = None, max_prompt_window: float = None):
-        super().__init__(non_speech_strategy, segment_padding_left, segment_padding_right, max_silent_period, max_merge_size, max_prompt_window)
+                 max_merge_size: float = None, max_prompt_window: float = None, initial_segment_index = -1):
+        super().__init__(non_speech_strategy, segment_padding_left, segment_padding_right, max_silent_period, max_merge_size, max_prompt_window, initial_segment_index)
         self.periodic_duration = periodic_duration
 
 class AbstractTranscription(ABC):
@@ -91,37 +93,26 @@ class AbstractTranscription(ABC):
         """
         return
 
-    def
+    def get_merged_timestamps(self, audio: str, config: TranscriptionConfig):
         """
-
+        Get the start and end timestamps of the sections that should be transcribed by this VAD method,
+        after merging the segments using the specified configuration.
 
         Parameters
         ----------
         audio: str
-            The audio file.
-
-
-            The callback that is used to invoke Whisper on an audio file/buffer. The first parameter is the audio file/buffer,
-            the second parameter is an optional text prompt, and the last is the current detected language. The return value is the result of the Whisper call.
+            The audio file.
+        config: TranscriptionConfig
+            The transcription configuration.
 
         Returns
         -------
         A list of start and end timestamps, in fractional seconds.
         """
-
-        # get speech timestamps from full audio file
         seconds_timestamps = self.get_transcribe_timestamps(audio, config)
 
-
-
-
-        merged = merge_timestamps(seconds_timestamps, config.max_silent_period, config.max_merge_size, config.segment_padding_left, config.segment_padding_right)
-
-        # A deque of transcribed segments that is passed to the next segment as a prompt
-        prompt_window = deque()
-
-        print("Timestamps:")
-        pprint(merged)
+        merged = merge_timestamps(seconds_timestamps, config.max_silent_period, config.max_merge_size,
+                                  config.segment_padding_left, config.segment_padding_right)
 
         if config.non_speech_strategy != NonSpeechStrategy.SKIP:
             max_audio_duration = get_audio_duration(audio)
@@ -138,6 +129,32 @@
 
             print("Transcribing non-speech:")
             pprint(merged)
+        return merged
+
+    def transcribe(self, audio: str, whisperCallable: WhisperCallback, config: TranscriptionConfig):
+        """
+        Transcribe the given audio file.
+
+        Parameters
+        ----------
+        audio: str
+            The audio file.
+        whisperCallable: WhisperCallback
+            A callback object to call to transcribe each segment.
+
+        Returns
+        -------
+        A list of start and end timestamps, in fractional seconds.
+        """
+
+        # Get speech timestamps from full audio file
+        merged = self.get_merged_timestamps(audio, config)
+
+        # A deque of transcribed segments that is passed to the next segment as a prompt
+        prompt_window = deque()
+
+        print("Processing timestamps:")
+        pprint(merged)
 
         result = {
             'text': "",
@@ -147,7 +164,7 @@
         languageCounter = Counter()
         detected_language = None
 
-        segment_index =
+        segment_index = config.initial_segment_index
 
         # For each time segment, run whisper
         for segment in merged:
@@ -172,7 +189,7 @@
 
             print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ",
                   segment_duration, "expanded: ", segment_expand_amount, "prompt: ", segment_prompt, "language: ", detected_language)
-            segment_result = whisperCallable(segment_audio, segment_index, segment_prompt, detected_language)
+            segment_result = whisperCallable.invoke(segment_audio, segment_index, segment_prompt, detected_language)
 
             adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
 
@@ -373,6 +390,7 @@
         })
         return result
 
+
 class VadSileroTranscription(AbstractTranscription):
     def __init__(self, sampling_rate: int = 16000):
         super().__init__(sampling_rate=sampling_rate)
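
The core refactor in vad.py splits segment discovery out of transcribe: get_merged_timestamps can now be called once up front, and initial_segment_index lets each worker continue segment numbering where the previous chunk left off. A stripped-down sketch of that contract with a toy VAD in place of Silero (all names here are illustrative, not the real classes):

class ToyConfig:
    def __init__(self, initial_segment_index=-1):
        self.initial_segment_index = initial_segment_index

class ToyTranscription:
    def get_transcribe_timestamps(self, audio, config):
        # A real implementation would run VAD; fake three speech segments instead
        return [{'start': 0, 'end': 10}, {'start': 12, 'end': 20}, {'start': 25, 'end': 30}]

    def get_merged_timestamps(self, audio, config):
        # Padding/merging would happen here; the parallel path calls only this step
        return self.get_transcribe_timestamps(audio, config)

    def transcribe(self, audio, config):
        segment_index = config.initial_segment_index
        for segment in self.get_merged_timestamps(audio, config):
            segment_index += 1
            print("segment", segment_index, segment)

# A second worker handling the next chunk would be given initial_segment_index = 2
ToyTranscription().transcribe("audio.mp3", ToyConfig(initial_segment_index=-1))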
src/vadParallel.py
ADDED
@@ -0,0 +1,81 @@
from src.vad import AbstractTranscription, TranscriptionConfig
from src.whisperContainer import WhisperCallback

from multiprocessing import Pool

from typing import List
import os

class ParallelTranscriptionConfig(TranscriptionConfig):
    def __init__(self, device_id: str, override_timestamps, initial_segment_index, copy: TranscriptionConfig = None):
        super().__init__(copy.non_speech_strategy, copy.segment_padding_left, copy.segment_padding_right, copy.max_silent_period, copy.max_merge_size, copy.max_prompt_window, initial_segment_index)
        self.device_id = device_id
        self.override_timestamps = override_timestamps

class ParallelTranscription(AbstractTranscription):
    def __init__(self, sampling_rate: int = 16000):
        super().__init__(sampling_rate=sampling_rate)

    def transcribe_parallel(self, transcription: AbstractTranscription, audio: str, whisperCallable: WhisperCallback, config: TranscriptionConfig, devices: List[str]):
        # First, get the timestamps for the original audio
        merged = transcription.get_merged_timestamps(audio, config)

        # Split into a list for each device
        merged_split = self._chunks(merged, len(merged) // len(devices))

        # Parameters that will be passed to the transcribe function
        parameters = []
        segment_index = config.initial_segment_index

        for i in range(len(devices)):
            device_segment_list = merged_split[i]

            # Create a new config with the given device ID
            device_config = ParallelTranscriptionConfig(devices[i], device_segment_list, segment_index, config)
            segment_index += len(device_segment_list)

            parameters.append([audio, whisperCallable, device_config])

        merged = {
            'text': '',
            'segments': [],
            'language': None
        }

        with Pool(len(devices)) as p:
            # Run the transcription in parallel
            results = p.starmap(self.transcribe, parameters)

            for result in results:
                # Merge the results
                if (result['text'] is not None):
                    merged['text'] += result['text']
                if (result['segments'] is not None):
                    merged['segments'].extend(result['segments'])
                if (result['language'] is not None):
                    merged['language'] = result['language']

        return merged

    def get_transcribe_timestamps(self, audio: str, config: ParallelTranscriptionConfig):
        return []

    def get_merged_timestamps(self, audio: str, config: ParallelTranscriptionConfig):
        # Override timestamps that will be processed
        if (config.override_timestamps is not None):
            print("Using override timestamps of size " + str(len(config.override_timestamps)))
            return config.override_timestamps
        return super().get_merged_timestamps(audio, config)

    def transcribe(self, audio: str, whisperCallable: WhisperCallback, config: ParallelTranscriptionConfig):
        # Override device ID
        if (config.device_id is not None):
            print("Using device " + config.device_id)
            os.environ["CUDA_VISIBLE_DEVICES"] = config.device_id
        return super().transcribe(audio, whisperCallable, config)

    def _chunks(self, lst, n):
        """Return successive n-sized chunks from lst."""
        return [lst[i:i + n] for i in range(0, len(lst), n)]
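
ParallelTranscription's approach is: compute the merged timestamps once, split them into per-device chunks, run transcribe in a multiprocessing Pool with one worker per device, then concatenate the partial results. A reduced, runnable sketch of that chunk-and-starmap pattern, with a dummy worker in place of Whisper:

from multiprocessing import Pool

def chunks(lst, n):
    # Successive n-sized chunks, as in ParallelTranscription._chunks
    return [lst[i:i + n] for i in range(0, len(lst), n)]

def transcribe_chunk(device_id, segments):
    # Stand-in for ParallelTranscription.transcribe on one device
    return ["device " + device_id + ": " + str(s) for s in segments]

if __name__ == "__main__":
    segments = list(range(10))
    devices = ["0", "1"]
    split = chunks(segments, len(segments) // len(devices))

    with Pool(len(devices)) as p:
        results = p.starmap(transcribe_chunk, [(devices[i % len(devices)], chunk) for i, chunk in enumerate(split)])

    # Merge partial results in device order, as transcribe_parallel does
    merged = [line for part in results for line in part]
    print(merged)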
src/whisperContainer.py
ADDED
@@ -0,0 +1,91 @@
# External programs
import whisper

class WhisperContainer:
    def __init__(self, model_name: str, device: str = None):
        self.model_name = model_name
        self.device = device

        # Will be created on demand
        self.model = None

    def get_model(self):
        if self.model is None:
            print("Loading model " + self.model_name)
            self.model = whisper.load_model(self.model_name, device=self.device)
        return self.model

    def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
        """
        Create a WhisperCallback object that can be used to transcribe audio files.

        Parameters
        ----------
        language: str
            The target language of the transcription. If not specified, the language will be inferred from the audio content.
        task: str
            The task - either translate or transcribe.
        initial_prompt: str
            The initial prompt to use for the transcription.
        decodeOptions: dict
            Additional options to pass to the decoder. Must be pickleable.

        Returns
        -------
        A WhisperCallback object.
        """
        return WhisperCallback(self, language=language, task=task, initial_prompt=initial_prompt, **decodeOptions)

    # This is required for multiprocessing
    def __getstate__(self):
        return { "model_name": self.model_name, "device": self.device }

    def __setstate__(self, state):
        self.model_name = state["model_name"]
        self.device = state["device"]
        self.model = None


class WhisperCallback:
    def __init__(self, model_container: WhisperContainer, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
        self.model_container = model_container
        self.language = language
        self.task = task
        self.initial_prompt = initial_prompt
        self.decodeOptions = decodeOptions

    def invoke(self, audio, segment_index: int, prompt: str, detected_language: str):
        """
        Perform the transcription of the given audio file or data.

        Parameters
        ----------
        audio: Union[str, np.ndarray, torch.Tensor]
            The audio file to transcribe, or the audio data as a numpy array or torch tensor.
        segment_index: int
            The index of the segment within the overall transcription; the initial prompt is only applied to the first segment (index 0).
        prompt: str
            The prompt to use for the transcription.
        detected_language: str
            The detected language of the audio file.

        Returns
        -------
        The result of the Whisper call.
        """
        model = self.model_container.get_model()

        return model.transcribe(audio, \
            language=self.language if self.language else detected_language, task=self.task, \
            initial_prompt=self._concat_prompt(self.initial_prompt, prompt) if segment_index == 0 else prompt, \
            **self.decodeOptions)

    def _concat_prompt(self, prompt1, prompt2):
        if (prompt1 is None):
            return prompt2
        elif (prompt2 is None):
            return prompt1
        else:
            return prompt1 + " " + prompt2
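
WhisperContainer keeps the loaded model out of its pickled state so that each worker process started by multiprocessing reloads the model lazily on its own device. A tiny sketch of that __getstate__/__setstate__ pattern with a placeholder class (no Whisper involved):

import pickle

class LazyModelContainer:
    # Illustrative stand-in for WhisperContainer's pickling behaviour
    def __init__(self, model_name):
        self.model_name = model_name
        self.model = None  # loaded on demand, never pickled

    def __getstate__(self):
        return {"model_name": self.model_name}

    def __setstate__(self, state):
        self.model_name = state["model_name"]
        self.model = None  # each process reloads its own copy on first use

copy = pickle.loads(pickle.dumps(LazyModelContainer("medium")))
print(copy.model_name, copy.model)  # medium None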