Fix CLI for parallel devices
- app.py +4 -1
- cli.py +8 -4
- src/vadParallel.py +12 -5
- src/whisperContainer.py +3 -2
app.py
CHANGED

@@ -60,6 +60,9 @@ class WhisperTranscriber:
         self.inputAudioMaxDuration = input_audio_max_duration
         self.deleteUploadedFiles = delete_uploaded_files
 
+    def set_parallel_devices(self, vad_parallel_devices: str):
+        self.parallel_device_list = [ device.strip() for device in vad_parallel_devices.split(",") ] if vad_parallel_devices else None
+
     def transcribe_webui(self, modelName, languageName, urlData, uploadFile, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow):
         try:
             source, sourceName = self.__get_source(urlData, uploadFile, microphoneData)

@@ -255,7 +258,7 @@ def create_ui(input_audio_max_duration, share=False, server_name: str = None, se
     ui = WhisperTranscriber(input_audio_max_duration, vad_process_timeout)
 
     # Specify a list of devices to use for parallel processing
-    ui.
+    ui.set_parallel_devices(vad_parallel_devices)
 
     ui_description = "Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse "
     ui_description += " audio and is also a multi-task model that can perform multilingual speech recognition "
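For context, the new set_parallel_devices method only normalizes the comma-delimited device string into a list, falling back to None when the string is empty (which disables parallel processing). A minimal standalone sketch of that parsing logic, with parse_parallel_devices as a hypothetical free-function stand-in for the method:

# Hypothetical stand-in for WhisperTranscriber.set_parallel_devices in app.py,
# showing only the parsing behaviour of the new method.
def parse_parallel_devices(vad_parallel_devices: str):
    # "0, 1,2" -> ["0", "1", "2"]; "" or None -> None (parallel processing disabled)
    return [device.strip() for device in vad_parallel_devices.split(",")] if vad_parallel_devices else None

assert parse_parallel_devices("0, 1,2") == ["0", "1", "2"]
assert parse_parallel_devices("") is None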
cli.py
CHANGED

@@ -12,6 +12,7 @@ from app import LANGUAGES, WhisperTranscriber
 from src.download import download_url
 
 from src.utils import optional_float, optional_int, str2bool
+from src.whisperContainer import WhisperContainer
 
 
 def cli():

@@ -31,7 +32,7 @@ def cli():
     parser.add_argument("--vad_max_merge_size", type=optional_float, default=30, help="The maximum size (in seconds) of a voice segment")
     parser.add_argument("--vad_padding", type=optional_float, default=1, help="The padding (in seconds) to add to each voice segment")
     parser.add_argument("--vad_prompt_window", type=optional_float, default=3, help="The window size of the prompt to pass to Whisper")
-    parser.add_argument("--vad_parallel_devices", type=str, default="
+    parser.add_argument("--vad_parallel_devices", type=str, default="", help="A comma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.")
 
     parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
     parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")

@@ -73,9 +74,12 @@ def cli():
     vad_padding = args.pop("vad_padding")
     vad_prompt_window = args.pop("vad_prompt_window")
 
-    model = 
+    model = WhisperContainer(model_name, device=device, download_root=model_dir)
     transcriber = WhisperTranscriber(delete_uploaded_files=False)
-    transcriber.
+    transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
+
+    if (transcriber._has_parallel_devices()):
+        print("Using parallel devices:", transcriber.parallel_device_list)
 
     for audio_path in args.pop("audio"):
         sources = []

@@ -99,7 +103,7 @@ def cli():
 
         transcriber.write_result(result, source_name, output_dir)
 
-    transcriber.
+    transcriber.close()
 
 def uri_validator(x):
     try:
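Note that the new flag is consumed with args.pop before the remaining arguments are forwarded to Whisper, so it never reaches the transcribe call itself. A small sketch of that wiring, using argparse directly with a hard-coded argv list for illustration:

import argparse

# Sketch of how --vad_parallel_devices flows through cli.py; the argv list
# below stands in for a real command line such as:
#   python cli.py --vad_parallel_devices 0,1 audio.mp3
parser = argparse.ArgumentParser()
parser.add_argument("--vad_parallel_devices", type=str, default="")
parser.add_argument("audio", nargs="+")

args = vars(parser.parse_args(["--vad_parallel_devices", "0,1", "audio.mp3"]))
devices = args.pop("vad_parallel_devices")  # handled here, not passed to Whisper
print(devices)                              # "0,1" -> set_parallel_devices splits it
print(args)                                 # {'audio': ['audio.mp3']}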
src/vadParallel.py
CHANGED

@@ -88,14 +88,20 @@ class ParallelTranscription(AbstractTranscription):
 
         # Split into a list for each device
         # TODO: Split by time instead of by number of chunks
-        merged_split = self.
+        merged_split = list(self._split(merged, len(devices)))
 
         # Parameters that will be passed to the transcribe function
         parameters = []
         segment_index = config.initial_segment_index
 
         for i in range(len(merged_split)):
-            device_segment_list = merged_split[i]
+            device_segment_list = list(merged_split[i])
+            device_id = devices[i]
+
+            if (len(device_segment_list) <= 0):
+                continue
+
+            print("Device " + device_id + " (index " + str(i) + ") has " + str(len(device_segment_list)) + " segments")
 
             # Create a new config with the given device ID
             device_config = ParallelTranscriptionConfig(devices[i], device_segment_list, segment_index, config)

@@ -159,7 +165,8 @@ class ParallelTranscription(AbstractTranscription):
         os.environ["CUDA_VISIBLE_DEVICES"] = config.device_id
         return super().transcribe(audio, whisperCallable, config)
 
-    def 
-        """
-
+    def _split(self, a, n):
+        """Split a list into n approximately equal parts."""
+        k, m = divmod(len(a), n)
+        return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n))
 
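The _split helper is the standard divmod trick for dividing a list into n contiguous, approximately equal parts: the first len(a) % n parts each receive one extra element. A standalone copy with a worked example:

# Free-function copy of ParallelTranscription._split from src/vadParallel.py.
def split(a, n):
    """Split a list into n approximately equal parts."""
    k, m = divmod(len(a), n)
    return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n))

# Ten merged segments over three devices: 10 = 3*3 + 1, so the first part
# gets the one extra segment.
print([list(part) for part in split(list(range(10)), 3)])
# -> [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]

When there are more devices than segments, the trailing parts come out as empty slices, which is why the loop above skips a device whenever its segment list is empty.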
src/whisperContainer.py
CHANGED

@@ -23,9 +23,10 @@ class WhisperModelCache:
 GLOBAL_WHISPER_MODEL_CACHE = WhisperModelCache()
 
 class WhisperContainer:
-    def __init__(self, model_name: str, device: str = None, cache: WhisperModelCache = None):
+    def __init__(self, model_name: str, device: str = None, download_root: str = None, cache: WhisperModelCache = None):
         self.model_name = model_name
         self.device = device
+        self.download_root = download_root
         self.cache = cache
 
         # Will be created on demand

@@ -36,7 +37,7 @@ class WhisperContainer:
 
         if (self.cache is None):
             print("Loading whisper model " + self.model_name)
-            self.model = whisper.load_model(self.model_name, device=self.device)
+            self.model = whisper.load_model(self.model_name, device=self.device, download_root=self.download_root)
         else:
             self.model = self.cache.get(self.model_name, device=self.device)
         return self.model
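The new download_root argument is forwarded straight to whisper.load_model, which accepts it as the directory for caching downloaded weights. An illustrative use, assuming this repository's layout, a hypothetical models/ directory, and that the accessor whose body appears in the second hunk is named get_model (its name is cut off in the diff):

from src.whisperContainer import WhisperContainer

# Cache the model weights under ./models instead of the default
# ~/.cache/whisper; the model itself is still loaded lazily, on first use.
container = WhisperContainer("medium", device="cuda", download_root="models")
model = container.get_model()  # triggers whisper.load_model(..., download_root="models")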