|  | import multiprocessing | 
					
						
						|  | import threading | 
					
						
						|  | import time | 
					
						
						|  | from src.vad import AbstractTranscription, TranscriptionConfig, get_audio_duration | 
					
						
						|  | from src.whisperContainer import WhisperCallback | 
					
						
						|  |  | 
					
						
						|  | from multiprocessing import Pool | 
					
						
						|  |  | 
					
						
						|  | from typing import Any, Dict, List | 
					
						
						|  | import os | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class ParallelContext: | 
					
						
						|  | def __init__(self, num_processes: int = None, auto_cleanup_timeout_seconds: float = None): | 
					
						
						|  | self.num_processes = num_processes | 
					
						
						|  | self.auto_cleanup_timeout_seconds = auto_cleanup_timeout_seconds | 
					
						
						|  | self.lock = threading.Lock() | 
					
						
						|  |  | 
					
						
						|  | self.ref_count = 0 | 
					
						
						|  | self.pool = None | 
					
						
						|  | self.cleanup_timer = None | 
					
						
						|  |  | 
					
						
						|  | def get_pool(self): | 
					
						
						|  |  | 
					
						
						|  | if (self.pool is None): | 
					
						
						|  | context = multiprocessing.get_context('spawn') | 
					
						
						|  | self.pool = context.Pool(self.num_processes) | 
					
						
						|  |  | 
					
						
						|  | self.ref_count = self.ref_count + 1 | 
					
						
						|  |  | 
					
						
						|  | if (self.auto_cleanup_timeout_seconds is not None): | 
					
						
						|  | self._stop_auto_cleanup() | 
					
						
						|  |  | 
					
						
						|  | return self.pool | 
					
						
						|  |  | 
					
						
						|  | def return_pool(self, pool): | 
					
						
						|  | if (self.pool == pool and self.ref_count > 0): | 
					
						
						|  | self.ref_count = self.ref_count - 1 | 
					
						
						|  |  | 
					
						
						|  | if (self.ref_count == 0): | 
					
						
						|  | if (self.auto_cleanup_timeout_seconds is not None): | 
					
						
						|  | self._start_auto_cleanup() | 
					
						
						|  |  | 
					
						
						|  | def _start_auto_cleanup(self): | 
					
						
						|  | if (self.cleanup_timer is not None): | 
					
						
						|  | self.cleanup_timer.cancel() | 
					
						
						|  | self.cleanup_timer = threading.Timer(self.auto_cleanup_timeout_seconds, self._execute_cleanup) | 
					
						
						|  | self.cleanup_timer.start() | 
					
						
						|  |  | 
					
						
						|  | print("Started auto cleanup of pool in " + str(self.auto_cleanup_timeout_seconds) + " seconds") | 
					
						
						|  |  | 
					
						
						|  | def _stop_auto_cleanup(self): | 
					
						
						|  | if (self.cleanup_timer is not None): | 
					
						
						|  | self.cleanup_timer.cancel() | 
					
						
						|  | self.cleanup_timer = None | 
					
						
						|  |  | 
					
						
						|  | print("Stopped auto cleanup of pool") | 
					
						
						|  |  | 
					
						
						|  | def _execute_cleanup(self): | 
					
						
						|  | print("Executing cleanup of pool") | 
					
						
						|  |  | 
					
						
						|  | if (self.ref_count == 0): | 
					
						
						|  | self.close() | 
					
						
						|  |  | 
					
						
						|  | def close(self): | 
					
						
						|  | self._stop_auto_cleanup() | 
					
						
						|  |  | 
					
						
						|  | if (self.pool is not None): | 
					
						
						|  | print("Closing pool of " + str(self.num_processes) + " processes") | 
					
						
						|  | self.pool.close() | 
					
						
						|  | self.pool.join() | 
					
						
						|  | self.pool = None | 
					
						
						|  |  | 
					
						
						|  | class ParallelTranscriptionConfig(TranscriptionConfig): | 
					
						
						|  | def __init__(self, device_id: str, override_timestamps, initial_segment_index, copy: TranscriptionConfig = None): | 
					
						
						|  | super().__init__(copy.non_speech_strategy, copy.segment_padding_left, copy.segment_padding_right, copy.max_silent_period, copy.max_merge_size, copy.max_prompt_window, initial_segment_index) | 
					
						
						|  | self.device_id = device_id | 
					
						
						|  | self.override_timestamps = override_timestamps | 
					
						
						|  |  | 
					
						
						|  | class ParallelTranscription(AbstractTranscription): | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | MIN_CPU_CHUNK_SIZE_SECONDS = 2 * 60 | 
					
						
						|  |  | 
					
						
						|  | def __init__(self, sampling_rate: int = 16000): | 
					
						
						|  | super().__init__(sampling_rate=sampling_rate) | 
					
						
						|  |  | 
					
						
						|  | def transcribe_parallel(self, transcription: AbstractTranscription, audio: str, whisperCallable: WhisperCallback, config: TranscriptionConfig, | 
					
						
						|  | cpu_device_count: int, gpu_devices: List[str], cpu_parallel_context: ParallelContext = None, gpu_parallel_context: ParallelContext = None): | 
					
						
						|  | total_duration = get_audio_duration(audio) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if (cpu_device_count > 1 and not transcription.is_transcribe_timestamps_fast()): | 
					
						
						|  | merged = self._get_merged_timestamps_parallel(transcription, audio, config, total_duration, cpu_device_count, cpu_parallel_context) | 
					
						
						|  | else: | 
					
						
						|  | timestamp_segments = transcription.get_transcribe_timestamps(audio, config, 0, total_duration) | 
					
						
						|  | merged = transcription.get_merged_timestamps(timestamp_segments, config, total_duration) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if (len(gpu_devices) > 1): | 
					
						
						|  | whisperCallable.model_container.ensure_downloaded() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | merged_split = list(self._split(merged, len(gpu_devices))) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | parameters = [] | 
					
						
						|  | segment_index = config.initial_segment_index | 
					
						
						|  |  | 
					
						
						|  | for i in range(len(gpu_devices)): | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | device_segment_list = list(merged_split[i]) if i < len(merged_split) else [] | 
					
						
						|  | device_id = gpu_devices[i] | 
					
						
						|  |  | 
					
						
						|  | print("Device " + str(device_id) + " (index " + str(i) + ") has " + str(len(device_segment_list)) + " segments") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | device_config = ParallelTranscriptionConfig(device_id, device_segment_list, segment_index, config) | 
					
						
						|  | segment_index += len(device_segment_list) | 
					
						
						|  |  | 
					
						
						|  | parameters.append([audio, whisperCallable, device_config]); | 
					
						
						|  |  | 
					
						
						|  | merged = { | 
					
						
						|  | 'text': '', | 
					
						
						|  | 'segments': [], | 
					
						
						|  | 'language': None | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | created_context = False | 
					
						
						|  |  | 
					
						
						|  | perf_start_gpu = time.perf_counter() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | try: | 
					
						
						|  | if (gpu_parallel_context is None): | 
					
						
						|  | gpu_parallel_context = ParallelContext(len(gpu_devices)) | 
					
						
						|  | created_context = True | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | pool = gpu_parallel_context.get_pool() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | results = pool.starmap(self.transcribe, parameters) | 
					
						
						|  |  | 
					
						
						|  | for result in results: | 
					
						
						|  |  | 
					
						
						|  | if (result['text'] is not None): | 
					
						
						|  | merged['text'] += result['text'] | 
					
						
						|  | if (result['segments'] is not None): | 
					
						
						|  | merged['segments'].extend(result['segments']) | 
					
						
						|  | if (result['language'] is not None): | 
					
						
						|  | merged['language'] = result['language'] | 
					
						
						|  |  | 
					
						
						|  | finally: | 
					
						
						|  |  | 
					
						
						|  | if (gpu_parallel_context is not None): | 
					
						
						|  | gpu_parallel_context.return_pool(pool) | 
					
						
						|  |  | 
					
						
						|  | if (created_context): | 
					
						
						|  | gpu_parallel_context.close() | 
					
						
						|  |  | 
					
						
						|  | perf_end_gpu = time.perf_counter() | 
					
						
						|  | print("Parallel transcription took " + str(perf_end_gpu - perf_start_gpu) + " seconds") | 
					
						
						|  |  | 
					
						
						|  | return merged | 
					
						
						|  |  | 
					
						
						|  | def _get_merged_timestamps_parallel(self, transcription: AbstractTranscription, audio: str, config: TranscriptionConfig, total_duration: float, | 
					
						
						|  | cpu_device_count: int, cpu_parallel_context: ParallelContext = None): | 
					
						
						|  | parameters = [] | 
					
						
						|  |  | 
					
						
						|  | chunk_size = max(total_duration / cpu_device_count, self.MIN_CPU_CHUNK_SIZE_SECONDS) | 
					
						
						|  | chunk_start = 0 | 
					
						
						|  | cpu_device_id = 0 | 
					
						
						|  |  | 
					
						
						|  | perf_start_time = time.perf_counter() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | while (chunk_start < total_duration): | 
					
						
						|  | chunk_end = min(chunk_start + chunk_size, total_duration) | 
					
						
						|  |  | 
					
						
						|  | if (chunk_end - chunk_start < 1): | 
					
						
						|  |  | 
					
						
						|  | break | 
					
						
						|  |  | 
					
						
						|  | print("Parallel VAD: Executing chunk from " + str(chunk_start) + " to " + | 
					
						
						|  | str(chunk_end) + " on CPU device " + str(cpu_device_id)) | 
					
						
						|  | parameters.append([audio, config, chunk_start, chunk_end]); | 
					
						
						|  |  | 
					
						
						|  | cpu_device_id += 1 | 
					
						
						|  | chunk_start = chunk_end | 
					
						
						|  |  | 
					
						
						|  | created_context = False | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | try: | 
					
						
						|  | if (cpu_parallel_context is None): | 
					
						
						|  | cpu_parallel_context = ParallelContext(cpu_device_count) | 
					
						
						|  | created_context = True | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | pool = cpu_parallel_context.get_pool() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | results = pool.starmap(transcription.get_transcribe_timestamps, parameters) | 
					
						
						|  |  | 
					
						
						|  | timestamps = [] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | for result in results: | 
					
						
						|  | timestamps.extend(result) | 
					
						
						|  |  | 
					
						
						|  | merged = transcription.get_merged_timestamps(timestamps, config, total_duration) | 
					
						
						|  |  | 
					
						
						|  | perf_end_time = time.perf_counter() | 
					
						
						|  | print("Parallel VAD processing took {} seconds".format(perf_end_time - perf_start_time)) | 
					
						
						|  | return merged | 
					
						
						|  |  | 
					
						
						|  | finally: | 
					
						
						|  |  | 
					
						
						|  | if (cpu_parallel_context is not None): | 
					
						
						|  | cpu_parallel_context.return_pool(pool) | 
					
						
						|  |  | 
					
						
						|  | if (created_context): | 
					
						
						|  | cpu_parallel_context.close() | 
					
						
						|  |  | 
					
						
						|  | def get_transcribe_timestamps(self, audio: str, config: ParallelTranscriptionConfig, start_time: float, duration: float): | 
					
						
						|  | return [] | 
					
						
						|  |  | 
					
						
						|  | def get_merged_timestamps(self,  timestamps: List[Dict[str, Any]], config: ParallelTranscriptionConfig, total_duration: float): | 
					
						
						|  |  | 
					
						
						|  | if (config.override_timestamps is not None): | 
					
						
						|  | print("Using override timestamps of size " + str(len(config.override_timestamps))) | 
					
						
						|  | return config.override_timestamps | 
					
						
						|  | return super().get_merged_timestamps(timestamps, config, total_duration) | 
					
						
						|  |  | 
					
						
						|  | def transcribe(self, audio: str, whisperCallable: WhisperCallback, config: ParallelTranscriptionConfig): | 
					
						
						|  |  | 
					
						
						|  | if (os.environ.get("INITIALIZED", None) is None): | 
					
						
						|  | os.environ["INITIALIZED"] = "1" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if (config.device_id is not None): | 
					
						
						|  | print("Using device " + config.device_id) | 
					
						
						|  | os.environ["CUDA_VISIBLE_DEVICES"] = config.device_id | 
					
						
						|  |  | 
					
						
						|  | return super().transcribe(audio, whisperCallable, config) | 
					
						
						|  |  | 
					
						
						|  | def _split(self, a, n): | 
					
						
						|  | """Split a list into n approximately equal parts.""" | 
					
						
						|  | k, m = divmod(len(a), n) | 
					
						
						|  | return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n)) | 
					
						
						|  |  | 
					
						
						|  |  |