# ECAPA2 / handler.py
# Last commit: 8c5eea2 "Update handler.py" by oza75 (verified)
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from typing import Dict, Union, Any
from io import BytesIO
import logging
# Module-level logger; EndpointHandler reports unsupported inputs through it.
logger = logging.getLogger(__name__)
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for the ECAPA2 speaker-embedding model.

    Each request flows through ``preprocess`` (decode, downmix, resample),
    ``predict`` (TorchScript forward pass) and ``postprocess`` (JSON-friendly dict).
    """

    def __init__(self, model_dir: str, **kwargs: Any) -> None:
        """Initialize the ECAPA2 speaker embedding model handler.

        Downloads ``ecapa2.pt`` from the ``oza75/ECAPA2`` Hub repository (cached
        under ``model_dir``) and loads it on the GPU if available, otherwise on
        the CPU.

        Args:
            model_dir: Directory used as the Hugging Face Hub download cache.
            **kwargs: Accepted for compatibility with the endpoint toolkit; unused.
        """
        # Determine the device to use
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Download and initialize the ECAPA2 model (a TorchScript archive).
        model_file = hf_hub_download(
            repo_id="oza75/ECAPA2", filename="ecapa2.pt", cache_dir=model_dir
        )
        self.model = torch.jit.load(model_file, map_location=self.device)
        # Always compute in float32, and switch any train-only behaviour
        # (e.g. dropout / batch-norm statistics updates) off for serving.
        self.model.to(torch.float32)
        self.model.eval()

        # Expected sample rate for the model
        self.target_sample_rate = 16000  # ECAPA2 expects 16kHz audio

        logger.info("Initialized ECAPA2 model on device: %s", self.device)

    def preprocess(self, input_data: Union[bytes, Dict[str, Any]]) -> torch.Tensor:
        """Decode the request audio into a mono waveform at the model's sample rate.

        Args:
            input_data: Either the raw bytes of an audio file, or a dictionary
                whose ``"inputs"`` entry holds those bytes. ``bytearray`` and
                ``memoryview`` are accepted anywhere ``bytes`` is.

        Returns:
            A float32 waveform with a single channel (``(1, num_samples)``),
            sampled at ``self.target_sample_rate`` and placed on ``self.device``.

        Raises:
            RuntimeError: If the payload is missing, has an unsupported type, or
                cannot be decoded/resampled. The underlying exception is chained
                as ``__cause__``.
        """
        bytes_like = (bytes, bytearray, memoryview)
        try:
            audio = input_data if isinstance(input_data, bytes_like) else input_data["inputs"]

            if not isinstance(audio, bytes_like):
                logger.error("Unsupported input type: %s", type(audio))
                # Only reachable for mapping inputs, so ``.keys()`` exists here.
                logger.debug("Input data: %s", input_data.keys())
                raise ValueError("Unsupported input type")

            # torchaudio.load returns (channels, frames) plus the file's sample rate.
            waveform, sample_rate = torchaudio.load(BytesIO(audio))

            # Always use float32 for processing
            waveform = waveform.to(torch.float32)

            # Downmix to mono *before* resampling: both operations are linear,
            # so the result is equivalent, but only one channel gets resampled.
            if waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)

            if sample_rate != self.target_sample_rate:
                waveform = torchaudio.functional.resample(
                    waveform, sample_rate, self.target_sample_rate
                )

            return waveform.to(self.device)
        except Exception as e:
            raise RuntimeError(f"Error in preprocessing: {e}") from e

    def predict(self, audio_tensor: torch.Tensor) -> torch.Tensor:
        """Run inference with the ECAPA2 model.

        The forward pass runs with autograd disabled. No graph is recorded, which
        saves memory per request, and the returned embedding never requires
        grad, so it can be converted to NumPy safely.

        Args:
            audio_tensor: Waveform produced by ``preprocess``.

        Returns:
            The speaker embedding tensor returned by the model.

        Raises:
            RuntimeError: If the forward pass fails (the cause is chained).
        """
        try:
            with torch.no_grad():
                return self.model(audio_tensor)
        except Exception as e:
            raise RuntimeError(f"Error during inference: {e}") from e

    def postprocess(self, embedding: torch.Tensor) -> Dict[str, Any]:
        """Convert the model output into a JSON-serializable response.

        Args:
            embedding: Speaker embedding tensor from the model.

        Returns:
            A dict with the embedding as nested lists, the size of its last
            dimension, and metadata about the model/device/sample rate.
        """
        # detach() guards against tensors that still require grad, which
        # .numpy() refuses to convert.
        embedding_np = embedding.detach().cpu().numpy()
        return {
            "embedding": embedding_np.tolist(),
            "embedding_dimension": embedding_np.shape[-1],
            "metadata": {
                "model_name": "ECAPA2",
                "device": str(self.device),
                "sample_rate": self.target_sample_rate,
            },
        }

    def __call__(self, input_data: Union[bytes, Dict[str, Any]]) -> Dict[str, Any]:
        """Main entry point for the handler.

        Args:
            input_data: Raw request payload (see ``preprocess``).

        Returns:
            The ``postprocess`` result on success. On any failure, a dict
            ``{"error": <message>, "status": "error"}``; the exception is not
            propagated to the caller.
        """
        try:
            audio_tensor = self.preprocess(input_data)
            embedding = self.predict(audio_tensor)
            return self.postprocess(embedding)
        except Exception as e:
            # Failures are reported in the response body by design, but the
            # traceback is kept in the server logs for debugging.
            logger.exception("ECAPA2 request failed")
            return {
                "error": str(e),
                "status": "error",
            }