import unittest
import numpy as np
import torch
from unittest.mock import Mock
from audio_separator.separator.uvr_lib_v5.stft import STFT

# Short-Time Fourier Transform (STFT) Process Overview:
#
# STFT transforms a time-domain signal into a frequency-domain representation.
#   This transformation is achieved by dividing the signal into short frames (or segments) and applying the Fourier Transform to each frame.
#
# n_fft: The number of points used in the Fourier Transform, which determines the resolution of the frequency domain representation.
#   Essentially, it dictates how many frequency bins we get in our STFT.
#
# hop_length: The number of samples by which we shift each frame of the signal.
#   It affects the overlap between consecutive frames. If the hop_length is less than n_fft, we get overlapping frames.
#
# Windowing: Each frame of the signal is multiplied by a window function (e.g. Hann window) before applying the Fourier Transform.
#   This is done to minimize discontinuities at the borders of each frame.
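
# Illustration only: a minimal sketch of how a forward STFT with the parameters above is
# typically computed via torch.stft, assuming a Hann window and centre padding (which matches
# the frame-count formula used in the tests below). The helper name `_reference_stft` is ours;
# it is not the STFT class under test and may differ from it in padding, truncation to dim_f,
# and channel handling.
def _reference_stft(signal: torch.Tensor, n_fft: int = 2048, hop_length: int = 512) -> torch.Tensor:
    """Return the complex one-sided STFT of a (batch, samples) float signal."""
    window = torch.hann_window(n_fft, device=signal.device)
    # With center=True (the torch.stft default), the output shape is (batch, n_fft // 2 + 1, L // hop_length + 1).
    return torch.stft(signal, n_fft=n_fft, hop_length=hop_length, window=window, center=True, return_complex=True)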


class TestSTFT(unittest.TestCase):
    def setUp(self):
        self.n_fft = 2048
        self.hop_length = 512
        self.dim_f = 1025
        self.device = torch.device("cpu")
        self.stft = STFT(logger=Mock(), n_fft=self.n_fft, hop_length=self.hop_length, dim_f=self.dim_f, device=self.device)

    def create_mock_tensor(self, shape, device=None):
        tensor = torch.rand(shape)
        if device:
            tensor = tensor.to(device)
        return tensor

    def test_stft_initialization(self):
        self.assertEqual(self.stft.n_fft, self.n_fft)
        self.assertEqual(self.stft.hop_length, self.hop_length)
        self.assertEqual(self.stft.dim_f, self.dim_f)
        self.assertEqual(self.stft.device.type, "cpu")
        self.assertIsInstance(self.stft.hann_window, torch.Tensor)

    def test_stft_call(self):
        input_tensor = self.create_mock_tensor((1, 16000))

        # Apply STFT
        stft_result = self.stft(input_tensor)

        # Test conditions
        self.assertIsNotNone(stft_result)
        self.assertIsInstance(stft_result, torch.Tensor)

        # Calculate the expected shape based on input parameters:

        # Frequency Dimension (dim_f): This corresponds to the number of frequency bins in the STFT output.
        #   For a real-valued input signal (like audio), the Fourier Transform output is conjugate-symmetric,
        #   so only half of the spectrum carries unique information.
        #   For an n_fft of 2048, the one-sided spectrum therefore has n_fft // 2 + 1 = 1025 frequency bins
        #   (from 0 Hz up to the Nyquist frequency).
        #   dim_f specifies how many of those bins we keep; in this test it is set to 1025, i.e. the full one-sided spectrum.

        # Time Dimension: This corresponds to how many frames (or segments) the input signal has been divided into.
        #   It depends on the length of the input signal and the hop_length.
        #     Length of Input Signal: Let's denote it as L. In this test, the input tensor has a shape of [1, 16000], so L is 16000 (ignoring the batch dimension for simplicity).
        #     Number of Frames: For each frame, the window is advanced by hop_length samples, so the frame count is roughly L / hop_length.
        #     Because the signal is padded so that frames are centred on it (torch.stft's default behaviour), one extra frame is produced,
        #     giving N_frames = (L // hop_length) + 1.

        # Putting It All Together
        #   expected_shape thus becomes (dim_f, N_frames), which is (1025, (16000 // 512) + 1) in this test case.
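        #   Worked example with the values used here: L = 16000 and hop_length = 512,
        #   so N_frames = (16000 // 512) + 1 = 31 + 1 = 32, and the expected shape is (1025, 32).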

        expected_shape = (self.dim_f, (input_tensor.shape[1] // self.hop_length) + 1)

        self.assertEqual(stft_result.shape[-2:], expected_shape)

    def test_calculate_inverse_dimensions(self):
        # Create a sample input tensor
        sample_input = torch.randn(1, 2, 500, 32)  # Batch, Channel, Frequency, Time dimensions
        batch_dims, channel_dim, freq_dim, time_dim, num_freq_bins = self.stft.calculate_inverse_dimensions(sample_input)

        # Expected values
        expected_num_freq_bins = self.n_fft // 2 + 1
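        # (for n_fft = 2048 this is 2048 // 2 + 1 = 1025)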

        # Assertions
        self.assertEqual(batch_dims, sample_input.shape[:-3])
        self.assertEqual(channel_dim, 2)
        self.assertEqual(freq_dim, 500)
        self.assertEqual(time_dim, 32)
        self.assertEqual(num_freq_bins, expected_num_freq_bins)

    def test_pad_frequency_dimension(self):
        # Create a sample input tensor
        sample_input = torch.randn(1, 2, 500, 32)  # Batch, Channel, Frequency, Time dimensions
        batch_dims, channel_dim, freq_dim, time_dim, num_freq_bins = self.stft.calculate_inverse_dimensions(sample_input)

        # Apply padding
        padded_output = self.stft.pad_frequency_dimension(sample_input, batch_dims, channel_dim, freq_dim, time_dim, num_freq_bins)

        # Expected frequency dimension after padding
        expected_freq_dim = num_freq_bins
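        # (i.e. the 500-bin input is padded up to the full n_fft // 2 + 1 = 1025 bins expected by the inverse transform)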

        # Assertions
        self.assertEqual(padded_output.shape[-2], expected_freq_dim)

    def test_prepare_for_istft(self):
        # Create a sample input tensor
        sample_input = torch.randn(1, 2, 500, 32)  # Batch, Channel, Frequency, Time dimensions
        batch_dims, channel_dim, freq_dim, time_dim, num_freq_bins = self.stft.calculate_inverse_dimensions(sample_input)
        padded_output = self.stft.pad_frequency_dimension(sample_input, batch_dims, channel_dim, freq_dim, time_dim, num_freq_bins)

        # Apply prepare_for_istft
        complex_tensor = self.stft.prepare_for_istft(padded_output, batch_dims, channel_dim, num_freq_bins, time_dim)

        # Calculate the expected flattened batch size (flattening batch and channel dimensions)
        expected_flattened_batch_size = batch_dims[0] * (channel_dim // 2)
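        # With this input that is 1 * (2 // 2) = 1: pairs of channels are presumably combined into a single complex channel.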

        # Expected shape of the complex tensor
        expected_shape = (expected_flattened_batch_size, num_freq_bins, time_dim)

        # Assertions
        self.assertEqual(complex_tensor.shape, expected_shape)

    def test_inverse_stft(self):
        # Create a mock tensor with the correct input shape
        input_tensor = torch.rand(1, 2, 1025, 32)  # shape matching output of STFT

        # Apply inverse STFT
        output_tensor = self.stft.inverse(input_tensor)

        # Check if the output tensor is on the CPU
        self.assertEqual(output_tensor.device.type, "cpu")

        # Expected output shape: (Batch size, Channel dimension, Time dimension)
        expected_shape = (1, 2, 7936)  # Calculated based on STFT parameters

        # Check if the output tensor has the expected shape
        self.assertEqual(output_tensor.shape, expected_shape)

    @unittest.skipIf(not torch.backends.mps.is_available(), "MPS not available")
    def test_stft_with_mps_device(self):
        mps_device = torch.device("mps")
        self.stft.device = mps_device
        input_tensor = self.create_mock_tensor((1, 16000), device=mps_device)
        stft_result = self.stft(input_tensor)
        self.assertIsNotNone(stft_result)
        self.assertIsInstance(stft_result, torch.Tensor)

    @unittest.skipIf(not torch.backends.mps.is_available(), "MPS not available")
    def test_inverse_with_mps_device(self):
        mps_device = torch.device("mps")
        self.stft.device = mps_device
        input_tensor = self.create_mock_tensor((1, 2, 1025, 32), device=mps_device)
        istft_result = self.stft.inverse(input_tensor)
        self.assertIsNotNone(istft_result)
        self.assertIsInstance(istft_result, torch.Tensor)


# Minimal logger stub (the tests above pass unittest.mock.Mock() as the logger instead)
class MockLogger:
    def debug(self, message):
        pass


if __name__ == "__main__":
    unittest.main()