Spaces:
Running
Running
import logging | |
import librosa | |
import torch | |
from torch import nn | |
# Module-level logger used by the spectrogram helpers.
logger = logging.getLogger(name=__name__)

# Lazily-filled caches shared by the functional helpers in this module. Keys are
# strings that encode the parameters, dtype and device of the cached tensor, so
# each window / filterbank is built once per configuration and then reused.
hann_window: dict = dict()
mel_basis: dict = dict()
def amp_to_db(x: torch.Tensor, *, spec_gain: float = 1.0, clip_val: float = 1e-5) -> torch.Tensor:
    """Spectral normalization / dynamic range compression.

    Values of ``x`` below ``clip_val`` are raised to ``clip_val`` so the
    logarithm stays finite. The result is then scaled by ``spec_gain``, and the
    element-wise *natural* logarithm is returned (despite the name, this is
    not a decibel scale).

    Args:
        x: amplitude values.
        spec_gain: multiplicative gain applied after clamping.
        clip_val: lower bound applied to ``x`` before the log.

    Returns:
        ``log(max(x, clip_val) * spec_gain)``, element-wise.
    """
    floored = x.clamp(min=clip_val)
    scaled = floored * spec_gain
    return scaled.log()
def db_to_amp(x: torch.Tensor, *, spec_gain: float = 1.0) -> torch.Tensor:
    """Spectral denormalization / dynamic range decompression.

    Undoes the log compression of ``amp_to_db`` for values that were not
    clamped: exponentiates ``x`` and divides by ``spec_gain``.

    Args:
        x: log-compressed values.
        spec_gain: gain that was applied before the log; divided out here.

    Returns:
        ``exp(x) / spec_gain``, element-wise.
    """
    expanded = x.exp()
    return torch.div(expanded, spec_gain)
def wav_to_spec(y: torch.Tensor, n_fft: int, hop_length: int, win_length: int, *, center: bool = False) -> torch.Tensor:
    """Compute a linear magnitude spectrogram from a batch of waveforms.

    Before the STFT, the waveform is reflect-padded on each side by
    ``int((n_fft - hop_length) / 2)`` samples. A Hann window of ``win_length``
    samples is used. It is cached in the module-level ``hann_window`` dict,
    keyed by window length, dtype and device.

    Samples outside ``[-1, 1]`` are only reported through ``logger.info``;
    they are not clipped.

    Args:
        y: waveform batch.
        n_fft: FFT size.
        hop_length: hop size (in samples) between STFT frames.
        win_length: window size in samples.
        center: forwarded to ``torch.stft`` as ``center``.

    Args Shapes:
        - y : :math:`[B, 1, T]`

    Return Shapes:
        - spec : :math:`[B,C,T]`
    """
    wav = y.squeeze(1)

    lowest = torch.min(wav)
    if lowest < -1.0:
        logger.info("min value is %.3f", lowest)
    highest = torch.max(wav)
    if highest > 1.0:
        logger.info("max value is %.3f", highest)

    cache_key = f"{win_length}_{wav.dtype}_{wav.device}"
    window = hann_window.get(cache_key)
    if window is None:
        window = torch.hann_window(win_length).to(dtype=wav.dtype, device=wav.device)
        hann_window[cache_key] = window

    # Reflect padding is applied on a [B, 1, T] view, then the channel axis is dropped again.
    side = int((n_fft - hop_length) / 2)
    wav = torch.nn.functional.pad(wav.unsqueeze(1), (side, side), mode="reflect").squeeze(1)

    complex_spec = torch.stft(
        wav,
        n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )
    real_imag = torch.view_as_real(complex_spec)
    # Magnitude; the small epsilon presumably keeps sqrt away from exactly zero.
    return torch.sqrt(real_imag.pow(2).sum(-1) + 1e-6)
