oza75 committed
Commit 5575b68 · verified · Parent: 5336376

Create handler.py

Files changed (1)
  1. handler.py +135 -0
handler.py ADDED
@@ -0,0 +1,135 @@
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from typing import Dict, Union, Any
from io import BytesIO
import logging

logger = logging.getLogger(__name__)


class EndpointHandler:
    def __init__(self, model_dir: str, **kwargs):
        """Initialize the ECAPA2 speaker embedding model handler.

        Sets up the model on GPU if available, otherwise on CPU.
        """
        # Determine the device to use
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Download and load the TorchScript ECAPA2 model
        model_file = hf_hub_download(repo_id='oza75/ECAPA2', filename='ecapa2.pt', cache_dir=model_dir)
        self.model = torch.jit.load(model_file, map_location=self.device)
        self.model.eval()

        # Convert to half precision on CUDA for faster inference
        if self.device.type == "cuda":
            self.model.half()

        # Expected sample rate for the model
        self.target_sample_rate = 16000  # ECAPA2 expects 16 kHz audio

        logger.info(f"Initialized ECAPA2 model on device: {self.device}")

    def preprocess(self, input_data: Union[bytes, Dict[str, Any]]) -> torch.Tensor:
        """Preprocess the input audio data.

        Args:
            input_data: Either the raw bytes of an audio file or a dictionary
                holding those bytes under the "inputs" key.

        Returns:
            torch.Tensor: Preprocessed audio tensor ready for inference.
        """
        try:
            audio = input_data if isinstance(input_data, bytes) else input_data['inputs']

            # Handle different input types
            if isinstance(audio, bytes):
                # Load audio from bytes
                audio_buffer = BytesIO(audio)
                waveform, sample_rate = torchaudio.load(audio_buffer)
            else:
                logger.error(f"Unsupported input type: {type(audio)}")
                logger.debug(f"Input data keys: {input_data.keys()}")
                raise ValueError("Unsupported input type")

            # Resample if necessary. This stays in float32 on the CPU, since
            # half-precision resampling is not reliably supported there.
            if sample_rate != self.target_sample_rate:
                waveform = torchaudio.functional.resample(
                    waveform,
                    sample_rate,
                    self.target_sample_rate
                )

            # If stereo, convert to mono by averaging channels
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)

            # Move to the target device, then match the model's precision
            waveform = waveform.to(self.device)
            if self.device.type == "cuda":
                waveform = waveform.half()

            return waveform

        except Exception as e:
            raise RuntimeError(f"Error in preprocessing: {e}") from e

    def predict(self, audio_tensor: torch.Tensor) -> torch.Tensor:
        """Run inference with the ECAPA2 model.

        Args:
            audio_tensor: Preprocessed audio tensor

        Returns:
            Speaker embedding tensor
        """
        try:
            # Gradients are never needed at inference time
            with torch.no_grad():
                embedding = self.model(audio_tensor)
            return embedding

        except Exception as e:
            raise RuntimeError(f"Error during inference: {e}") from e

    def postprocess(self, embedding: torch.Tensor) -> Dict[str, Any]:
        """Postprocess the model output.

        Args:
            embedding: Speaker embedding tensor from the model

        Returns:
            Dictionary containing the embedding and metadata
        """
        # Cast back to float32 and convert to a nested list for JSON serialization
        embedding_np = embedding.detach().cpu().float().numpy()

        return {
            "embedding": embedding_np.tolist(),
            "embedding_dimension": embedding_np.shape[-1],
            "metadata": {
                "model_name": "ECAPA2",
                "device": str(self.device),
                "sample_rate": self.target_sample_rate
            }
        }

    def __call__(self, input_data: Union[bytes, Dict[str, Any]]) -> Dict[str, Any]:
        """Main entry point for the handler.

        Args:
            input_data: Raw input data

        Returns:
            Processed results with the speaker embedding, or an error payload
        """
        try:
            # Execute the full pipeline
            audio_tensor = self.preprocess(input_data)
            embedding = self.predict(audio_tensor)
            final_output = self.postprocess(embedding)

            return final_output

        except Exception as e:
            return {
                "error": str(e),
                "status": "error"
            }
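
A minimal local smoke test for this handler might look like the sketch below. The paths sample.wav, speaker_a.wav, speaker_b.wav, and the cache directory are hypothetical stand-ins, not part of this commit; as implemented above, the handler accepts either raw audio bytes or a dict carrying those bytes under the "inputs" key.

from handler import EndpointHandler

# Hypothetical cache directory; any writable path works.
handler = EndpointHandler(model_dir="./model_cache")

# Any audio format torchaudio can decode is fine here.
with open("sample.wav", "rb") as f:
    result = handler({"inputs": f.read()})

if result.get("status") == "error":
    print("Handler failed:", result["error"])
else:
    print("Embedding dimension:", result["embedding_dimension"])

Speaker embeddings like these are usually compared rather than inspected directly, so a natural follow-up is the cosine similarity between two utterances (again a sketch under the same assumptions):

import torch
import torch.nn.functional as F

emb_a = torch.tensor(handler({"inputs": open("speaker_a.wav", "rb").read()})["embedding"])
emb_b = torch.tensor(handler({"inputs": open("speaker_b.wav", "rb").read()})["embedding"])

# Higher cosine similarity suggests the two clips share a speaker.
print("Cosine similarity:", F.cosine_similarity(emb_a, emb_b, dim=-1).item())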