import logging
from io import BytesIO
from typing import Any, Dict, Union

import torch
import torchaudio
from huggingface_hub import hf_hub_download

logger = logging.getLogger(__name__)

# Binary buffer types that BytesIO (and therefore torchaudio.load) can consume.
_AUDIO_BUFFER_TYPES = (bytes, bytearray, memoryview)


class EndpointHandler:
    """Inference-endpoint handler that turns audio-file bytes into an ECAPA2 speaker embedding."""

    def __init__(self, model_dir: str, **kwargs: Any) -> None:
        """Initialize the ECAPA2 speaker embedding model handler.

        Downloads ``ecapa2.pt`` from the ``oza75/ECAPA2`` Hugging Face repo and
        loads it as a TorchScript module on GPU if available, otherwise on CPU.

        Args:
            model_dir: Directory passed to ``hf_hub_download`` as its cache dir.
            **kwargs: Accepted for handler-interface compatibility; unused.
        """
        # Determine the device to use
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Download and initialize the ECAPA2 model
        model_file = hf_hub_download(
            repo_id='oza75/ECAPA2', filename='ecapa2.pt', cache_dir=model_dir
        )

        # Load model in float32 precision initially, and put it in eval mode so
        # any training-only behaviour in the module is disabled for inference.
        self.model = torch.jit.load(model_file, map_location=self.device)
        self.model.to(torch.float32)
        self.model.eval()

        # Expected sample rate for the model
        self.target_sample_rate = 16000  # ECAPA2 expects 16kHz audio

        print(f"Initialized ECAPA2 model on device: {self.device}")

    def preprocess(self, input_data: Union[bytes, Dict[str, Any]]) -> torch.Tensor:
        """Decode, resample and downmix the input audio.

        Args:
            input_data: Either the raw bytes of an audio file (``bytes``,
                ``bytearray`` or ``memoryview``), or a dictionary whose
                ``"inputs"`` entry holds those bytes.

        Returns:
            torch.Tensor: float32 mono waveform (one channel row, as produced by
            ``torchaudio.load``) at ``self.target_sample_rate``, on ``self.device``.

        Raises:
            RuntimeError: If the payload is malformed or the audio cannot be
                decoded/resampled. The underlying error is chained as the cause.
        """
        try:
            if isinstance(input_data, _AUDIO_BUFFER_TYPES):
                audio = input_data
            else:
                try:
                    audio = input_data['inputs']
                except (KeyError, TypeError) as err:
                    # Give a clear message instead of a bare "'inputs'" KeyError.
                    raise ValueError(
                        "Expected raw audio bytes or a dict with an 'inputs' key"
                    ) from err

            # Handle different input types
            if not isinstance(audio, _AUDIO_BUFFER_TYPES):
                logger.error(f"Unsupported input type: {type(audio)}")
                if isinstance(input_data, dict):
                    logger.debug(f"Input data: {input_data.keys()}")
                raise ValueError("Unsupported input type")

            # Load audio from bytes
            waveform, sample_rate = torchaudio.load(BytesIO(audio))

            # Always use float32 for processing
            waveform = waveform.to(torch.float32)

            # Resample if necessary
            if sample_rate != self.target_sample_rate:
                waveform = torchaudio.functional.resample(
                    waveform, sample_rate, self.target_sample_rate
                )

            # If multi-channel, convert to mono by averaging channels
            # (torchaudio.load returns channels-first tensors by default).
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)

            # Move to appropriate device
            return waveform.to(self.device)
        except Exception as e:
            raise RuntimeError(f"Error in preprocessing: {str(e)}") from e

    def predict(self, audio_tensor: torch.Tensor) -> torch.Tensor:
        """Run inference with the ECAPA2 model.

        Args:
            audio_tensor: Preprocessed audio tensor from ``preprocess``.

        Returns:
            Speaker embedding tensor produced by the model (no autograd history).

        Raises:
            RuntimeError: If the forward pass fails; the cause is chained.
        """
        try:
            # Pure inference: don't build an autograd graph. This saves memory
            # and keeps the output from requiring grad, which would otherwise
            # make the later ``.numpy()`` conversion fail.
            with torch.no_grad():
                embedding = self.model(audio_tensor)
        except Exception as e:
            raise RuntimeError(f"Error during inference: {str(e)}") from e
        return embedding

    def postprocess(self, embedding: torch.Tensor) -> Dict[str, Any]:
        """Convert the model output into a JSON-serializable response.

        Args:
            embedding: Speaker embedding tensor from the model.

        Returns:
            Dictionary with the embedding as nested lists, the size of its last
            dimension, and metadata (model name, device, sample rate).
        """
        # detach() guards against tensors that still carry autograd history;
        # numpy() refuses tensors that require grad.
        embedding_np = embedding.detach().cpu().numpy()

        return {
            "embedding": embedding_np.tolist(),
            "embedding_dimension": embedding_np.shape[-1],
            "metadata": {
                "model_name": "ECAPA2",
                "device": str(self.device),
                "sample_rate": self.target_sample_rate
            }
        }

    def __call__(self, input_data: Union[bytes, Dict[str, Any]]) -> Dict[str, Any]:
        """Main entry point: preprocess -> predict -> postprocess.

        Args:
            input_data: Raw audio bytes or a dict with an ``"inputs"`` entry.

        Returns:
            The ``postprocess`` response, or ``{"error": ..., "status": "error"}``
            if any stage raised (errors are reported, not propagated).
        """
        try:
            audio_tensor = self.preprocess(input_data)
            embedding = self.predict(audio_tensor)
            return self.postprocess(embedding)
        except Exception as e:
            # Deliberately report failures in the response, but keep a traceback
            # in the logs so they are debuggable server-side.
            logger.exception("ECAPA2 request failed")
            return {
                "error": str(e),
                "status": "error"
            }