import torch
import torchaudio
from torchaudio.pipelines import SQUIM_OBJECTIVE
import numpy as np
from typing import Dict, Union, Any
from io import BytesIO
import logging

logger = logging.getLogger(__name__)


class EndpointHandler:
    """Inference handler that scores audio with the SQUIM objective model.

    The handler runs a three-stage pipeline (``preprocess`` -> ``predict`` ->
    ``postprocess``) and is invoked through ``__call__``.
    """

    # Buffer types that can be handed to ``BytesIO`` as encoded audio.
    _BYTES_LIKE = (bytes, bytearray, memoryview)

    def __init__(self, model_dir: str, **kwargs):
        """Initialize the SQUIM model handler.

        Sets up the model on GPU if available, otherwise on CPU.

        Args:
            model_dir: Accepted for interface compatibility; not read here
                because the model is obtained from ``SQUIM_OBJECTIVE``.
            **kwargs: Accepted and ignored.
        """
        # Determine the device to use
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Build the model in float32, in evaluation mode, on the chosen device
        self.model = SQUIM_OBJECTIVE.get_model().to(self.device).float()
        self.model.eval()

        # Sample rate every input is resampled to before inference
        self.target_sample_rate = SQUIM_OBJECTIVE.sample_rate

        # Use the module logger rather than print so the message follows the
        # application's logging configuration.
        logger.info("Initialized SQUIM model on device: %s", self.device)

    @classmethod
    def _extract_audio(cls, input_data: Any) -> Any:
        """Return the encoded audio buffer carried by ``input_data``.

        Args:
            input_data: A bytes-like object, or a mapping whose ``"inputs"``
                entry is a bytes-like object.

        Returns:
            The bytes-like object holding the encoded audio.

        Raises:
            ValueError: If no bytes-like audio payload can be found.
        """
        if isinstance(input_data, cls._BYTES_LIKE):
            return input_data

        try:
            audio = input_data["inputs"]
        except (KeyError, IndexError, TypeError) as e:
            # Without this, the caller would only see the bare key name.
            raise ValueError(
                'Expected raw audio bytes or a mapping with an "inputs" entry, '
                f"got {type(input_data).__name__}"
            ) from e

        if isinstance(audio, cls._BYTES_LIKE):
            return audio

        logger.error("Unsupported input type: %s", type(audio))
        # Guarded so that a payload without .keys() cannot mask the real error.
        if hasattr(input_data, "keys"):
            logger.debug("Input data: %s", input_data.keys())
        raise ValueError("Unsupported input type")

    def preprocess(self, input_data: Union[bytes, Dict[str, Any]]) -> torch.Tensor:
        """Preprocess the input audio data.

        Decodes the audio, converts it to float32, resamples it to
        ``self.target_sample_rate`` when the rates differ, averages multiple
        channels into one, and moves the result to ``self.device``.

        Args:
            input_data: Either raw bytes of an audio file or a dictionary
                containing the audio bytes under ``"inputs"``.

        Returns:
            torch.Tensor: Preprocessed audio tensor ready for inference.

        Raises:
            RuntimeError: If the payload is missing, unsupported, empty, or
                cannot be decoded. The original exception is chained.
        """
        try:
            audio = self._extract_audio(input_data)

            # Load audio from the in-memory buffer
            waveform, sample_rate = torchaudio.load(BytesIO(audio))

            # Reject empty audio here instead of failing inside the model
            if waveform.numel() == 0:
                raise ValueError("Audio contains no samples")

            # Convert to float32
            waveform = waveform.float()

            # Resample if necessary
            if sample_rate != self.target_sample_rate:
                waveform = torchaudio.functional.resample(
                    waveform,
                    sample_rate,
                    self.target_sample_rate,
                )

            # If multi-channel, convert to mono by averaging channels
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)

            # Move to appropriate device
            return waveform.to(self.device)

        except Exception as e:
            raise RuntimeError(f"Error in preprocessing: {str(e)}") from e

    def predict(self, audio_tensor: torch.Tensor) -> Dict[str, float]:
        """Run inference with the SQUIM model.

        Args:
            audio_tensor: Preprocessed audio tensor.

        Returns:
            Dictionary with the ``"stoi"``, ``"pesq"`` and ``"si_sdr"`` values
            as Python floats.

        Raises:
            RuntimeError: If the forward pass or scalar conversion fails. The
                original exception is chained.
        """
        try:
            with torch.no_grad():
                stoi, pesq, si_sdr = self.model(audio_tensor)

            # .item() requires each output to hold exactly one element
            return {
                "stoi": stoi.item(),
                "pesq": pesq.item(),
                "si_sdr": si_sdr.item(),
            }

        except Exception as e:
            raise RuntimeError(f"Error during inference: {str(e)}") from e

    def postprocess(self, model_output: Dict[str, float]) -> Dict[str, Any]:
        """Postprocess the model output.

        Args:
            model_output: Dictionary containing the raw model outputs.

        Returns:
            Dictionary with the outputs under ``"metrics"`` and the model name,
            device and sample rate under ``"metadata"``.
        """
        return {
            "metrics": model_output,
            "metadata": {
                "model_name": "SQUIM",
                "device": str(self.device),
                "sample_rate": self.target_sample_rate,
            },
        }

    def __call__(self, input_data: Union[bytes, Dict[str, Any]]) -> Dict[str, Any]:
        """Main entry point for the handler.

        Args:
            input_data: Raw input data.

        Returns:
            Processed results with quality metrics, or a dictionary with
            ``"error"`` and ``"status"`` keys if any stage fails.
        """
        try:
            # Execute the full pipeline
            audio_tensor = self.preprocess(input_data)
            predictions = self.predict(audio_tensor)
            return self.postprocess(predictions)

        except Exception as e:
            # Log the traceback; the caller only receives the message string.
            logger.exception("SQUIM handler failed")
            return {
                "error": str(e),
                "status": "error",
            }