import torch
import torchaudio
from huggingface_hub import hf_hub_download
from typing import Dict, Union, Any
from io import BytesIO
import logging

logger = logging.getLogger(__name__)

class EndpointHandler:
    def __init__(self, model_dir: str, **kwargs):
        """Initialize the ECAPA2 speaker embedding model handler.

        Sets up the model on GPU if available, otherwise on CPU.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Download the TorchScript checkpoint from the Hub, caching it in the
        # endpoint's model directory.
        model_file = hf_hub_download(repo_id="oza75/ECAPA2", filename="ecapa2.pt", cache_dir=model_dir)

        self.model = torch.jit.load(model_file, map_location=self.device)
        self.model.to(torch.float32)
        self.model.eval()

        # ECAPA2 expects 16 kHz mono audio.
        self.target_sample_rate = 16000

        logger.info(f"Initialized ECAPA2 model on device: {self.device}")

    def preprocess(self, input_data: Union[bytes, Dict[str, Any]]) -> torch.Tensor:
        """Preprocess the input audio data.

        Args:
            input_data: Either the raw bytes of an audio file or a dictionary
                whose "inputs" key holds those bytes.

        Returns:
            torch.Tensor: Preprocessed audio tensor ready for inference.
        """
        try:
            audio = input_data if isinstance(input_data, bytes) else input_data["inputs"]

            if isinstance(audio, bytes):
                # Decode the raw bytes into a waveform tensor.
                audio_buffer = BytesIO(audio)
                waveform, sample_rate = torchaudio.load(audio_buffer)
            else:
                logger.error(f"Unsupported input type: {type(audio)}")
                logger.debug(f"Input data keys: {input_data.keys()}")
                raise ValueError("Unsupported input type")

            waveform = waveform.to(torch.float32)

            # Resample to the model's expected rate if needed.
            if sample_rate != self.target_sample_rate:
                waveform = torchaudio.functional.resample(
                    waveform,
                    sample_rate,
                    self.target_sample_rate
                )

            # Downmix multi-channel audio to mono.
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)

            return waveform.to(self.device)

        except Exception as e:
            raise RuntimeError(f"Error in preprocessing: {e}") from e

    def predict(self, audio_tensor: torch.Tensor) -> torch.Tensor:
        """Run inference with the ECAPA2 model.

        Args:
            audio_tensor: Preprocessed audio tensor.

        Returns:
            Speaker embedding tensor.
        """
        try:
            # Inference only: disable autograd to save memory and time.
            with torch.no_grad():
                return self.model(audio_tensor)

        except Exception as e:
            raise RuntimeError(f"Error during inference: {e}") from e

    def postprocess(self, embedding: torch.Tensor) -> Dict[str, Any]:
        """Postprocess the model output.

        Args:
            embedding: Speaker embedding tensor from the model.

        Returns:
            Dictionary containing the embedding and metadata.
        """
        # Detach and move to CPU before converting to a JSON-serializable list.
        embedding_np = embedding.detach().cpu().numpy()

        return {
            "embedding": embedding_np.tolist(),
            "embedding_dimension": embedding_np.shape[-1],
            "metadata": {
                "model_name": "ECAPA2",
                "device": str(self.device),
                "sample_rate": self.target_sample_rate
            }
        }
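
    # Illustrative shape of the response returned by __call__ on success.
    # The values below are placeholders; the actual embedding length depends
    # on the ECAPA2 checkpoint:
    # {
    #     "embedding": [[0.0123, -0.0456, ...]],
    #     "embedding_dimension": 192,
    #     "metadata": {"model_name": "ECAPA2", "device": "cuda", "sample_rate": 16000}
    # }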

    def __call__(self, input_data: Union[bytes, Dict[str, Any]]) -> Dict[str, Any]:
        """Main entry point for the handler.

        Args:
            input_data: Raw input data.

        Returns:
            Processed results with the speaker embedding, or an error payload.
        """
        try:
            audio_tensor = self.preprocess(input_data)
            embedding = self.predict(audio_tensor)
            return self.postprocess(embedding)

        except Exception as e:
            logger.exception("Handler failed")
            return {
                "error": str(e),
                "status": "error"
            }
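
# Minimal local smoke test: a sketch of how the handler can be exercised
# outside the Inference Endpoints runtime. "sample.wav" and "./model_cache"
# are hypothetical paths, not part of the repo; a dict payload of the form
# {"inputs": audio_bytes} works the same way as raw bytes.
if __name__ == "__main__":
    handler = EndpointHandler(model_dir="./model_cache")
    with open("sample.wav", "rb") as f:
        result = handler(f.read())
    if "error" in result:
        print(f"Request failed: {result['error']}")
    else:
        print(f"Embedding dimension: {result['embedding_dimension']}")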