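"""Custom Hugging Face Inference Endpoints handler for the ECAPA2 speaker
embedding model (checkpoint: oza75/ECAPA2). Accepts raw audio bytes (or a
dict with an 'inputs' key holding them) and returns a JSON-serializable
speaker embedding with metadata.
"""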
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from typing import Dict, Union, Any
from io import BytesIO
import logging
logger = logging.getLogger(__name__)


class EndpointHandler:
    def __init__(self, model_dir: str, **kwargs):
        """Initialize the ECAPA2 speaker embedding model handler.

        Sets up the model on GPU if available, otherwise on CPU.
        """
        # Determine the device to use
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Download the TorchScript checkpoint of the ECAPA2 model
        model_file = hf_hub_download(repo_id='oza75/ECAPA2', filename='ecapa2.pt', cache_dir=model_dir)

        # Load the scripted model and keep it in float32 precision
        self.model = torch.jit.load(model_file, map_location=self.device)
        self.model.to(torch.float32)

        # Expected sample rate for the model
        self.target_sample_rate = 16000  # ECAPA2 expects 16 kHz audio

        logger.info(f"Initialized ECAPA2 model on device: {self.device}")
    def preprocess(self, input_data: Union[bytes, Dict[str, Any]]) -> torch.Tensor:
        """Preprocess the input audio data.

        Args:
            input_data: Either the raw bytes of an audio file or a dictionary
                whose 'inputs' key holds those bytes.

        Returns:
            torch.Tensor: Preprocessed audio tensor ready for inference.
        """
        try:
            audio = input_data if isinstance(input_data, bytes) else input_data['inputs']

            # Handle different input types
            if isinstance(audio, bytes):
                # Decode the audio file from the in-memory bytes
                audio_buffer = BytesIO(audio)
                waveform, sample_rate = torchaudio.load(audio_buffer)
            else:
                logger.error(f"Unsupported input type: {type(audio)}")
                logger.debug(f"Input data keys: {input_data.keys()}")
                raise ValueError("Unsupported input type")

            # Always use float32 for processing
            waveform = waveform.to(torch.float32)

            # Resample if necessary
            if sample_rate != self.target_sample_rate:
                waveform = torchaudio.functional.resample(
                    waveform,
                    sample_rate,
                    self.target_sample_rate,
                )

            # If stereo, convert to mono by averaging channels
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)

            # Move to the same device as the model
            waveform = waveform.to(self.device)
            return waveform
        except Exception as e:
            raise RuntimeError(f"Error in preprocessing: {e}") from e
    def predict(self, audio_tensor: torch.Tensor) -> torch.Tensor:
        """Run inference with the ECAPA2 model.

        Args:
            audio_tensor: Preprocessed audio tensor.

        Returns:
            Speaker embedding tensor.
        """
        try:
            # Disable autograd: we only extract embeddings, never backpropagate
            with torch.inference_mode():
                embedding = self.model(audio_tensor)
            return embedding
        except Exception as e:
            raise RuntimeError(f"Error during inference: {e}") from e
    def postprocess(self, embedding: torch.Tensor) -> Dict[str, Any]:
        """Postprocess the model output.

        Args:
            embedding: Speaker embedding tensor from the model.

        Returns:
            Dictionary containing the embedding and metadata.
        """
        # Convert the embedding to a nested list for JSON serialization
        embedding_np = embedding.detach().cpu().numpy()
        return {
            "embedding": embedding_np.tolist(),
            "embedding_dimension": embedding_np.shape[-1],
            "metadata": {
                "model_name": "ECAPA2",
                "device": str(self.device),
                "sample_rate": self.target_sample_rate,
            },
        }
    def __call__(self, input_data: Union[bytes, Dict[str, Any]]) -> Dict[str, Any]:
        """Main entry point for the handler.

        Args:
            input_data: Raw input data.

        Returns:
            Processed results with the speaker embedding, or an error payload.
        """
        try:
            # Execute the full pipeline: preprocess -> predict -> postprocess
            audio_tensor = self.preprocess(input_data)
            embedding = self.predict(audio_tensor)
            return self.postprocess(embedding)
        except Exception as e:
            logger.exception("Handler pipeline failed")
            return {
                "error": str(e),
                "status": "error",
            }
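
# --- Usage sketch ---
# A minimal local smoke test, assuming an audio file at "sample.wav" and a
# writable cache directory "./model_cache"; both paths are hypothetical
# placeholders, not part of the handler's contract. On a real Inference
# Endpoint the platform instantiates EndpointHandler itself and calls it
# with the request body instead.
if __name__ == "__main__":
    handler = EndpointHandler(model_dir="./model_cache")

    # Read the raw audio bytes exactly as an HTTP request body would carry them
    with open("sample.wav", "rb") as f:
        audio_bytes = f.read()

    result = handler(audio_bytes)
    if "error" in result:
        print(f"Inference failed: {result['error']}")
    else:
        print(f"Embedding dimension: {result['embedding_dimension']}")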