File size: 4,492 Bytes
be6e0e6
 
 
 
 
 
145d68f
 
 
be6e0e6
 
 
73a666a
be6e0e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c87043
be6e0e6
3c87043
be6e0e6
3c87043
be6e0e6
 
3c87043
 
be6e0e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import torch
import torchaudio
from torchaudio.pipelines import SQUIM_OBJECTIVE
import numpy as np
from typing import Dict, Union, Any
from io import BytesIO
import logging

logger = logging.getLogger(__name__)


class EndpointHandler:
    """Inference handler for torchaudio's SQUIM objective speech-quality model.

    Pipeline: raw encoded audio bytes -> preprocess (decode, resample,
    downmix to mono) -> predict (STOI / PESQ / SI-SDR scores) ->
    postprocess (attach metadata). ``__call__`` never raises; it returns an
    error payload instead, as hosting frameworks expect a JSON-serializable
    response.
    """

    def __init__(self, model_dir: str, **kwargs):
        """Load the SQUIM objective model onto GPU if available, else CPU.

        Args:
            model_dir: Unused here; presumably required by the hosting
                framework's handler contract — TODO confirm.
            **kwargs: Ignored; accepted for forward compatibility.
        """
        # Prefer CUDA when present; all tensors are moved to this device.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Fetch the pretrained SQUIM objective bundle and force float32
        # weights on the chosen device.
        self.model = SQUIM_OBJECTIVE.get_model().to(self.device).float()

        # Sample rate the bundle was trained at; inputs are resampled to it.
        self.target_sample_rate = SQUIM_OBJECTIVE.sample_rate

        # Inference only: disable dropout/batch-norm training behavior.
        self.model.eval()

        # Use the module logger (not print) so the message respects the
        # host's logging configuration; lazy %-args avoid eager formatting.
        logger.info("Initialized SQUIM model on device: %s", self.device)

    def preprocess(self, input_data: Union[bytes, Dict[str, Any]]) -> torch.Tensor:
        """Decode input audio and prepare it for SQUIM inference.

        Args:
            input_data: Raw bytes of an encoded audio file, or a dict
                carrying those bytes under the ``'inputs'`` key.

        Returns:
            torch.Tensor: Float32 mono waveform of shape ``(1, num_samples)``
            at the model's sample rate, on ``self.device``.

        Raises:
            RuntimeError: If decoding or any preprocessing step fails
                (the original exception is chained as the cause).
        """
        try:
            audio = input_data if isinstance(input_data, bytes) else input_data['inputs']
            if isinstance(audio, bytes):
                # Decode the container/codec from an in-memory buffer.
                waveform, sample_rate = torchaudio.load(BytesIO(audio))
            else:
                logger.error("Unsupported input type: %s", type(audio))
                logger.debug("Input data: %s", input_data.keys())
                raise ValueError("Unsupported input type")

            # SQUIM expects float32 input.
            waveform = waveform.float()

            # Resample to the rate the SQUIM bundle was trained at.
            if sample_rate != self.target_sample_rate:
                waveform = torchaudio.functional.resample(
                    waveform,
                    sample_rate,
                    self.target_sample_rate,
                )

            # Downmix multi-channel audio to mono by averaging channels;
            # channel dimension is dim 0 in torchaudio's (C, T) layout.
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)

            return waveform.to(self.device)

        except Exception as e:
            # Chain the cause so the full traceback is preserved for callers.
            raise RuntimeError(f"Error in preprocessing: {str(e)}") from e

    def predict(self, audio_tensor: torch.Tensor) -> Dict[str, float]:
        """Run SQUIM inference on a preprocessed waveform.

        Args:
            audio_tensor: Mono float32 waveform ``(1, num_samples)`` on
                ``self.device`` at ``self.target_sample_rate``.

        Returns:
            Dict with Python-float entries ``"stoi"``, ``"pesq"`` and
            ``"si_sdr"``.

        Raises:
            RuntimeError: If the forward pass fails (cause chained).
        """
        try:
            # no_grad: inference only, skip autograd bookkeeping.
            with torch.no_grad():
                stoi, pesq, si_sdr = self.model(audio_tensor)

            # .item() converts 0-d tensors to JSON-friendly Python floats.
            return {
                "stoi": stoi.item(),
                "pesq": pesq.item(),
                "si_sdr": si_sdr.item()
            }

        except Exception as e:
            raise RuntimeError(f"Error during inference: {str(e)}") from e

    def postprocess(self, model_output: Dict[str, float]) -> Dict[str, Any]:
        """Wrap raw metric values with handler metadata.

        Args:
            model_output: Metric dict as produced by :meth:`predict`.

        Returns:
            Dict with the metrics under ``"metrics"`` and model name,
            device string and sample rate under ``"metadata"``.
        """
        return {
            "metrics": model_output,
            "metadata": {
                "model_name": "SQUIM",
                "device": str(self.device),
                "sample_rate": self.target_sample_rate
            }
        }

    def __call__(self, input_data: Union[bytes, Dict[str, Any]]) -> Dict[str, Any]:
        """Main entry point: run the full preprocess/predict/postprocess chain.

        Args:
            input_data: Raw audio bytes or a dict with bytes under 'inputs'.

        Returns:
            On success, the :meth:`postprocess` payload; on any failure, a
            ``{"error": <message>, "status": "error"}`` dict — this handler
            deliberately never raises to its caller.
        """
        try:
            audio_tensor = self.preprocess(input_data)
            predictions = self.predict(audio_tensor)
            return self.postprocess(predictions)

        except Exception as e:
            return {
                "error": str(e),
                "status": "error"
            }