|
import torch |
|
import torchaudio |
|
from torchaudio.pipelines import SQUIM_OBJECTIVE |
|
import numpy as np |
|
from typing import Dict, Union, Any |
|
from io import BytesIO |
|
import logging |
|
|
|
logger = logging.getLogger(__name__)


class EndpointHandler:
    """Endpoint handler that scores audio with the torchaudio SQUIM objective model.

    A request flows through three stages: ``preprocess`` decodes the audio
    payload into a mono waveform at the model's sample rate, ``predict`` runs
    the model, and ``postprocess`` wraps the scores with metadata. Calling the
    instance runs all three stages and never raises; failures are logged and
    reported as an error dictionary.
    """

    # Payload types that ``BytesIO`` can wrap directly.
    _BYTES_LIKE = (bytes, bytearray, memoryview)

    def __init__(self, model_dir: str, **kwargs):
        """Initialize the SQUIM model handler.

        Sets up the model on GPU if available, otherwise on CPU.

        Args:
            model_dir: Accepted for interface compatibility; not read here,
                the model is obtained from ``SQUIM_OBJECTIVE.get_model()``.
            **kwargs: Accepted and ignored.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = SQUIM_OBJECTIVE.get_model().to(self.device).float()
        self.target_sample_rate = SQUIM_OBJECTIVE.sample_rate
        self.model.eval()

        print(f"Initialized SQUIM model on device: {self.device}")

    @classmethod
    def _extract_audio(cls, input_data: Any) -> Any:
        """Return the bytes-like audio payload carried by *input_data*.

        Raises:
            ValueError: If the payload is not bytes-like, is not a mapping
                with an ``'inputs'`` entry, or is empty.
        """
        if isinstance(input_data, cls._BYTES_LIKE):
            audio = input_data
        else:
            try:
                audio = input_data["inputs"]
            except (KeyError, TypeError) as exc:
                # A bare KeyError/TypeError message tells the client nothing
                # about what shape of payload is expected.
                raise ValueError(
                    "Request payload must be raw bytes or a mapping "
                    "with an 'inputs' entry"
                ) from exc

            if not isinstance(audio, cls._BYTES_LIKE):
                logger.error("Unsupported input type: %s", type(audio))
                if isinstance(input_data, dict):
                    logger.debug("Input data: %s", input_data.keys())
                raise ValueError("Unsupported input type")

        if len(audio) == 0:
            raise ValueError("Audio payload is empty")
        return audio

    def preprocess(self, input_data: Union[bytes, Dict[str, Any]]) -> torch.Tensor:
        """Preprocess the input audio data.

        Args:
            input_data: Either the raw bytes of an audio file, or a dictionary
                whose ``'inputs'`` entry holds those bytes.

        Returns:
            torch.Tensor: Float waveform with a single channel, resampled to
            ``self.target_sample_rate`` and moved to ``self.device``.

        Raises:
            RuntimeError: If the payload is unsupported or empty, or if
                decoding/resampling fails. The original exception is attached
                as ``__cause__``.
        """
        try:
            audio = self._extract_audio(input_data)
            waveform, sample_rate = torchaudio.load(BytesIO(audio))
            waveform = waveform.float()

            # Reject a decoded-but-empty file here rather than letting the
            # model fail later with a less descriptive error.
            if waveform.numel() == 0:
                raise ValueError("Audio contains no samples")

            if sample_rate != self.target_sample_rate:
                waveform = torchaudio.functional.resample(
                    waveform,
                    sample_rate,
                    self.target_sample_rate,
                )

            # Down-mix multi-channel audio by averaging over the channel axis.
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)

            return waveform.to(self.device)

        except Exception as e:
            raise RuntimeError(f"Error in preprocessing: {str(e)}") from e

    def predict(self, audio_tensor: torch.Tensor) -> Dict[str, float]:
        """Run inference with the SQUIM model.

        Args:
            audio_tensor: Preprocessed audio tensor.

        Returns:
            Dictionary with the three model outputs as Python floats under
            the keys ``"stoi"``, ``"pesq"`` and ``"si_sdr"``.

        Raises:
            RuntimeError: If the model call or the conversion to floats
                fails. The original exception is attached as ``__cause__``.
        """
        try:
            with torch.no_grad():
                stoi, pesq, si_sdr = self.model(audio_tensor)

            return {
                "stoi": stoi.item(),
                "pesq": pesq.item(),
                "si_sdr": si_sdr.item(),
            }

        except Exception as e:
            raise RuntimeError(f"Error during inference: {str(e)}") from e

    def postprocess(self, model_output: Dict[str, float]) -> Dict[str, Any]:
        """Postprocess the model output.

        Args:
            model_output: Dictionary containing the raw model outputs.

        Returns:
            Dictionary with the outputs under ``"metrics"`` and the model
            name, device and sample rate under ``"metadata"``.
        """
        return {
            "metrics": model_output,
            "metadata": {
                "model_name": "SQUIM",
                "device": str(self.device),
                "sample_rate": self.target_sample_rate,
            },
        }

    def __call__(self, input_data: Union[bytes, Dict[str, Any]]) -> Dict[str, Any]:
        """Main entry point for the handler.

        Args:
            input_data: Raw input data.

        Returns:
            Processed results with quality metrics, or a dictionary of the
            form ``{"error": <message>, "status": "error"}`` if any stage
            failed.
        """
        try:
            audio_tensor = self.preprocess(input_data)
            predictions = self.predict(audio_tensor)
            return self.postprocess(predictions)

        except Exception as e:
            # The caller only receives the message; keep the traceback in the
            # server log so failures stay diagnosable.
            logger.exception("SQUIM request failed")
            return {
                "error": str(e),
                "status": "error",
            }
|
|