File size: 4,492 Bytes
be6e0e6 145d68f be6e0e6 73a666a be6e0e6 3c87043 be6e0e6 3c87043 be6e0e6 3c87043 be6e0e6 3c87043 be6e0e6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import torch
import torchaudio
from torchaudio.pipelines import SQUIM_OBJECTIVE
import numpy as np
from typing import Dict, Union, Any
from io import BytesIO
import logging
logger = logging.getLogger(__name__)
class EndpointHandler:
    """Endpoint handler that scores speech quality with torchaudio's SQUIM objective model.

    A request payload is decoded into a mono waveform at the model's sample rate.
    The handler then returns the STOI, PESQ and SI-SDR estimates produced by
    ``SQUIM_OBJECTIVE``, together with some metadata.
    """

    def __init__(self, model_dir: str, **kwargs):
        """Initialize the SQUIM model handler.

        The model is placed on the GPU if one is available, otherwise on the CPU.

        Args:
            model_dir: Accepted for the handler interface but not used; the weights
                come from ``SQUIM_OBJECTIVE.get_model()``.
            **kwargs: Accepted and ignored.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = SQUIM_OBJECTIVE.get_model().to(self.device).float()
        # Every input is resampled to this rate before inference.
        self.target_sample_rate = SQUIM_OBJECTIVE.sample_rate
        self.model.eval()
        logger.info("Initialized SQUIM model on device: %s", self.device)

    def preprocess(self, input_data: Union[bytes, Dict[str, Any]]) -> torch.Tensor:
        """Decode an audio payload into a waveform ready for inference.

        Args:
            input_data: Either the raw bytes of an audio file, or a dict whose
                ``'inputs'`` entry holds those bytes. ``bytearray`` is accepted
                wherever ``bytes`` is.

        Returns:
            torch.Tensor: float32 waveform with a single channel (multi-channel
            audio is averaged), at ``self.target_sample_rate``, on ``self.device``.

        Raises:
            RuntimeError: If the payload is missing or of an unsupported type,
                cannot be decoded, or contains no samples. The underlying error
                is chained as ``__cause__``.
        """
        try:
            if isinstance(input_data, (bytes, bytearray)):
                audio = input_data
            else:
                try:
                    audio = input_data['inputs']
                except (KeyError, TypeError) as err:
                    raise ValueError(
                        "Expected raw audio bytes or a dict with an 'inputs' key"
                    ) from err

            if not isinstance(audio, (bytes, bytearray)):
                logger.error("Unsupported input type: %s", type(audio))
                if isinstance(input_data, dict):
                    logger.debug("Input data keys: %s", list(input_data.keys()))
                raise ValueError("Unsupported input type")

            waveform, sample_rate = torchaudio.load(BytesIO(audio))
            if waveform.numel() == 0:
                raise ValueError("Audio contains no samples")
            waveform = waveform.float()

            # Downmix before resampling so only one channel has to be resampled.
            # Averaging and resampling are both linear, so the result matches
            # resample-then-downmix up to floating-point rounding.
            if waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)
            if sample_rate != self.target_sample_rate:
                waveform = torchaudio.functional.resample(
                    waveform,
                    sample_rate,
                    self.target_sample_rate,
                )
            return waveform.to(self.device)
        except Exception as e:
            raise RuntimeError(f"Error in preprocessing: {e}") from e

    def predict(self, audio_tensor: torch.Tensor) -> Dict[str, float]:
        """Run SQUIM inference on a preprocessed waveform.

        Args:
            audio_tensor: Waveform tensor. A 1-D tensor is treated as a batch of one.

        Returns:
            Dict with float ``"stoi"``, ``"pesq"`` and ``"si_sdr"`` scores. Each
            score is read with ``.item()``, so the input must hold exactly one
            waveform.

        Raises:
            RuntimeError: If inference fails. The underlying error is chained as
                ``__cause__``.
        """
        try:
            # The model consumes (batch, time); promote a bare waveform to a batch of one.
            if audio_tensor.dim() == 1:
                audio_tensor = audio_tensor.unsqueeze(0)
            with torch.no_grad():
                stoi, pesq, si_sdr = self.model(audio_tensor)
            return {
                "stoi": stoi.item(),
                "pesq": pesq.item(),
                "si_sdr": si_sdr.item(),
            }
        except Exception as e:
            raise RuntimeError(f"Error during inference: {e}") from e

    def postprocess(self, model_output: Dict[str, float]) -> Dict[str, Any]:
        """Wrap the raw scores together with handler metadata.

        Args:
            model_output: Score dict as returned by ``predict``.

        Returns:
            ``{"metrics": model_output, "metadata": {...}}``. The metadata holds
            the model name, the device string and the target sample rate.
        """
        return {
            "metrics": model_output,
            "metadata": {
                "model_name": "SQUIM",
                "device": str(self.device),
                "sample_rate": self.target_sample_rate,
            },
        }

    def __call__(self, input_data: Union[bytes, Dict[str, Any]]) -> Dict[str, Any]:
        """Run the full preprocess -> predict -> postprocess pipeline for one request.

        Args:
            input_data: Raw audio bytes, or a dict with the bytes under ``'inputs'``.

        Returns:
            The ``postprocess`` result on success. On any failure, returns
            ``{"error": <message>, "status": "error"}`` instead of raising.
        """
        try:
            audio_tensor = self.preprocess(input_data)
            predictions = self.predict(audio_tensor)
            return self.postprocess(predictions)
        except Exception as e:
            # The error is reported in the response rather than raised, so log
            # the traceback to keep the failure visible on the server side.
            logger.exception("SQUIM request failed")
            return {
                "error": str(e),
                "status": "error",
            }
|