def spec_to_mel(
    spec: torch.Tensor, n_fft: int, num_mels: int, sample_rate: int, fmin: float, fmax: float
) -> torch.Tensor:
    """Project a linear spectrogram onto a mel filterbank and log-compress it.

    The filterbank is built with ``librosa.filters.mel`` and cached in the
    module-level ``mel_basis`` dict. The cache key covers every argument that
    shapes the filterbank (``n_fft``, ``num_mels``, ``sample_rate``, ``fmin``,
    ``fmax``) plus the dtype and device of ``spec``. Calls with different mel
    settings in the same process therefore never reuse a stale filterbank.

    Args:
        spec: linear magnitude spectrogram (e.g. the output of ``wav_to_spec``).
        n_fft: FFT size used to produce ``spec``.
        num_mels: number of mel bands.
        sample_rate: audio sampling rate in Hz.
        fmin: lowest filterbank frequency in Hz.
        fmax: highest filterbank frequency in Hz.

    Args Shapes:
        - spec : :math:`[B,C,T]`

    Return Shapes:
        - mel : :math:`[B, num_mels, T]`

    Returns:
        The mel projection passed through ``amp_to_db`` (clamped log compression).
    """
    # Include every filterbank parameter in the key; previously num_mels,
    # sample_rate and fmin were missing, so differing configs collided.
    cache_key = f"{n_fft}_{num_mels}_{sample_rate}_{fmin}_{fmax}_{spec.dtype}_{spec.device}"
    if cache_key not in mel_basis:
        # TODO: switch librosa to torchaudio
        mel = librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[cache_key] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
    mel = torch.matmul(mel_basis[cache_key], spec)
    return amp_to_db(mel)
def wav_to_mel(
    y: torch.Tensor,
    n_fft: int,
    num_mels: int,
    sample_rate: int,
    hop_length: int,
    win_length: int,
    fmin: float,
    fmax: float,
    *,
    center: bool = False,
) -> torch.Tensor:
    """Compute a log-compressed mel spectrogram directly from waveforms.

    This is the composition of ``wav_to_spec`` (linear magnitude STFT) and
    ``spec_to_mel`` (mel projection followed by ``amp_to_db``).

    Args:
        y: waveform batch.
        n_fft: FFT size.
        num_mels: number of mel bands.
        sample_rate: audio sampling rate in Hz.
        hop_length: hop size (in samples) between STFT frames.
        win_length: window size in samples.
        fmin: lowest mel filter frequency in Hz.
        fmax: highest mel filter frequency in Hz.
        center: forwarded to ``wav_to_spec``.

    Args Shapes:
        - y : :math:`[B, 1, T]`

    Return Shapes:
        - mel : :math:`[B,C,T]`
    """
    linear = wav_to_spec(y, n_fft, hop_length, win_length, center=center)
    return spec_to_mel(
        spec=linear,
        n_fft=n_fft,
        num_mels=num_mels,
        sample_rate=sample_rate,
        fmin=fmin,
        fmax=fmax,
    )
