oza75 committed on
Commit be6e0e6 · verified · 1 Parent(s): beaebd4

Create handler.py

Files changed (1)
  1. handler.py +143 -0
handler.py ADDED
@@ -0,0 +1,143 @@
import torch
import torchaudio
from torchaudio.pipelines import SQUIM_OBJECTIVE
import numpy as np
from typing import Dict, Union, Any
from io import BytesIO


class EndpointHandler:
    def __init__(self, **kwargs):
        """Initialize the SQUIM model handler.

        Sets up the model on GPU if available, otherwise on CPU.
        """
        # Determine the device to use
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Initialize the SQUIM objective model
        self.model = SQUIM_OBJECTIVE.get_model().to(self.device).float()

        # Store the sample rate the model expects
        self.target_sample_rate = SQUIM_OBJECTIVE.sample_rate

        # Set model to evaluation mode
        self.model.eval()

        print(f"Initialized SQUIM model on device: {self.device}")

    def preprocess(self, input_data: Union[bytes, Dict[str, Any]]) -> torch.Tensor:
        """Preprocess the input audio data.

        Args:
            input_data: Either the raw bytes of an audio file or a dictionary
                containing audio data under the 'audio' key (with an optional
                'sampling_rate' key).

        Returns:
            torch.Tensor: Preprocessed audio tensor ready for inference.
        """
        try:
            # Handle the different input types
            if isinstance(input_data, bytes):
                # Load audio from raw file bytes
                audio_buffer = BytesIO(input_data)
                waveform, sample_rate = torchaudio.load(audio_buffer)
            elif isinstance(input_data, dict):
                if "audio" in input_data:
                    # Handle list or numpy array input
                    audio_array = input_data["audio"]
                    if isinstance(audio_array, list):
                        audio_array = np.array(audio_array)
                    waveform = torch.from_numpy(audio_array)
                    sample_rate = input_data.get("sampling_rate", self.target_sample_rate)
                    # Ensure a 2D tensor of shape [channels, time]
                    if waveform.dim() == 1:
                        waveform = waveform.unsqueeze(0)
                else:
                    raise ValueError("Input dictionary must contain an 'audio' key")
            else:
                raise ValueError(f"Unsupported input type: {type(input_data)}")

            # Convert to float32
            waveform = waveform.float()

            # Resample if necessary
            if sample_rate != self.target_sample_rate:
                waveform = torchaudio.functional.resample(
                    waveform,
                    sample_rate,
                    self.target_sample_rate,
                )

            # If stereo, convert to mono by averaging the channels
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)

            # Move to the appropriate device
            waveform = waveform.to(self.device)

            return waveform

        except Exception as e:
            raise RuntimeError(f"Error in preprocessing: {e}") from e

    def predict(self, audio_tensor: torch.Tensor) -> Dict[str, float]:
        """Run inference with the SQUIM model.

        Args:
            audio_tensor: Preprocessed audio tensor.

        Returns:
            Dictionary containing the quality metrics.
        """
        try:
            with torch.no_grad():
                stoi, pesq, si_sdr = self.model(audio_tensor)

            return {
                "stoi": stoi.item(),
                "pesq": pesq.item(),
                "si_sdr": si_sdr.item(),
            }

        except Exception as e:
            raise RuntimeError(f"Error during inference: {e}") from e

    def postprocess(self, model_output: Dict[str, float]) -> Dict[str, Any]:
        """Postprocess the model output.

        Args:
            model_output: Dictionary containing the raw model outputs.

        Returns:
            Dictionary containing the formatted results with additional metadata.
        """
        return {
            "metrics": model_output,
            "metadata": {
                "model_name": "SQUIM",
                "device": str(self.device),
                "sample_rate": self.target_sample_rate,
            },
        }

    def __call__(self, input_data: Union[bytes, Dict[str, Any]]) -> Dict[str, Any]:
        """Main entry point for the handler.

        Args:
            input_data: Raw input data.

        Returns:
            Processed results with quality metrics, or an error payload.
        """
        try:
            # Execute the full pipeline: preprocess -> predict -> postprocess
            audio_tensor = self.preprocess(input_data)
            predictions = self.predict(audio_tensor)
            final_output = self.postprocess(predictions)

            return final_output

        except Exception as e:
            return {
                "error": str(e),
                "status": "error",
            }
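
The commit adds only the handler itself; for reference, a minimal local smoke test might look like the sketch below. The synthetic sine-wave input and the `handler` module name are assumptions for illustration, not part of the commit, and the first run will download the SQUIM weights via torchaudio.

# Hypothetical local smoke test for EndpointHandler (not part of this commit).
# Assumes the file above is importable as handler.py.
import torch
from handler import EndpointHandler

handler = EndpointHandler()

# One second of a synthetic 440 Hz tone at the model's 16 kHz rate,
# exercising the dict input path of preprocess().
t = torch.arange(16000) / 16000.0
tone = (0.1 * torch.sin(2 * torch.pi * 440.0 * t)).numpy()

result = handler({"audio": tone, "sampling_rate": 16000})
print(result)  # {'metrics': {'stoi': ..., 'pesq': ..., 'si_sdr': ...}, 'metadata': {...}}

Passing raw file bytes instead (e.g. `handler(open("sample.wav", "rb").read())`) exercises the bytes branch, where torchaudio infers the sample rate from the file header.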