ASesYusuf1's picture
Upload 131 files
01f8b5b verified
import torch
class STFT:
"""
This class performs the Short-Time Fourier Transform (STFT) and its inverse (ISTFT).
These functions are essential for converting the audio between the time domain and the frequency domain,
which is a crucial aspect of audio processing in neural networks.
"""
def __init__(self, logger, n_fft, hop_length, dim_f, device):
self.logger = logger
self.n_fft = n_fft
self.hop_length = hop_length
self.dim_f = dim_f
self.device = device
# Create a Hann window tensor for use in the STFT.
self.hann_window = torch.hann_window(window_length=self.n_fft, periodic=True)
def __call__(self, input_tensor):
# Determine if the input tensor's device is not a standard computing device (i.e., not CPU or CUDA).
is_non_standard_device = not input_tensor.device.type in ["cuda", "cpu"]
# If on a non-standard device, temporarily move the tensor to CPU for processing.
if is_non_standard_device:
input_tensor = input_tensor.cpu()
# Transfer the pre-defined window tensor to the same device as the input tensor.
stft_window = self.hann_window.to(input_tensor.device)
# Extract batch dimensions (all dimensions except the last two which are channel and time).
batch_dimensions = input_tensor.shape[:-2]
# Extract channel and time dimensions (last two dimensions of the tensor).
channel_dim, time_dim = input_tensor.shape[-2:]
# Reshape the tensor to merge batch and channel dimensions for STFT processing.
reshaped_tensor = input_tensor.reshape([-1, time_dim])
# Perform the Short-Time Fourier Transform (STFT) on the reshaped tensor.
stft_output = torch.stft(reshaped_tensor, n_fft=self.n_fft, hop_length=self.hop_length, window=stft_window, center=True, return_complex=False)
# Rearrange the dimensions of the STFT output to bring the frequency dimension forward.
permuted_stft_output = stft_output.permute([0, 3, 1, 2])
# Reshape the output to restore the original batch and channel dimensions, while keeping the newly formed frequency and time dimensions.
final_output = permuted_stft_output.reshape([*batch_dimensions, channel_dim, 2, -1, permuted_stft_output.shape[-1]]).reshape(
[*batch_dimensions, channel_dim * 2, -1, permuted_stft_output.shape[-1]]
)
# If the original tensor was on a non-standard device, move the processed tensor back to that device.
if is_non_standard_device:
final_output = final_output.to(self.device)
# Return the transformed tensor, sliced to retain only the required frequency dimension (`dim_f`).
return final_output[..., : self.dim_f, :]
def pad_frequency_dimension(self, input_tensor, batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins):
"""
Adds zero padding to the frequency dimension of the input tensor.
"""
# Create a padding tensor for the frequency dimension
freq_padding = torch.zeros([*batch_dimensions, channel_dim, num_freq_bins - freq_dim, time_dim]).to(input_tensor.device)
# Concatenate the padding to the input tensor along the frequency dimension.
padded_tensor = torch.cat([input_tensor, freq_padding], -2)
return padded_tensor
def calculate_inverse_dimensions(self, input_tensor):
# Extract batch dimensions and frequency-time dimensions.
batch_dimensions = input_tensor.shape[:-3]
channel_dim, freq_dim, time_dim = input_tensor.shape[-3:]
# Calculate the number of frequency bins for the inverse STFT.
num_freq_bins = self.n_fft // 2 + 1
return batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins
def prepare_for_istft(self, padded_tensor, batch_dimensions, channel_dim, num_freq_bins, time_dim):
"""
Prepares the tensor for Inverse Short-Time Fourier Transform (ISTFT) by reshaping
and creating a complex tensor from the real and imaginary parts.
"""
# Reshape the tensor to separate real and imaginary parts and prepare for ISTFT.
reshaped_tensor = padded_tensor.reshape([*batch_dimensions, channel_dim // 2, 2, num_freq_bins, time_dim])
# Flatten batch dimensions and rearrange for ISTFT.
flattened_tensor = reshaped_tensor.reshape([-1, 2, num_freq_bins, time_dim])
# Rearrange the dimensions of the tensor to bring the frequency dimension forward.
permuted_tensor = flattened_tensor.permute([0, 2, 3, 1])
# Combine real and imaginary parts into a complex tensor.
complex_tensor = permuted_tensor[..., 0] + permuted_tensor[..., 1] * 1.0j
return complex_tensor
def inverse(self, input_tensor):
# Determine if the input tensor's device is not a standard computing device (i.e., not CPU or CUDA).
is_non_standard_device = not input_tensor.device.type in ["cuda", "cpu"]
# If on a non-standard device, temporarily move the tensor to CPU for processing.
if is_non_standard_device:
input_tensor = input_tensor.cpu()
# Transfer the pre-defined Hann window tensor to the same device as the input tensor.
stft_window = self.hann_window.to(input_tensor.device)
batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins = self.calculate_inverse_dimensions(input_tensor)
padded_tensor = self.pad_frequency_dimension(input_tensor, batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins)
complex_tensor = self.prepare_for_istft(padded_tensor, batch_dimensions, channel_dim, num_freq_bins, time_dim)
# Perform the Inverse Short-Time Fourier Transform (ISTFT).
istft_result = torch.istft(complex_tensor, n_fft=self.n_fft, hop_length=self.hop_length, window=stft_window, center=True)
# Reshape ISTFT result to restore original batch and channel dimensions.
final_output = istft_result.reshape([*batch_dimensions, 2, -1])
# If the original tensor was on a non-standard device, move the processed tensor back to that device.
if is_non_standard_device:
final_output = final_output.to(self.device)
return final_output