class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
    """Some of the audio processing functions using Torch for faster batch processing.

    Args:
        n_fft (int):
            FFT window size for STFT.
        hop_length (int):
            Number of audio samples between adjacent STFT columns.
        win_length (int):
            STFT window length.
        pad_wav (bool, optional):
            If True, reflect-pad the audio with (n_fft - hop_length) / 2 samples on each side
            before the STFT. Defaults to False.
        window (str, optional):
            Name of a ``torch`` function that creates the window tensor applied/multiplied to each
            frame/window. Defaults to "hann_window".
        sample_rate (int, optional):
            Target audio sampling rate; passed to librosa as ``sr`` when ``use_mel`` is True.
            Defaults to None.
        mel_fmin (float, optional):
            Minimum filter frequency for computing melspectrograms. Defaults to 0.
        mel_fmax (float, optional):
            Maximum filter frequency for computing melspectrograms. Defaults to None.
        n_mels (int, optional):
            Number of melspectrogram dimensions. Defaults to 80.
        use_mel (bool, optional):
            If True, project the spectrogram onto a mel filterbank; otherwise return the linear
            spectrogram. Defaults to False.
        do_amp_to_db (bool, optional):
            Enable/disable amplitude to dB conversion (via ``_amp_to_db``) of the output
            spectrogram, linear or mel. Defaults to False.
        spec_gain (float, optional):
            Gain applied when converting amplitude to dB. Defaults to 1.0.
        power (float, optional):
            Exponent applied to the magnitude spectrogram before the optional mel projection,
            e.g. 1 for magnitude, 2 for power. If None, no exponent is applied. Defaults to None.
        use_htk (bool, optional):
            Use HTK formula in mel filter instead of Slaney. Defaults to False.
        mel_norm (None, 'slaney', or number, optional):
            If 'slaney', divide the triangular mel weights by the width of the mel band
            (area normalization).
            If numeric, use `librosa.util.normalize` to normalize each filter to unit l_p norm.
            See `librosa.util.normalize` for a full description of supported norm values
            (including `+-np.inf`).
            Otherwise, leave all the triangles aiming for a peak value of 1.0. Defaults to "slaney".
        normalized (bool, optional):
            Forwarded to ``torch.stft`` as ``normalized``. Defaults to False.
    """

    def __init__(
        self,
        n_fft,
        hop_length,
        win_length,
        pad_wav=False,
        window="hann_window",
        sample_rate=None,
        mel_fmin=0,
        mel_fmax=None,
        n_mels=80,
        use_mel=False,
        do_amp_to_db=False,
        spec_gain=1.0,
        power=None,
        use_htk=False,
        mel_norm="slaney",
        normalized=False,
    ):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.pad_wav = pad_wav
        self.sample_rate = sample_rate
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.n_mels = n_mels
        self.use_mel = use_mel
        self.do_amp_to_db = do_amp_to_db
        self.spec_gain = spec_gain
        self.power = power
        self.use_htk = use_htk
        self.mel_norm = mel_norm
        # Non-trainable Parameter: registered on the module, so it follows
        # .to()/.cuda() moves and is stored in the state_dict.
        self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
        self.mel_basis = None
        self.normalized = normalized
        if use_mel:
            self._build_mel_basis()

    def __call__(self, x):
        """Compute spectrogram frames by torch based stft.

        Args:
            x (Tensor): input waveform.

        Returns:
            Tensor: spectrogram frames. This is the STFT magnitude, optionally raised to
            ``power``, optionally projected onto the mel basis, and optionally passed
            through ``_amp_to_db``.

        Shapes:
            x: :math:`[B, T]` or :math:`[B, 1, T]`
        """
        if x.ndim == 2:
            x = x.unsqueeze(1)
        if self.pad_wav:
            padding = int((self.n_fft - self.hop_length) / 2)
            x = torch.nn.functional.pad(x, (padding, padding), mode="reflect")
        # B x D x T x 2
        o = torch.view_as_real(
            torch.stft(
                x.squeeze(1),
                self.n_fft,
                self.hop_length,
                self.win_length,
                self.window,
                center=True,
                pad_mode="reflect",  # compatible with audio.py
                normalized=self.normalized,
                onesided=True,
                return_complex=True,
            )
        )
        M = o[:, :, :, 0]
        P = o[:, :, :, 1]
        # Clamp before sqrt so silent bins do not produce an exact zero magnitude.
        S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8))
        if self.power is not None:
            S = S**self.power
        if self.use_mel:
            # .to(x) matches the cached basis to the input's dtype and device.
            S = torch.matmul(self.mel_basis.to(x), S)
        if self.do_amp_to_db:
            S = self._amp_to_db(S, spec_gain=self.spec_gain)
        return S

    def _build_mel_basis(self):
        """Build the mel filterbank with librosa and store it as a float32 tensor in ``self.mel_basis``."""
        mel_basis = librosa.filters.mel(
            sr=self.sample_rate,
            n_fft=self.n_fft,
            n_mels=self.n_mels,
            fmin=self.mel_fmin,
            fmax=self.mel_fmax,
            htk=self.use_htk,
            norm=self.mel_norm,
        )
        self.mel_basis = torch.from_numpy(mel_basis).float()