# Copyright 2019 Jian Wu
# License: Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

import math

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as tf
import librosa.filters as filters

from typing import Optional, Tuple, Union
from distutils.version import LooseVersion

EPSILON = float(np.finfo(np.float32).eps)
TORCH_VERSION = th.__version__

# torch.fft module (fft_func) is only available since PyTorch 1.7
if TORCH_VERSION >= LooseVersion("1.7"):
    from torch.fft import fft as fft_func


def export_jit(transform: nn.Module) -> nn.Module:
    """
    Export transform module for inference: keep only the sub-modules whose
    exportable() returns True and wrap them in nn.Sequential
    Args:
        transform: container of transform modules (iterated directly, so it
                   is presumably an nn.Sequential / nn.ModuleList)
    Return:
        nn.Sequential of the exportable sub-modules
    """
    export_out = [module for module in transform if module.exportable()]
    return nn.Sequential(*export_out)


def init_window(wnd: str, frame_len: int, device: th.device = "cpu") -> th.Tensor:
    """
    Return window coefficient
    Args:
        wnd: window name, one of bartlett/hann/hamm/blackman/rect/sqrthann
        frame_len: length of the frame
        device: device to place the window tensor on
    Return:
        window (Tensor): 1D tensor of length frame_len
    Raises:
        RuntimeError: if wnd is not a supported window name
    """

    def sqrthann(frame_len, periodic=True):
        return th.hann_window(frame_len, periodic=periodic) ** 0.5

    if wnd not in ["bartlett", "hann", "hamm", "blackman", "rect", "sqrthann"]:
        raise RuntimeError(f"Unknown window type: {wnd}")
    wnd_tpl = {
        "sqrthann": sqrthann,
        "hann": th.hann_window,
        "hamm": th.hamming_window,
        "blackman": th.blackman_window,
        "bartlett": th.bartlett_window,
        "rect": th.ones,
    }
    if wnd != "rect":
        # periodic windows, match with librosa
        c = wnd_tpl[wnd](frame_len, periodic=True)
    else:
        c = wnd_tpl[wnd](frame_len)
    return c.to(device)


def init_kernel(
    frame_len: int,
    frame_hop: int,
    window: th.Tensor,
    round_pow_of_two: bool = True,
    normalized: bool = False,
    inverse: bool = False,
    mode: str = "librosa",
) -> Tuple[th.Tensor, th.Tensor]:
    """
    Return STFT kernels
    Args:
        frame_len: length of the frame
        frame_hop: hop size between frames (not used to build the kernel)
        window: window tensor
        round_pow_of_two: if true, choose round(#power_of_two) as the FFT size
        normalized: return normalized DFT matrix
        inverse: return iDFT matrix
        mode: framing mode (librosa or kaldi)
    Return:
        kernel (Tensor): 2B x 1 x W DFT kernel (real parts stacked over imag parts)
        window (Tensor): window, zero-padded (centered) to B in librosa mode
    Raises:
        ValueError: if mode is not librosa or kaldi
    """
    if mode not in ["librosa", "kaldi"]:
        raise ValueError(f"Unsupported mode: {mode}")
    # FFT size: B (kaldi always rounds up to a power of two)
    if round_pow_of_two or mode == "kaldi":
        fft_size = 2 ** math.ceil(math.log2(frame_len))
    else:
        fft_size = frame_len
    # center padding window if needed
    if mode == "librosa" and fft_size != frame_len:
        lpad = (fft_size - frame_len) // 2
        window = tf.pad(window, (lpad, fft_size - frame_len - lpad))
    if normalized:
        # make K^H * K = I
        S = fft_size ** 0.5
    else:
        S = 1
    # W x B x 2: row n is the DFT of a unit impulse at sample n
    if TORCH_VERSION >= LooseVersion("1.7"):
        K = fft_func(th.eye(fft_size) / S, dim=-1)
        K = th.stack([K.real, K.imag], dim=-1)
    else:
        I = th.stack([th.eye(fft_size), th.zeros(fft_size, fft_size)], dim=-1)
        K = th.fft(I / S, 1)
    if mode == "kaldi":
        # kaldi: the window is not padded, keep only the first frame_len samples
        K = K[:frame_len]
    if inverse and not normalized:
        # scale by 1/B so that the iDFT inverts the un-normalized DFT
        K = K / fft_size
    # 2 x B x W
    K = th.transpose(K, 0, 2)
    # 2B x 1 x W
    K = th.reshape(K, (fft_size * 2, 1, K.shape[-1]))
    return K.to(window.device), window


