# torchaudio-squim / handler.py
# Author: oza75
# Commit message: Update handler.py
# Commit: 3c87043 (verified)
import torch
import torchaudio
from torchaudio.pipelines import SQUIM_OBJECTIVE
import numpy as np
from typing import Dict, Union, Any
from io import BytesIO
import logging
# Logger for this module, named after the module; the handler below sends its
# error and debug messages here.
logger: logging.Logger = logging.getLogger(name=__name__)
class EndpointHandler:
    """Inference handler that scores speech quality with torchaudio's SQUIM model.

    A request carries an encoded audio file. The handler decodes it, converts it
    to a mono float32 waveform at the model's sample rate, and runs
    ``SQUIM_OBJECTIVE``. It returns the STOI, PESQ and SI-SDR estimates along
    with some metadata.
    """

    # Payload types that ``BytesIO`` can wrap directly as an encoded audio file.
    _BYTES_TYPES = (bytes, bytearray, memoryview)

    def __init__(self, model_dir: str, **kwargs):
        """Load the SQUIM objective model.

        Args:
            model_dir: Model directory argument; not used by this handler. The
                weights come from ``torchaudio.pipelines.SQUIM_OBJECTIVE``.
            **kwargs: Accepted and ignored.

        The model runs on CUDA when available, otherwise on CPU.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = SQUIM_OBJECTIVE.get_model().to(self.device).float()
        # Any input recorded at a different rate is resampled to this one.
        self.target_sample_rate = SQUIM_OBJECTIVE.sample_rate
        self.model.eval()
        logger.info("Initialized SQUIM model on device: %s", self.device)

    def preprocess(self, input_data: Union[bytes, Dict[str, Any]]) -> torch.Tensor:
        """Decode an audio payload into a model-ready waveform.

        Args:
            input_data: An encoded audio file as bytes (``bytearray`` and
                ``memoryview`` are also accepted), or a dict whose ``"inputs"``
                entry holds such bytes.

        Returns:
            torch.Tensor: A mono float32 waveform at ``self.target_sample_rate``,
            placed on ``self.device``.

        Raises:
            RuntimeError: The payload is missing, has an unsupported type,
                cannot be decoded, or contains no samples. The underlying
                exception is chained as ``__cause__``.
        """
        try:
            if isinstance(input_data, self._BYTES_TYPES):
                audio = input_data
            else:
                audio = input_data["inputs"]
            if not isinstance(audio, self._BYTES_TYPES):
                logger.error("Unsupported input type: %s", type(audio))
                logger.debug("Input data: %s", input_data.keys())
                raise ValueError("Unsupported input type")
            waveform, sample_rate = torchaudio.load(BytesIO(audio))
            return self._prepare_waveform(waveform, sample_rate)
        except Exception as e:
            raise RuntimeError(f"Error in preprocessing: {str(e)}") from e

    def _prepare_waveform(self, waveform: torch.Tensor, sample_rate: int) -> torch.Tensor:
        """Convert a decoded channels-first waveform into model input.

        Args:
            waveform: Decoded audio, channels along dim 0 (as returned by
                ``torchaudio.load``).
            sample_rate: Sample rate of ``waveform`` in Hz.

        Returns:
            torch.Tensor: A float32, single-channel waveform (channel dim kept),
            resampled to ``self.target_sample_rate`` and moved to ``self.device``.

        Raises:
            ValueError: The waveform has no samples.
        """
        waveform = waveform.float()
        # Downmix before resampling. Resampling is a linear filter, so averaging
        # first gives the same signal while resampling only one channel.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        if waveform.shape[-1] == 0:
            raise ValueError("Audio contains no samples")
        if sample_rate != self.target_sample_rate:
            waveform = torchaudio.functional.resample(
                waveform, sample_rate, self.target_sample_rate
            )
        return waveform.to(self.device)

    def predict(self, audio_tensor: torch.Tensor) -> Dict[str, float]:
        """Run SQUIM inference on a preprocessed waveform.

        Args:
            audio_tensor: Output of ``preprocess``. Each model output must hold
                a single value, because it is converted with ``.item()``.

        Returns:
            Dict with float ``"stoi"``, ``"pesq"`` and ``"si_sdr"`` entries.

        Raises:
            RuntimeError: The model call or the scalar conversion failed. The
                underlying exception is chained as ``__cause__``.
        """
        try:
            # inference_mode also skips autograd version tracking, which makes
            # it cheaper than no_grad for pure inference.
            with torch.inference_mode():
                stoi, pesq, si_sdr = self.model(audio_tensor)
            return {
                "stoi": stoi.item(),
                "pesq": pesq.item(),
                "si_sdr": si_sdr.item(),
            }
        except Exception as e:
            raise RuntimeError(f"Error during inference: {str(e)}") from e

    def postprocess(self, model_output: Dict[str, float]) -> Dict[str, Any]:
        """Wrap the raw metrics with handler metadata.

        Args:
            model_output: Metrics dict returned by ``predict``.

        Returns:
            ``{"metrics": model_output, "metadata": {...}}``. The metadata holds
            the model name, the device and the target sample rate.
        """
        return {
            "metrics": model_output,
            "metadata": {
                "model_name": "SQUIM",
                "device": str(self.device),
                "sample_rate": self.target_sample_rate,
            },
        }

    def __call__(self, input_data: Union[bytes, Dict[str, Any]]) -> Dict[str, Any]:
        """Handle one request: preprocess, predict, postprocess.

        Args:
            input_data: Raw request payload (see ``preprocess``).

        Returns:
            The ``postprocess`` result on success. On any failure, returns
            ``{"error": <message>, "status": "error"}`` instead of raising.
        """
        try:
            audio_tensor = self.preprocess(input_data)
            predictions = self.predict(audio_tensor)
            return self.postprocess(predictions)
        except Exception as e:
            # Top-level boundary: the client gets an error payload, and the
            # traceback is kept in the server logs instead of being discarded.
            logger.exception("SQUIM request failed")
            return {
                "error": str(e),
                "status": "error",
            }