|
|
|
|
|
|
|
|
|
|
|
import typing as tp |
|
|
|
from einops import rearrange |
|
from librosa import filters |
|
import torch |
|
from torch import nn |
|
import torch.nn.functional as F |
|
import torchaudio |
|
|
|
|
|
class ChromaExtractor(nn.Module):
    """Extract per-frame chroma features from a waveform, optionally quantized to one-hot.

    A power spectrogram of the input is projected onto a librosa chroma filterbank,
    and every frame is then normalized across its chroma bins.

    Args:
        sample_rate (int): Sample rate for the chroma extraction.
        n_chroma (int): Number of chroma bins for the chroma extraction.
        radix2_exp (int): Size of stft window for the chroma extraction (power of 2, e.g. 12 -> 2^12).
            Only used when `winlen` is not provided.
        nfft (int, optional): Number of FFT. Falls back to the window length.
        winlen (int, optional): Window length. Falls back to `2 ** radix2_exp`.
        winhop (int, optional): Window hop size. Falls back to a quarter of the window length.
        argmax (bool, optional): Whether to replace each frame by a one-hot vector marking
            its strongest chroma bin. Defaults to False.
        norm (float, optional): Norm for chroma normalization. Defaults to inf.
    """

    def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12, nfft: tp.Optional[int] = None,
                 winlen: tp.Optional[int] = None, winhop: tp.Optional[int] = None, argmax: bool = False,
                 norm: float = torch.inf):
        super().__init__()
        # Plain configuration values.
        self.sample_rate = sample_rate
        self.n_chroma = n_chroma
        self.argmax = argmax
        self.norm = norm

        # STFT geometry: any falsy value (None or 0) falls back to a value derived from the window length.
        self.winlen = winlen if winlen else 2 ** radix2_exp
        self.nfft = nfft if nfft else self.winlen
        self.winhop = winhop if winhop else self.winlen // 4

        # Chroma filterbank, kept as a non-persistent buffer: it follows the module across
        # devices but is not written to the state dict.
        chroma_filters = filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0, n_chroma=self.n_chroma)
        self.register_buffer('fbanks', torch.from_numpy(chroma_filters), persistent=False)

        # Power spectrogram transform feeding the filterbank.
        self.spec = torchaudio.transforms.Spectrogram(
            n_fft=self.nfft,
            win_length=self.winlen,
            hop_length=self.winhop,
            power=2,
            center=True,
            pad=0,
            normalized=True,
        )

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        """Compute normalized chroma features for `wav`.

        Args:
            wav (torch.Tensor): Input waveform with time on the last axis.
                Presumably shaped (batch, 1, time) given the `squeeze(1)` below — TODO confirm.

        Returns:
            torch.Tensor: Chroma features laid out as (batch, frames, chroma). When `argmax`
                is enabled every frame is a one-hot vector selecting its largest bin.
        """
        num_samples = wav.shape[-1]

        # Inputs shorter than nfft are zero-padded on both sides up to exactly nfft samples;
        # when the deficit is odd, the extra sample goes on the right.
        if num_samples < self.nfft:
            deficit = self.nfft - num_samples
            left = deficit // 2
            right = deficit - left
            wav = F.pad(wav, (left, right), 'constant', 0)
            assert wav.shape[-1] == self.nfft, f"expected len {self.nfft} but got {wav.shape[-1]}"

        # Drop axis 1 of the spectrogram when it has size one.
        power_spec = self.spec(wav).squeeze(1)
        # Project the frequency axis onto the chroma bins, keeping any leading axes.
        projected = torch.einsum('cf,...ft->...ct', self.fbanks, power_spec)
        # Normalize each frame across the chroma axis (second to last at this point).
        chroma = F.normalize(projected, p=self.norm, dim=-2, eps=1e-6)
        # Move the chroma axis last.
        chroma = rearrange(chroma, 'b d t -> b t d')

        if not self.argmax:
            return chroma

        # Quantize in place: clear everything, then set the winning bin of each frame to 1.
        winners = chroma.argmax(-1, keepdim=True)
        chroma.zero_()
        chroma.scatter_(dim=-1, index=winners, value=1)
        return chroma
|
|