|
from typing import List, Optional

from src.diarization.diarization import Diarization, DiarizationEntry
from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
from src.vadParallel import ParallelContext
|
|
|
class DiarizationContainer:
    """Run speaker diarization either in-process or through a one-worker pool.

    The ``Diarization`` model is created lazily (optionally via a ``ModelCache``).
    When ``enable_daemon_process`` is true, ``run`` dispatches the work through a
    ``ParallelContext`` configured with a single process; otherwise it runs in the
    calling process.

    Args:
        auth_token: Token passed to the ``Diarization`` constructor.
        enable_daemon_process: Use a pooled worker process for ``run`` when true.
        auto_cleanup_timeout_seconds: Forwarded to the ``ParallelContext`` that
            ``run`` creates (its auto-cleanup timeout).
        cache: Optional cache used by ``get_model`` to share the loaded model.
    """

    def __init__(self, auth_token: Optional[str] = None, enable_daemon_process: bool = True,
                 auto_cleanup_timeout_seconds=60, cache: Optional[ModelCache] = None):
        self.auth_token = auth_token
        self.enable_daemon_process = enable_daemon_process
        self.auto_cleanup_timeout_seconds = auto_cleanup_timeout_seconds
        # Created lazily by run(); reset to None by cleanup().
        self.diarization_context: Optional[ParallelContext] = None
        self.cache = cache
        # Loaded lazily by get_model().
        self.model = None

    def run(self, audio_file, **kwargs):
        """Diarize ``audio_file`` and return the model's results as a list.

        Extra keyword arguments are forwarded to the model's ``run`` method.
        """
        if self.diarization_context is None and self.enable_daemon_process:
            # A single worker process; the configured timeout is forwarded so the
            # context can apply its auto-cleanup behaviour.
            self.diarization_context = ParallelContext(
                num_processes=1,
                auto_cleanup_timeout_seconds=self.auto_cleanup_timeout_seconds,
            )

        context = self.diarization_context
        if context is None:
            # Daemon processing disabled: run directly in this process.
            return self.execute(audio_file, **kwargs)

        pool = context.get_pool()
        try:
            # If the pool pickles ``self`` (as multiprocessing pools do), __getstate__
            # ensures the worker receives a container without the loaded model.
            return pool.apply(self.execute, (audio_file,), kwargs)
        finally:
            # Hand the pool back to the same context it was taken from.
            context.return_pool(pool)

    def mark_speakers(self, diarization_result: List[DiarizationEntry], whisper_result: dict):
        """Delegate speaker marking to a ``Diarization`` model and return its result.

        Uses the already-loaded model when present; otherwise a new ``Diarization``
        is constructed for this call without going through the cache.
        """
        model = self.model if self.model is not None else Diarization(self.auth_token)
        return model.mark_speakers(diarization_result, whisper_result)

    def get_model(self):
        """Return the diarization model, creating it (via the cache if set) on first use."""
        if self.model is None:
            if self.cache:
                print("Loading diarization model from cache")
                self.model = self.cache.get("diarization", lambda: Diarization(self.auth_token))
            else:
                print("Loading diarization model")
                self.model = Diarization(self.auth_token)
        return self.model

    def execute(self, audio_file, **kwargs):
        """Run the model on ``audio_file`` in the current process and return a list."""
        model = self.get_model()
        # Materialize the output (which may be lazy) so the result is a plain list.
        return list(model.run(audio_file, **kwargs))

    def cleanup(self):
        """Close the worker context, if any, and drop the reference to it.

        After cleanup, the next ``run`` builds a fresh context instead of reusing a
        closed one. Calling ``cleanup`` again is a no-op.
        """
        context = self.diarization_context
        if context is not None:
            context.close()
            self.diarization_context = None

    def __getstate__(self):
        """Pickle only the configuration; the context, cache and model are excluded."""
        return {
            "auth_token": self.auth_token,
            "enable_daemon_process": self.enable_daemon_process,
            "auto_cleanup_timeout_seconds": self.auto_cleanup_timeout_seconds
        }

    def __setstate__(self, state):
        """Restore configuration; the unpickled copy uses the global model cache."""
        self.auth_token = state["auth_token"]
        self.enable_daemon_process = state["enable_daemon_process"]
        self.auto_cleanup_timeout_seconds = state["auto_cleanup_timeout_seconds"]
        self.diarization_context = None
        self.cache = GLOBAL_MODEL_CACHE
        self.model = None