Spaces:
Runtime error
Runtime error
| import torch | |
class STFT:
    """Forward/inverse short-time Fourier transform on batched waveforms.

    The forward transform (``__call__``) returns a real-valued tensor in
    which the real and imaginary parts of each input channel are stored as
    two consecutive channels, cropped to the first ``dim_f`` frequency bins.
    ``inverse`` zero-pads the cropped bins back to ``n_fft // 2 + 1`` and
    reconstructs the waveform with ``torch.istft``.
    """

    def __init__(self, n_fft, hop_length, dim_f):
        """Store the transform parameters and build the analysis window.

        Args:
            n_fft: FFT size; also used as the Hann window length.
            hop_length: Hop between successive frames, in samples.
            dim_f: Number of low-frequency bins kept by ``__call__``.
        """
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.window = torch.hann_window(window_length=n_fft, periodic=True)
        self.dim_f = dim_f

    def __call__(self, x):
        """Compute the cropped spectrogram of ``x``.

        Args:
            x: Tensor of shape ``(*batch, channels, time)``.

        Returns:
            Tensor of shape ``(*batch, channels * 2, F, frames)`` where
            ``F`` is at most ``dim_f``. Channel ``2 * k`` holds the real part
            and channel ``2 * k + 1`` the imaginary part of input channel
            ``k``.
        """
        window = self.window.to(x.device)
        batch_dims = x.shape[:-2]
        c, t = x.shape[-2:]

        # torch.stft accepts at most one batch dimension, so fold every
        # leading dimension (including channels) into it.
        spec = torch.stft(
            x.reshape([-1, t]),
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=window,
            center=True,
            return_complex=True,
        )

        # (N, freq, frames) complex -> (N, 2, freq, frames) real.
        spec = torch.view_as_real(spec).permute([0, 3, 1, 2])
        n_frames = spec.shape[-1]

        # Unfold the batch and merge (channel, real/imag) into one axis.
        spec = spec.reshape([*batch_dims, c * 2, -1, n_frames])
        return spec[..., : self.dim_f, :]

    def inverse(self, x):
        """Reconstruct waveforms from a spectrogram produced by ``__call__``.

        Args:
            x: Tensor of shape ``(*batch, channels * 2, freq, frames)`` with
                real/imaginary parts stored as consecutive channels.

        Returns:
            Tensor of shape ``(*batch, channels, time)``.
        """
        window = self.window.to(x.device)
        batch_dims = x.shape[:-3]
        c, f, t = x.shape[-3:]
        n = self.n_fft // 2 + 1

        # Restore the frequency bins dropped by the forward crop with zeros,
        # allocated directly on the input's device.
        f_pad = torch.zeros([*batch_dims, c, n - f, t], device=x.device)
        x = torch.cat([x, f_pad], -2)

        # Split the merged axis back into (channel, real/imag) and fold the
        # batch; view_as_complex needs real/imag as a contiguous last dim.
        x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t])
        x = x.permute([0, 2, 3, 1]).contiguous()

        x = torch.istft(
            torch.view_as_complex(x),
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=window,
            center=True,
        )

        # Use the actual channel count rather than assuming stereo, so mono
        # and multi-channel inputs round-trip to the right shape.
        return x.reshape([*batch_dims, c // 2, -1])