def mel_filter(
    frame_len: int,
    round_pow_of_two: bool = True,
    num_bins: Optional[int] = None,
    sr: int = 16000,
    num_mels: int = 80,
    fmin: float = 0.0,
    fmax: Optional[float] = None,
    norm: bool = False,
) -> th.Tensor:
    """
    Return mel filter coefficients
    Args:
        frame_len: length of the frame
        round_pow_of_two: if true, choose round(#power_of_two) as the FFT size
        num_bins: number of the frequency bins produced by STFT
                  (if given, the FFT size is derived as (num_bins - 1) * 2)
        sr: sample rate of the signal
        num_mels: number of the mel bands
        fmin: lowest frequency (in Hz), clamped to >= 0
        fmax: highest frequency (in Hz), clamped to <= sr // 2; a negative
              value is interpreted as an offset from sr // 2
        norm: normalize the mel filter coefficients ("slaney" normalization)
    Return:
        weight (Tensor): num_mels x (N // 2 + 1) filter bank
    """
    # FFT points
    if num_bins is None:
        N = 2 ** math.ceil(math.log2(frame_len)) if round_pow_of_two else frame_len
    else:
        N = (num_bins - 1) * 2
    # fmin & fmax
    freq_upper = sr // 2
    if fmax is None:
        fmax = freq_upper
    else:
        fmax = min(fmax + freq_upper if fmax < 0 else fmax, freq_upper)
    fmin = max(0, fmin)
    # mel filter coefficients
    # NOTE: pass sr/n_fft as keywords, they are keyword-only since librosa 0.10
    mel = filters.mel(
        sr=sr,
        n_fft=N,
        n_mels=num_mels,
        fmax=fmax,
        fmin=fmin,
        htk=True,
        norm="slaney" if norm else None,
    )
    # num_mels x (N // 2 + 1)
    return th.tensor(mel, dtype=th.float32)


def speed_perturb_filter(
    src_sr: int, dst_sr: int, cutoff_ratio: float = 0.95, num_zeros: int = 64
) -> th.Tensor:
    """
    Return speed perturb filters, reference:
        https://github.com/danpovey/filtering/blob/master/lilfilter/resampler.py
    Args:
        src_sr: sample rate of the source signal
        dst_sr: sample rate of the target signal
        cutoff_ratio: cutoff frequency relative to the lower (reduced) sample rate
        num_zeros: controls the filter width (number of sinc zero crossings)
    Return:
        weight (Tensor): coefficients of the filter, dst_sr' x src_sr' x K,
                         where src_sr'/dst_sr' are the rates divided by their gcd
    Raises:
        ValueError: if src_sr == dst_sr or the ratio reduces to an integer factor
    """
    if src_sr == dst_sr:
        raise ValueError(f"src_sr should not be equal to dst_sr: {src_sr}/{dst_sr}")
    gcd = math.gcd(src_sr, dst_sr)
    src_sr = src_sr // gcd
    dst_sr = dst_sr // gcd
    if src_sr == 1 or dst_sr == 1:
        raise ValueError("do not support integer downsample/upsample")
    zeros_per_block = min(src_sr, dst_sr) * cutoff_ratio
    padding = 1 + int(num_zeros / zeros_per_block)
    # dst_sr x src_sr x K, K = 2 * padding + 1
    times = (
        np.arange(dst_sr)[:, None, None] / float(dst_sr)
        - np.arange(src_sr)[None, :, None] / float(src_sr)
        - np.arange(2 * padding + 1)[None, None, :]
        + padding
    )
    # raised-cosine window, zero outside |times| < padding
    window = np.heaviside(1 - np.abs(times / padding), 0.0) * (
        0.5 + 0.5 * np.cos(times / padding * math.pi)
    )
    weight = np.sinc(times * zeros_per_block) * window * zeros_per_block / float(src_sr)
    return th.tensor(weight, dtype=th.float32)


