File size: 4,625 Bytes
5575b68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6fc5349
 
8c5eea2
5575b68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6fc5349
 
5575b68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from typing import Dict, Union, Any
from io import BytesIO
import logging

logger = logging.getLogger(__name__)

class EndpointHandler:
    """Inference endpoint handler producing ECAPA2 speaker embeddings from audio.

    Pipeline: ``preprocess`` (decode, downmix to mono, resample, move to
    device) -> ``predict`` (run the TorchScript model without autograd) ->
    ``postprocess`` (convert to a JSON-serializable dictionary).
    """

    def __init__(self, model_dir: str, **kwargs):
        """Download and load the ECAPA2 TorchScript model.

        Args:
            model_dir: Directory used as the Hugging Face Hub cache for the
                downloaded ``ecapa2.pt`` file.
            **kwargs: Accepted for compatibility with the endpoint runtime;
                not used.

        The model is placed on CUDA when available, otherwise on CPU.
        """
        # Determine the device to use
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Download (or reuse from the cache) the serialized ECAPA2 model.
        model_file = hf_hub_download(
            repo_id='oza75/ECAPA2', filename='ecapa2.pt', cache_dir=model_dir
        )

        # Load in float32 and switch to eval mode in case the module was
        # serialized in training mode; this handler only ever runs inference.
        self.model = torch.jit.load(model_file, map_location=self.device)
        self.model.to(torch.float32)
        self.model.eval()

        # Sample rate every incoming waveform is resampled to.
        self.target_sample_rate = 16000  # ECAPA2 expects 16kHz audio

        logger.info("Initialized ECAPA2 model on device: %s", self.device)

    @staticmethod
    def _extract_audio_bytes(input_data: Any) -> bytes:
        """Return the raw audio bytes from a binary payload or an ``{"inputs": ...}`` mapping.

        Raises:
            ValueError: If the payload has no ``"inputs"`` entry, holds a
                non-binary value, or is empty.
        """
        binary_types = (bytes, bytearray, memoryview)
        if isinstance(input_data, binary_types):
            audio = input_data
        else:
            try:
                audio = input_data['inputs']
            except (KeyError, TypeError) as err:
                raise ValueError(
                    "Expected raw audio bytes or a dict with an 'inputs' key"
                ) from err

        if not isinstance(audio, binary_types):
            logger.error("Unsupported input type: %s", type(audio))
            if isinstance(input_data, dict):
                logger.debug("Input data: %s", input_data.keys())
            raise ValueError("Unsupported input type")

        audio = bytes(audio)  # no copy when already ``bytes``
        if not audio:
            raise ValueError("Empty audio payload")
        return audio

    def preprocess(self, input_data: Union[bytes, Dict[str, Any]]) -> torch.Tensor:
        """Decode an audio payload into a mono float32 waveform at the target rate.

        Args:
            input_data: Either the raw bytes of an audio file, or a dictionary
                whose ``'inputs'`` entry holds those bytes. ``bytearray`` and
                ``memoryview`` are accepted wherever ``bytes`` is.

        Returns:
            torch.Tensor: Waveform of shape ``(1, num_samples)`` (torchaudio
            loads channels-first), resampled to ``self.target_sample_rate``
            and placed on ``self.device``.

        Raises:
            RuntimeError: If the payload is missing, unsupported, empty, or
                cannot be decoded; the original error is chained as the cause.
        """
        try:
            audio = self._extract_audio_bytes(input_data)
            waveform, sample_rate = torchaudio.load(BytesIO(audio))
            if waveform.numel() == 0:
                raise ValueError("Decoded audio contains no samples")

            # Always use float32 for processing
            waveform = waveform.to(torch.float32)

            # Downmix to mono *before* resampling: both operations are linear,
            # so the result is equivalent but only one channel is resampled.
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)

            if sample_rate != self.target_sample_rate:
                waveform = torchaudio.functional.resample(
                    waveform,
                    sample_rate,
                    self.target_sample_rate,
                )

            return waveform.to(self.device)

        except Exception as e:
            raise RuntimeError(f"Error in preprocessing: {str(e)}") from e

    def predict(self, audio_tensor: torch.Tensor) -> torch.Tensor:
        """Run the ECAPA2 model on a preprocessed waveform.

        Args:
            audio_tensor: Output of ``preprocess``.

        Returns:
            Speaker embedding tensor (no autograd history).

        Raises:
            RuntimeError: If the model call fails; the original error is chained.
        """
        try:
            # Pure inference: skip building an autograd graph, which saves
            # memory/time and keeps the output convertible to numpy.
            with torch.no_grad():
                return self.model(audio_tensor)

        except Exception as e:
            raise RuntimeError(f"Error during inference: {str(e)}") from e

    def postprocess(self, embedding: torch.Tensor) -> Dict[str, Any]:
        """Convert the model output into a JSON-serializable response.

        Args:
            embedding: Speaker embedding tensor from the model.

        Returns:
            Dictionary with the embedding as nested lists, the size of its last
            dimension, and metadata (model name, device, sample rate).
        """
        # detach() so this also works on tensors that still require grad:
        # Tensor.numpy() refuses those.
        embedding_np = embedding.detach().cpu().numpy()

        return {
            "embedding": embedding_np.tolist(),
            "embedding_dimension": int(embedding_np.shape[-1]),
            "metadata": {
                "model_name": "ECAPA2",
                "device": str(self.device),
                "sample_rate": self.target_sample_rate
            }
        }

    def __call__(self, input_data: Union[bytes, Dict[str, Any]]) -> Dict[str, Any]:
        """Main entry point: run preprocess -> predict -> postprocess.

        Args:
            input_data: Raw audio bytes or a dict with an ``'inputs'`` entry.

        Returns:
            The ``postprocess`` result on success, otherwise a dictionary
            ``{"error": <message>, "status": "error"}``.
        """
        try:
            audio_tensor = self.preprocess(input_data)
            embedding = self.predict(audio_tensor)
            return self.postprocess(embedding)

        except Exception as e:
            # Endpoint boundary: keep the traceback in the server logs while
            # returning a serializable error payload to the client.
            logger.exception("ECAPA2 request failed")
            return {
                "error": str(e),
                "status": "error"
            }