|
import torch |
|
import torchaudio |
|
from torchaudio.pipelines import SQUIM_OBJECTIVE |
|
import numpy as np |
|
from typing import Dict, Union, Any |
|
from io import BytesIO |
|
|
|
|
|
class EndpointHandler:
    """Callable handler that scores audio with torchaudio's SQUIM objective model.

    The model is loaded once at construction. Each call converts an audio
    payload into a mono float waveform at the model's sample rate, runs the
    model, and returns the three values it produces (reported under the keys
    ``stoi``, ``pesq`` and ``si_sdr``) together with some metadata.
    """

    def __init__(self, path: str = "", **kwargs):
        """Initialize the SQUIM model handler.

        Sets up the model on GPU if available, otherwise on CPU.

        Args:
            path: Unused. Accepted so that loaders which pass the model
                directory as the first positional argument can construct the
                handler (NOTE(review): this is how Hugging Face custom
                handlers appear to be instantiated; confirm against the
                deployment target). Weights come from the torchaudio bundle.
            **kwargs: Ignored; kept so keyword-style construction still works.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load the pretrained bundle's model in float32 on the chosen device.
        self.model = SQUIM_OBJECTIVE.get_model().to(self.device).float()

        # All incoming audio is resampled to the rate the bundle declares.
        self.target_sample_rate = SQUIM_OBJECTIVE.sample_rate

        # Inference only: switch off training-time behavior.
        self.model.eval()

        print(f"Initialized SQUIM model on device: {self.device}")

    def _load_waveform(self, input_data: Any):
        """Decode a request payload into ``(waveform, sample_rate)``.

        Accepts encoded audio bytes, a dictionary with an ``'audio'`` array
        (and optional ``'sampling_rate'``), or either of those wrapped under
        an ``'inputs'`` key.

        Raises:
            ValueError: If the payload type, keys or sampling rate are invalid.
        """
        # Unwrap an {"inputs": ...} request envelope when the dictionary does
        # not carry the audio directly.
        if (
            isinstance(input_data, dict)
            and "audio" not in input_data
            and "inputs" in input_data
        ):
            input_data = input_data["inputs"]

        if isinstance(input_data, (bytes, bytearray)):
            # Encoded audio file: let torchaudio decode it from memory.
            waveform, sample_rate = torchaudio.load(BytesIO(input_data))
            return waveform, sample_rate

        if not isinstance(input_data, dict):
            raise ValueError("Unsupported input type")

        if "audio" not in input_data:
            raise ValueError("Input dictionary must contain 'audio' key")

        audio = input_data["audio"]
        if isinstance(audio, torch.Tensor):
            waveform = audio
        else:
            # np.asarray handles lists, tuples and existing arrays alike.
            waveform = torch.as_tensor(np.asarray(audio))

        # A missing (or null) rate means the audio is already at the target.
        raw_rate = input_data.get("sampling_rate")
        if raw_rate is None:
            raw_rate = self.target_sample_rate
        try:
            sample_rate = int(raw_rate)
        except (TypeError, ValueError):
            raise ValueError(f"Invalid sampling rate: {raw_rate!r}") from None
        if sample_rate <= 0:
            raise ValueError(f"Invalid sampling rate: {raw_rate!r}")

        return waveform, sample_rate

    def preprocess(self, input_data: Union[bytes, Dict[str, Any]]) -> torch.Tensor:
        """Preprocess the input audio data.

        Args:
            input_data: Either raw bytes of an audio file or a dictionary
                containing audio data (``'audio'`` plus optional
                ``'sampling_rate'``), optionally wrapped under ``'inputs'``.

        Returns:
            torch.Tensor: Float32 tensor of shape ``(1, num_samples)`` at the
            target sample rate, placed on the handler's device.

        Raises:
            RuntimeError: If the payload cannot be decoded or is invalid; the
                original exception is attached as the cause.
        """
        try:
            waveform, sample_rate = self._load_waveform(input_data)

            # Normalize to a (channels, samples) layout.
            if waveform.dim() == 1:
                waveform = waveform.unsqueeze(0)
            elif waveform.dim() != 2:
                raise ValueError(
                    f"Expected 1D or 2D audio, got {waveform.dim()} dimensions"
                )

            if waveform.numel() == 0:
                raise ValueError("Input audio is empty")

            waveform = waveform.float()

            if sample_rate != self.target_sample_rate:
                waveform = torchaudio.functional.resample(
                    waveform,
                    sample_rate,
                    self.target_sample_rate,
                )

            # Downmix multi-channel audio by averaging the channels.
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)

            return waveform.to(self.device)

        except Exception as e:
            raise RuntimeError(f"Error in preprocessing: {str(e)}") from e

    def predict(self, audio_tensor: torch.Tensor) -> Dict[str, float]:
        """Run inference with the SQUIM model.

        Args:
            audio_tensor: Preprocessed audio tensor. Each model output must
                hold a single element, which is the case for the
                single-channel tensor produced by ``preprocess``.

        Returns:
            Dictionary with the keys ``stoi``, ``pesq`` and ``si_sdr`` mapped
            to Python floats.

        Raises:
            RuntimeError: If the model call or scalar extraction fails.
        """
        try:
            with torch.no_grad():
                stoi, pesq, si_sdr = self.model(audio_tensor)

            return {
                "stoi": stoi.item(),
                "pesq": pesq.item(),
                "si_sdr": si_sdr.item(),
            }

        except Exception as e:
            raise RuntimeError(f"Error during inference: {str(e)}") from e

    def postprocess(self, model_output: Dict[str, float]) -> Dict[str, Any]:
        """Postprocess the model output.

        Args:
            model_output: Dictionary containing the raw model outputs.

        Returns:
            Dictionary with the outputs under ``metrics`` and the model name,
            device and sample rate under ``metadata``.
        """
        return {
            "metrics": model_output,
            "metadata": {
                "model_name": "SQUIM",
                "device": str(self.device),
                "sample_rate": self.target_sample_rate,
            },
        }

    def __call__(self, input_data: Union[bytes, Dict[str, Any]]) -> Dict[str, Any]:
        """Main entry point for the handler.

        Args:
            input_data: Raw input data (see ``preprocess`` for accepted forms).

        Returns:
            Processed results with quality metrics, or a dictionary with
            ``error`` and ``status`` keys when any stage fails. Failures are
            reported in the response rather than raised.
        """
        try:
            audio_tensor = self.preprocess(input_data)
            predictions = self.predict(audio_tensor)
            return self.postprocess(predictions)

        except Exception as e:
            # Top-level boundary: convert any failure into an error payload.
            return {
                "error": str(e),
                "status": "error",
            }
|
|