def splice_feature(
    feats: th.Tensor, lctx: int = 1, rctx: int = 1, op: str = "cat"
) -> th.Tensor:
    """
    Splice feature
    Args:
        feats (Tensor): N x ... x T x F, original feature
        lctx: left context
        rctx: right context
        op: operator on feature context, "cat" or "stack"
    Return:
        splice (Tensor): feature with context padded (edge frames are repeated),
                         N x ... x T x FD ("cat") or N x ... x T x F x D ("stack"),
                         D = lctx + rctx + 1
    Raises:
        ValueError: if op is not "cat" or "stack"
    """
    if lctx + rctx == 0:
        return feats
    if op not in ["cat", "stack"]:
        raise ValueError(f"Unknown op for feature splicing: {op}")
    # [N x ... x T x F, ...]
    ctx = []
    T = feats.shape[-2]
    for c in range(-lctx, rctx + 1):
        idx = th.arange(c, c + T, device=feats.device, dtype=th.int64)
        # clamp to repeat the first/last frame at the boundaries
        idx = th.clamp(idx, min=0, max=T - 1)
        ctx.append(th.index_select(feats, -2, idx))
    if op == "cat":
        # N x ... x T x FD
        splice = th.cat(ctx, -1)
    else:
        # N x ... x T x F x D
        splice = th.stack(ctx, -1)
    return splice


def _forward_stft(
    wav: th.Tensor,
    kernel: th.Tensor,
    window: th.Tensor,
    return_polar: bool = False,
    pre_emphasis: float = 0,
    frame_hop: int = 256,
    onesided: bool = False,
    center: bool = False,
    eps: float = EPSILON,
) -> th.Tensor:
    """
    STFT function implemented by conv1d (not efficient, but we don't care during training)
    Args:
        wav (Tensor): N x (C) x S
        kernel (Tensor): STFT transform kernels, from init_kernel(...)
        window (Tensor): window coefficients, from init_kernel(...)
        return_polar: return [magnitude; phase] Tensor or [real; imag] Tensor
        pre_emphasis: factor of preemphasis
        frame_hop: frame hop size in number samples
        onesided: return half FFT bins
        center: if true, we assumed to have centered frames
        eps: small constant added before sqrt when computing the magnitude
    Return:
        transform (Tensor): STFT transform results, N x (C) x F x T x 2
    Raises:
        RuntimeError: if wav is not a 2D/3D tensor
    """
    wav_dim = wav.dim()
    if wav_dim not in [2, 3]:
        raise RuntimeError(f"STFT expect 2D/3D tensor, but got {wav_dim:d}D")
    # if N x S, reshape N x 1 x S
    # else: reshape NC x 1 x S
    N, S = wav.shape[0], wav.shape[-1]
    wav = wav.reshape(-1, 1, S)
    # NC x 1 x S+2P
    if center:
        pad = kernel.shape[-1] // 2
        # NOTE: match with librosa
        wav = tf.pad(wav, (pad, pad), mode="reflect")
    # STFT
    kernel = kernel * window
    if pre_emphasis > 0:
        # NC x W x T
        frames = tf.unfold(
            wav[:, None], (1, kernel.shape[-1]), stride=frame_hop, padding=0
        )
        # follow Kaldi's Preemphasize (uses the un-modified previous samples)
        frames[:, 1:] = frames[:, 1:] - pre_emphasis * frames[:, :-1]
        frames[:, 0] *= 1 - pre_emphasis
        # 1 x 2B x W, NC x W x T, NC x 2B x T
        packed = th.matmul(kernel[:, 0][None, ...], frames)
    else:
        packed = tf.conv1d(wav, kernel, stride=frame_hop, padding=0)
    # NC x 2B x T => N x C x 2B x T
    if wav_dim == 3:
        packed = packed.view(N, -1, packed.shape[-2], packed.shape[-1])
    # N x (C) x B x T
    real, imag = th.chunk(packed, 2, dim=-2)
    # N x (C) x B/2+1 x T
    if onesided:
        num_bins = kernel.shape[0] // 4 + 1
        real = real[..., :num_bins, :]
        imag = imag[..., :num_bins, :]
    if return_polar:
        mag = (real ** 2 + imag ** 2 + eps) ** 0.5
        pha = th.atan2(imag, real)
        return th.stack([mag, pha], dim=-1)
    else:
        return th.stack([real, imag], dim=-1)


def _inverse_stft(
    transform: th.Tensor,
    kernel: th.Tensor,
    window: th.Tensor,
    return_polar: bool = False,
    frame_hop: int = 256,
    onesided: bool = False,
    center: bool = False,
    eps: float = EPSILON,
) -> th.Tensor:
    """
    iSTFT function implemented by conv1d
    Args:
        transform (Tensor): STFT transform results, (N) x F x T x 2
        kernel (Tensor): STFT transform kernels, from init_kernel(...)
        window (Tensor): window coefficients, from init_kernel(...)
        return_polar (bool): keep same with the one in _forward_stft
        frame_hop: frame hop size in number samples
        onesided: whether transform only contains the B // 2 + 1 onesided bins
        center: used in _forward_stft
        eps: small constant to avoid division by zero in the OLA normalization
    Return:
        wav (Tensor), N x S
    Raises:
        RuntimeError: if transform is not a 3D/4D tensor
    """
    # (N) x F x T x 2
    transform_dim = transform.dim()
    # if F x T x 2, reshape 1 x F x T x 2
    if transform_dim == 3:
        transform = th.unsqueeze(transform, 0)
    # check the (possibly unsqueezed) tensor, so that 3D input is accepted
    if transform.dim() != 4:
        raise RuntimeError(f"Expect 4D tensor, but got {transform_dim}D")

    if return_polar:
        real = transform[..., 0] * th.cos(transform[..., 1])
        imag = transform[..., 0] * th.sin(transform[..., 1])
    else:
        real, imag = transform[..., 0], transform[..., 1]

    if onesided:
        # recover the missing bins from X[B - k] = conj(X[k]); the number of
        # them is B - (B // 2 + 1), which also holds for odd FFT size B
        fft_size = kernel.shape[0] // 2
        num_mirror = fft_size - (fft_size // 2 + 1)
        # bins [num_mirror, ..., 1]
        real_rev = th.flip(real[:, 1 : num_mirror + 1], dims=[1])
        imag_rev = th.flip(imag[:, 1 : num_mirror + 1], dims=[1])
        # extend matrix: N x B x T
        real = th.cat([real, real_rev], 1)
        imag = th.cat([imag, -imag_rev], 1)
    # pack: N x 2B x T
    packed = th.cat([real, imag], dim=1)
    # N x 1 x S
    wav = tf.conv_transpose1d(packed, kernel * window, stride=frame_hop, padding=0)
    # normalized audio samples
    # refer: https://github.com/pytorch/audio/blob/2ebbbf511fb1e6c47b59fd32ad7e66023fa0dff1/torchaudio/functional.py#L171
    num_frames = packed.shape[-1]
    win_length = window.shape[0]
    # W x T
    win = th.repeat_interleave(window[..., None] ** 2, num_frames, dim=-1)
    # Do OLA on the squared windows (transposed conv with an identity kernel)
    I = th.eye(win_length, device=win.device)[:, None]
    denorm = tf.conv_transpose1d(win[None, ...], I, stride=frame_hop, padding=0)
    if center:
        pad = kernel.shape[-1] // 2
        wav = wav[..., pad:-pad]
        denorm = denorm[..., pad:-pad]
    wav = wav / (denorm + eps)
    # N x S
    return wav.squeeze(1)


def _pytorch_stft(
    wav: th.Tensor,
    frame_len: int,
    frame_hop: int,
    n_fft: int = 512,
    return_polar: bool = False,
    window: Union[str, th.Tensor] = "sqrthann",
    normalized: bool = False,
    onesided: bool = True,
    center: bool = False,
    eps: float = EPSILON,
) -> th.Tensor:
    """
    Wrapper of PyTorch STFT function
    Args:
        wav (Tensor): source audio signal, N x (C) x S
        frame_len: length of the frame (used when window is given by name)
        frame_hop: hop size between frames
        n_fft: number of the FFT size
        return_polar: return the results in polar coordinate
        window: window tensor, or a window name accepted by init_window(...)
        center: same definition with the parameter in librosa.stft
        normalized: use normalized DFT kernel
        onesided: output onesided STFT
        eps: small constant added before sqrt when computing the magnitude
    Return:
        transform (Tensor), STFT transform results, N x (C) x F x T x 2
    Raises:
        RuntimeError: if PyTorch < 1.7 or wav is not a 2D/3D tensor
    """
    if TORCH_VERSION < LooseVersion("1.7"):
        raise RuntimeError("Can not use this function as TORCH_VERSION < 1.7")

    wav_dim = wav.dim()
    if wav_dim not in [2, 3]:
        raise RuntimeError(f"STFT expect 2D/3D tensor, but got {wav_dim:d}D")
    # the default window is a name, turn it into coefficients
    if isinstance(window, str):
        window = init_window(window, frame_len, device=wav.device)
    # remember the batch size before folding the channel axis
    N = wav.shape[0]
    # if N x C x S, reshape NC x S
    wav = wav.reshape(-1, wav.shape[-1])
    # STFT: NC x F x T x 2
    stft = th.view_as_real(
        th.stft(
            wav,
            n_fft,
            hop_length=frame_hop,
            win_length=window.shape[-1],
            window=window,
            center=center,
            normalized=normalized,
            onesided=onesided,
            return_complex=True,
        )
    )
    if wav_dim == 3:
        _, F, T, _ = stft.shape
        # NC x F x T x 2 => N x C x F x T x 2
        stft = stft.reshape(N, -1, F, T, 2)
    # N x (C) x F x T x 2
    if not return_polar:
        return stft
    # N x (C) x F x T
    real, imag = stft[..., 0], stft[..., 1]
    mag = (real ** 2 + imag ** 2 + eps) ** 0.5
    pha = th.atan2(imag, real)
    return th.stack([mag, pha], dim=-1)


def _pytorch_istft(
    transform: th.Tensor,
    frame_len: int,
    frame_hop: int,
    window: th.Tensor,
    n_fft: int = 512,
    return_polar: bool = False,
    normalized: bool = False,
    onesided: bool = True,
    center: bool = False,
    eps: float = EPSILON,
) -> th.Tensor:
    """
    Wrapper of PyTorch iSTFT function
    Args:
        transform (Tensor): results of STFT, (N) x F x T x 2
        frame_len: length of the frame
        frame_hop: hop size between frames
        window: window tensor
        n_fft: number of the FFT size
        return_polar: keep same with _pytorch_stft
        center: same definition with the parameter in librosa.stft
        normalized: use normalized DFT kernel
        onesided: output onesided STFT
        eps: unused, kept for interface compatibility
    Return:
        wav (Tensor): synthetic audio
    Raises:
        RuntimeError: if PyTorch < 1.7 or transform is not a 3D/4D tensor
    """
    if TORCH_VERSION < LooseVersion("1.7"):
        raise RuntimeError("Can not use this function as TORCH_VERSION < 1.7")

    transform_dim = transform.dim()
    # if F x T x 2, reshape 1 x F x T x 2
    if transform_dim == 3:
        transform = th.unsqueeze(transform, 0)
    # check the (possibly unsqueezed) tensor, so that 3D input is accepted
    if transform.dim() != 4:
        raise RuntimeError(f"Expect 4D tensor, but got {transform_dim}D")

    if return_polar:
        real = transform[..., 0] * th.cos(transform[..., 1])
        imag = transform[..., 0] * th.sin(transform[..., 1])
        transform = th.stack([real, imag], -1)
    # stft is a complex tensor of PyTorch
    # view_as_complex requires a contiguous last dimension of size 2
    stft = th.view_as_complex(transform.contiguous())
    # (N) x S
    wav = th.istft(
        stft,
        n_fft,
        hop_length=frame_hop,
        win_length=window.shape[-1],
        window=window,
        center=center,
        normalized=normalized,
        onesided=onesided,
        return_complex=False,
    )
    return wav


def forward_stft(
    wav: th.Tensor,
    frame_len: int,
    frame_hop: int,
    window: str = "sqrthann",
    round_pow_of_two: bool = True,
    return_polar: bool = False,
    pre_emphasis: float = 0,
    normalized: bool = False,
    onesided: bool = True,
    center: bool = False,
    mode: str = "librosa",
    eps: float = EPSILON,
) -> th.Tensor:
    """
    STFT function implementation, equals to STFT layer
    Args:
        wav: source audio signal
        frame_len: length of the frame
        frame_hop: hop size between frames
        return_polar: return [magnitude; phase] Tensor or [real; imag] Tensor
        window: window name
        center: center flag (similar with that in librosa.stft)
        round_pow_of_two: if true, choose round(#power_of_two) as the FFT size
        pre_emphasis: factor of preemphasis (ignored in "torch" mode)
        normalized: use normalized DFT kernel
        onesided: output onesided STFT
        mode: STFT mode, "kaldi" or "librosa" or "torch"
        eps: small constant added before sqrt when computing the magnitude
    Return:
        transform: results of STFT
    """
    window = init_window(window, frame_len, device=wav.device)
    if mode == "torch":
        n_fft = 2 ** math.ceil(math.log2(frame_len)) if round_pow_of_two else frame_len
        return _pytorch_stft(
            wav,
            frame_len,
            frame_hop,
            n_fft=n_fft,
            return_polar=return_polar,
            window=window,
            normalized=normalized,
            onesided=onesided,
            center=center,
            eps=eps,
        )
    else:
        kernel, window = init_kernel(
            frame_len,
            frame_hop,
            window=window,
            round_pow_of_two=round_pow_of_two,
            normalized=normalized,
            inverse=False,
            mode=mode,
        )
        return _forward_stft(
            wav,
            kernel,
            window,
            return_polar=return_polar,
            frame_hop=frame_hop,
            pre_emphasis=pre_emphasis,
            onesided=onesided,
            center=center,
            eps=eps,
        )


def inverse_stft(
    transform: th.Tensor,
    frame_len: int,
    frame_hop: int,
    return_polar: bool = False,
    window: str = "sqrthann",
    round_pow_of_two: bool = True,
    normalized: bool = False,
    onesided: bool = True,
    center: bool = False,
    mode: str = "librosa",
    eps: float = EPSILON,
) -> th.Tensor:
    """
    iSTFT function implementation, equals to iSTFT layer
    Args:
        transform: results of STFT
        frame_len: length of the frame
        frame_hop: hop size between frames
        return_polar: keep same with function forward_stft(...)
        window: window name
        center: center flag (similar with that in librosa.stft)
        round_pow_of_two: if true, choose round(#power_of_two) as the FFT size
        normalized: use normalized DFT kernel
        onesided: output onesided STFT
        mode: STFT mode, "kaldi" or "librosa" or "torch"
        eps: small constant to avoid division by zero (conv1d modes)
    Return:
        wav: synthetic signals
    """
    window = init_window(window, frame_len, device=transform.device)
    if mode == "torch":
        n_fft = 2 ** math.ceil(math.log2(frame_len)) if round_pow_of_two else frame_len
        return _pytorch_istft(
            transform,
            frame_len,
            frame_hop,
            n_fft=n_fft,
            return_polar=return_polar,
            window=window,
            normalized=normalized,
            onesided=onesided,
            center=center,
            eps=eps,
        )
    else:
        kernel, window = init_kernel(
            frame_len,
            frame_hop,
            window,
            round_pow_of_two=round_pow_of_two,
            normalized=normalized,
            inverse=True,
            mode=mode,
        )
        return _inverse_stft(
            transform,
            kernel,
            window,
            return_polar=return_polar,
            frame_hop=frame_hop,
            onesided=onesided,
            center=center,
            eps=eps,
        )


class STFTBase(nn.Module):
    """
    Base layer for (i)STFT
    Args:
        frame_len: length of the frame
        frame_hop: hop size between frames
        window: window name
        center: center flag (similar with that in librosa.stft)
        round_pow_of_two: if true, choose round(#power_of_two) as the FFT size
        normalized: use normalized DFT kernel
        pre_emphasis: factor of preemphasis (forced to 0 in "torch" mode)
        mode: STFT mode, "kaldi" or "librosa" or "torch"
        onesided: output onesided STFT
        inverse: using iDFT kernel (for iSTFT)
    """

    def __init__(
        self,
        frame_len: int,
        frame_hop: int,
        window: str = "sqrthann",
        round_pow_of_two: bool = True,
        normalized: bool = False,
        pre_emphasis: float = 0,
        onesided: bool = True,
        inverse: bool = False,
        center: bool = False,
        mode: str = "librosa",
    ) -> None:
        super(STFTBase, self).__init__()
        if mode != "torch":
            K, w = init_kernel(
                frame_len,
                frame_hop,
                init_window(window, frame_len),
                round_pow_of_two=round_pow_of_two,
                normalized=normalized,
                inverse=inverse,
                mode=mode,
            )
            self.K = nn.Parameter(K, requires_grad=False)
            self.w = nn.Parameter(w, requires_grad=False)
            self.num_bins = self.K.shape[0] // 4 + 1
            self.pre_emphasis = pre_emphasis
            self.win_length = self.K.shape[2]
            # FFT size B (K is 2B x 1 x W)
            self.fft_size = self.K.shape[0] // 2
        else:
            self.K = None
            w = init_window(window, frame_len)
            self.w = nn.Parameter(w, requires_grad=False)
            fft_size = (
                2 ** math.ceil(math.log2(frame_len)) if round_pow_of_two else frame_len
            )
            self.num_bins = fft_size // 2 + 1
            self.pre_emphasis = 0
            self.win_length = fft_size
            # keep the exact FFT size: (num_bins - 1) * 2 is wrong for odd sizes
            self.fft_size = fft_size
        self.frame_len = frame_len
        self.frame_hop = frame_hop
        self.window = window
        self.normalized = normalized
        self.onesided = onesided
        self.center = center
        self.mode = mode

    def num_frames(self, wav_len: th.Tensor) -> th.Tensor:
        """
        Compute number of the frames
        Args:
            wav_len (Tensor): number of samples of each signal
        Return:
            num_frames (Tensor): number of the STFT frames of each signal
        """
        assert th.sum(wav_len <= self.win_length) == 0
        if self.center:
            # both sides are padded with win_length // 2 samples; out-of-place
            # addition so that the caller's tensor is left untouched
            wav_len = wav_len + (self.win_length // 2) * 2
        return (
            th.div(wav_len - self.win_length, self.frame_hop, rounding_mode="trunc") + 1
        )

    def extra_repr(self) -> str:
        str_repr = (
            f"num_bins={self.num_bins}, win_length={self.win_length}, "
            + f"stride={self.frame_hop}, window={self.window}, "
            + f"center={self.center}, mode={self.mode}"
        )
        if not self.onesided:
            str_repr += f", onesided={self.onesided}"
        if self.pre_emphasis > 0:
            str_repr += f", pre_emphasis={self.pre_emphasis}"
        if self.normalized:
            str_repr += f", normalized={self.normalized}"
        return str_repr


class STFT(STFTBase):
    """
    Short-time Fourier Transform as a Layer
    """

    def __init__(self, *args, **kwargs):
        super(STFT, self).__init__(*args, inverse=False, **kwargs)

    def forward(
        self, wav: th.Tensor, return_polar: bool = False, eps: float = EPSILON
    ) -> th.Tensor:
        """
        Accept (single or multiple channel) raw waveform and output magnitude and phase
        Args
            wav (Tensor) input signal, N x (C) x S
        Return
            transform (Tensor), N x (C) x F x T x 2
        """
        if self.mode == "torch":
            return _pytorch_stft(
                wav,
                self.frame_len,
                self.frame_hop,
                n_fft=self.fft_size,
                return_polar=return_polar,
                window=self.w,
                normalized=self.normalized,
                onesided=self.onesided,
                center=self.center,
                eps=eps,
            )
        else:
            return _forward_stft(
                wav,
                self.K,
                self.w,
                return_polar=return_polar,
                frame_hop=self.frame_hop,
                pre_emphasis=self.pre_emphasis,
                onesided=self.onesided,
                center=self.center,
                eps=eps,
            )


class iSTFT(STFTBase):
    """
    Inverse Short-time Fourier Transform as a Layer
    """

    def __init__(self, *args, **kwargs):
        super(iSTFT, self).__init__(*args, inverse=True, **kwargs)

    def forward(
        self, transform: th.Tensor, return_polar: bool = False, eps: float = EPSILON
    ) -> th.Tensor:
        """
        Accept phase & magnitude and output raw waveform
        Args
            transform (Tensor): STFT output, N x F x T x 2
        Return
            s (Tensor): N x S
        """
        if self.mode == "torch":
            return _pytorch_istft(
                transform,
                self.frame_len,
                self.frame_hop,
                n_fft=self.fft_size,
                return_polar=return_polar,
                window=self.w,
                normalized=self.normalized,
                onesided=self.onesided,
                center=self.center,
                eps=eps,
            )
        else:
            return _inverse_stft(
                transform,
                self.K,
                self.w,
                return_polar=return_polar,
                frame_hop=self.frame_hop,
                onesided=self.onesided,
                center=self.center,
                eps=eps,
            )