Upload 11 files
- audiocraft/modules/activations.py +96 -0
- audiocraft/modules/chroma.py +66 -0
- audiocraft/modules/codebooks_patterns.py +548 -0
- audiocraft/modules/conditioners.py +1416 -0
- audiocraft/modules/conv.py +243 -0
- audiocraft/modules/diffusion_schedule.py +272 -0
- audiocraft/modules/lstm.py +25 -0
- audiocraft/modules/rope.py +125 -0
- audiocraft/modules/seanet.py +258 -0
- audiocraft/modules/streaming.py +131 -0
- audiocraft/modules/transformer.py +755 -0
audiocraft/modules/activations.py
ADDED
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch import Tensor
from typing import Union, Callable


class CustomGLU(nn.Module):
    """Custom Gated Linear Unit activation.
    Applies a modified gated linear unit :math:`a * f(b)` where :math:`a` is the first half
    of the input matrices, :math:`b` is the second half, and :math:`f` is a provided activation
    function (i.e. sigmoid, swish, etc.).

    Args:
        activation (nn.Module): The custom activation to apply in the Gated Linear Unit
        dim (int): the dimension on which to split the input. Default: -1

    Shape:
        - Input: :math:`(\ast_1, N, \ast_2)` where `*` means any number of additional
          dimensions
        - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`

    Examples::
        >>> m = CustomGLU(nn.Sigmoid())
        >>> input = torch.randn(4, 2)
        >>> output = m(input)
    """
    def __init__(self, activation: nn.Module, dim: int = -1):
        super(CustomGLU, self).__init__()
        self.dim = dim
        self.activation = activation

    def forward(self, x: Tensor):
        assert x.shape[self.dim] % 2 == 0  # M = N / 2
        a, b = torch.chunk(x, 2, dim=self.dim)
        return a * self.activation(b)


class SwiGLU(CustomGLU):
    """SiLU Gated Linear Unit activation.
    Applies SiLU Gated Linear Unit :math:`a * SiLU(b)` where :math:`a` is
    the first half of the input matrices, :math:`b` is the second half.

    Args:
        dim (int): the dimension on which to split the input. Default: -1
    """
    def __init__(self, dim: int = -1):
        super(SwiGLU, self).__init__(nn.SiLU(), dim)


class GeGLU(CustomGLU):
    """GeLU Gated Linear Unit activation.
    Applies GeLU Gated Linear Unit :math:`a * GELU(b)` where :math:`a` is
    the first half of the input matrices, :math:`b` is the second half.

    Args:
        dim (int): the dimension on which to split the input. Default: -1
    """
    def __init__(self, dim: int = -1):
        super(GeGLU, self).__init__(nn.GELU(), dim)


class ReGLU(CustomGLU):
    """ReLU Gated Linear Unit activation.
    Applies ReLU Gated Linear Unit :math:`a * ReLU(b)` where :math:`a` is
    the first half of the input matrices, :math:`b` is the second half.

    Args:
        dim (int): the dimension on which to split the input. Default: -1
    """
    def __init__(self, dim: int = -1):
        super(ReGLU, self).__init__(nn.ReLU(), dim)


def get_activation_fn(
    activation: Union[str, Callable[[Tensor], Tensor]]
) -> Union[str, Callable[[Tensor], Tensor]]:
    """Helper function to map an activation string to the activation class.
    If the supplied activation is not a string that is recognized, the activation is passed back.

    Args:
        activation (str, or Callable[[Tensor], Tensor]): Activation to check
    """
    if isinstance(activation, str):
        if activation == "reglu":
            return ReGLU()
        elif activation == "geglu":
            return GeGLU()
        elif activation == "swiglu":
            return SwiGLU()
    return activation
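
For orientation, a brief usage sketch of the module above (not part of the diff; it assumes the package is importable as `audiocraft`, and the layer sizes are arbitrary examples). The linear layer doubles the width so the gated unit can split its input in half, and `get_activation_fn` passes back anything it does not recognize.

import torch
import torch.nn as nn
from audiocraft.modules.activations import SwiGLU, get_activation_fn

# Feed-forward block: the Linear outputs 2 * 1024 features, SwiGLU halves that back to 1024.
ff = nn.Sequential(nn.Linear(512, 2 * 1024), SwiGLU(dim=-1), nn.Linear(1024, 512))
y = ff(torch.randn(2, 16, 512))                      # -> shape (2, 16, 512)

act = get_activation_fn("swiglu")                    # returns a SwiGLU() instance
same = get_activation_fn(torch.nn.functional.gelu)   # not a known string -> returned unchanged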
audiocraft/modules/chroma.py
ADDED
@@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp

from einops import rearrange
from librosa import filters
import torch
from torch import nn
import torch.nn.functional as F
import torchaudio


class ChromaExtractor(nn.Module):
    """Chroma extraction and quantization.

    Args:
        sample_rate (int): Sample rate for the chroma extraction.
        n_chroma (int): Number of chroma bins for the chroma extraction.
        radix2_exp (int): Size of stft window for the chroma extraction (power of 2, e.g. 12 -> 2^12).
        nfft (int, optional): Number of FFT.
        winlen (int, optional): Window length.
        winhop (int, optional): Window hop size.
        argmax (bool, optional): Whether to use argmax. Defaults to False.
        norm (float, optional): Norm for chroma normalization. Defaults to inf.
    """
    def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12, nfft: tp.Optional[int] = None,
                 winlen: tp.Optional[int] = None, winhop: tp.Optional[int] = None, argmax: bool = False,
                 norm: float = torch.inf):
        super().__init__()
        self.winlen = winlen or 2 ** radix2_exp
        self.nfft = nfft or self.winlen
        self.winhop = winhop or (self.winlen // 4)
        self.sample_rate = sample_rate
        self.n_chroma = n_chroma
        self.norm = norm
        self.argmax = argmax
        self.register_buffer('fbanks', torch.from_numpy(filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0,
                                                                       n_chroma=self.n_chroma)), persistent=False)
        self.spec = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen,
                                                      hop_length=self.winhop, power=2, center=True,
                                                      pad=0, normalized=True)

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        T = wav.shape[-1]
        # in case we are getting a wav that was dropped out (nullified)
        # from the conditioner, make sure wav length is no less than nfft
        if T < self.nfft:
            pad = self.nfft - T
            r = 0 if pad % 2 == 0 else 1
            wav = F.pad(wav, (pad // 2, pad // 2 + r), 'constant', 0)
            assert wav.shape[-1] == self.nfft, f"expected len {self.nfft} but got {wav.shape[-1]}"

        spec = self.spec(wav).squeeze(1)
        raw_chroma = torch.einsum('cf,...ft->...ct', self.fbanks, spec)
        norm_chroma = torch.nn.functional.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6)
        norm_chroma = rearrange(norm_chroma, 'b d t -> b t d')

        if self.argmax:
            idx = norm_chroma.argmax(-1, keepdim=True)
            norm_chroma[:] = 0
            norm_chroma.scatter_(dim=-1, index=idx, value=1)

        return norm_chroma
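
A minimal usage sketch for the extractor above (not part of the diff; the sample rate, batch shape, and duration are arbitrary example values, and it assumes `audiocraft`, `librosa`, and `torchaudio` are installed):

import torch
from audiocraft.modules.chroma import ChromaExtractor

extractor = ChromaExtractor(sample_rate=32000, n_chroma=12, argmax=True)
wav = torch.randn(4, 1, 32000 * 5)   # (batch, channels, samples): 5 seconds at 32 kHz
chroma = extractor(wav)              # -> (batch, frames, n_chroma); one-hot per frame since argmax=True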
audiocraft/modules/codebooks_patterns.py
ADDED
@@ -0,0 +1,548 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from collections import namedtuple
from dataclasses import dataclass
from functools import lru_cache
import logging
import typing as tp

from abc import ABC, abstractmethod
import torch

LayoutCoord = namedtuple('LayoutCoord', ['t', 'q'])  # (timestep, codebook index)
PatternLayout = tp.List[tp.List[LayoutCoord]]  # Sequence of coordinates
logger = logging.getLogger(__name__)


@dataclass
class Pattern:
    """Base implementation of a pattern over a sequence with multiple codebooks.

    The codebook pattern consists of a layout, defining for each sequence step
    the list of coordinates of each codebook timestep in the resulting interleaved sequence.
    The first item of the pattern is always an empty list in order to properly insert a special token
    to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern
    and ``timesteps`` the number of timesteps corresponding to the original sequence.

    The pattern provides convenient methods to build and revert interleaved sequences from it:
    ``build_pattern_sequence`` maps a given dense input tensor of multi-codebook sequence from [B, K, T]
    to the interleaved sequence of shape [B, K, S] applying the pattern, with B being the batch size,
    K being the number of codebooks, T the number of original timesteps and S the number of sequence steps
    for the output sequence. The unfilled positions are replaced with a special token and the built sequence
    is returned along with a mask indicating valid tokens.
    ``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment
    of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask
    to fill and specify invalid positions if needed.
    See the dedicated methods for more details.
    """
    # Pattern layout, for each sequence step, we have a list of coordinates
    # corresponding to the original codebook timestep and position.
    # The first list is always an empty list in order to properly insert
    # a special token to start with.
    layout: PatternLayout
    timesteps: int
    n_q: int

    def __post_init__(self):
        assert len(self.layout) > 0
        self._validate_layout()
        self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
        self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
        logger.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))

    def _validate_layout(self):
        """Runs checks on the layout to ensure a valid pattern is defined.
        A pattern is considered invalid if:
            - Multiple timesteps for the same codebook are defined in the same sequence step
            - The timesteps for a given codebook are not in ascending order as we advance in the sequence
              (this would mean that we have future timesteps before past timesteps).
        """
        q_timesteps = {q: 0 for q in range(self.n_q)}
        for s, seq_coords in enumerate(self.layout):
            if len(seq_coords) > 0:
                qs = set()
                for coord in seq_coords:
                    qs.add(coord.q)
                    last_q_timestep = q_timesteps[coord.q]
                    assert coord.t >= last_q_timestep, \
                        f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
                    q_timesteps[coord.q] = coord.t
                # each sequence step contains at max 1 coordinate per codebook
                assert len(qs) == len(seq_coords), \
                    f"Multiple entries for a same codebook are found at step {s}"

    @property
    def num_sequence_steps(self):
        return len(self.layout) - 1

    @property
    def max_delay(self):
        max_t_in_seq_coords = 0
        for seq_coords in self.layout[1:]:
            for coords in seq_coords:
                max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
        return max_t_in_seq_coords - self.timesteps

    @property
    def valid_layout(self):
        valid_step = len(self.layout) - self.max_delay
        return self.layout[:valid_step]

    def starts_with_special_token(self):
        return self.layout[0] == []

    def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
        """Get codebook coordinates in the layout that corresponds to the specified timestep t
        and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
        and the actual codebook coordinates.
        """
        assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps"
        if q is not None:
            assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks"
        coords = []
        for s, seq_codes in enumerate(self.layout):
            for code in seq_codes:
                if code.t == t and (q is None or code.q == q):
                    coords.append((s, code))
        return coords

    def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]:
        return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)]

    def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]:
        steps_with_timesteps = self.get_steps_with_timestep(t, q)
        return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None

    def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool,
                                                device: tp.Union[torch.device, str] = 'cpu'):
        """Build scatter indexes corresponding to the pattern, up to the provided sequence_steps.

        Args:
            timesteps (int): Maximum number of timesteps to consider.
            keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
            device (torch.device or str): Device for created tensors.
        Returns:
            indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
        """
        assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
        assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern"
        # use the proper layout based on whether we limit ourselves to valid steps only or not,
        # note that using the valid_layout will result in a truncated sequence up to the valid steps
        ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
        # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
        indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy()
        mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy()
        # fill indexes with last sequence step value that will correspond to our special token
        # the last value is n_q * timesteps as we have flattened z and append special token as the last token
        # which will correspond to the index: n_q * timesteps
        indexes[:] = n_q * timesteps
        # iterate over the pattern and fill scattered indexes and mask
        for s, sequence_coords in enumerate(ref_layout):
            for coords in sequence_coords:
                if coords.t < timesteps:
                    indexes[coords.q, s] = coords.t + coords.q * timesteps
                    mask[coords.q, s] = 1
        indexes = torch.from_numpy(indexes).to(device)
        mask = torch.from_numpy(mask).to(device)
        return indexes, mask

    def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
        """Build sequence corresponding to the pattern from the input tensor z.
        The sequence is built using up to sequence_steps if specified, and non-pattern
        coordinates are filled with the special token.

        Args:
            z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T].
            special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
            keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
                Steps that are beyond valid steps will be replaced by the special_token in that case.
        Returns:
            values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S
                corresponding either to the sequence_steps if provided, otherwise to the length of the pattern.
            indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S].
            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].
        """
        B, K, T = z.shape
        indexes, mask = self._build_pattern_sequence_scatter_indexes(
            T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
        )
        z = z.view(B, -1)
        # we append the special token as the last index of our flattened z tensor
        z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
        values = z[:, indexes.view(-1)]
        values = values.view(B, K, indexes.shape[-1])
        return values, indexes, mask

    def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int,
                                                 keep_only_valid_steps: bool = False,
                                                 is_model_output: bool = False,
                                                 device: tp.Union[torch.device, str] = 'cpu'):
        """Builds scatter indexes required to retrieve the original multi-codebook sequence
        from interleaving pattern.

        Args:
            sequence_steps (int): Sequence steps.
            n_q (int): Number of codebooks.
            keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
                Steps that are beyond valid steps will be replaced by the special_token in that case.
            is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not.
            device (torch.device or str): Device for created tensors.
        Returns:
            indexes (torch.Tensor): Indexes for reconstructing the output, of shape [K, T].
            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
        """
        ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
        # TODO(jade): Do we want to further truncate to only valid timesteps here as well?
        timesteps = self.timesteps
        assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
        assert sequence_steps <= len(ref_layout), \
            f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"

        # ensure we take the appropriate indexes to keep the model output from the first special token as well
        if is_model_output and self.starts_with_special_token():
            ref_layout = ref_layout[1:]

        # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
        indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy()
        mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy()
        # fill indexes with last sequence step value that will correspond to our special token
        indexes[:] = n_q * sequence_steps
        for s, sequence_codes in enumerate(ref_layout):
            if s < sequence_steps:
                for code in sequence_codes:
                    if code.t < timesteps:
                        indexes[code.q, code.t] = s + code.q * sequence_steps
                        mask[code.q, code.t] = 1
        indexes = torch.from_numpy(indexes).to(device)
        mask = torch.from_numpy(mask).to(device)
        return indexes, mask

    def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
        """Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving.
        The sequence is reverted using up to timesteps if specified, and non-pattern coordinates
        are filled with the special token.

        Args:
            s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
            special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
        Returns:
            values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T
                corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise.
            indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T].
            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
        """
        B, K, S = s.shape
        indexes, mask = self._build_reverted_sequence_scatter_indexes(
            S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
        )
        s = s.view(B, -1)
        # we append the special token as the last index of our flattened z tensor
        s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
        values = s[:, indexes.view(-1)]
        values = values.view(B, K, indexes.shape[-1])
        return values, indexes, mask

    def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False):
        """Revert model logits obtained on a sequence built from the pattern
        back to a tensor matching the original sequence.

        This method is similar to ``revert_pattern_sequence`` with the following specificities:
        1. It is designed to work with the extra cardinality dimension
        2. We return the logits for the first sequence item that matches the special_token and
        whose matching target in the original sequence is the first item of the sequence,
        while we skip the last logits as there is no matching target
        """
        B, card, K, S = logits.shape
        indexes, mask = self._build_reverted_sequence_scatter_indexes(
            S, K, keep_only_valid_steps, is_model_output=True, device=logits.device
        )
        logits = logits.reshape(B, card, -1)
        # we append the special token as the last index of our flattened z tensor
        logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1)  # [B, card, K x S]
        values = logits[:, :, indexes.view(-1)]
        values = values.view(B, card, K, indexes.shape[-1])
        return values, indexes, mask


class CodebooksPatternProvider(ABC):
    """Abstraction around providing pattern for interleaving codebooks.

    The CodebooksPatternProvider abstraction allows implementing various strategies to
    define the interleaving pattern of sequences composed of multiple codebooks. For a given
    number of codebooks `n_q`, the pattern provider can generate a specified pattern
    corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern
    can be used to construct a new sequence from the original codes respecting the specified
    pattern. The pattern is defined as a list of list of code coordinates, code coordinate
    being a tuple with the original timestep and codebook to build the new sequence.
    Note that all patterns must start with an empty list that is then used to insert a first
    sequence step of special tokens in the newly generated sequence.

    Args:
        n_q (int): number of codebooks.
        cached (bool): if True, patterns for a given length are cached. In general
            that should be true for efficiency reasons to avoid synchronization points.
    """
    def __init__(self, n_q: int, cached: bool = True):
        assert n_q > 0
        self.n_q = n_q
        self.get_pattern = lru_cache(100)(self.get_pattern)  # type: ignore

    @abstractmethod
    def get_pattern(self, timesteps: int) -> Pattern:
        """Builds pattern with specific interleaving between codebooks.

        Args:
            timesteps (int): Total number of timesteps.
        """
        raise NotImplementedError()


class DelayedPatternProvider(CodebooksPatternProvider):
    """Provider for delayed pattern across delayed codebooks.
    Codebooks are delayed in the sequence and sequence steps will contain codebooks
    from different timesteps.

    Example:
        Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence:
        [[1, 2, 3, 4],
        [1, 2, 3, 4],
        [1, 2, 3, 4]]
        The resulting sequence obtained from the returned pattern is:
        [[S, 1, 2, 3, 4],
        [S, S, 1, 2, 3],
        [S, S, S, 1, 2]]
        (with S being a special token)

    Args:
        n_q (int): Number of codebooks.
        delays (list of int, optional): Delay for each of the codebooks.
            If delays not defined, each codebook is delayed by 1 compared to the previous one.
        flatten_first (int): Flatten the first N timesteps.
        empty_initial (int): Prepend with N empty list of coordinates.
    """
    def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None,
                 flatten_first: int = 0, empty_initial: int = 0):
        super().__init__(n_q)
        if delays is None:
            delays = list(range(n_q))
        self.delays = delays
        self.flatten_first = flatten_first
        self.empty_initial = empty_initial
        assert len(self.delays) == self.n_q
        assert sorted(self.delays) == self.delays

    def get_pattern(self, timesteps: int) -> Pattern:
        omit_special_token = self.empty_initial < 0
        out: PatternLayout = [] if omit_special_token else [[]]
        max_delay = max(self.delays)
        if self.empty_initial:
            out += [[] for _ in range(self.empty_initial)]
        if self.flatten_first:
            for t in range(min(timesteps, self.flatten_first)):
                for q in range(self.n_q):
                    out.append([LayoutCoord(t, q)])
        for t in range(self.flatten_first, timesteps + max_delay):
            v = []
            for q, delay in enumerate(self.delays):
                t_for_q = t - delay
                if t_for_q >= self.flatten_first:
                    v.append(LayoutCoord(t_for_q, q))
            out.append(v)
        return Pattern(out, n_q=self.n_q, timesteps=timesteps)


class ParallelPatternProvider(DelayedPatternProvider):
    """Provider for parallel pattern across codebooks.
    This pattern provider is a special case of the delayed pattern with actually no delay,
    hence delays=repeat(0, n_q).

    Args:
        n_q (int): Number of codebooks.
        empty_initial (int): Prepend with N empty list of coordinates.
    """
    def __init__(self, n_q: int, empty_initial: int = 0):
        super().__init__(n_q, [0] * n_q, empty_initial=empty_initial)


class UnrolledPatternProvider(CodebooksPatternProvider):
    """Provider for unrolling codebooks pattern.
    This pattern provider allows the codebooks to be flattened completely or only to some extent,
    while also specifying a given delay between the flattened codebook representations, allowing
    the codebooks to be unrolled in the sequence.

    Example:
        1. Flattening of the codebooks.
        By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q),
        taking n_q = 3 and timesteps = 4:
        [[1, 2, 3, 4],
         [1, 2, 3, 4],
         [1, 2, 3, 4]]
        will result into:
        [[S, S, 1, S, S, 2, S, S, 3, S, S, 4],
         [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
         [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
        2. Partial flattening of the codebooks. The ``flattening`` parameter allows specifying the inner step
        for each of the codebooks, defining which codebooks to flatten (or keep in parallel), for example
        taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]:
        [[1, 2, 3, 4],
         [1, 2, 3, 4],
         [1, 2, 3, 4]]
        will result into:
        [[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
         [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
         [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
        3. Flattening with delay. The ``delay`` parameter further unrolls the sequence of codebooks
        by specifying a delay per codebook. Note that the delay between codebooks flattened to the
        same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1]
        and delays = [0, 3, 3]:
        [[1, 2, 3, 4],
         [1, 2, 3, 4],
         [1, 2, 3, 4]]
        will result into:
        [[S, S, S, 1, S, 2, S, 3, S, 4],
         [S, S, S, 1, S, 2, S, 3, S, 4],
         [1, 2, 3, S, 4, S, 5, S, 6, S]]

    Args:
        n_q (int): Number of codebooks.
        flattening (list of int, optional): Flattening schema over the codebooks. If not defined,
            the codebooks will be flattened to 1 codebook per step, meaning that the sequence will
            have n_q extra steps for each timestep.
        delays (list of int, optional): Delay for each of the codebooks. If not defined,
            no delay is added and therefore will default to [0] * ``n_q``.
            Note that two codebooks that will be flattened to the same inner step
            should have the same delay, otherwise the pattern is considered as invalid.
    """
    FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay'])

    def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None,
                 delays: tp.Optional[tp.List[int]] = None):
        super().__init__(n_q)
        if flattening is None:
            flattening = list(range(n_q))
        if delays is None:
            delays = [0] * n_q
        assert len(flattening) == n_q
        assert len(delays) == n_q
        assert sorted(flattening) == flattening
        assert sorted(delays) == delays
        self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening)
        self.max_delay = max(delays)

    def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]):
        """Build a flattened codebooks representation as a dictionary of inner step
        and the actual codebook indices corresponding to the flattened codebook. For convenience, we
        also store the delay associated to the flattened codebook to avoid maintaining an extra mapping.
        """
        flattened_codebooks: dict = {}
        for q, (inner_step, delay) in enumerate(zip(flattening, delays)):
            if inner_step not in flattened_codebooks:
                flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay)
            else:
                flat_codebook = flattened_codebooks[inner_step]
                assert flat_codebook.delay == delay, (
                    "Delay and flattening between codebooks is inconsistent: ",
                    "two codebooks flattened to the same position should have the same delay."
                )
                flat_codebook.codebooks.append(q)
            flattened_codebooks[inner_step] = flat_codebook
        return flattened_codebooks

    @property
    def _num_inner_steps(self):
        """Number of inner steps to unroll between timesteps in order to flatten the codebooks.
        """
        return max([inner_step for inner_step in self._flattened_codebooks.keys()]) + 1

    def num_virtual_steps(self, timesteps: int) -> int:
        return timesteps * self._num_inner_steps + 1

    def get_pattern(self, timesteps: int) -> Pattern:
        """Builds pattern for delay across codebooks.

        Args:
            timesteps (int): Total number of timesteps.
        """
        # the PatternLayout is built as a tuple of sequence position and list of coordinates
        # so that it can be reordered properly given the required delay between codebooks of given timesteps
        indexed_out: list = [(-1, [])]
        max_timesteps = timesteps + self.max_delay
        for t in range(max_timesteps):
            # for each timestep, we unroll the flattened codebooks,
            # emitting the sequence step with the corresponding delay
            for step in range(self._num_inner_steps):
                if step in self._flattened_codebooks:
                    # we have codebooks at this virtual step to emit
                    step_codebooks = self._flattened_codebooks[step]
                    t_for_q = t + step_codebooks.delay
                    coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks]
                    if t_for_q < max_timesteps and t < max_timesteps:
                        indexed_out.append((t_for_q, coords))
                else:
                    # there is no codebook in this virtual step so we emit an empty list
                    indexed_out.append((t, []))
        out = [coords for _, coords in sorted(indexed_out)]
        return Pattern(out, n_q=self.n_q, timesteps=timesteps)


class CoarseFirstPattern(CodebooksPatternProvider):
    """First generates all the codebooks #1 (e.g. coarser), then the remaining ones,
    potentially with delays.

    ..Warning:: You must always generate the full training duration at test time, for instance,
        30 seconds, as otherwise, the fine codebooks will start being generated in an unexpected
        location. This is due to the non causality of the remaining codebooks with respect to
        the first ones.

    Args:
        n_q (int): Number of codebooks.
        delays (list of int, optional): Delay for each of the codebooks.
            If delays not defined, each codebook is delayed by 1 compared to the previous one.
    """
    def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None):
        super().__init__(n_q)
        if delays is None:
            delays = [0] * (n_q - 1)
        self.delays = delays
        assert len(self.delays) == self.n_q - 1
        assert sorted(self.delays) == self.delays

    def get_pattern(self, timesteps: int) -> Pattern:
        out: PatternLayout = [[]]
        for t in range(timesteps):
            out.append([LayoutCoord(t, 0)])
        max_delay = max(self.delays)
        for t in range(timesteps + max_delay):
            v = []
            for q, delay in enumerate(self.delays):
                t_for_q = t - delay
                if t_for_q >= 0:
                    v.append(LayoutCoord(t_for_q, q + 1))
            out.append(v)
        return Pattern(out, n_q=self.n_q, timesteps=timesteps)


class MusicLMPattern(CodebooksPatternProvider):
    """Almost MusicLM style pattern. This is equivalent to full flattening
    but in a different order.

    Args:
        n_q (int): Number of codebooks.
        group_by (int): Number of codebooks to group together.
    """
    def __init__(self, n_q: int, group_by: int = 2):
        super().__init__(n_q)
        self.group_by = group_by

    def get_pattern(self, timesteps: int) -> Pattern:
        out: PatternLayout = [[]]
        for offset in range(0, self.n_q, self.group_by):
            for t in range(timesteps):
                for q in range(offset, offset + self.group_by):
                    out.append([LayoutCoord(t, q)])
        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
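
To make the interleaving concrete, here is a short round-trip sketch with the delay pattern (not part of the diff; batch size, number of codebooks, timesteps, and cardinality are arbitrary example values). With the default delays and `keep_only_valid_steps=False`, every original position is marked valid in the returned mask, so reverting the built sequence recovers the input codes.

import torch
from audiocraft.modules.codebooks_patterns import DelayedPatternProvider

n_q, T, card = 4, 8, 2048                       # example sizes; card is the codebook cardinality
provider = DelayedPatternProvider(n_q=n_q)
pattern = provider.get_pattern(T)

codes = torch.randint(card, (1, n_q, T))        # [B, K, T] token indices
seq, _, _ = pattern.build_pattern_sequence(codes, special_token=card)
# seq: [B, K, S] with S = 1 (leading special-token step) + T + max delay

decoded, _, mask = pattern.revert_pattern_sequence(seq, special_token=card)
assert torch.equal(decoded[:, mask], codes[:, mask])   # round-trip on valid positions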
audiocraft/modules/conditioners.py
ADDED
@@ -0,0 +1,1416 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
from copy import deepcopy
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
from itertools import chain
|
| 11 |
+
import logging
|
| 12 |
+
import math
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
import random
|
| 15 |
+
import re
|
| 16 |
+
import typing as tp
|
| 17 |
+
import warnings
|
| 18 |
+
|
| 19 |
+
import einops
|
| 20 |
+
from num2words import num2words
|
| 21 |
+
import spacy
|
| 22 |
+
from transformers import RobertaTokenizer, T5EncoderModel, T5Tokenizer # type: ignore
|
| 23 |
+
import torch
|
| 24 |
+
from torch import nn
|
| 25 |
+
import torch.nn.functional as F
|
| 26 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 27 |
+
|
| 28 |
+
from .chroma import ChromaExtractor
|
| 29 |
+
from .streaming import StreamingModule
|
| 30 |
+
from .transformer import create_sin_embedding
|
| 31 |
+
from ..data.audio import audio_read
|
| 32 |
+
from ..data.audio_dataset import SegmentInfo
|
| 33 |
+
from ..data.audio_utils import convert_audio
|
| 34 |
+
from ..environment import AudioCraftEnvironment
|
| 35 |
+
from ..quantization import ResidualVectorQuantizer
|
| 36 |
+
from ..utils.autocast import TorchAutocast
|
| 37 |
+
from ..utils.cache import EmbeddingCache
|
| 38 |
+
from ..utils.utils import collate, hash_trick, length_to_mask, load_clap_state_dict, warn_once
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
logger = logging.getLogger(__name__)
|
| 42 |
+
TextCondition = tp.Optional[str] # a text condition can be a string or None (if doesn't exist)
|
| 43 |
+
ConditionType = tp.Tuple[torch.Tensor, torch.Tensor] # condition, mask
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class WavCondition(tp.NamedTuple):
|
| 47 |
+
wav: torch.Tensor
|
| 48 |
+
length: torch.Tensor
|
| 49 |
+
sample_rate: tp.List[int]
|
| 50 |
+
path: tp.List[tp.Optional[str]] = []
|
| 51 |
+
seek_time: tp.List[tp.Optional[float]] = []
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class JointEmbedCondition(tp.NamedTuple):
|
| 55 |
+
wav: torch.Tensor
|
| 56 |
+
text: tp.List[tp.Optional[str]]
|
| 57 |
+
length: torch.Tensor
|
| 58 |
+
sample_rate: tp.List[int]
|
| 59 |
+
path: tp.List[tp.Optional[str]] = []
|
| 60 |
+
seek_time: tp.List[tp.Optional[float]] = []
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@dataclass
|
| 64 |
+
class ConditioningAttributes:
|
| 65 |
+
text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
|
| 66 |
+
wav: tp.Dict[str, WavCondition] = field(default_factory=dict)
|
| 67 |
+
joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)
|
| 68 |
+
|
| 69 |
+
def __getitem__(self, item):
|
| 70 |
+
return getattr(self, item)
|
| 71 |
+
|
| 72 |
+
@property
|
| 73 |
+
def text_attributes(self):
|
| 74 |
+
return self.text.keys()
|
| 75 |
+
|
| 76 |
+
@property
|
| 77 |
+
def wav_attributes(self):
|
| 78 |
+
return self.wav.keys()
|
| 79 |
+
|
| 80 |
+
@property
|
| 81 |
+
def joint_embed_attributes(self):
|
| 82 |
+
return self.joint_embed.keys()
|
| 83 |
+
|
| 84 |
+
@property
|
| 85 |
+
def attributes(self):
|
| 86 |
+
return {
|
| 87 |
+
"text": self.text_attributes,
|
| 88 |
+
"wav": self.wav_attributes,
|
| 89 |
+
"joint_embed": self.joint_embed_attributes,
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
def to_flat_dict(self):
|
| 93 |
+
return {
|
| 94 |
+
**{f"text.{k}": v for k, v in self.text.items()},
|
| 95 |
+
**{f"wav.{k}": v for k, v in self.wav.items()},
|
| 96 |
+
**{f"joint_embed.{k}": v for k, v in self.joint_embed.items()}
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
@classmethod
|
| 100 |
+
def from_flat_dict(cls, x):
|
| 101 |
+
out = cls()
|
| 102 |
+
for k, v in x.items():
|
| 103 |
+
kind, att = k.split(".")
|
| 104 |
+
out[kind][att] = v
|
| 105 |
+
return out
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class SegmentWithAttributes(SegmentInfo):
|
| 109 |
+
"""Base class for all dataclasses that are used for conditioning.
|
| 110 |
+
All child classes should implement `to_condition_attributes` that converts
|
| 111 |
+
the existing attributes to a dataclass of type ConditioningAttributes.
|
| 112 |
+
"""
|
| 113 |
+
def to_condition_attributes(self) -> ConditioningAttributes:
|
| 114 |
+
raise NotImplementedError()
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def nullify_condition(condition: ConditionType, dim: int = 1):
|
| 118 |
+
"""Transform an input condition to a null condition.
|
| 119 |
+
The way it is done by converting it to a single zero vector similarly
|
| 120 |
+
to how it is done inside WhiteSpaceTokenizer and NoopTokenizer.
|
| 121 |
+
|
| 122 |
+
Args:
|
| 123 |
+
condition (ConditionType): A tuple of condition and mask (tuple[torch.Tensor, torch.Tensor])
|
| 124 |
+
dim (int): The dimension that will be truncated (should be the time dimension)
|
| 125 |
+
WARNING!: dim should not be the batch dimension!
|
| 126 |
+
Returns:
|
| 127 |
+
ConditionType: A tuple of null condition and mask
|
| 128 |
+
"""
|
| 129 |
+
assert dim != 0, "dim cannot be the batch dimension!"
|
| 130 |
+
assert isinstance(condition, tuple) and \
|
| 131 |
+
isinstance(condition[0], torch.Tensor) and \
|
| 132 |
+
isinstance(condition[1], torch.Tensor), "'nullify_condition' got an unexpected input type!"
|
| 133 |
+
cond, mask = condition
|
| 134 |
+
B = cond.shape[0]
|
| 135 |
+
last_dim = cond.dim() - 1
|
| 136 |
+
out = cond.transpose(dim, last_dim)
|
| 137 |
+
out = 0. * out[..., :1]
|
| 138 |
+
out = out.transpose(dim, last_dim)
|
| 139 |
+
mask = torch.zeros((B, 1), device=out.device).int()
|
| 140 |
+
assert cond.dim() == out.dim()
|
| 141 |
+
return out, mask
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def nullify_wav(cond: WavCondition) -> WavCondition:
|
| 145 |
+
"""Transform a WavCondition to a nullified WavCondition.
|
| 146 |
+
It replaces the wav by a null tensor, forces its length to 0, and replaces metadata by dummy attributes.
|
| 147 |
+
|
| 148 |
+
Args:
|
| 149 |
+
cond (WavCondition): Wav condition with wav, tensor of shape [B, T].
|
| 150 |
+
Returns:
|
| 151 |
+
WavCondition: Nullified wav condition.
|
| 152 |
+
"""
|
| 153 |
+
null_wav, _ = nullify_condition((cond.wav, torch.zeros_like(cond.wav)), dim=cond.wav.dim() - 1)
|
| 154 |
+
return WavCondition(
|
| 155 |
+
wav=null_wav,
|
| 156 |
+
length=torch.tensor([0] * cond.wav.shape[0], device=cond.wav.device),
|
| 157 |
+
sample_rate=cond.sample_rate,
|
| 158 |
+
path=[None] * cond.wav.shape[0],
|
| 159 |
+
seek_time=[None] * cond.wav.shape[0],
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def nullify_joint_embed(embed: JointEmbedCondition) -> JointEmbedCondition:
|
| 164 |
+
"""Nullify the joint embedding condition by replacing it by a null tensor, forcing its length to 0,
|
| 165 |
+
and replacing metadata by dummy attributes.
|
| 166 |
+
|
| 167 |
+
Args:
|
| 168 |
+
embed (JointEmbedCondition): Joint embedding condition with wav and text, wav tensor of shape [B, C, T].
|
| 169 |
+
"""
|
| 170 |
+
null_wav, _ = nullify_condition((embed.wav, torch.zeros_like(embed.wav)), dim=embed.wav.dim() - 1)
|
| 171 |
+
return JointEmbedCondition(
|
| 172 |
+
wav=null_wav, text=[None] * len(embed.text),
|
| 173 |
+
length=torch.LongTensor([0]).to(embed.wav.device),
|
| 174 |
+
sample_rate=embed.sample_rate,
|
| 175 |
+
path=[None] * embed.wav.shape[0],
|
| 176 |
+
seek_time=[0] * embed.wav.shape[0],
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class Tokenizer:
|
| 181 |
+
"""Base tokenizer implementation
|
| 182 |
+
(in case we want to introduce more advanced tokenizers in the future).
|
| 183 |
+
"""
|
| 184 |
+
def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
| 185 |
+
raise NotImplementedError()
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class WhiteSpaceTokenizer(Tokenizer):
|
| 189 |
+
"""This tokenizer should be used for natural language descriptions.
|
| 190 |
+
For example:
|
| 191 |
+
["he didn't, know he's going home.", 'shorter sentence'] =>
|
| 192 |
+
[[78, 62, 31, 4, 78, 25, 19, 34],
|
| 193 |
+
[59, 77, 0, 0, 0, 0, 0, 0]]
|
| 194 |
+
"""
|
| 195 |
+
PUNCTUATION = "?:!.,;"
|
| 196 |
+
|
| 197 |
+
def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_core_web_sm",
|
| 198 |
+
lemma: bool = True, stopwords: bool = True) -> None:
|
| 199 |
+
self.n_bins = n_bins
|
| 200 |
+
self.pad_idx = pad_idx
|
| 201 |
+
self.lemma = lemma
|
| 202 |
+
self.stopwords = stopwords
|
| 203 |
+
try:
|
| 204 |
+
self.nlp = spacy.load(language)
|
| 205 |
+
except IOError:
|
| 206 |
+
spacy.cli.download(language) # type: ignore
|
| 207 |
+
self.nlp = spacy.load(language)
|
| 208 |
+
|
| 209 |
+
@tp.no_type_check
|
| 210 |
+
def __call__(self, texts: tp.List[tp.Optional[str]],
|
| 211 |
+
return_text: bool = False) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
| 212 |
+
"""Take a list of strings and convert them to a tensor of indices.
|
| 213 |
+
|
| 214 |
+
Args:
|
| 215 |
+
texts (list[str]): List of strings.
|
| 216 |
+
return_text (bool, optional): Whether to return text as additional tuple item. Defaults to False.
|
| 217 |
+
Returns:
|
| 218 |
+
tuple[torch.Tensor, torch.Tensor]:
|
| 219 |
+
- Indices of words in the LUT.
|
| 220 |
+
- And a mask indicating where the padding tokens are
|
| 221 |
+
"""
|
| 222 |
+
output, lengths = [], []
|
| 223 |
+
texts = deepcopy(texts)
|
| 224 |
+
for i, text in enumerate(texts):
|
| 225 |
+
# if current sample doesn't have a certain attribute, replace with pad token
|
| 226 |
+
if text is None:
|
| 227 |
+
output.append(torch.Tensor([self.pad_idx]))
|
| 228 |
+
lengths.append(0)
|
| 229 |
+
continue
|
| 230 |
+
|
| 231 |
+
# convert numbers to words
|
| 232 |
+
text = re.sub(r"(\d+)", lambda x: num2words(int(x.group(0))), text) # type: ignore
|
| 233 |
+
# normalize text
|
| 234 |
+
text = self.nlp(text) # type: ignore
|
| 235 |
+
# remove stopwords
|
| 236 |
+
if self.stopwords:
|
| 237 |
+
text = [w for w in text if not w.is_stop] # type: ignore
|
| 238 |
+
# remove punctuation
|
| 239 |
+
text = [w for w in text if w.text not in self.PUNCTUATION] # type: ignore
|
| 240 |
+
# lemmatize if needed
|
| 241 |
+
text = [getattr(t, "lemma_" if self.lemma else "text") for t in text] # type: ignore
|
| 242 |
+
|
| 243 |
+
texts[i] = " ".join(text)
|
| 244 |
+
lengths.append(len(text))
|
| 245 |
+
# convert to tensor
|
| 246 |
+
tokens = torch.Tensor([hash_trick(w, self.n_bins) for w in text])
|
| 247 |
+
output.append(tokens)
|
| 248 |
+
|
| 249 |
+
mask = length_to_mask(torch.IntTensor(lengths)).int()
|
| 250 |
+
padded_output = pad_sequence(output, padding_value=self.pad_idx).int().t()
|
| 251 |
+
if return_text:
|
| 252 |
+
return padded_output, mask, texts # type: ignore
|
| 253 |
+
return padded_output, mask
|
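# Sketch of WhiteSpaceTokenizer usage (requires the spacy "en_core_web_sm" model;
# actual token indices depend on hash_trick, so only shapes are shown here).
# >>> tokenizer = WhiteSpaceTokenizer(n_bins=128)
# >>> tokens, mask = tokenizer(["a rock song with a guitar solo", None])
# >>> tokens.shape == mask.shape  # [B, max_len] for both
# True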
| 254 |
+
|
| 255 |
+
|
| 256 |
+
class NoopTokenizer(Tokenizer):
|
| 257 |
+
"""This tokenizer should be used for global conditioners such as: artist, genre, key, etc.
|
| 258 |
+
The difference between this and WhiteSpaceTokenizer is that NoopTokenizer does not split
|
| 259 |
+
strings, so "Jeff Buckley" will get it's own index. Whereas WhiteSpaceTokenizer will
|
| 260 |
+
split it to ["Jeff", "Buckley"] and return an index per word.
|
| 261 |
+
|
| 262 |
+
For example:
|
| 263 |
+
["Queen", "ABBA", "Jeff Buckley"] => [43, 55, 101]
|
| 264 |
+
["Metal", "Rock", "Classical"] => [0, 223, 51]
|
| 265 |
+
"""
|
| 266 |
+
def __init__(self, n_bins: int, pad_idx: int = 0):
|
| 267 |
+
self.n_bins = n_bins
|
| 268 |
+
self.pad_idx = pad_idx
|
| 269 |
+
|
| 270 |
+
def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
| 271 |
+
output, lengths = [], []
|
| 272 |
+
for text in texts:
|
| 273 |
+
# if current sample doesn't have a certain attribute, replace with pad token
|
| 274 |
+
if text is None:
|
| 275 |
+
output.append(self.pad_idx)
|
| 276 |
+
lengths.append(0)
|
| 277 |
+
else:
|
| 278 |
+
output.append(hash_trick(text, self.n_bins))
|
| 279 |
+
lengths.append(1)
|
| 280 |
+
|
| 281 |
+
tokens = torch.LongTensor(output).unsqueeze(1)
|
| 282 |
+
mask = length_to_mask(torch.IntTensor(lengths)).int()
|
| 283 |
+
return tokens, mask
|
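# Sketch of NoopTokenizer: each full string is hashed to a single index and None
# maps to the pad index, so the mask marks which entries are real.
# >>> tokenizer = NoopTokenizer(n_bins=512)
# >>> tokens, mask = tokenizer(["Jeff Buckley", None])
# >>> tokens.shape, mask.tolist()
# (torch.Size([2, 1]), [[1], [0]])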
| 284 |
+
|
| 285 |
+
|
| 286 |
+
class BaseConditioner(nn.Module):
|
| 287 |
+
"""Base model for all conditioner modules.
|
| 288 |
+
We allow the output dim to be different from the hidden dim for two reasons:
|
| 289 |
+
1) keep our LUTs small when the vocab is large;
|
| 290 |
+
2) make all condition dims consistent.
|
| 291 |
+
|
| 292 |
+
Args:
|
| 293 |
+
dim (int): Hidden dim of the model.
|
| 294 |
+
output_dim (int): Output dim of the conditioner.
|
| 295 |
+
"""
|
| 296 |
+
def __init__(self, dim: int, output_dim: int):
|
| 297 |
+
super().__init__()
|
| 298 |
+
self.dim = dim
|
| 299 |
+
self.output_dim = output_dim
|
| 300 |
+
self.output_proj = nn.Linear(dim, output_dim)
|
| 301 |
+
|
| 302 |
+
def tokenize(self, *args, **kwargs) -> tp.Any:
|
| 303 |
+
"""Should be any part of the processing that will lead to a synchronization
|
| 304 |
+
point, e.g. BPE tokenization with transfer to the GPU.
|
| 305 |
+
|
| 306 |
+
The returned value will be saved and returned later when calling forward().
|
| 307 |
+
"""
|
| 308 |
+
raise NotImplementedError()
|
| 309 |
+
|
| 310 |
+
def forward(self, inputs: tp.Any) -> ConditionType:
|
| 311 |
+
"""Gets input that should be used as conditioning (e.g, genre, description or a waveform).
|
| 312 |
+
Outputs a ConditionType, after the input data was embedded as a dense vector.
|
| 313 |
+
|
| 314 |
+
Returns:
|
| 315 |
+
ConditionType:
|
| 316 |
+
- A tensor of size [B, T, D] where B is the batch size, T is the length of the
|
| 317 |
+
output embedding and D is the dimension of the embedding.
|
| 318 |
+
- And a mask indicating where the padding tokens are.
|
| 319 |
+
"""
|
| 320 |
+
raise NotImplementedError()
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
class TextConditioner(BaseConditioner):
|
| 324 |
+
...
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
class LUTConditioner(TextConditioner):
|
| 328 |
+
"""Lookup table TextConditioner.
|
| 329 |
+
|
| 330 |
+
Args:
|
| 331 |
+
n_bins (int): Number of bins.
|
| 332 |
+
dim (int): Hidden dim of the model (text-encoder/LUT).
|
| 333 |
+
output_dim (int): Output dim of the conditioner.
|
| 334 |
+
tokenizer (str): Name of the tokenizer.
|
| 335 |
+
pad_idx (int, optional): Index for padding token. Defaults to 0.
|
| 336 |
+
"""
|
| 337 |
+
def __init__(self, n_bins: int, dim: int, output_dim: int, tokenizer: str, pad_idx: int = 0):
|
| 338 |
+
super().__init__(dim, output_dim)
|
| 339 |
+
self.embed = nn.Embedding(n_bins, dim)
|
| 340 |
+
self.tokenizer: Tokenizer
|
| 341 |
+
if tokenizer == 'whitespace':
|
| 342 |
+
self.tokenizer = WhiteSpaceTokenizer(n_bins, pad_idx=pad_idx)
|
| 343 |
+
elif tokenizer == 'noop':
|
| 344 |
+
self.tokenizer = NoopTokenizer(n_bins, pad_idx=pad_idx)
|
| 345 |
+
else:
|
| 346 |
+
raise ValueError(f"unrecognized tokenizer `{tokenizer}`.")
|
| 347 |
+
|
| 348 |
+
def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
| 349 |
+
device = self.embed.weight.device
|
| 350 |
+
tokens, mask = self.tokenizer(x)
|
| 351 |
+
tokens, mask = tokens.to(device), mask.to(device)
|
| 352 |
+
return tokens, mask
|
| 353 |
+
|
| 354 |
+
def forward(self, inputs: tp.Tuple[torch.Tensor, torch.Tensor]) -> ConditionType:
|
| 355 |
+
tokens, mask = inputs
|
| 356 |
+
embeds = self.embed(tokens)
|
| 357 |
+
embeds = self.output_proj(embeds)
|
| 358 |
+
embeds = (embeds * mask.unsqueeze(-1))
|
| 359 |
+
return embeds, mask
|
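# Sketch of the LUT conditioner end-to-end on a tiny batch; the dims are arbitrary.
# >>> conditioner = LUTConditioner(n_bins=64, dim=16, output_dim=32, tokenizer='noop')
# >>> embeds, mask = conditioner(conditioner.tokenize(["rock", None]))
# >>> embeds.shape  # [B, T, output_dim] with T=1 for the noop tokenizer
# torch.Size([2, 1, 32])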
| 360 |
+
|
| 361 |
+
|
| 362 |
+
class T5Conditioner(TextConditioner):
|
| 363 |
+
"""T5-based TextConditioner.
|
| 364 |
+
|
| 365 |
+
Args:
|
| 366 |
+
name (str): Name of the T5 model.
|
| 367 |
+
output_dim (int): Output dim of the conditioner.
|
| 368 |
+
finetune (bool): Whether to fine-tune T5 at train time.
|
| 369 |
+
device (str): Device for T5 Conditioner.
|
| 370 |
+
autocast_dtype (tp.Optional[str], optional): Autocast dtype.
|
| 371 |
+
word_dropout (float, optional): Word dropout probability.
|
| 372 |
+
normalize_text (bool, optional): Whether to apply text normalization.
|
| 373 |
+
"""
|
| 374 |
+
MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
|
| 375 |
+
"google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
|
| 376 |
+
"google/flan-t5-xl", "google/flan-t5-xxl"]
|
| 377 |
+
MODELS_DIMS = {
|
| 378 |
+
"t5-small": 512,
|
| 379 |
+
"t5-base": 768,
|
| 380 |
+
"t5-large": 1024,
|
| 381 |
+
"t5-3b": 1024,
|
| 382 |
+
"t5-11b": 1024,
|
| 383 |
+
"google/flan-t5-small": 512,
|
| 384 |
+
"google/flan-t5-base": 768,
|
| 385 |
+
"google/flan-t5-large": 1024,
|
| 386 |
+
"google/flan-t5-3b": 1024,
|
| 387 |
+
"google/flan-t5-11b": 1024,
|
| 388 |
+
}
|
| 389 |
+
|
| 390 |
+
def __init__(self, name: str, output_dim: int, finetune: bool, device: str,
|
| 391 |
+
autocast_dtype: tp.Optional[str] = 'float32', word_dropout: float = 0.,
|
| 392 |
+
normalize_text: bool = False):
|
| 393 |
+
assert name in self.MODELS, f"Unrecognized t5 model name (should be in {self.MODELS})"
|
| 394 |
+
super().__init__(self.MODELS_DIMS[name], output_dim)
|
| 395 |
+
self.device = device
|
| 396 |
+
self.name = name
|
| 397 |
+
self.finetune = finetune
|
| 398 |
+
self.word_dropout = word_dropout
|
| 399 |
+
if autocast_dtype is None or self.device == 'cpu':
|
| 400 |
+
self.autocast = TorchAutocast(enabled=False)
|
| 401 |
+
if self.device != 'cpu':
|
| 402 |
+
logger.warning("T5 has no autocast, this might lead to NaN")
|
| 403 |
+
else:
|
| 404 |
+
dtype = getattr(torch, autocast_dtype)
|
| 405 |
+
assert isinstance(dtype, torch.dtype)
|
| 406 |
+
logger.info(f"T5 will be evaluated with autocast as {autocast_dtype}")
|
| 407 |
+
self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
|
| 408 |
+
# Let's disable logging temporarily because T5 will vomit some errors otherwise.
|
| 409 |
+
# thanks https://gist.github.com/simon-weber/7853144
|
| 410 |
+
previous_level = logging.root.manager.disable
|
| 411 |
+
logging.disable(logging.ERROR)
|
| 412 |
+
with warnings.catch_warnings():
|
| 413 |
+
warnings.simplefilter("ignore")
|
| 414 |
+
try:
|
| 415 |
+
self.t5_tokenizer = T5Tokenizer.from_pretrained(name)
|
| 416 |
+
t5 = T5EncoderModel.from_pretrained(name).train(mode=finetune)
|
| 417 |
+
finally:
|
| 418 |
+
logging.disable(previous_level)
|
| 419 |
+
if finetune:
|
| 420 |
+
self.t5 = t5
|
| 421 |
+
else:
|
| 422 |
+
# this makes sure that the t5 model is not part
|
| 423 |
+
# of the saved checkpoint
|
| 424 |
+
self.__dict__['t5'] = t5.to(device)
|
| 425 |
+
|
| 426 |
+
self.normalize_text = normalize_text
|
| 427 |
+
if normalize_text:
|
| 428 |
+
self.text_normalizer = WhiteSpaceTokenizer(1, lemma=True, stopwords=True)
|
| 429 |
+
|
| 430 |
+
def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]:
|
| 431 |
+
# if current sample doesn't have a certain attribute, replace with empty string
|
| 432 |
+
entries: tp.List[str] = [xi if xi is not None else "" for xi in x]
|
| 433 |
+
if self.normalize_text:
|
| 434 |
+
_, _, entries = self.text_normalizer(entries, return_text=True)
|
| 435 |
+
if self.word_dropout > 0. and self.training:
|
| 436 |
+
new_entries = []
|
| 437 |
+
for entry in entries:
|
| 438 |
+
words = [word for word in entry.split(" ") if random.random() >= self.word_dropout]
|
| 439 |
+
new_entries.append(" ".join(words))
|
| 440 |
+
entries = new_entries
|
| 441 |
+
|
| 442 |
+
empty_idx = torch.LongTensor([i for i, xi in enumerate(entries) if xi == ""])
|
| 443 |
+
|
| 444 |
+
inputs = self.t5_tokenizer(entries, return_tensors='pt', padding=True).to(self.device)
|
| 445 |
+
mask = inputs['attention_mask']
|
| 446 |
+
mask[empty_idx, :] = 0  # zero-out indices where the input is non-existent
|
| 447 |
+
return inputs
|
| 448 |
+
|
| 449 |
+
def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType:
|
| 450 |
+
mask = inputs['attention_mask']
|
| 451 |
+
with torch.set_grad_enabled(self.finetune), self.autocast:
|
| 452 |
+
embeds = self.t5(**inputs).last_hidden_state
|
| 453 |
+
embeds = self.output_proj(embeds.to(self.output_proj.weight))
|
| 454 |
+
embeds = (embeds * mask.unsqueeze(-1))
|
| 455 |
+
return embeds, mask
|
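# Sketch of T5Conditioner usage (downloads the pretrained t5-small weights on first use,
# so this is only meant to illustrate the expected output shapes).
# >>> conditioner = T5Conditioner('t5-small', output_dim=64, finetune=False, device='cpu')
# >>> embeds, mask = conditioner(conditioner.tokenize(["a calm piano piece"]))
# >>> embeds.shape[0], embeds.shape[-1]  # [B, T, output_dim]
# (1, 64)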
| 456 |
+
|
| 457 |
+
|
| 458 |
+
class WaveformConditioner(BaseConditioner):
|
| 459 |
+
"""Base class for all conditioners that take a waveform as input.
|
| 460 |
+
Classes that inherit must implement `_get_wav_embedding` that outputs
|
| 461 |
+
a continuous tensor, and `_downsampling_factor` that returns the down-sampling
|
| 462 |
+
factor of the embedding model.
|
| 463 |
+
|
| 464 |
+
Args:
|
| 465 |
+
dim (int): The internal representation dimension.
|
| 466 |
+
output_dim (int): Output dimension.
|
| 467 |
+
device (tp.Union[torch.device, str]): Device.
|
| 468 |
+
"""
|
| 469 |
+
def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.device, str]):
|
| 470 |
+
super().__init__(dim, output_dim)
|
| 471 |
+
self.device = device
|
| 472 |
+
# if False, no masking is done; used in ChromaStemConditioner when completing a sample by periodicity.
|
| 473 |
+
self._use_masking = True
|
| 474 |
+
|
| 475 |
+
def tokenize(self, x: WavCondition) -> WavCondition:
|
| 476 |
+
wav, length, sample_rate, path, seek_time = x
|
| 477 |
+
assert length is not None
|
| 478 |
+
return WavCondition(wav.to(self.device), length.to(self.device), sample_rate, path, seek_time)
|
| 479 |
+
|
| 480 |
+
def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
|
| 481 |
+
"""Gets as input a WavCondition and returns a dense embedding."""
|
| 482 |
+
raise NotImplementedError()
|
| 483 |
+
|
| 484 |
+
def _downsampling_factor(self):
|
| 485 |
+
"""Returns the downsampling factor of the embedding model."""
|
| 486 |
+
raise NotImplementedError()
|
| 487 |
+
|
| 488 |
+
def forward(self, x: WavCondition) -> ConditionType:
|
| 489 |
+
"""Extract condition embedding and mask from a waveform and its metadata.
|
| 490 |
+
Args:
|
| 491 |
+
x (WavCondition): Waveform condition containing raw waveform and metadata.
|
| 492 |
+
Returns:
|
| 493 |
+
ConditionType: a dense vector representing the conditioning along with its mask
|
| 494 |
+
"""
|
| 495 |
+
wav, lengths, *_ = x
|
| 496 |
+
with torch.no_grad():
|
| 497 |
+
embeds = self._get_wav_embedding(x)
|
| 498 |
+
embeds = embeds.to(self.output_proj.weight)
|
| 499 |
+
embeds = self.output_proj(embeds)
|
| 500 |
+
|
| 501 |
+
if lengths is not None and self._use_masking:
|
| 502 |
+
lengths = lengths / self._downsampling_factor()
|
| 503 |
+
mask = length_to_mask(lengths, max_len=embeds.shape[1]).int() # type: ignore
|
| 504 |
+
else:
|
| 505 |
+
mask = torch.ones_like(embeds[..., 0])
|
| 506 |
+
embeds = (embeds * mask.unsqueeze(-1))
|
| 507 |
+
return embeds, mask
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
class ChromaStemConditioner(WaveformConditioner):
|
| 511 |
+
"""Chroma conditioner based on stems.
|
| 512 |
+
The ChromaStemConditioner uses DEMUCS to first filter out drums and bass, as
|
| 513 |
+
drums and bass often dominate the chroma, leaving the chroma features with
|
| 514 |
+
little information about the melody.
|
| 515 |
+
|
| 516 |
+
Args:
|
| 517 |
+
output_dim (int): Output dimension for the conditioner.
|
| 518 |
+
sample_rate (int): Sample rate for the chroma extractor.
|
| 519 |
+
n_chroma (int): Number of chroma bins for the chroma extractor.
|
| 520 |
+
radix2_exp (int): Size of stft window for the chroma extractor (power of 2, e.g. 12 -> 2^12).
|
| 521 |
+
duration (int): duration used during training. This is later used for correct padding
|
| 522 |
+
in case we are using chroma as prefix.
|
| 523 |
+
match_len_on_eval (bool, optional): if True then all chromas are padded to the training
|
| 524 |
+
duration. Defaults to True.
|
| 525 |
+
eval_wavs (str, optional): path to a dataset manifest with waveforms; these waveforms are used as
|
| 526 |
+
conditions during eval (for cases where we don't want to leak test conditions like MusicCaps).
|
| 527 |
+
Defaults to None.
|
| 528 |
+
n_eval_wavs (int, optional): limits the number of waveforms used for conditioning. Defaults to 0.
|
| 529 |
+
device (tp.Union[torch.device, str], optional): Device for the conditioner.
|
| 530 |
+
**kwargs: Additional parameters for the chroma extractor.
|
| 531 |
+
"""
|
| 532 |
+
def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, radix2_exp: int,
|
| 533 |
+
duration: float, match_len_on_eval: bool = True, eval_wavs: tp.Optional[str] = None,
|
| 534 |
+
n_eval_wavs: int = 0, cache_path: tp.Optional[tp.Union[str, Path]] = None,
|
| 535 |
+
device: tp.Union[torch.device, str] = 'cpu', **kwargs):
|
| 536 |
+
from demucs import pretrained
|
| 537 |
+
super().__init__(dim=n_chroma, output_dim=output_dim, device=device)
|
| 538 |
+
self.autocast = TorchAutocast(enabled=device != 'cpu', device_type=self.device, dtype=torch.float32)
|
| 539 |
+
self.sample_rate = sample_rate
|
| 540 |
+
self.match_len_on_eval = match_len_on_eval
|
| 541 |
+
if match_len_on_eval:
|
| 542 |
+
self._use_masking = False
|
| 543 |
+
self.duration = duration
|
| 544 |
+
self.__dict__['demucs'] = pretrained.get_model('htdemucs').to(device)
|
| 545 |
+
stem_sources: list = self.demucs.sources # type: ignore
|
| 546 |
+
self.stem_indices = torch.LongTensor([stem_sources.index('vocals'), stem_sources.index('other')]).to(device)
|
| 547 |
+
self.chroma = ChromaExtractor(sample_rate=sample_rate, n_chroma=n_chroma,
|
| 548 |
+
radix2_exp=radix2_exp, **kwargs).to(device)
|
| 549 |
+
self.chroma_len = self._get_chroma_len()
|
| 550 |
+
self.eval_wavs: tp.Optional[torch.Tensor] = self._load_eval_wavs(eval_wavs, n_eval_wavs)
|
| 551 |
+
self.cache = None
|
| 552 |
+
if cache_path is not None:
|
| 553 |
+
self.cache = EmbeddingCache(Path(cache_path) / 'wav', self.device,
|
| 554 |
+
compute_embed_fn=self._get_full_chroma_for_cache,
|
| 555 |
+
extract_embed_fn=self._extract_chroma_chunk)
|
| 556 |
+
|
| 557 |
+
def _downsampling_factor(self) -> int:
|
| 558 |
+
return self.chroma.winhop
|
| 559 |
+
|
| 560 |
+
def _load_eval_wavs(self, path: tp.Optional[str], num_samples: int) -> tp.Optional[torch.Tensor]:
|
| 561 |
+
"""Load pre-defined waveforms from a json.
|
| 562 |
+
These waveforms will be used for chroma extraction during evaluation.
|
| 563 |
+
This is done to make the evaluation on MusicCaps fair (we shouldn't see the chromas of MusicCaps).
|
| 564 |
+
"""
|
| 565 |
+
if path is None:
|
| 566 |
+
return None
|
| 567 |
+
|
| 568 |
+
logger.info(f"Loading evaluation wavs from {path}")
|
| 569 |
+
from audiocraft.data.audio_dataset import AudioDataset
|
| 570 |
+
dataset: AudioDataset = AudioDataset.from_meta(
|
| 571 |
+
path, segment_duration=self.duration, min_audio_duration=self.duration,
|
| 572 |
+
sample_rate=self.sample_rate, channels=1)
|
| 573 |
+
|
| 574 |
+
if len(dataset) > 0:
|
| 575 |
+
eval_wavs = dataset.collater([dataset[i] for i in range(num_samples)]).to(self.device)
|
| 576 |
+
logger.info(f"Using {len(eval_wavs)} evaluation wavs for chroma-stem conditioner")
|
| 577 |
+
return eval_wavs
|
| 578 |
+
else:
|
| 579 |
+
raise ValueError("Could not find evaluation wavs, check lengths of wavs")
|
| 580 |
+
|
| 581 |
+
def reset_eval_wavs(self, eval_wavs: tp.Optional[torch.Tensor]) -> None:
|
| 582 |
+
self.eval_wavs = eval_wavs
|
| 583 |
+
|
| 584 |
+
def has_eval_wavs(self) -> bool:
|
| 585 |
+
return self.eval_wavs is not None
|
| 586 |
+
|
| 587 |
+
def _sample_eval_wavs(self, num_samples: int) -> torch.Tensor:
|
| 588 |
+
"""Sample wavs from a predefined list."""
|
| 589 |
+
assert self.eval_wavs is not None, "Cannot sample eval wavs as no eval wavs provided."
|
| 590 |
+
total_eval_wavs = len(self.eval_wavs)
|
| 591 |
+
out = self.eval_wavs
|
| 592 |
+
if num_samples > total_eval_wavs:
|
| 593 |
+
out = self.eval_wavs.repeat(num_samples // total_eval_wavs + 1, 1, 1)
|
| 594 |
+
return out[torch.randperm(len(out))][:num_samples]
|
| 595 |
+
|
| 596 |
+
def _get_chroma_len(self) -> int:
|
| 597 |
+
"""Get length of chroma during training."""
|
| 598 |
+
dummy_wav = torch.zeros((1, int(self.sample_rate * self.duration)), device=self.device)
|
| 599 |
+
dummy_chr = self.chroma(dummy_wav)
|
| 600 |
+
return dummy_chr.shape[1]
|
| 601 |
+
|
| 602 |
+
@torch.no_grad()
|
| 603 |
+
def _get_stemmed_wav(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
|
| 604 |
+
"""Get parts of the wav that holds the melody, extracting the main stems from the wav."""
|
| 605 |
+
from demucs.apply import apply_model
|
| 606 |
+
from demucs.audio import convert_audio
|
| 607 |
+
with self.autocast:
|
| 608 |
+
wav = convert_audio(
|
| 609 |
+
wav, sample_rate, self.demucs.samplerate, self.demucs.audio_channels) # type: ignore
|
| 610 |
+
stems = apply_model(self.demucs, wav, device=self.device)
|
| 611 |
+
stems = stems[:, self.stem_indices] # extract relevant stems for melody conditioning
|
| 612 |
+
mix_wav = stems.sum(1) # merge extracted stems to single waveform
|
| 613 |
+
mix_wav = convert_audio(mix_wav, self.demucs.samplerate, self.sample_rate, 1) # type: ignore
|
| 614 |
+
return mix_wav
|
| 615 |
+
|
| 616 |
+
@torch.no_grad()
|
| 617 |
+
def _extract_chroma(self, wav: torch.Tensor) -> torch.Tensor:
|
| 618 |
+
"""Extract chroma features from the waveform."""
|
| 619 |
+
with self.autocast:
|
| 620 |
+
return self.chroma(wav)
|
| 621 |
+
|
| 622 |
+
@torch.no_grad()
|
| 623 |
+
def _compute_wav_embedding(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
|
| 624 |
+
"""Compute wav embedding, applying stem and chroma extraction."""
|
| 625 |
+
# avoid 0-size tensors when we are working with null conds
|
| 626 |
+
if wav.shape[-1] == 1:
|
| 627 |
+
return self._extract_chroma(wav)
|
| 628 |
+
stems = self._get_stemmed_wav(wav, sample_rate)
|
| 629 |
+
chroma = self._extract_chroma(stems)
|
| 630 |
+
return chroma
|
| 631 |
+
|
| 632 |
+
@torch.no_grad()
|
| 633 |
+
def _get_full_chroma_for_cache(self, path: tp.Union[str, Path], x: WavCondition, idx: int) -> torch.Tensor:
|
| 634 |
+
"""Extract chroma from the whole audio waveform at the given path."""
|
| 635 |
+
wav, sr = audio_read(path)
|
| 636 |
+
wav = wav[None].to(self.device)
|
| 637 |
+
wav = convert_audio(wav, sr, self.sample_rate, to_channels=1)
|
| 638 |
+
chroma = self._compute_wav_embedding(wav, self.sample_rate)[0]
|
| 639 |
+
return chroma
|
| 640 |
+
|
| 641 |
+
def _extract_chroma_chunk(self, full_chroma: torch.Tensor, x: WavCondition, idx: int) -> torch.Tensor:
|
| 642 |
+
"""Extract a chunk of chroma from the full chroma derived from the full waveform."""
|
| 643 |
+
wav_length = x.wav.shape[-1]
|
| 644 |
+
seek_time = x.seek_time[idx]
|
| 645 |
+
assert seek_time is not None, (
|
| 646 |
+
"WavCondition seek_time is required "
|
| 647 |
+
"when extracting chroma chunks from pre-computed chroma.")
|
| 648 |
+
full_chroma = full_chroma.float()
|
| 649 |
+
frame_rate = self.sample_rate / self._downsampling_factor()
|
| 650 |
+
target_length = int(frame_rate * wav_length / self.sample_rate)
|
| 651 |
+
index = int(frame_rate * seek_time)
|
| 652 |
+
out = full_chroma[index: index + target_length]
|
| 653 |
+
out = F.pad(out[None], (0, 0, 0, target_length - out.shape[0]))[0]
|
| 654 |
+
return out.to(self.device)
|
| 655 |
+
|
| 656 |
+
@torch.no_grad()
|
| 657 |
+
def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
|
| 658 |
+
"""Get the wav embedding from the WavCondition.
|
| 659 |
+
The conditioner will either extract the embedding on-the-fly, computing it from the condition wav directly,
|
| 660 |
+
or will rely on the embedding cache to load the pre-computed embedding if relevant.
|
| 661 |
+
"""
|
| 662 |
+
sampled_wav: tp.Optional[torch.Tensor] = None
|
| 663 |
+
if not self.training and self.eval_wavs is not None:
|
| 664 |
+
warn_once(logger, "Using precomputed evaluation wavs!")
|
| 665 |
+
sampled_wav = self._sample_eval_wavs(len(x.wav))
|
| 666 |
+
|
| 667 |
+
no_undefined_paths = all(p is not None for p in x.path)
|
| 668 |
+
no_nullified_cond = x.wav.shape[-1] > 1
|
| 669 |
+
if sampled_wav is not None:
|
| 670 |
+
chroma = self._compute_wav_embedding(sampled_wav, self.sample_rate)
|
| 671 |
+
elif self.cache is not None and no_undefined_paths and no_nullified_cond:
|
| 672 |
+
paths = [Path(p) for p in x.path if p is not None]
|
| 673 |
+
chroma = self.cache.get_embed_from_cache(paths, x)
|
| 674 |
+
else:
|
| 675 |
+
assert all(sr == x.sample_rate[0] for sr in x.sample_rate), "All sample rates in batch should be equal."
|
| 676 |
+
chroma = self._compute_wav_embedding(x.wav, x.sample_rate[0])
|
| 677 |
+
|
| 678 |
+
if self.match_len_on_eval:
|
| 679 |
+
B, T, C = chroma.shape
|
| 680 |
+
if T > self.chroma_len:
|
| 681 |
+
chroma = chroma[:, :self.chroma_len]
|
| 682 |
+
logger.debug(f"Chroma was truncated to match length! ({T} -> {chroma.shape[1]})")
|
| 683 |
+
elif T < self.chroma_len:
|
| 684 |
+
n_repeat = int(math.ceil(self.chroma_len / T))
|
| 685 |
+
chroma = chroma.repeat(1, n_repeat, 1)
|
| 686 |
+
chroma = chroma[:, :self.chroma_len]
|
| 687 |
+
logger.debug(f"Chroma was repeated to match length! ({T} -> {chroma.shape[1]})")
|
| 688 |
+
|
| 689 |
+
return chroma
|
| 690 |
+
|
| 691 |
+
def tokenize(self, x: WavCondition) -> WavCondition:
|
| 692 |
+
"""Apply WavConditioner tokenization and populate cache if needed."""
|
| 693 |
+
x = super().tokenize(x)
|
| 694 |
+
no_undefined_paths = all(p is not None for p in x.path)
|
| 695 |
+
if self.cache is not None and no_undefined_paths:
|
| 696 |
+
paths = [Path(p) for p in x.path if p is not None]
|
| 697 |
+
self.cache.populate_embed_cache(paths, x)
|
| 698 |
+
return x
|
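# Sketch of ChromaStemConditioner usage (requires the demucs package and downloads the
# htdemucs weights; the constructor values below are illustrative, not recommended ones).
# >>> conditioner = ChromaStemConditioner(output_dim=64, sample_rate=32000, n_chroma=12,
# ...                                     radix2_exp=12, duration=10.0, device='cpu')
# >>> cond = WavCondition(torch.randn(1, 1, 32000), torch.tensor([32000]),
# ...                     sample_rate=[32000], path=[None], seek_time=[None])
# >>> embeds, mask = conditioner(conditioner.tokenize(cond))
# >>> embeds.shape[-1]
# 64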
| 699 |
+
|
| 700 |
+
|
| 701 |
+
class JointEmbeddingConditioner(BaseConditioner):
|
| 702 |
+
"""Joint embedding conditioning supporting both audio or text conditioning.
|
| 703 |
+
|
| 704 |
+
Args:
|
| 705 |
+
dim (int): Dimension.
|
| 706 |
+
output_dim (int): Output dimension.
|
| 707 |
+
device (str): Device.
|
| 708 |
+
attribute (str): Attribute used by the conditioner.
|
| 709 |
+
autocast_dtype (str): Autocast for the conditioner.
|
| 710 |
+
quantize (bool): Whether to quantize the CLAP embedding.
|
| 711 |
+
n_q (int): Number of residual quantizers (used if quantize is true).
|
| 712 |
+
bins (int): Quantizers' codebooks size (used if quantize is true).
|
| 713 |
+
kwargs: Additional parameters for residual vector quantizer.
|
| 714 |
+
"""
|
| 715 |
+
def __init__(self, dim: int, output_dim: int, device: str, attribute: str,
|
| 716 |
+
autocast_dtype: tp.Optional[str] = 'float32', quantize: bool = True,
|
| 717 |
+
n_q: int = 12, bins: int = 1024, **kwargs):
|
| 718 |
+
super().__init__(dim=dim, output_dim=output_dim)
|
| 719 |
+
self.device = device
|
| 720 |
+
self.attribute = attribute
|
| 721 |
+
if autocast_dtype is None or device == 'cpu':
|
| 722 |
+
self.autocast = TorchAutocast(enabled=False)
|
| 723 |
+
logger.warning("JointEmbeddingConditioner has no autocast, this might lead to NaN.")
|
| 724 |
+
else:
|
| 725 |
+
dtype = getattr(torch, autocast_dtype)
|
| 726 |
+
assert isinstance(dtype, torch.dtype)
|
| 727 |
+
logger.info(f"JointEmbeddingConditioner will be evaluated with autocast as {autocast_dtype}.")
|
| 728 |
+
self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
|
| 729 |
+
# residual vector quantizer to discretize the conditioned embedding
|
| 730 |
+
self.quantizer: tp.Optional[ResidualVectorQuantizer] = None
|
| 731 |
+
if quantize:
|
| 732 |
+
self.quantizer = ResidualVectorQuantizer(dim, n_q=n_q, bins=bins, **kwargs)
|
| 733 |
+
|
| 734 |
+
def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
| 735 |
+
"""Get joint embedding in latent space from the inputs.
|
| 736 |
+
|
| 737 |
+
Returns:
|
| 738 |
+
tuple[torch.Tensor, torch.Tensor]: Tensor for the latent embedding
|
| 739 |
+
and corresponding empty indexes.
|
| 740 |
+
"""
|
| 741 |
+
raise NotImplementedError()
|
| 742 |
+
|
| 743 |
+
def forward(self, x: JointEmbedCondition) -> ConditionType:
|
| 744 |
+
with self.autocast:
|
| 745 |
+
embed, empty_idx = self._get_embed(x)
|
| 746 |
+
if self.quantizer is not None:
|
| 747 |
+
embed = embed.view(-1, self.dim, 1)
|
| 748 |
+
q_res = self.quantizer(embed, frame_rate=1)
|
| 749 |
+
out_embed = q_res.x.view(-1, self.dim)
|
| 750 |
+
else:
|
| 751 |
+
out_embed = embed
|
| 752 |
+
out_embed = self.output_proj(out_embed).view(-1, 1, self.output_dim)
|
| 753 |
+
mask = torch.ones(*out_embed.shape[:2], device=out_embed.device)
|
| 754 |
+
mask[empty_idx, :] = 0  # zero-out indices where the input is non-existent
|
| 755 |
+
out_embed = (out_embed * mask.unsqueeze(-1))
|
| 756 |
+
return out_embed, mask
|
| 757 |
+
|
| 758 |
+
def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
|
| 759 |
+
return x
|
| 760 |
+
|
| 761 |
+
|
| 762 |
+
class CLAPEmbeddingConditioner(JointEmbeddingConditioner):
|
| 763 |
+
"""Joint Embedding conditioner based on pre-trained CLAP model.
|
| 764 |
+
|
| 765 |
+
This CLAP-based conditioner supports a caching mechanism
|
| 766 |
+
over the computed embeddings for faster training.
|
| 767 |
+
|
| 768 |
+
Args:
|
| 769 |
+
dim (int): Dimension.
|
| 770 |
+
output_dim (int): Output dimension.
|
| 771 |
+
device (str): Device.
|
| 772 |
+
attribute (str): Attribute used by the conditioner.
|
| 773 |
+
quantize (bool): Whether to quantize the CLAP embedding.
|
| 774 |
+
n_q (int): Number of residual quantizers (used if quantize is true).
|
| 775 |
+
bins (int): Quantizers' codebooks size (used if quantize is true).
|
| 776 |
+
checkpoint (str): Path to CLAP checkpoint.
|
| 777 |
+
model_arch (str): CLAP model architecture.
|
| 778 |
+
enable_fusion (bool): Enable fusion for CLAP model.
|
| 779 |
+
sample_rate (int): Sample rate used by CLAP model.
|
| 780 |
+
max_audio_length (float): Maximum audio length for CLAP model.
|
| 781 |
+
audio_stride (float): Stride to use for getting a CLAP embedding on the full sequence.
|
| 782 |
+
normalize (bool): Whether to normalize the CLAP embedding.
|
| 783 |
+
text_p (float): Probability of using text representation instead of audio at train time.
|
| 784 |
+
batch_size (Optional[int]): Batch size for CLAP embedding computation.
|
| 785 |
+
autocast_dtype (str): Autocast for the conditioner.
|
| 786 |
+
cache_path (Optional[str]): Path for pre-computed embeddings caching.
|
| 787 |
+
kwargs: Additional parameters for residual vector quantizer.
|
| 788 |
+
"""
|
| 789 |
+
def __init__(self, dim: int, output_dim: int, device: str, attribute: str,
|
| 790 |
+
quantize: bool, n_q: int, bins: int, checkpoint: tp.Union[str, Path], model_arch: str,
|
| 791 |
+
enable_fusion: bool, sample_rate: int, max_audio_length: int, audio_stride: int,
|
| 792 |
+
normalize: bool, text_p: float, batch_size: tp.Optional[int] = None,
|
| 793 |
+
autocast_dtype: tp.Optional[str] = 'float32', cache_path: tp.Optional[str] = None, **kwargs):
|
| 794 |
+
try:
|
| 795 |
+
import laion_clap # type: ignore
|
| 796 |
+
except ImportError:
|
| 797 |
+
raise ImportError("Please install CLAP to use the CLAPEmbeddingConditioner: 'pip install laion_clap'")
|
| 798 |
+
warnings.warn("Sample rate for CLAP conditioner was fixed in version v1.1.0, (from 44.1 to 48 kHz). "
|
| 799 |
+
"Please retrain all models.")
|
| 800 |
+
checkpoint = AudioCraftEnvironment.resolve_reference_path(checkpoint)
|
| 801 |
+
clap_tokenize = RobertaTokenizer.from_pretrained('roberta-base')
|
| 802 |
+
clap_model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch)
|
| 803 |
+
load_clap_state_dict(clap_model, checkpoint)
|
| 804 |
+
clap_model.eval()
|
| 805 |
+
clap_model.to(device)
|
| 806 |
+
super().__init__(dim=dim, output_dim=output_dim, device=device, attribute=attribute,
|
| 807 |
+
autocast_dtype=autocast_dtype, quantize=quantize, n_q=n_q, bins=bins,
|
| 808 |
+
**kwargs)
|
| 809 |
+
self.checkpoint = checkpoint
|
| 810 |
+
self.enable_fusion = enable_fusion
|
| 811 |
+
self.model_arch = model_arch
|
| 812 |
+
self.clap: laion_clap.CLAP_Module
|
| 813 |
+
self.clap_tokenize: RobertaTokenizer
|
| 814 |
+
self.clap_sample_rate = sample_rate
|
| 815 |
+
self.clap_max_frames = int(self.clap_sample_rate * max_audio_length)
|
| 816 |
+
self.clap_stride = int(self.clap_sample_rate * audio_stride)
|
| 817 |
+
self.batch_size = batch_size or 1
|
| 818 |
+
self.normalize = normalize
|
| 819 |
+
self.text_p = text_p
|
| 820 |
+
self.__dict__['clap_tokenize'] = clap_tokenize
|
| 821 |
+
self.__dict__['clap'] = clap_model
|
| 822 |
+
self.wav_cache, self.text_cache = None, None
|
| 823 |
+
if cache_path is not None:
|
| 824 |
+
self.wav_cache = EmbeddingCache(Path(cache_path) / 'wav', self.device,
|
| 825 |
+
compute_embed_fn=self._get_wav_embedding_for_cache,
|
| 826 |
+
extract_embed_fn=self._extract_wav_embedding_chunk)
|
| 827 |
+
self.text_cache = EmbeddingCache(Path(cache_path) / 'text', self.device,
|
| 828 |
+
compute_embed_fn=self._get_text_embedding_for_cache)
|
| 829 |
+
|
| 830 |
+
def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
|
| 831 |
+
# we use the default params from CLAP module here as well
|
| 832 |
+
return self.clap_tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt")
|
| 833 |
+
|
| 834 |
+
def _compute_text_embedding(self, text: tp.List[str]) -> torch.Tensor:
|
| 835 |
+
"""Compute text embedding from CLAP model on a given a batch of text.
|
| 836 |
+
|
| 837 |
+
Args:
|
| 838 |
+
text (list[str]): List of text for the batch, with B items.
|
| 839 |
+
Returns:
|
| 840 |
+
torch.Tensor: CLAP embedding derived from text, of shape [B, 1, D], with D the CLAP embedding dimension.
|
| 841 |
+
"""
|
| 842 |
+
with torch.no_grad():
|
| 843 |
+
embed = self.clap.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True)
|
| 844 |
+
return embed.view(embed.size(0), 1, embed.size(-1))
|
| 845 |
+
|
| 846 |
+
def _get_text_embedding_for_cache(self, path: tp.Union[Path, str],
|
| 847 |
+
x: JointEmbedCondition, idx: int) -> torch.Tensor:
|
| 848 |
+
"""Get text embedding function for the cache."""
|
| 849 |
+
text = x.text[idx]
|
| 850 |
+
text = text if text is not None else ""
|
| 851 |
+
return self._compute_text_embedding([text])[0]
|
| 852 |
+
|
| 853 |
+
def _preprocess_wav(self, wav: torch.Tensor, length: torch.Tensor, sample_rates: tp.List[int]) -> torch.Tensor:
|
| 854 |
+
"""Preprocess wav to expected format by CLAP model.
|
| 855 |
+
|
| 856 |
+
Args:
|
| 857 |
+
wav (torch.Tensor): Audio wav, of shape [B, C, T].
|
| 858 |
+
length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B].
|
| 859 |
+
sample_rates (list[int]): Sample rates for each sample in the batch
|
| 860 |
+
Returns:
|
| 861 |
+
torch.Tensor: Audio wav of shape [B, T].
|
| 862 |
+
"""
|
| 863 |
+
assert wav.dim() == 3, "Expecting wav to be [B, C, T]"
|
| 864 |
+
if sample_rates is not None:
|
| 865 |
+
_wav = []
|
| 866 |
+
for i, audio in enumerate(wav):
|
| 867 |
+
sr = sample_rates[i]
|
| 868 |
+
audio = convert_audio(audio, from_rate=sr, to_rate=self.clap_sample_rate, to_channels=1)
|
| 869 |
+
_wav.append(audio)
|
| 870 |
+
wav = torch.stack(_wav, dim=0)
|
| 871 |
+
wav = wav.mean(dim=1)
|
| 872 |
+
return wav
|
| 873 |
+
|
| 874 |
+
def _compute_wav_embedding(self, wav: torch.Tensor, length: torch.Tensor,
|
| 875 |
+
sample_rates: tp.List[int], reduce_mean: bool = False) -> torch.Tensor:
|
| 876 |
+
"""Compute audio wave embedding from CLAP model.
|
| 877 |
+
|
| 878 |
+
Since CLAP operates on fixed-length audio inputs and we need to process longer audio sequences,
|
| 879 |
+
we calculate the wav embeddings on `clap_max_frames` windows with `clap_stride`-second stride and
|
| 880 |
+
average the resulting embeddings.
|
| 881 |
+
|
| 882 |
+
Args:
|
| 883 |
+
wav (torch.Tensor): Audio wav, of shape [B, C, T].
|
| 884 |
+
length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B].
|
| 885 |
+
sample_rates (list[int]): Sample rates for each sample in the batch.
|
| 886 |
+
reduce_mean (bool): Whether to get the average tensor.
|
| 887 |
+
Returns:
|
| 888 |
+
torch.Tensor: Audio embedding of shape [B, F, D], F being the number of chunks, D the dimension.
|
| 889 |
+
"""
|
| 890 |
+
with torch.no_grad():
|
| 891 |
+
wav = self._preprocess_wav(wav, length, sample_rates)
|
| 892 |
+
B, T = wav.shape
|
| 893 |
+
if T >= self.clap_max_frames:
|
| 894 |
+
wav = wav.unfold(-1, self.clap_max_frames, self.clap_stride) # [B, F, T]
|
| 895 |
+
else:
|
| 896 |
+
wav = wav.view(-1, 1, T) # [B, F, T] with F=1
|
| 897 |
+
wav = einops.rearrange(wav, 'b f t -> (b f) t')
|
| 898 |
+
embed_list = []
|
| 899 |
+
for i in range(0, wav.size(0), self.batch_size):
|
| 900 |
+
_wav = wav[i:i+self.batch_size, ...]
|
| 901 |
+
_embed = self.clap.get_audio_embedding_from_data(_wav, use_tensor=True)
|
| 902 |
+
embed_list.append(_embed)
|
| 903 |
+
embed = torch.cat(embed_list, dim=0)
|
| 904 |
+
embed = einops.rearrange(embed, '(b f) d -> b f d', b=B)
|
| 905 |
+
if reduce_mean:
|
| 906 |
+
embed = embed.mean(dim=1, keepdim=True)
|
| 907 |
+
return embed # [B, F, D] with F=1 if reduce_mean is True
|
| 908 |
+
|
| 909 |
+
def _get_wav_embedding_for_cache(self, path: tp.Union[str, Path],
|
| 910 |
+
x: JointEmbedCondition, idx: int) -> torch.Tensor:
|
| 911 |
+
"""Compute audio wave embedding for the cache.
|
| 912 |
+
The embedding is computed on the audio read from the given file.
|
| 913 |
+
|
| 914 |
+
Args:
|
| 915 |
+
path (str or Path): Path to the full audio file.
|
| 916 |
+
Returns:
|
| 917 |
+
torch.Tensor: Single-item tensor of shape [F, D], F being the number of chunks, D the dimension.
|
| 918 |
+
"""
|
| 919 |
+
wav, sr = audio_read(path) # [C, T]
|
| 920 |
+
wav = wav.unsqueeze(0).to(self.device) # [1, C, T]
|
| 921 |
+
wav_len = torch.LongTensor([wav.shape[-1]]).to(self.device)
|
| 922 |
+
embed = self._compute_wav_embedding(wav, wav_len, [sr], reduce_mean=False) # [B, F, D]
|
| 923 |
+
return embed.squeeze(0) # [F, D]
|
| 924 |
+
|
| 925 |
+
def _extract_wav_embedding_chunk(self, full_embed: torch.Tensor, x: JointEmbedCondition, idx: int) -> torch.Tensor:
|
| 926 |
+
"""Extract the chunk of embedding matching the seek_time and length from the full CLAP audio embedding.
|
| 927 |
+
|
| 928 |
+
Args:
|
| 929 |
+
full_embed (torch.Tensor): CLAP embedding computed on the full wave, of shape [F, D].
|
| 930 |
+
x (JointEmbedCondition): Joint embedding condition for the full batch.
|
| 931 |
+
idx (int): Index considered for the given embedding to extract.
|
| 932 |
+
Returns:
|
| 933 |
+
torch.Tensor: Wav embedding averaged on sliding window, of shape [1, D].
|
| 934 |
+
"""
|
| 935 |
+
sample_rate = x.sample_rate[idx]
|
| 936 |
+
seek_time = x.seek_time[idx]
|
| 937 |
+
seek_time = 0. if seek_time is None else seek_time
|
| 938 |
+
clap_stride = int(self.clap_stride / self.clap_sample_rate) * sample_rate
|
| 939 |
+
end_seek_time = seek_time + self.clap_max_frames / self.clap_sample_rate
|
| 940 |
+
start_offset = int(seek_time * sample_rate // clap_stride)
|
| 941 |
+
end_offset = int(end_seek_time * sample_rate // clap_stride)
|
| 942 |
+
wav_embed = full_embed[start_offset:end_offset, ...]
|
| 943 |
+
wav_embed = wav_embed.mean(dim=0, keepdim=True)
|
| 944 |
+
return wav_embed.to(self.device)  # [1, D]
|
| 945 |
+
|
| 946 |
+
def _get_text_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
|
| 947 |
+
"""Get CLAP embedding from a batch of text descriptions."""
|
| 948 |
+
no_nullified_cond = x.wav.shape[-1] > 1  # we don't want to read from the cache when the condition was dropped out
|
| 949 |
+
if self.text_cache is not None and no_nullified_cond:
|
| 950 |
+
assert all(p is not None for p in x.path), "Cache requires all JointEmbedCondition paths to be provided"
|
| 951 |
+
paths = [Path(p) for p in x.path if p is not None]
|
| 952 |
+
embed = self.text_cache.get_embed_from_cache(paths, x)
|
| 953 |
+
else:
|
| 954 |
+
text = [xi if xi is not None else "" for xi in x.text]
|
| 955 |
+
embed = self._compute_text_embedding(text)
|
| 956 |
+
if self.normalize:
|
| 957 |
+
embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1)
|
| 958 |
+
return embed
|
| 959 |
+
|
| 960 |
+
def _get_wav_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
|
| 961 |
+
"""Get CLAP embedding from a batch of audio tensors (and corresponding sample rates)."""
|
| 962 |
+
no_undefined_paths = all(p is not None for p in x.path)
|
| 963 |
+
no_nullified_cond = x.wav.shape[-1] > 1  # we don't want to read from the cache when the condition was dropped out
|
| 964 |
+
if self.wav_cache is not None and no_undefined_paths and no_nullified_cond:
|
| 965 |
+
paths = [Path(p) for p in x.path if p is not None]
|
| 966 |
+
embed = self.wav_cache.get_embed_from_cache(paths, x)
|
| 967 |
+
else:
|
| 968 |
+
embed = self._compute_wav_embedding(x.wav, x.length, x.sample_rate, reduce_mean=True)
|
| 969 |
+
if self.normalize:
|
| 970 |
+
embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1)
|
| 971 |
+
return embed
|
| 972 |
+
|
| 973 |
+
def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
|
| 974 |
+
# Try to limit sync points as much as possible when the cache is warm.
|
| 975 |
+
no_undefined_paths = all(p is not None for p in x.path)
|
| 976 |
+
if self.wav_cache is not None and no_undefined_paths:
|
| 977 |
+
assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided"
|
| 978 |
+
paths = [Path(p) for p in x.path if p is not None]
|
| 979 |
+
self.wav_cache.populate_embed_cache(paths, x)
|
| 980 |
+
if self.text_cache is not None and no_undefined_paths:
|
| 981 |
+
assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided"
|
| 982 |
+
paths = [Path(p) for p in x.path if p is not None]
|
| 983 |
+
self.text_cache.populate_embed_cache(paths, x)
|
| 984 |
+
return x
|
| 985 |
+
|
| 986 |
+
def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
| 987 |
+
"""Extract shared latent representation from either the wav or the text using CLAP."""
|
| 988 |
+
# decide whether to use text embedding at train time or not
|
| 989 |
+
use_text_embed = random.random() < self.text_p
|
| 990 |
+
if self.training and not use_text_embed:
|
| 991 |
+
embed = self._get_wav_embedding(x)
|
| 992 |
+
empty_idx = torch.LongTensor([]) # we assume we always have the audio wav
|
| 993 |
+
else:
|
| 994 |
+
embed = self._get_text_embedding(x)
|
| 995 |
+
empty_idx = torch.LongTensor([i for i, xi in enumerate(x.text) if xi is None or xi == ""])
|
| 996 |
+
return embed, empty_idx
|
| 997 |
+
|
| 998 |
+
|
| 999 |
+
def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str) -> ConditioningAttributes:
|
| 1000 |
+
"""Utility function for nullifying an attribute inside an ConditioningAttributes object.
|
| 1001 |
+
If the condition is of type "wav", then nullify it using `nullify_condition` function.
|
| 1002 |
+
If the condition is of any other type, set its value to None.
|
| 1003 |
+
Works in-place.
|
| 1004 |
+
"""
|
| 1005 |
+
if condition_type not in ['text', 'wav', 'joint_embed']:
|
| 1006 |
+
raise ValueError(
|
| 1007 |
+
"dropout_condition got an unexpected condition type!"
|
| 1008 |
+
f" expected 'text', 'wav' or 'joint_embed' but got '{condition_type}'"
|
| 1009 |
+
)
|
| 1010 |
+
|
| 1011 |
+
if condition not in getattr(sample, condition_type):
|
| 1012 |
+
raise ValueError(
|
| 1013 |
+
"dropout_condition received an unexpected condition!"
|
| 1014 |
+
f" expected wav={sample.wav.keys()} and text={sample.text.keys()}"
|
| 1015 |
+
f" but got '{condition}' of type '{condition_type}'!"
|
| 1016 |
+
)
|
| 1017 |
+
|
| 1018 |
+
if condition_type == 'wav':
|
| 1019 |
+
wav_cond = sample.wav[condition]
|
| 1020 |
+
sample.wav[condition] = nullify_wav(wav_cond)
|
| 1021 |
+
elif condition_type == 'joint_embed':
|
| 1022 |
+
embed = sample.joint_embed[condition]
|
| 1023 |
+
sample.joint_embed[condition] = nullify_joint_embed(embed)
|
| 1024 |
+
else:
|
| 1025 |
+
sample.text[condition] = None
|
| 1026 |
+
|
| 1027 |
+
return sample
|
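# Sketch: dropping a (hypothetical) "description" text attribute in-place.
# >>> sample = ConditioningAttributes(text={"description": "warm jazz"})
# >>> _ = dropout_condition(sample, "text", "description")
# >>> sample.text["description"] is None
# True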
| 1028 |
+
|
| 1029 |
+
|
| 1030 |
+
class DropoutModule(nn.Module):
|
| 1031 |
+
"""Base module for all dropout modules."""
|
| 1032 |
+
def __init__(self, seed: int = 1234):
|
| 1033 |
+
super().__init__()
|
| 1034 |
+
self.rng = torch.Generator()
|
| 1035 |
+
self.rng.manual_seed(seed)
|
| 1036 |
+
|
| 1037 |
+
|
| 1038 |
+
class AttributeDropout(DropoutModule):
|
| 1039 |
+
"""Dropout with a given probability per attribute.
|
| 1040 |
+
This is different from the behavior of ClassifierFreeGuidanceDropout as this allows for attributes
|
| 1041 |
+
to be dropped out separately. For example, "artist" can be dropped while "genre" remains.
|
| 1042 |
+
This is in contrast to ClassifierFreeGuidanceDropout where if "artist" is dropped "genre"
|
| 1043 |
+
must also be dropped.
|
| 1044 |
+
|
| 1045 |
+
Args:
|
| 1046 |
+
p (tp.Dict[str, tp.Dict[str, float]]): A dict mapping each condition type to per-attribute dropout probabilities. For example:
|
| 1047 |
+
...
|
| 1048 |
+
"genre": 0.1,
|
| 1049 |
+
"artist": 0.5,
|
| 1050 |
+
"wav": 0.25,
|
| 1051 |
+
...
|
| 1052 |
+
active_on_eval (bool, optional): Whether the dropout is active at eval. Defaults to False.
|
| 1053 |
+
seed (int, optional): Random seed.
|
| 1054 |
+
"""
|
| 1055 |
+
def __init__(self, p: tp.Dict[str, tp.Dict[str, float]], active_on_eval: bool = False, seed: int = 1234):
|
| 1056 |
+
super().__init__(seed=seed)
|
| 1057 |
+
self.active_on_eval = active_on_eval
|
| 1058 |
+
# construct dicts that return the probability from p, defaulting to 0 otherwise
|
| 1059 |
+
self.p = {}
|
| 1060 |
+
for condition_type, probs in p.items():
|
| 1061 |
+
self.p[condition_type] = defaultdict(lambda: 0, probs)
|
| 1062 |
+
|
| 1063 |
+
def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
|
| 1064 |
+
"""
|
| 1065 |
+
Args:
|
| 1066 |
+
samples (list[ConditioningAttributes]): List of conditions.
|
| 1067 |
+
Returns:
|
| 1068 |
+
list[ConditioningAttributes]: List of conditions after certain attributes were set to None.
|
| 1069 |
+
"""
|
| 1070 |
+
if not self.training and not self.active_on_eval:
|
| 1071 |
+
return samples
|
| 1072 |
+
|
| 1073 |
+
samples = deepcopy(samples)
|
| 1074 |
+
for condition_type, ps in self.p.items(): # for condition types [text, wav]
|
| 1075 |
+
for condition, p in ps.items(): # for attributes of each type (e.g., [artist, genre])
|
| 1076 |
+
if torch.rand(1, generator=self.rng).item() < p:
|
| 1077 |
+
for sample in samples:
|
| 1078 |
+
dropout_condition(sample, condition_type, condition)
|
| 1079 |
+
return samples
|
| 1080 |
+
|
| 1081 |
+
def __repr__(self):
|
| 1082 |
+
return f"AttributeDropout({dict(self.p)})"
|
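# Sketch of AttributeDropout: each attribute is dropped independently with its own
# probability (the attribute names are hypothetical and must exist in the samples).
# >>> dropout = AttributeDropout(p={"text": {"genre": 0.5, "description": 0.1}})
# >>> _ = dropout.train()  # only active in training mode unless active_on_eval=True
# >>> samples = [ConditioningAttributes(text={"genre": "rock", "description": "a rock song"})]
# >>> dropped = dropout(samples)  # "genre"/"description" independently set to None with p=0.5/0.1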
| 1083 |
+
|
| 1084 |
+
|
| 1085 |
+
class ClassifierFreeGuidanceDropout(DropoutModule):
|
| 1086 |
+
"""Classifier Free Guidance dropout.
|
| 1087 |
+
All attributes are dropped with the same probability.
|
| 1088 |
+
|
| 1089 |
+
Args:
|
| 1090 |
+
p (float): Probability to apply condition dropout during training.
|
| 1091 |
+
seed (int): Random seed.
|
| 1092 |
+
"""
|
| 1093 |
+
def __init__(self, p: float, seed: int = 1234):
|
| 1094 |
+
super().__init__(seed=seed)
|
| 1095 |
+
self.p = p
|
| 1096 |
+
|
| 1097 |
+
def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
|
| 1098 |
+
"""
|
| 1099 |
+
Args:
|
| 1100 |
+
samples (list[ConditioningAttributes]): List of conditions.
|
| 1101 |
+
Returns:
|
| 1102 |
+
list[ConditioningAttributes]: List of conditions after all attributes were set to None.
|
| 1103 |
+
"""
|
| 1104 |
+
if not self.training:
|
| 1105 |
+
return samples
|
| 1106 |
+
|
| 1107 |
+
# decide on which attributes to drop in a batched fashion
|
| 1108 |
+
drop = torch.rand(1, generator=self.rng).item() < self.p
|
| 1109 |
+
if not drop:
|
| 1110 |
+
return samples
|
| 1111 |
+
|
| 1112 |
+
# nullify conditions of all attributes
|
| 1113 |
+
samples = deepcopy(samples)
|
| 1114 |
+
for condition_type in ["wav", "text"]:
|
| 1115 |
+
for sample in samples:
|
| 1116 |
+
for condition in sample.attributes[condition_type]:
|
| 1117 |
+
dropout_condition(sample, condition_type, condition)
|
| 1118 |
+
return samples
|
| 1119 |
+
|
| 1120 |
+
def __repr__(self):
|
| 1121 |
+
return f"ClassifierFreeGuidanceDropout(p={self.p})"
|
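# Sketch of ClassifierFreeGuidanceDropout: with p=1.0 every attribute is nullified
# during training, which is what classifier-free guidance relies on.
# >>> cfg_dropout = ClassifierFreeGuidanceDropout(p=1.0)
# >>> _ = cfg_dropout.train()
# >>> out = cfg_dropout([ConditioningAttributes(text={"genre": "rock"})])
# >>> out[0].text["genre"] is None
# True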
| 1122 |
+
|
| 1123 |
+
|
| 1124 |
+
class ConditioningProvider(nn.Module):
|
| 1125 |
+
"""Prepare and provide conditions given all the supported conditioners.
|
| 1126 |
+
|
| 1127 |
+
Args:
|
| 1128 |
+
conditioners (dict): Dictionary of conditioners.
|
| 1129 |
+
device (torch.device or str, optional): Device for conditioners and output condition types.
|
| 1130 |
+
"""
|
| 1131 |
+
def __init__(self, conditioners: tp.Dict[str, BaseConditioner], device: tp.Union[torch.device, str] = "cpu"):
|
| 1132 |
+
super().__init__()
|
| 1133 |
+
self.device = device
|
| 1134 |
+
self.conditioners = nn.ModuleDict(conditioners)
|
| 1135 |
+
|
| 1136 |
+
@property
|
| 1137 |
+
def joint_embed_conditions(self):
|
| 1138 |
+
return [m.attribute for m in self.conditioners.values() if isinstance(m, JointEmbeddingConditioner)]
|
| 1139 |
+
|
| 1140 |
+
@property
|
| 1141 |
+
def has_joint_embed_conditions(self):
|
| 1142 |
+
return len(self.joint_embed_conditions) > 0
|
| 1143 |
+
|
| 1144 |
+
@property
|
| 1145 |
+
def text_conditions(self):
|
| 1146 |
+
return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)]
|
| 1147 |
+
|
| 1148 |
+
@property
|
| 1149 |
+
def wav_conditions(self):
|
| 1150 |
+
return [k for k, v in self.conditioners.items() if isinstance(v, WaveformConditioner)]
|
| 1151 |
+
|
| 1152 |
+
@property
|
| 1153 |
+
def has_wav_condition(self):
|
| 1154 |
+
return len(self.wav_conditions) > 0
|
| 1155 |
+
|
| 1156 |
+
def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
|
| 1157 |
+
"""Match attributes/wavs with existing conditioners in self, and compute tokenize them accordingly.
|
| 1158 |
+
This should be called before starting any real GPU work to avoid synchronization points.
|
| 1159 |
+
This will return a dict matching conditioner names to their arbitrary tokenized representations.
|
| 1160 |
+
|
| 1161 |
+
Args:
|
| 1162 |
+
inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing
|
| 1163 |
+
text and wav conditions.
|
| 1164 |
+
"""
|
| 1165 |
+
assert all([isinstance(x, ConditioningAttributes) for x in inputs]), (
|
| 1166 |
+
"Got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]",
|
| 1167 |
+
f" but types were {set([type(x) for x in inputs])}"
|
| 1168 |
+
)
|
| 1169 |
+
|
| 1170 |
+
output = {}
|
| 1171 |
+
text = self._collate_text(inputs)
|
| 1172 |
+
wavs = self._collate_wavs(inputs)
|
| 1173 |
+
joint_embeds = self._collate_joint_embeds(inputs)
|
| 1174 |
+
|
| 1175 |
+
assert set(text.keys() | wavs.keys() | joint_embeds.keys()).issubset(set(self.conditioners.keys())), (
|
| 1176 |
+
f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ",
|
| 1177 |
+
f"got {text.keys(), wavs.keys(), joint_embeds.keys()}"
|
| 1178 |
+
)
|
| 1179 |
+
|
| 1180 |
+
for attribute, batch in chain(text.items(), wavs.items(), joint_embeds.items()):
|
| 1181 |
+
output[attribute] = self.conditioners[attribute].tokenize(batch)
|
| 1182 |
+
return output
|
| 1183 |
+
|
| 1184 |
+
def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]:
|
| 1185 |
+
"""Compute pairs of `(embedding, mask)` using the configured conditioners and the tokenized representations.
|
| 1186 |
+
The output is for example:
|
| 1187 |
+
{
|
| 1188 |
+
"genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])),
|
| 1189 |
+
"description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])),
|
| 1190 |
+
...
|
| 1191 |
+
}
|
| 1192 |
+
|
| 1193 |
+
Args:
|
| 1194 |
+
tokenized (dict): Dict of tokenized representations as returned by `tokenize()`.
|
| 1195 |
+
"""
|
| 1196 |
+
output = {}
|
| 1197 |
+
for attribute, inputs in tokenized.items():
|
| 1198 |
+
condition, mask = self.conditioners[attribute](inputs)
|
| 1199 |
+
output[attribute] = (condition, mask)
|
| 1200 |
+
return output
|
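# Sketch of a minimal ConditioningProvider wired with a single hypothetical "genre"
# LUT conditioner; tokenize() is called first, then forward() on the tokenized output.
# >>> provider = ConditioningProvider({"genre": LUTConditioner(16, 8, 8, tokenizer='noop')})
# >>> attrs = [ConditioningAttributes(text={"genre": "rock"})]
# >>> out = provider(provider.tokenize(attrs))
# >>> out["genre"][0].shape  # (embedding, mask) pair per attribute
# torch.Size([1, 1, 8])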
| 1201 |
+
|
| 1202 |
+
def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]:
|
| 1203 |
+
"""Given a list of ConditioningAttributes objects, compile a dictionary where the keys
|
| 1204 |
+
are the attributes and the values are the aggregated input per attribute.
|
| 1205 |
+
For example:
|
| 1206 |
+
Input:
|
| 1207 |
+
[
|
| 1208 |
+
ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...),
|
| 1209 |
+
ConditioningAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, wav=...),
|
| 1210 |
+
]
|
| 1211 |
+
Output:
|
| 1212 |
+
{
|
| 1213 |
+
"genre": ["Rock", "Hip-hop"],
|
| 1214 |
+
"description": ["A rock song with a guitar solo", "A hip-hop verse"]
|
| 1215 |
+
}
|
| 1216 |
+
|
| 1217 |
+
Args:
|
| 1218 |
+
samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
|
| 1219 |
+
Returns:
|
| 1220 |
+
dict[str, list[str, optional]]: A dictionary mapping an attribute name to text batch.
|
| 1221 |
+
"""
|
| 1222 |
+
out: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list)
|
| 1223 |
+
texts = [x.text for x in samples]
|
| 1224 |
+
for text in texts:
|
| 1225 |
+
for condition in self.text_conditions:
|
| 1226 |
+
out[condition].append(text[condition])
|
| 1227 |
+
return out
|
| 1228 |
+
|
| 1229 |
+
def _collate_wavs(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, WavCondition]:
|
| 1230 |
+
"""Generate a dict where the keys are attributes by which we fetch similar wavs,
|
| 1231 |
+
and the values are Tensors of wavs according to said attributes.
|
| 1232 |
+
|
| 1233 |
+
*Note*: by the time the samples reach this function, each sample should have some waveform
|
| 1234 |
+
inside the "wav" attribute. It should be either:
|
| 1235 |
+
1. A real waveform
|
| 1236 |
+
2. A null waveform due to the sample having no similar waveforms (nullified by the dataset)
|
| 1237 |
+
3. A null waveform due to it being dropped in a dropout module (nullified by dropout)
|
| 1238 |
+
|
| 1239 |
+
Args:
|
| 1240 |
+
samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
|
| 1241 |
+
Returns:
|
| 1242 |
+
dict[str, WavCondition]: A dictionary mapping an attribute name to wavs.
|
| 1243 |
+
"""
|
| 1244 |
+
wavs = defaultdict(list)
|
| 1245 |
+
lengths = defaultdict(list)
|
| 1246 |
+
sample_rates = defaultdict(list)
|
| 1247 |
+
paths = defaultdict(list)
|
| 1248 |
+
seek_times = defaultdict(list)
|
| 1249 |
+
out: tp.Dict[str, WavCondition] = {}
|
| 1250 |
+
|
| 1251 |
+
for sample in samples:
|
| 1252 |
+
for attribute in self.wav_conditions:
|
| 1253 |
+
wav, length, sample_rate, path, seek_time = sample.wav[attribute]
|
| 1254 |
+
assert wav.dim() == 3, f"Got wav with dim={wav.dim()}, but expected 3 [1, C, T]"
|
| 1255 |
+
assert wav.size(0) == 1, f"Got wav [B, C, T] with shape={wav.shape}, but expected B == 1"
|
| 1256 |
+
# mono-channel conditioning
|
| 1257 |
+
wav = wav.mean(1, keepdim=True) # [1, 1, T]
|
| 1258 |
+
wavs[attribute].append(wav.flatten()) # [T]
|
| 1259 |
+
lengths[attribute].append(length)
|
| 1260 |
+
sample_rates[attribute].extend(sample_rate)
|
| 1261 |
+
paths[attribute].extend(path)
|
| 1262 |
+
seek_times[attribute].extend(seek_time)
|
| 1263 |
+
|
| 1264 |
+
# stack all wavs to a single tensor
|
| 1265 |
+
for attribute in self.wav_conditions:
|
| 1266 |
+
stacked_wav, _ = collate(wavs[attribute], dim=0)
|
| 1267 |
+
out[attribute] = WavCondition(
|
| 1268 |
+
stacked_wav.unsqueeze(1), torch.cat(lengths[attribute]), sample_rates[attribute],
|
| 1269 |
+
paths[attribute], seek_times[attribute])
|
| 1270 |
+
|
| 1271 |
+
return out
|
| 1272 |
+
|
| 1273 |
+
def _collate_joint_embeds(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, JointEmbedCondition]:
|
| 1274 |
+
"""Generate a dict where the keys are attributes by which we compute joint embeddings,
|
| 1275 |
+
and the values are Tensors of pre-computed embeddings and the corresponding text attributes.
|
| 1276 |
+
|
| 1277 |
+
Args:
|
| 1278 |
+
samples (list[ConditioningAttributes]): List of ConditioningAttributes samples.
|
| 1279 |
+
Returns:
|
| 1280 |
+
A dictionary mapping an attribute name to joint embeddings.
|
| 1281 |
+
"""
|
| 1282 |
+
texts = defaultdict(list)
|
| 1283 |
+
wavs = defaultdict(list)
|
| 1284 |
+
lengths = defaultdict(list)
|
| 1285 |
+
sample_rates = defaultdict(list)
|
| 1286 |
+
paths = defaultdict(list)
|
| 1287 |
+
seek_times = defaultdict(list)
|
| 1288 |
+
channels: int = 0
|
| 1289 |
+
|
| 1290 |
+
out = {}
|
| 1291 |
+
for sample in samples:
|
| 1292 |
+
for attribute in self.joint_embed_conditions:
|
| 1293 |
+
wav, text, length, sample_rate, path, seek_time = sample.joint_embed[attribute]
|
| 1294 |
+
assert wav.dim() == 3
|
| 1295 |
+
if channels == 0:
|
| 1296 |
+
channels = wav.size(1)
|
| 1297 |
+
else:
|
| 1298 |
+
assert channels == wav.size(1), "not all audio has same number of channels in batch"
|
| 1299 |
+
assert wav.size(0) == 1, "Expecting single-wav batch in the collate method"
|
| 1300 |
+
wav = einops.rearrange(wav, "b c t -> (b c t)") # [1, C, T] => [C * T]
|
| 1301 |
+
wavs[attribute].append(wav)
|
| 1302 |
+
texts[attribute].extend(text)
|
| 1303 |
+
lengths[attribute].append(length)
|
| 1304 |
+
sample_rates[attribute].extend(sample_rate)
|
| 1305 |
+
paths[attribute].extend(path)
|
| 1306 |
+
seek_times[attribute].extend(seek_time)
|
| 1307 |
+
|
| 1308 |
+
for attribute in self.joint_embed_conditions:
|
| 1309 |
+
stacked_texts = texts[attribute]
|
| 1310 |
+
stacked_paths = paths[attribute]
|
| 1311 |
+
stacked_seek_times = seek_times[attribute]
|
| 1312 |
+
stacked_wavs = pad_sequence(wavs[attribute]).to(self.device)
|
| 1313 |
+
stacked_wavs = einops.rearrange(stacked_wavs, "(c t) b -> b c t", c=channels)
|
| 1314 |
+
stacked_sample_rates = sample_rates[attribute]
|
| 1315 |
+
stacked_lengths = torch.cat(lengths[attribute]).to(self.device)
|
| 1316 |
+
assert stacked_lengths.size(0) == stacked_wavs.size(0)
|
| 1317 |
+
assert len(stacked_sample_rates) == stacked_wavs.size(0)
|
| 1318 |
+
assert len(stacked_texts) == stacked_wavs.size(0)
|
| 1319 |
+
out[attribute] = JointEmbedCondition(
|
| 1320 |
+
text=stacked_texts, wav=stacked_wavs,
|
| 1321 |
+
length=stacked_lengths, sample_rate=stacked_sample_rates,
|
| 1322 |
+
path=stacked_paths, seek_time=stacked_seek_times)
|
| 1323 |
+
|
| 1324 |
+
return out
|
| 1325 |
+
|
| 1326 |
+
|
| 1327 |
+
class ConditionFuser(StreamingModule):
|
| 1328 |
+
"""Condition fuser handles the logic to combine the different conditions
|
| 1329 |
+
to the actual model input.
|
| 1330 |
+
|
| 1331 |
+
Args:
|
| 1332 |
+
fuse2cond (tp.Dict[str, str]): A dictionary that says how to fuse
|
| 1333 |
+
each condition. For example:
|
| 1334 |
+
{
|
| 1335 |
+
"prepend": ["description"],
|
| 1336 |
+
"sum": ["genre", "bpm"],
|
| 1337 |
+
"cross": ["description"],
|
| 1338 |
+
}
|
| 1339 |
+
cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention.
|
| 1340 |
+
cross_attention_pos_emb_scale (int): Scale for positional embeddings in cross attention if used.
|
| 1341 |
+
"""
|
| 1342 |
+
FUSING_METHODS = ["sum", "prepend", "cross", "input_interpolate"]
|
| 1343 |
+
|
| 1344 |
+
def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False,
|
| 1345 |
+
cross_attention_pos_emb_scale: float = 1.0):
|
| 1346 |
+
super().__init__()
|
| 1347 |
+
assert all(
|
| 1348 |
+
[k in self.FUSING_METHODS for k in fuse2cond.keys()]
|
| 1349 |
+
), f"Got invalid fuse method, allowed methods: {self.FUSING_METHODS}"
|
| 1350 |
+
self.cross_attention_pos_emb = cross_attention_pos_emb
|
| 1351 |
+
self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale
|
| 1352 |
+
self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond
|
| 1353 |
+
self.cond2fuse: tp.Dict[str, str] = {}
|
| 1354 |
+
for fuse_method, conditions in fuse2cond.items():
|
| 1355 |
+
for condition in conditions:
|
| 1356 |
+
self.cond2fuse[condition] = fuse_method
|
| 1357 |
+
|
| 1358 |
+
def forward(
|
| 1359 |
+
self,
|
| 1360 |
+
input: torch.Tensor,
|
| 1361 |
+
conditions: tp.Dict[str, ConditionType]
|
| 1362 |
+
) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
|
| 1363 |
+
"""Fuse the conditions to the provided model input.
|
| 1364 |
+
|
| 1365 |
+
Args:
|
| 1366 |
+
input (torch.Tensor): Transformer input.
|
| 1367 |
+
conditions (dict[str, ConditionType]): Dict of conditions.
|
| 1368 |
+
Returns:
|
| 1369 |
+
tuple[torch.Tensor, torch.Tensor]: The first tensor is the transformer input
|
| 1370 |
+
after the conditions have been fused. The second output tensor is the tensor
|
| 1371 |
+
used for cross-attention or None if no cross attention inputs exist.
|
| 1372 |
+
"""
|
| 1373 |
+
B, T, _ = input.shape
|
| 1374 |
+
|
| 1375 |
+
if 'offsets' in self._streaming_state:
|
| 1376 |
+
first_step = False
|
| 1377 |
+
offsets = self._streaming_state['offsets']
|
| 1378 |
+
else:
|
| 1379 |
+
first_step = True
|
| 1380 |
+
offsets = torch.zeros(input.shape[0], dtype=torch.long, device=input.device)
|
| 1381 |
+
|
| 1382 |
+
assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \
|
| 1383 |
+
f"given conditions contain unknown attributes for fuser, " \
|
| 1384 |
+
f"expected {self.cond2fuse.keys()}, got {conditions.keys()}"
|
| 1385 |
+
cross_attention_output = None
|
| 1386 |
+
for cond_type, (cond, cond_mask) in conditions.items():
|
| 1387 |
+
op = self.cond2fuse[cond_type]
|
| 1388 |
+
if op == 'sum':
|
| 1389 |
+
input += cond
|
| 1390 |
+
elif op == 'input_interpolate':
|
| 1391 |
+
cond = einops.rearrange(cond, "b t d -> b d t")
|
| 1392 |
+
cond = F.interpolate(cond, size=input.shape[1])
|
| 1393 |
+
input += einops.rearrange(cond, "b d t -> b t d")
|
| 1394 |
+
elif op == 'prepend':
|
| 1395 |
+
if first_step:
|
| 1396 |
+
input = torch.cat([cond, input], dim=1)
|
| 1397 |
+
elif op == 'cross':
|
| 1398 |
+
if cross_attention_output is not None:
|
| 1399 |
+
cross_attention_output = torch.cat([cross_attention_output, cond], dim=1)
|
| 1400 |
+
else:
|
| 1401 |
+
cross_attention_output = cond
|
| 1402 |
+
else:
|
| 1403 |
+
raise ValueError(f"unknown op ({op})")
|
| 1404 |
+
|
| 1405 |
+
if self.cross_attention_pos_emb and cross_attention_output is not None:
|
| 1406 |
+
positions = torch.arange(
|
| 1407 |
+
cross_attention_output.shape[1],
|
| 1408 |
+
device=cross_attention_output.device
|
| 1409 |
+
).view(1, -1, 1)
|
| 1410 |
+
pos_emb = create_sin_embedding(positions, cross_attention_output.shape[-1])
|
| 1411 |
+
cross_attention_output = cross_attention_output + self.cross_attention_pos_emb_scale * pos_emb
|
| 1412 |
+
|
| 1413 |
+
if self._is_streaming:
|
| 1414 |
+
self._streaming_state['offsets'] = offsets + T
|
| 1415 |
+
|
| 1416 |
+
return input, cross_attention_output
|
audiocraft/modules/conv.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
import typing as tp
|
| 9 |
+
import warnings
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from torch import nn
|
| 13 |
+
from torch.nn import functional as F
|
| 14 |
+
from torch.nn.utils import spectral_norm, weight_norm
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
|
| 18 |
+
'time_group_norm'])
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def apply_parametrization_norm(module: nn.Module, norm: str = 'none'):
|
| 22 |
+
assert norm in CONV_NORMALIZATIONS
|
| 23 |
+
if norm == 'weight_norm':
|
| 24 |
+
return weight_norm(module)
|
| 25 |
+
elif norm == 'spectral_norm':
|
| 26 |
+
return spectral_norm(module)
|
| 27 |
+
else:
|
| 28 |
+
# We already check was in CONV_NORMALIZATION, so any other choice
|
| 29 |
+
# doesn't need reparametrization.
|
| 30 |
+
return module
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs):
|
| 34 |
+
"""Return the proper normalization module. If causal is True, this will ensure the returned
|
| 35 |
+
module is causal, or return an error if the normalization doesn't support causal evaluation.
|
| 36 |
+
"""
|
| 37 |
+
assert norm in CONV_NORMALIZATIONS
|
| 38 |
+
if norm == 'time_group_norm':
|
| 39 |
+
if causal:
|
| 40 |
+
raise ValueError("GroupNorm doesn't support causal evaluation.")
|
| 41 |
+
assert isinstance(module, nn.modules.conv._ConvNd)
|
| 42 |
+
return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
|
| 43 |
+
else:
|
| 44 |
+
return nn.Identity()
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
|
| 48 |
+
padding_total: int = 0) -> int:
|
| 49 |
+
"""See `pad_for_conv1d`."""
|
| 50 |
+
length = x.shape[-1]
|
| 51 |
+
n_frames = (length - kernel_size + padding_total) / stride + 1
|
| 52 |
+
ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
|
| 53 |
+
return ideal_length - length
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
|
| 57 |
+
"""Pad for a convolution to make sure that the last window is full.
|
| 58 |
+
Extra padding is added at the end. This is required to ensure that we can rebuild
|
| 59 |
+
an output of the same length, as otherwise, even with padding, some time steps
|
| 60 |
+
might get removed.
|
| 61 |
+
For instance, with total padding = 4, kernel size = 4, stride = 2:
|
| 62 |
+
0 0 1 2 3 4 5 0 0 # (0s are padding)
|
| 63 |
+
1 2 3 # (output frames of a convolution, last 0 is never used)
|
| 64 |
+
0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
|
| 65 |
+
1 2 3 4 # once you removed padding, we are missing one time step !
|
| 66 |
+
"""
|
| 67 |
+
extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
|
| 68 |
+
return F.pad(x, (0, extra_padding))
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
|
| 72 |
+
"""Tiny wrapper around F.pad, just to allow for reflect padding on small input.
|
| 73 |
+
If this is the case, we insert extra 0 padding to the right before the reflection happen.
|
| 74 |
+
"""
|
| 75 |
+
length = x.shape[-1]
|
| 76 |
+
padding_left, padding_right = paddings
|
| 77 |
+
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
| 78 |
+
if mode == 'reflect':
|
| 79 |
+
max_pad = max(padding_left, padding_right)
|
| 80 |
+
extra_pad = 0
|
| 81 |
+
if length <= max_pad:
|
| 82 |
+
extra_pad = max_pad - length + 1
|
| 83 |
+
x = F.pad(x, (0, extra_pad))
|
| 84 |
+
padded = F.pad(x, paddings, mode, value)
|
| 85 |
+
end = padded.shape[-1] - extra_pad
|
| 86 |
+
return padded[..., :end]
|
| 87 |
+
else:
|
| 88 |
+
return F.pad(x, paddings, mode, value)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
|
| 92 |
+
"""Remove padding from x, handling properly zero padding. Only for 1d!"""
|
| 93 |
+
padding_left, padding_right = paddings
|
| 94 |
+
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
| 95 |
+
assert (padding_left + padding_right) <= x.shape[-1]
|
| 96 |
+
end = x.shape[-1] - padding_right
|
| 97 |
+
return x[..., padding_left: end]
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class NormConv1d(nn.Module):
|
| 101 |
+
"""Wrapper around Conv1d and normalization applied to this conv
|
| 102 |
+
to provide a uniform interface across normalization approaches.
|
| 103 |
+
"""
|
| 104 |
+
def __init__(self, *args, causal: bool = False, norm: str = 'none',
|
| 105 |
+
norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
|
| 106 |
+
super().__init__()
|
| 107 |
+
self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
|
| 108 |
+
self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
|
| 109 |
+
self.norm_type = norm
|
| 110 |
+
|
| 111 |
+
def forward(self, x):
|
| 112 |
+
x = self.conv(x)
|
| 113 |
+
x = self.norm(x)
|
| 114 |
+
return x
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class NormConv2d(nn.Module):
|
| 118 |
+
"""Wrapper around Conv2d and normalization applied to this conv
|
| 119 |
+
to provide a uniform interface across normalization approaches.
|
| 120 |
+
"""
|
| 121 |
+
def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
|
| 122 |
+
super().__init__()
|
| 123 |
+
self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
|
| 124 |
+
self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
|
| 125 |
+
self.norm_type = norm
|
| 126 |
+
|
| 127 |
+
def forward(self, x):
|
| 128 |
+
x = self.conv(x)
|
| 129 |
+
x = self.norm(x)
|
| 130 |
+
return x
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class NormConvTranspose1d(nn.Module):
|
| 134 |
+
"""Wrapper around ConvTranspose1d and normalization applied to this conv
|
| 135 |
+
to provide a uniform interface across normalization approaches.
|
| 136 |
+
"""
|
| 137 |
+
def __init__(self, *args, causal: bool = False, norm: str = 'none',
|
| 138 |
+
norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
|
| 139 |
+
super().__init__()
|
| 140 |
+
self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
|
| 141 |
+
self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
|
| 142 |
+
self.norm_type = norm
|
| 143 |
+
|
| 144 |
+
def forward(self, x):
|
| 145 |
+
x = self.convtr(x)
|
| 146 |
+
x = self.norm(x)
|
| 147 |
+
return x
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class NormConvTranspose2d(nn.Module):
|
| 151 |
+
"""Wrapper around ConvTranspose2d and normalization applied to this conv
|
| 152 |
+
to provide a uniform interface across normalization approaches.
|
| 153 |
+
"""
|
| 154 |
+
def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
|
| 155 |
+
super().__init__()
|
| 156 |
+
self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
|
| 157 |
+
self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
|
| 158 |
+
|
| 159 |
+
def forward(self, x):
|
| 160 |
+
x = self.convtr(x)
|
| 161 |
+
x = self.norm(x)
|
| 162 |
+
return x
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class StreamableConv1d(nn.Module):
|
| 166 |
+
"""Conv1d with some builtin handling of asymmetric or causal padding
|
| 167 |
+
and normalization.
|
| 168 |
+
"""
|
| 169 |
+
def __init__(self, in_channels: int, out_channels: int,
|
| 170 |
+
kernel_size: int, stride: int = 1, dilation: int = 1,
|
| 171 |
+
groups: int = 1, bias: bool = True, causal: bool = False,
|
| 172 |
+
norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
|
| 173 |
+
pad_mode: str = 'reflect'):
|
| 174 |
+
super().__init__()
|
| 175 |
+
# warn user on unusual setup between dilation and stride
|
| 176 |
+
if stride > 1 and dilation > 1:
|
| 177 |
+
warnings.warn("StreamableConv1d has been initialized with stride > 1 and dilation > 1"
|
| 178 |
+
f" (kernel_size={kernel_size} stride={stride}, dilation={dilation}).")
|
| 179 |
+
self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
|
| 180 |
+
dilation=dilation, groups=groups, bias=bias, causal=causal,
|
| 181 |
+
norm=norm, norm_kwargs=norm_kwargs)
|
| 182 |
+
self.causal = causal
|
| 183 |
+
self.pad_mode = pad_mode
|
| 184 |
+
|
| 185 |
+
def forward(self, x):
|
| 186 |
+
B, C, T = x.shape
|
| 187 |
+
kernel_size = self.conv.conv.kernel_size[0]
|
| 188 |
+
stride = self.conv.conv.stride[0]
|
| 189 |
+
dilation = self.conv.conv.dilation[0]
|
| 190 |
+
kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations
|
| 191 |
+
padding_total = kernel_size - stride
|
| 192 |
+
extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
|
| 193 |
+
if self.causal:
|
| 194 |
+
# Left padding for causal
|
| 195 |
+
x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
|
| 196 |
+
else:
|
| 197 |
+
# Asymmetric padding required for odd strides
|
| 198 |
+
padding_right = padding_total // 2
|
| 199 |
+
padding_left = padding_total - padding_right
|
| 200 |
+
x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
|
| 201 |
+
return self.conv(x)
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
class StreamableConvTranspose1d(nn.Module):
|
| 205 |
+
"""ConvTranspose1d with some builtin handling of asymmetric or causal padding
|
| 206 |
+
and normalization.
|
| 207 |
+
"""
|
| 208 |
+
def __init__(self, in_channels: int, out_channels: int,
|
| 209 |
+
kernel_size: int, stride: int = 1, causal: bool = False,
|
| 210 |
+
norm: str = 'none', trim_right_ratio: float = 1.,
|
| 211 |
+
norm_kwargs: tp.Dict[str, tp.Any] = {}):
|
| 212 |
+
super().__init__()
|
| 213 |
+
self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
|
| 214 |
+
causal=causal, norm=norm, norm_kwargs=norm_kwargs)
|
| 215 |
+
self.causal = causal
|
| 216 |
+
self.trim_right_ratio = trim_right_ratio
|
| 217 |
+
assert self.causal or self.trim_right_ratio == 1., \
|
| 218 |
+
"`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
|
| 219 |
+
assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
|
| 220 |
+
|
| 221 |
+
def forward(self, x):
|
| 222 |
+
kernel_size = self.convtr.convtr.kernel_size[0]
|
| 223 |
+
stride = self.convtr.convtr.stride[0]
|
| 224 |
+
padding_total = kernel_size - stride
|
| 225 |
+
|
| 226 |
+
y = self.convtr(x)
|
| 227 |
+
|
| 228 |
+
# We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
|
| 229 |
+
# removed at the very end, when keeping only the right length for the output,
|
| 230 |
+
# as removing it here would require also passing the length at the matching layer
|
| 231 |
+
# in the encoder.
|
| 232 |
+
if self.causal:
|
| 233 |
+
# Trim the padding on the right according to the specified ratio
|
| 234 |
+
# if trim_right_ratio = 1.0, trim everything from right
|
| 235 |
+
padding_right = math.ceil(padding_total * self.trim_right_ratio)
|
| 236 |
+
padding_left = padding_total - padding_right
|
| 237 |
+
y = unpad1d(y, (padding_left, padding_right))
|
| 238 |
+
else:
|
| 239 |
+
# Asymmetric padding required for odd strides
|
| 240 |
+
padding_right = padding_total // 2
|
| 241 |
+
padding_left = padding_total - padding_right
|
| 242 |
+
y = unpad1d(y, (padding_left, padding_right))
|
| 243 |
+
return y
|
audiocraft/modules/diffusion_schedule.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Functions for Noise Schedule, defines diffusion process, reverse process and data processor.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from collections import namedtuple
|
| 12 |
+
import random
|
| 13 |
+
import typing as tp
|
| 14 |
+
import julius
|
| 15 |
+
import torch
|
| 16 |
+
|
| 17 |
+
TrainingItem = namedtuple("TrainingItem", "noisy noise step")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def betas_from_alpha_bar(alpha_bar):
|
| 21 |
+
alphas = torch.cat([torch.Tensor([alpha_bar[0]]), alpha_bar[1:]/alpha_bar[:-1]])
|
| 22 |
+
return 1 - alphas
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class SampleProcessor(torch.nn.Module):
|
| 26 |
+
def project_sample(self, x: torch.Tensor):
|
| 27 |
+
"""Project the original sample to the 'space' where the diffusion will happen."""
|
| 28 |
+
return x
|
| 29 |
+
|
| 30 |
+
def return_sample(self, z: torch.Tensor):
|
| 31 |
+
"""Project back from diffusion space to the actual sample space."""
|
| 32 |
+
return z
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class MultiBandProcessor(SampleProcessor):
|
| 36 |
+
"""
|
| 37 |
+
MultiBand sample processor. The input audio is splitted across
|
| 38 |
+
frequency bands evenly distributed in mel-scale.
|
| 39 |
+
|
| 40 |
+
Each band will be rescaled to match the power distribution
|
| 41 |
+
of Gaussian noise in that band, using online metrics
|
| 42 |
+
computed on the first few samples.
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
n_bands (int): Number of mel-bands to split the signal over.
|
| 46 |
+
sample_rate (int): Sample rate of the audio.
|
| 47 |
+
num_samples (int): Number of samples to use to fit the rescaling
|
| 48 |
+
for each band. The processor won't be stable
|
| 49 |
+
until it has seen that many samples.
|
| 50 |
+
power_std (float or list/tensor): The rescaling factor computed to match the
|
| 51 |
+
power of Gaussian noise in each band is taken to
|
| 52 |
+
that power, i.e. `1.` means full correction of the energy
|
| 53 |
+
in each band, and values less than `1` means only partial
|
| 54 |
+
correction. Can be used to balance the relative importance
|
| 55 |
+
of low vs. high freq in typical audio signals.
|
| 56 |
+
"""
|
| 57 |
+
def __init__(self, n_bands: int = 8, sample_rate: float = 24_000,
|
| 58 |
+
num_samples: int = 10_000, power_std: tp.Union[float, tp.List[float], torch.Tensor] = 1.):
|
| 59 |
+
super().__init__()
|
| 60 |
+
self.n_bands = n_bands
|
| 61 |
+
self.split_bands = julius.SplitBands(sample_rate, n_bands=n_bands)
|
| 62 |
+
self.num_samples = num_samples
|
| 63 |
+
self.power_std = power_std
|
| 64 |
+
if isinstance(power_std, list):
|
| 65 |
+
assert len(power_std) == n_bands
|
| 66 |
+
power_std = torch.tensor(power_std)
|
| 67 |
+
self.register_buffer('counts', torch.zeros(1))
|
| 68 |
+
self.register_buffer('sum_x', torch.zeros(n_bands))
|
| 69 |
+
self.register_buffer('sum_x2', torch.zeros(n_bands))
|
| 70 |
+
self.register_buffer('sum_target_x2', torch.zeros(n_bands))
|
| 71 |
+
self.counts: torch.Tensor
|
| 72 |
+
self.sum_x: torch.Tensor
|
| 73 |
+
self.sum_x2: torch.Tensor
|
| 74 |
+
self.sum_target_x2: torch.Tensor
|
| 75 |
+
|
| 76 |
+
@property
|
| 77 |
+
def mean(self):
|
| 78 |
+
mean = self.sum_x / self.counts
|
| 79 |
+
return mean
|
| 80 |
+
|
| 81 |
+
@property
|
| 82 |
+
def std(self):
|
| 83 |
+
std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
|
| 84 |
+
return std
|
| 85 |
+
|
| 86 |
+
@property
|
| 87 |
+
def target_std(self):
|
| 88 |
+
target_std = self.sum_target_x2 / self.counts
|
| 89 |
+
return target_std
|
| 90 |
+
|
| 91 |
+
def project_sample(self, x: torch.Tensor):
|
| 92 |
+
assert x.dim() == 3
|
| 93 |
+
bands = self.split_bands(x)
|
| 94 |
+
if self.counts.item() < self.num_samples:
|
| 95 |
+
ref_bands = self.split_bands(torch.randn_like(x))
|
| 96 |
+
self.counts += len(x)
|
| 97 |
+
self.sum_x += bands.mean(dim=(2, 3)).sum(dim=1)
|
| 98 |
+
self.sum_x2 += bands.pow(2).mean(dim=(2, 3)).sum(dim=1)
|
| 99 |
+
self.sum_target_x2 += ref_bands.pow(2).mean(dim=(2, 3)).sum(dim=1)
|
| 100 |
+
rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size
|
| 101 |
+
bands = (bands - self.mean.view(-1, 1, 1, 1)) * rescale.view(-1, 1, 1, 1)
|
| 102 |
+
return bands.sum(dim=0)
|
| 103 |
+
|
| 104 |
+
def return_sample(self, x: torch.Tensor):
|
| 105 |
+
assert x.dim() == 3
|
| 106 |
+
bands = self.split_bands(x)
|
| 107 |
+
rescale = (self.std / self.target_std) ** self.power_std
|
| 108 |
+
bands = bands * rescale.view(-1, 1, 1, 1) + self.mean.view(-1, 1, 1, 1)
|
| 109 |
+
return bands.sum(dim=0)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class NoiseSchedule:
|
| 113 |
+
"""Noise schedule for diffusion.
|
| 114 |
+
|
| 115 |
+
Args:
|
| 116 |
+
beta_t0 (float): Variance of the first diffusion step.
|
| 117 |
+
beta_t1 (float): Variance of the last diffusion step.
|
| 118 |
+
beta_exp (float): Power schedule exponent
|
| 119 |
+
num_steps (int): Number of diffusion step.
|
| 120 |
+
variance (str): choice of the sigma value for the denoising eq. Choices: "beta" or "beta_tilde"
|
| 121 |
+
clip (float): clipping value for the denoising steps
|
| 122 |
+
rescale (float): rescaling value to avoid vanishing signals unused by default (i.e 1)
|
| 123 |
+
repartition (str): shape of the schedule only power schedule is supported
|
| 124 |
+
sample_processor (SampleProcessor): Module that normalize data to match better the gaussian distribution
|
| 125 |
+
noise_scale (float): Scaling factor for the noise
|
| 126 |
+
"""
|
| 127 |
+
def __init__(self, beta_t0: float = 1e-4, beta_t1: float = 0.02, num_steps: int = 1000, variance: str = 'beta',
|
| 128 |
+
clip: float = 5., rescale: float = 1., device='cuda', beta_exp: float = 1,
|
| 129 |
+
repartition: str = "power", alpha_sigmoid: dict = {}, n_bands: tp.Optional[int] = None,
|
| 130 |
+
sample_processor: SampleProcessor = SampleProcessor(), noise_scale: float = 1.0, **kwargs):
|
| 131 |
+
|
| 132 |
+
self.beta_t0 = beta_t0
|
| 133 |
+
self.beta_t1 = beta_t1
|
| 134 |
+
self.variance = variance
|
| 135 |
+
self.num_steps = num_steps
|
| 136 |
+
self.clip = clip
|
| 137 |
+
self.sample_processor = sample_processor
|
| 138 |
+
self.rescale = rescale
|
| 139 |
+
self.n_bands = n_bands
|
| 140 |
+
self.noise_scale = noise_scale
|
| 141 |
+
assert n_bands is None
|
| 142 |
+
if repartition == "power":
|
| 143 |
+
self.betas = torch.linspace(beta_t0 ** (1 / beta_exp), beta_t1 ** (1 / beta_exp), num_steps,
|
| 144 |
+
device=device, dtype=torch.float) ** beta_exp
|
| 145 |
+
else:
|
| 146 |
+
raise RuntimeError('Not implemented')
|
| 147 |
+
self.rng = random.Random(1234)
|
| 148 |
+
|
| 149 |
+
def get_beta(self, step: tp.Union[int, torch.Tensor]):
|
| 150 |
+
if self.n_bands is None:
|
| 151 |
+
return self.betas[step]
|
| 152 |
+
else:
|
| 153 |
+
return self.betas[:, step] # [n_bands, len(step)]
|
| 154 |
+
|
| 155 |
+
def get_initial_noise(self, x: torch.Tensor):
|
| 156 |
+
if self.n_bands is None:
|
| 157 |
+
return torch.randn_like(x)
|
| 158 |
+
return torch.randn((x.size(0), self.n_bands, x.size(2)))
|
| 159 |
+
|
| 160 |
+
def get_alpha_bar(self, step: tp.Optional[tp.Union[int, torch.Tensor]] = None) -> torch.Tensor:
|
| 161 |
+
"""Return 'alpha_bar', either for a given step, or as a tensor with its value for each step."""
|
| 162 |
+
if step is None:
|
| 163 |
+
return (1 - self.betas).cumprod(dim=-1) # works for simgle and multi bands
|
| 164 |
+
if type(step) is int:
|
| 165 |
+
return (1 - self.betas[:step + 1]).prod()
|
| 166 |
+
else:
|
| 167 |
+
return (1 - self.betas).cumprod(dim=0)[step].view(-1, 1, 1)
|
| 168 |
+
|
| 169 |
+
def get_training_item(self, x: torch.Tensor, tensor_step: bool = False) -> TrainingItem:
|
| 170 |
+
"""Create a noisy data item for diffusion model training:
|
| 171 |
+
|
| 172 |
+
Args:
|
| 173 |
+
x (torch.Tensor): clean audio data torch.tensor(bs, 1, T)
|
| 174 |
+
tensor_step (bool): If tensor_step = false, only one step t is sample,
|
| 175 |
+
the whole batch is diffused to the same step and t is int.
|
| 176 |
+
If tensor_step = true, t is a tensor of size (x.size(0),)
|
| 177 |
+
every element of the batch is diffused to a independently sampled.
|
| 178 |
+
"""
|
| 179 |
+
step: tp.Union[int, torch.Tensor]
|
| 180 |
+
if tensor_step:
|
| 181 |
+
bs = x.size(0)
|
| 182 |
+
step = torch.randint(0, self.num_steps, size=(bs,), device=x.device)
|
| 183 |
+
else:
|
| 184 |
+
step = self.rng.randrange(self.num_steps)
|
| 185 |
+
alpha_bar = self.get_alpha_bar(step) # [batch_size, n_bands, 1]
|
| 186 |
+
|
| 187 |
+
x = self.sample_processor.project_sample(x)
|
| 188 |
+
noise = torch.randn_like(x)
|
| 189 |
+
noisy = (alpha_bar.sqrt() / self.rescale) * x + (1 - alpha_bar).sqrt() * noise * self.noise_scale
|
| 190 |
+
return TrainingItem(noisy, noise, step)
|
| 191 |
+
|
| 192 |
+
def generate(self, model: torch.nn.Module, initial: tp.Optional[torch.Tensor] = None,
|
| 193 |
+
condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
|
| 194 |
+
"""Full ddpm reverse process.
|
| 195 |
+
|
| 196 |
+
Args:
|
| 197 |
+
model (nn.Module): Diffusion model.
|
| 198 |
+
initial (tensor): Initial Noise.
|
| 199 |
+
condition (tensor): Input conditionning Tensor (e.g. encodec compressed representation).
|
| 200 |
+
return_list (bool): Whether to return the whole process or only the sampled point.
|
| 201 |
+
"""
|
| 202 |
+
alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
|
| 203 |
+
current = initial
|
| 204 |
+
iterates = [initial]
|
| 205 |
+
for step in range(self.num_steps)[::-1]:
|
| 206 |
+
with torch.no_grad():
|
| 207 |
+
estimate = model(current, step, condition=condition).sample
|
| 208 |
+
alpha = 1 - self.betas[step]
|
| 209 |
+
previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
|
| 210 |
+
previous_alpha_bar = self.get_alpha_bar(step=step - 1)
|
| 211 |
+
if step == 0:
|
| 212 |
+
sigma2 = 0
|
| 213 |
+
elif self.variance == 'beta':
|
| 214 |
+
sigma2 = 1 - alpha
|
| 215 |
+
elif self.variance == 'beta_tilde':
|
| 216 |
+
sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
|
| 217 |
+
elif self.variance == 'none':
|
| 218 |
+
sigma2 = 0
|
| 219 |
+
else:
|
| 220 |
+
raise ValueError(f'Invalid variance type {self.variance}')
|
| 221 |
+
|
| 222 |
+
if sigma2 > 0:
|
| 223 |
+
previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
|
| 224 |
+
if self.clip:
|
| 225 |
+
previous = previous.clamp(-self.clip, self.clip)
|
| 226 |
+
current = previous
|
| 227 |
+
alpha_bar = previous_alpha_bar
|
| 228 |
+
if step == 0:
|
| 229 |
+
previous *= self.rescale
|
| 230 |
+
if return_list:
|
| 231 |
+
iterates.append(previous.cpu())
|
| 232 |
+
|
| 233 |
+
if return_list:
|
| 234 |
+
return iterates
|
| 235 |
+
else:
|
| 236 |
+
return self.sample_processor.return_sample(previous)
|
| 237 |
+
|
| 238 |
+
def generate_subsampled(self, model: torch.nn.Module, initial: torch.Tensor, step_list: tp.Optional[list] = None,
|
| 239 |
+
condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
|
| 240 |
+
"""Reverse process that only goes through Markov chain states in step_list."""
|
| 241 |
+
if step_list is None:
|
| 242 |
+
step_list = list(range(1000))[::-50] + [0]
|
| 243 |
+
alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
|
| 244 |
+
alpha_bars_subsampled = (1 - self.betas).cumprod(dim=0)[list(reversed(step_list))].cpu()
|
| 245 |
+
betas_subsampled = betas_from_alpha_bar(alpha_bars_subsampled)
|
| 246 |
+
current = initial * self.noise_scale
|
| 247 |
+
iterates = [current]
|
| 248 |
+
for idx, step in enumerate(step_list[:-1]):
|
| 249 |
+
with torch.no_grad():
|
| 250 |
+
estimate = model(current, step, condition=condition).sample * self.noise_scale
|
| 251 |
+
alpha = 1 - betas_subsampled[-1 - idx]
|
| 252 |
+
previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
|
| 253 |
+
previous_alpha_bar = self.get_alpha_bar(step_list[idx + 1])
|
| 254 |
+
if step == step_list[-2]:
|
| 255 |
+
sigma2 = 0
|
| 256 |
+
previous_alpha_bar = torch.tensor(1.0)
|
| 257 |
+
else:
|
| 258 |
+
sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
|
| 259 |
+
if sigma2 > 0:
|
| 260 |
+
previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
|
| 261 |
+
if self.clip:
|
| 262 |
+
previous = previous.clamp(-self.clip, self.clip)
|
| 263 |
+
current = previous
|
| 264 |
+
alpha_bar = previous_alpha_bar
|
| 265 |
+
if step == 0:
|
| 266 |
+
previous *= self.rescale
|
| 267 |
+
if return_list:
|
| 268 |
+
iterates.append(previous.cpu())
|
| 269 |
+
if return_list:
|
| 270 |
+
return iterates
|
| 271 |
+
else:
|
| 272 |
+
return self.sample_processor.return_sample(previous)
|
audiocraft/modules/lstm.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from torch import nn
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class StreamableLSTM(nn.Module):
|
| 11 |
+
"""LSTM without worrying about the hidden state, nor the layout of the data.
|
| 12 |
+
Expects input as convolutional layout.
|
| 13 |
+
"""
|
| 14 |
+
def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
|
| 15 |
+
super().__init__()
|
| 16 |
+
self.skip = skip
|
| 17 |
+
self.lstm = nn.LSTM(dimension, dimension, num_layers)
|
| 18 |
+
|
| 19 |
+
def forward(self, x):
|
| 20 |
+
x = x.permute(2, 0, 1)
|
| 21 |
+
y, _ = self.lstm(x)
|
| 22 |
+
if self.skip:
|
| 23 |
+
y = y + x
|
| 24 |
+
y = y.permute(1, 2, 0)
|
| 25 |
+
return y
|
audiocraft/modules/rope.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import typing as tp
|
| 8 |
+
|
| 9 |
+
from torch import nn
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class XPos(nn.Module):
|
| 14 |
+
"""Length-extrapolatable positional embedding (xPos) from [Sun et al 2022](https://arxiv.org/abs/2212.10554v1).
|
| 15 |
+
This applies an exponential decay to the RoPE rotation matrix.
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
dim (int): Embedding dimension.
|
| 19 |
+
smoothing (float): Smoothing factor applied to the decay rates.
|
| 20 |
+
base_scale (int): Base decay rate, given in terms of scaling time.
|
| 21 |
+
device (torch.device, optional): Device on which to initialize the module.
|
| 22 |
+
dtype (torch.dtype): dtype to use to generate the embedding.
|
| 23 |
+
"""
|
| 24 |
+
def __init__(self, dim: int, smoothing: float = 0.4, base_scale: int = 512,
|
| 25 |
+
device=None, dtype: torch.dtype = torch.float32):
|
| 26 |
+
super().__init__()
|
| 27 |
+
assert dim % 2 == 0
|
| 28 |
+
assert dtype in [torch.float64, torch.float32]
|
| 29 |
+
self.dtype = dtype
|
| 30 |
+
self.base_scale = base_scale
|
| 31 |
+
|
| 32 |
+
half_dim = dim // 2
|
| 33 |
+
adim = torch.arange(half_dim, device=device, dtype=dtype)
|
| 34 |
+
decay_rates = (adim / half_dim + smoothing) / (1.0 + smoothing)
|
| 35 |
+
self.register_buffer("decay_rates", decay_rates)
|
| 36 |
+
self.decay: tp.Optional[torch.Tensor] = None
|
| 37 |
+
|
| 38 |
+
def get_decay(self, start: int, end: int):
|
| 39 |
+
"""Create complex decay tensor, cache values for fast computation."""
|
| 40 |
+
if self.decay is None or end > self.decay.shape[0]:
|
| 41 |
+
assert isinstance(self.decay_rates, torch.Tensor) # Satisfy type checker.
|
| 42 |
+
idx = torch.arange(end, device=self.decay_rates.device, dtype=self.dtype)
|
| 43 |
+
power = idx / self.base_scale
|
| 44 |
+
scale = self.decay_rates ** power.unsqueeze(-1)
|
| 45 |
+
self.decay = torch.polar(scale, torch.zeros_like(scale))
|
| 46 |
+
return self.decay[start:end] # [T, C/2]
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class RotaryEmbedding(nn.Module):
|
| 50 |
+
"""Rotary positional embedding (RoPE) from [Su et al 2022](https://arxiv.org/abs/2104.09864).
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
dim (int): Embedding dimension (twice the number of frequencies).
|
| 54 |
+
max_period (float): Maximum period of the rotation frequencies.
|
| 55 |
+
xpos (bool): Use xPos, applies an exponential decay to rotation matrix.
|
| 56 |
+
scale (float): Scale of positional embedding, set to 0 to deactivate.
|
| 57 |
+
device (torch.device, optional): Device on which to initialize the module.
|
| 58 |
+
dtype (torch.dtype): dtype to use to generate the embedding.
|
| 59 |
+
"""
|
| 60 |
+
def __init__(self, dim: int, max_period: float = 10000.0, xpos: bool = False,
|
| 61 |
+
scale: float = 1.0, device=None, dtype: torch.dtype = torch.float32):
|
| 62 |
+
super().__init__()
|
| 63 |
+
assert dim % 2 == 0
|
| 64 |
+
self.scale = scale
|
| 65 |
+
assert dtype in [torch.float64, torch.float32]
|
| 66 |
+
self.dtype = dtype
|
| 67 |
+
|
| 68 |
+
adim = torch.arange(0, dim, 2, device=device, dtype=dtype)[: (dim // 2)]
|
| 69 |
+
frequencies = 1.0 / (max_period ** (adim / dim))
|
| 70 |
+
self.register_buffer("frequencies", frequencies)
|
| 71 |
+
self.rotation: tp.Optional[torch.Tensor] = None
|
| 72 |
+
|
| 73 |
+
self.xpos = XPos(dim, device=device, dtype=dtype) if xpos else None
|
| 74 |
+
|
| 75 |
+
def get_rotation(self, start: int, end: int):
|
| 76 |
+
"""Create complex rotation tensor, cache values for fast computation."""
|
| 77 |
+
if self.rotation is None or end > self.rotation.shape[0]:
|
| 78 |
+
assert isinstance(self.frequencies, torch.Tensor) # Satisfy type checker.
|
| 79 |
+
idx = torch.arange(end, device=self.frequencies.device, dtype=self.dtype)
|
| 80 |
+
angles = torch.outer(idx, self.frequencies)
|
| 81 |
+
self.rotation = torch.polar(torch.ones_like(angles), angles)
|
| 82 |
+
return self.rotation[start:end]
|
| 83 |
+
|
| 84 |
+
def rotate(self, x: torch.Tensor, start: int = 0, time_dim: int = 1, invert_decay: bool = False):
|
| 85 |
+
"""Apply rope rotation to query or key tensor."""
|
| 86 |
+
T = x.shape[time_dim]
|
| 87 |
+
target_shape = [1] * x.dim()
|
| 88 |
+
target_shape[time_dim] = T
|
| 89 |
+
target_shape[-1] = -1
|
| 90 |
+
rotation = self.get_rotation(start, start + T).view(target_shape)
|
| 91 |
+
|
| 92 |
+
if self.xpos:
|
| 93 |
+
decay = self.xpos.get_decay(start, start + T).view(target_shape)
|
| 94 |
+
else:
|
| 95 |
+
decay = 1.0
|
| 96 |
+
|
| 97 |
+
if invert_decay:
|
| 98 |
+
decay = decay ** -1
|
| 99 |
+
|
| 100 |
+
x_complex = torch.view_as_complex(x.to(self.dtype).reshape(*x.shape[:-1], -1, 2))
|
| 101 |
+
scaled_rotation = (rotation * decay) * self.scale + (1.0 - self.scale)
|
| 102 |
+
x_out = torch.view_as_real(x_complex * scaled_rotation).view_as(x)
|
| 103 |
+
|
| 104 |
+
return x_out.type_as(x)
|
| 105 |
+
|
| 106 |
+
def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0, time_dim: int = 1):
|
| 107 |
+
""" Apply rope rotation to both query and key tensors.
|
| 108 |
+
Supports streaming mode, in which query and key are not expected to have the same shape.
|
| 109 |
+
In streaming mode, key will be of length [P + C] with P the cached past timesteps, but
|
| 110 |
+
query will be [C] (typically C == 1).
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
query (torch.Tensor): Query to rotate.
|
| 114 |
+
key (torch.Tensor): Key to rotate.
|
| 115 |
+
start (int): Start index of the sequence for time offset.
|
| 116 |
+
time_dim (int): which dimension represent the time steps.
|
| 117 |
+
"""
|
| 118 |
+
query_timesteps = query.shape[time_dim]
|
| 119 |
+
key_timesteps = key.shape[time_dim]
|
| 120 |
+
streaming_offset = key_timesteps - query_timesteps
|
| 121 |
+
|
| 122 |
+
query_out = self.rotate(query, start + streaming_offset, time_dim)
|
| 123 |
+
key_out = self.rotate(key, start, time_dim, invert_decay=True)
|
| 124 |
+
|
| 125 |
+
return query_out, key_out
|
audiocraft/modules/seanet.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import typing as tp
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
|
| 12 |
+
from .conv import StreamableConv1d, StreamableConvTranspose1d
|
| 13 |
+
from .lstm import StreamableLSTM
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class SEANetResnetBlock(nn.Module):
|
| 17 |
+
"""Residual block from SEANet model.
|
| 18 |
+
|
| 19 |
+
Args:
|
| 20 |
+
dim (int): Dimension of the input/output.
|
| 21 |
+
kernel_sizes (list): List of kernel sizes for the convolutions.
|
| 22 |
+
dilations (list): List of dilations for the convolutions.
|
| 23 |
+
activation (str): Activation function.
|
| 24 |
+
activation_params (dict): Parameters to provide to the activation function.
|
| 25 |
+
norm (str): Normalization method.
|
| 26 |
+
norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
|
| 27 |
+
causal (bool): Whether to use fully causal convolution.
|
| 28 |
+
pad_mode (str): Padding mode for the convolutions.
|
| 29 |
+
compress (int): Reduced dimensionality in residual branches (from Demucs v3).
|
| 30 |
+
true_skip (bool): Whether to use true skip connection or a simple
|
| 31 |
+
(streamable) convolution as the skip connection.
|
| 32 |
+
"""
|
| 33 |
+
def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1],
|
| 34 |
+
activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
|
| 35 |
+
norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False,
|
| 36 |
+
pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True):
|
| 37 |
+
super().__init__()
|
| 38 |
+
assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations'
|
| 39 |
+
act = getattr(nn, activation)
|
| 40 |
+
hidden = dim // compress
|
| 41 |
+
block = []
|
| 42 |
+
for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
|
| 43 |
+
in_chs = dim if i == 0 else hidden
|
| 44 |
+
out_chs = dim if i == len(kernel_sizes) - 1 else hidden
|
| 45 |
+
block += [
|
| 46 |
+
act(**activation_params),
|
| 47 |
+
StreamableConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation,
|
| 48 |
+
norm=norm, norm_kwargs=norm_params,
|
| 49 |
+
causal=causal, pad_mode=pad_mode),
|
| 50 |
+
]
|
| 51 |
+
self.block = nn.Sequential(*block)
|
| 52 |
+
self.shortcut: nn.Module
|
| 53 |
+
if true_skip:
|
| 54 |
+
self.shortcut = nn.Identity()
|
| 55 |
+
else:
|
| 56 |
+
self.shortcut = StreamableConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params,
|
| 57 |
+
causal=causal, pad_mode=pad_mode)
|
| 58 |
+
|
| 59 |
+
def forward(self, x):
|
| 60 |
+
return self.shortcut(x) + self.block(x)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class SEANetEncoder(nn.Module):
|
| 64 |
+
"""SEANet encoder.
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
channels (int): Audio channels.
|
| 68 |
+
dimension (int): Intermediate representation dimension.
|
| 69 |
+
n_filters (int): Base width for the model.
|
| 70 |
+
n_residual_layers (int): nb of residual layers.
|
| 71 |
+
ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
|
| 72 |
+
upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
|
| 73 |
+
that must match the decoder order. We use the decoder order as some models may only employ the decoder.
|
| 74 |
+
activation (str): Activation function.
|
| 75 |
+
activation_params (dict): Parameters to provide to the activation function.
|
| 76 |
+
norm (str): Normalization method.
|
| 77 |
+
norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
|
| 78 |
+
kernel_size (int): Kernel size for the initial convolution.
|
| 79 |
+
last_kernel_size (int): Kernel size for the initial convolution.
|
| 80 |
+
residual_kernel_size (int): Kernel size for the residual layers.
|
| 81 |
+
dilation_base (int): How much to increase the dilation with each layer.
|
| 82 |
+
causal (bool): Whether to use fully causal convolution.
|
| 83 |
+
pad_mode (str): Padding mode for the convolutions.
|
| 84 |
+
true_skip (bool): Whether to use true skip connection or a simple
|
| 85 |
+
(streamable) convolution as the skip connection in the residual network blocks.
|
| 86 |
+
compress (int): Reduced dimensionality in residual branches (from Demucs v3).
|
| 87 |
+
lstm (int): Number of LSTM layers at the end of the encoder.
|
| 88 |
+
disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
|
| 89 |
+
For the encoder, it corresponds to the N first blocks.
|
| 90 |
+
"""
|
| 91 |
+
def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
|
| 92 |
+
ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
|
| 93 |
+
norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
|
| 94 |
+
last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
|
| 95 |
+
pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
|
| 96 |
+
disable_norm_outer_blocks: int = 0):
|
| 97 |
+
super().__init__()
|
| 98 |
+
self.channels = channels
|
| 99 |
+
self.dimension = dimension
|
| 100 |
+
self.n_filters = n_filters
|
| 101 |
+
self.ratios = list(reversed(ratios))
|
| 102 |
+
del ratios
|
| 103 |
+
self.n_residual_layers = n_residual_layers
|
| 104 |
+
self.hop_length = np.prod(self.ratios)
|
| 105 |
+
self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks
|
| 106 |
+
self.disable_norm_outer_blocks = disable_norm_outer_blocks
|
| 107 |
+
assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
|
| 108 |
+
"Number of blocks for which to disable norm is invalid." \
|
| 109 |
+
"It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
|
| 110 |
+
|
| 111 |
+
act = getattr(nn, activation)
|
| 112 |
+
mult = 1
|
| 113 |
+
model: tp.List[nn.Module] = [
|
| 114 |
+
StreamableConv1d(channels, mult * n_filters, kernel_size,
|
| 115 |
+
norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
|
| 116 |
+
norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
|
| 117 |
+
]
|
| 118 |
+
# Downsample to raw audio scale
|
| 119 |
+
for i, ratio in enumerate(self.ratios):
|
| 120 |
+
block_norm = 'none' if self.disable_norm_outer_blocks >= i + 2 else norm
|
| 121 |
+
# Add residual layers
|
| 122 |
+
for j in range(n_residual_layers):
|
| 123 |
+
model += [
|
| 124 |
+
SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1],
|
| 125 |
+
dilations=[dilation_base ** j, 1],
|
| 126 |
+
norm=block_norm, norm_params=norm_params,
|
| 127 |
+
activation=activation, activation_params=activation_params,
|
| 128 |
+
                                      causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)]

            # Add downsampling layers
            model += [
                act(**activation_params),
                StreamableConv1d(mult * n_filters, mult * n_filters * 2,
                                 kernel_size=ratio * 2, stride=ratio,
                                 norm=block_norm, norm_kwargs=norm_params,
                                 causal=causal, pad_mode=pad_mode),
            ]
            mult *= 2

        if lstm:
            model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]

        model += [
            act(**activation_params),
            StreamableConv1d(mult * n_filters, dimension, last_kernel_size,
                             norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


class SEANetDecoder(nn.Module):
    """SEANet decoder.

    Args:
        channels (int): Audio channels.
        dimension (int): Intermediate representation dimension.
        n_filters (int): Base width for the model.
        n_residual_layers (int): Number of residual layers.
        ratios (Sequence[int]): Kernel size and stride ratios.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        final_activation (str): Final activation function after all convolutions.
        final_activation_params (dict): Parameters to provide to the final activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the final convolution.
        residual_kernel_size (int): Kernel size for the residual layers.
        dilation_base (int): How much to increase the dilation with each layer.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        true_skip (bool): Whether to use a true skip connection or a simple
            (streamable) convolution as the skip connection in the residual network blocks.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        lstm (int): Number of LSTM layers at the beginning of the decoder.
        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
            For the decoder, it corresponds to the N last blocks.
        trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
            If equal to 1.0, it means that all the trimming is done at the right.
    """
    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
                 final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None,
                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
                 pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
                 disable_norm_outer_blocks: int = 0, trim_right_ratio: float = 1.0):
        super().__init__()
        self.dimension = dimension
        self.channels = channels
        self.n_filters = n_filters
        self.ratios = ratios
        del ratios
        self.n_residual_layers = n_residual_layers
        self.hop_length = np.prod(self.ratios)
        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
        self.disable_norm_outer_blocks = disable_norm_outer_blocks
        assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
            "Number of blocks for which to disable norm is invalid. " \
            "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."

        act = getattr(nn, activation)
        mult = int(2 ** len(self.ratios))
        model: tp.List[nn.Module] = [
            StreamableConv1d(dimension, mult * n_filters, kernel_size,
                             norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
        ]

        if lstm:
            model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]

        # Upsample to raw audio scale
        for i, ratio in enumerate(self.ratios):
            block_norm = 'none' if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1) else norm
            # Add upsampling layers
            model += [
                act(**activation_params),
                StreamableConvTranspose1d(mult * n_filters, mult * n_filters // 2,
                                          kernel_size=ratio * 2, stride=ratio,
                                          norm=block_norm, norm_kwargs=norm_params,
                                          causal=causal, trim_right_ratio=trim_right_ratio),
            ]
            # Add residual layers
            for j in range(n_residual_layers):
                model += [
                    SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1],
                                      dilations=[dilation_base ** j, 1],
                                      activation=activation, activation_params=activation_params,
                                      norm=block_norm, norm_params=norm_params, causal=causal,
                                      pad_mode=pad_mode, compress=compress, true_skip=true_skip)]

            mult //= 2

        # Add final layers
        model += [
            act(**activation_params),
            StreamableConv1d(n_filters, channels, last_kernel_size,
                             norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
        ]
        # Add optional final activation to decoder (eg. tanh)
        if final_activation is not None:
            final_act = getattr(nn, final_activation)
            final_activation_params = final_activation_params or {}
            model += [
                final_act(**final_activation_params)
            ]
        self.model = nn.Sequential(*model)

    def forward(self, z):
        y = self.model(z)
        return y
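With the default ratios [8, 5, 4, 2], this encoder/decoder pair downsamples by hop_length = 8 * 5 * 4 * 2 = 320: a mono waveform of shape [B, 1, 320 * T] is encoded to a latent of shape [B, dimension, T], and the decoder maps it back to audio. A minimal round-trip sketch (illustration only, not part of the uploaded file; it assumes the constructor defaults shown above):

import torch
from audiocraft.modules.seanet import SEANetEncoder, SEANetDecoder

encoder = SEANetEncoder(channels=1, dimension=128, n_filters=32, ratios=[8, 5, 4, 2])
decoder = SEANetDecoder(channels=1, dimension=128, n_filters=32, ratios=[8, 5, 4, 2])

wav = torch.randn(2, 1, 320 * 50)   # [B, channels, samples]
latent = encoder(wav)               # [2, 128, 50], downsampled by hop_length = 320
recon = decoder(latent)             # [2, 1, 320 * 50]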
audiocraft/modules/streaming.py
ADDED
@@ -0,0 +1,131 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Streaming module API that should be implemented by all Streaming components.
"""

from contextlib import contextmanager
import typing as tp
from torch import nn
import torch


State = tp.Dict[str, torch.Tensor]


class StreamingModule(nn.Module):
    """Common API for streaming components.

    Each streaming component has a streaming state, which is just a dict[str, Tensor].
    By convention, the first dim of each tensor must be the batch size.
    Don't use dots in the key names, as this would clash with submodules
    (like in state_dict).

    If `self._is_streaming` is True, the component should use and remember
    the proper state inside `self._streaming_state`.

    To set a streaming component in streaming state, use

        with module.streaming():
            ...

    This will automatically reset the streaming state when exiting the context manager.
    This also automatically propagates to all streaming children modules.

    Some modules might also implement the `StreamingModule.flush` method, although
    this one is trickier, as all parent modules must be StreamingModule and implement
    it as well for it to work properly. See `StreamingSequential` after.
    """
    def __init__(self) -> None:
        super().__init__()
        self._streaming_state: State = {}
        self._is_streaming = False

    def _apply_named_streaming(self, fn: tp.Any):
        for name, module in self.named_modules():
            if isinstance(module, StreamingModule):
                fn(name, module)

    def _set_streaming(self, streaming: bool):
        def _set_streaming(name, module):
            module._is_streaming = streaming
        self._apply_named_streaming(_set_streaming)

    @contextmanager
    def streaming(self):
        """Context manager to enter streaming mode. Reset streaming state on exit."""
        self._set_streaming(True)
        try:
            yield
        finally:
            self._set_streaming(False)
            self.reset_streaming()

    def reset_streaming(self):
        """Reset the streaming state."""
        def _reset(name: str, module: StreamingModule):
            module._streaming_state.clear()

        self._apply_named_streaming(_reset)

    def get_streaming_state(self) -> State:
        """Return the streaming state, including that of sub-modules."""
        state: State = {}

        def _add(name: str, module: StreamingModule):
            if name:
                name += "."
            for key, value in module._streaming_state.items():
                state[name + key] = value

        self._apply_named_streaming(_add)
        return state

    def set_streaming_state(self, state: State):
        """Set the streaming state, including that of sub-modules."""
        state = dict(state)

        def _set(name: str, module: StreamingModule):
            if name:
                name += "."
            module._streaming_state.clear()
            for key, value in list(state.items()):
                # complexity is not ideal here, but probably fine.
                if key.startswith(name):
                    local_key = key[len(name):]
                    if '.' not in local_key:
                        module._streaming_state[local_key] = value
                        del state[key]

        self._apply_named_streaming(_set)
        assert len(state) == 0, list(state.keys())

    def flush(self, x: tp.Optional[torch.Tensor] = None):
        """Flush any remaining outputs that were waiting for completion.
        Typically, for convolutions, this will add the final padding
        and process the last buffer.

        This should take an optional argument `x`, which will be provided
        if a module before this one in the streaming pipeline has already
        spit out a flushed buffer.
        """
        if x is None:
            return None
        else:
            return self(x)


class StreamingSequential(StreamingModule, nn.Sequential):
    """A streaming compatible alternative of `nn.Sequential`.
    """
    def flush(self, x: tp.Optional[torch.Tensor] = None):
        for module in self:
            if isinstance(module, StreamingModule):
                x = module.flush(x)
            elif x is not None:
                x = module(x)
        return x
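The intended calling pattern for these streaming components is: enter the streaming() context, feed consecutive chunks, and let each submodule keep whatever it needs (partial convolution buffers, past keys/values) inside _streaming_state; the state is reset automatically when the context exits. A minimal sketch (illustration only, not part of the uploaded file; model, full_input and chunk_size are placeholder names, and the time dimension is the last one for convolutional modules and dim=1 for the transformer below):

outputs = []
with model.streaming():                 # streaming state is reset automatically on exit
    for chunk in full_input.split(chunk_size, dim=-1):
        outputs.append(model(chunk))    # each call reads and updates the streaming state
out = torch.cat(outputs, dim=-1)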
audiocraft/modules/transformer.py
ADDED
@@ -0,0 +1,755 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Transformer model, with streaming support, xformer attention support
and easy causal attention with a potentially finite receptive field.

See `StreamingTransformer` for more information.

Unlike regular PyTorch Transformer, we make the hard choice that batches are first.
"""

import typing as tp

from einops import rearrange
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint as torch_checkpoint
from xformers import ops

from .rope import RotaryEmbedding
from .streaming import StreamingModule

_efficient_attention_backend: str = 'torch'


def set_efficient_attention_backend(backend: str = 'torch'):
    # Using torch by default, it seems a bit faster on older P100 GPUs (~20% faster).
    global _efficient_attention_backend
    assert backend in ['xformers', 'torch']
    _efficient_attention_backend = backend


def _get_attention_time_dimension(memory_efficient: bool) -> int:
    if _efficient_attention_backend == 'torch' and memory_efficient:
        return 2
    else:
        return 1


def _is_profiled() -> bool:
    # Return true if we are currently running with an xformers profiler activated.
    try:
        from xformers.profiler import profiler
    except ImportError:
        return False
    return profiler._Profiler._CURRENT_PROFILER is not None


def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module:
    """Create normalization module for transformer encoder layer.

    Args:
        norm_type (str): Normalization method.
        dim (int): Dimension of the normalized layer.
        **kwargs (dict): Additional parameters for normalization layer.
    Returns:
        nn.Module: Normalization module.
    """
    if norm_type == 'layer_norm':
        return nn.LayerNorm(dim, eps=1e-5, **kwargs)
    else:
        raise ValueError(f"Unknown norm type: {norm_type}")


def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000,
                         dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Create sinusoidal positional embedding, with shape `[B, T, C]`.

    Args:
        positions (torch.Tensor): LongTensor of positions.
        dim (int): Dimension of the embedding.
        max_period (float): Maximum period of the cosine/sine functions.
        dtype (torch.dtype or str): dtype to use to generate the embedding.
    Returns:
        torch.Tensor: Sinusoidal positional embedding.
    """
    # We aim for BTC format
    assert dim % 2 == 0
    half_dim = dim // 2
    positions = positions.to(dtype)
    adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
    max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype)  # avoid sync point
    phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
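# NOTE (editor): illustrative sketch of `create_sin_embedding`, not part of the original
# file. With positions of shape [B, T, 1] the result is [B, T, dim]; the first half of
# the last dimension holds cosines and the second half sines, matching the `torch.cat`
# above:
#
#     positions = torch.arange(16).view(1, -1, 1)          # [1, 16, 1]
#     pos_emb = create_sin_embedding(positions, dim=8)      # [1, 16, 8]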

def expand_repeated_kv(x: torch.Tensor, n_rep: int, memory_efficient: bool) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers."""
    if n_rep == 1:
        return x
    if _efficient_attention_backend == 'torch' and memory_efficient:
        bs, n_kv_heads, slen, head_dim = x.shape
        return (
            x[:, :, None, :, :]
            .expand(bs, n_kv_heads, n_rep, slen, head_dim)
            .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
        )
    else:
        bs, slen, n_kv_heads, head_dim = x.shape
        return (
            x[:, :, :, None, :]
            .expand(bs, slen, n_kv_heads, n_rep, head_dim)
            .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
        )


class LayerScale(nn.Module):
    """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
    This rescales diagonally the residual outputs close to 0, with a learnt scale.

    Args:
        channels (int): Number of channels.
        init (float): Initial scale.
        channel_last (bool): If True, expect `[*, C]` shaped tensors, otherwise, `[*, C, T]`.
        device (torch.device or str, optional): Device on which to initialize the module.
        dtype (torch.dtype, optional): dtype to use to initialize the module.
    """
    def __init__(self, channels: int, init: float = 1e-4, channel_last: bool = True,
                 device=None, dtype=None):
        super().__init__()
        self.channel_last = channel_last
        self.scale = nn.Parameter(
            torch.full((channels,), init,
                       requires_grad=True, device=device, dtype=dtype))

    def forward(self, x: torch.Tensor):
        if self.channel_last:
            return self.scale * x
        else:
            return self.scale[:, None] * x


class StreamingMultiheadAttention(StreamingModule):
    """Similar to `nn.MultiheadAttention` but with support for streaming, causal evaluation.

    Args:
        embed_dim (int): Dimension to project to.
        num_heads (int): Number of heads.
        dropout (float): Dropout level.
        bias (bool): Use bias in projections.
        causal (bool): Causal mask applied automatically.
        past_context (int, optional): Receptive field for the causal mask, infinite if None.
        custom (bool): Use custom MHA implementation, for testing / benchmarking.
        memory_efficient (bool): Use xformers based memory efficient attention.
        attention_as_float32 (bool): Perform the attention as float32
            (especially important with memory_efficient as autocast won't do this automatically).
        rope (`RotaryEmbedding`, optional): Rope embedding to use.
        cross_attention: Should be true when used as a cross attention.
            All keys and values must be available at once, streaming is only for the queries.
            Cannot be used with `causal` or `rope` (as it wouldn't make sense to
            interpret the time steps in the keys relative to those in the queries).
        safe_streaming (bool): Bug fix, will go away with xformers update.
        qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product.
        kv_repeat (int): If > 1, will repeat keys and values multiple times (must divide num_heads).
            This will lead to faster decoding time on A100 or other GPUs with tensorcore.
        device (torch.device, optional): Device on which to initialize.
        dtype (torch.dtype, optional): dtype to use.
    """
    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True,
                 causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False,
                 memory_efficient: bool = False, attention_as_float32: bool = False,
                 rope: tp.Optional[RotaryEmbedding] = None, cross_attention: bool = False,
                 safe_streaming: bool = True, qk_layer_norm: bool = False, kv_repeat: int = 1,
                 device=None, dtype=None):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        if past_context is not None:
            assert causal

        self.embed_dim = embed_dim
        self.causal = causal
        self.past_context = past_context
        self.memory_efficient = memory_efficient
        self.attention_as_float32 = attention_as_float32
        self.rope = rope
        self.cross_attention = cross_attention
        self.safe_streaming = safe_streaming
        self.num_heads = num_heads
        self.dropout = dropout
        self.kv_repeat = kv_repeat
        if cross_attention:
            assert not causal, "Causal cannot work with cross attention."
            assert rope is None, "Rope cannot work with cross attention."

        if memory_efficient:
            _verify_xformers_memory_efficient_compat()

        self.custom = _is_custom(custom, memory_efficient)
        if self.custom:
            out_dim = embed_dim
            assert num_heads % kv_repeat == 0
            assert not cross_attention or kv_repeat == 1
            num_kv = num_heads // kv_repeat
            kv_dim = (embed_dim // num_heads) * num_kv
            out_dim += 2 * kv_dim
            in_proj = nn.Linear(embed_dim, out_dim, bias=bias, **factory_kwargs)
            # We try to follow the default PyTorch MHA convention, to easily compare results.
            self.in_proj_weight = in_proj.weight
            self.in_proj_bias = in_proj.bias
            if bias:
                self.in_proj_bias.data.zero_()  # Following Pytorch convention
            self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
            if bias:
                self.out_proj.bias.data.zero_()
        else:
            assert not qk_layer_norm
            assert kv_repeat == 1
            self.mha = nn.MultiheadAttention(
                embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True,
                **factory_kwargs)
        self.qk_layer_norm = qk_layer_norm
        if qk_layer_norm:
            assert self.custom
            assert kv_repeat == 1
            ln_dim = embed_dim
            self.q_layer_norm = nn.LayerNorm(ln_dim)
            self.k_layer_norm = nn.LayerNorm(ln_dim)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        if not self.custom:
            # Support compat with regular MHA
            keys = [n for n, _ in self.mha.named_parameters()]
            for key in keys:
                if prefix + key in state_dict:
                    state_dict[prefix + "mha." + key] = state_dict.pop(prefix + key)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def _get_mask(self, current_steps: int, device: torch.device, dtype: torch.dtype):
        # Return a causal mask, accounting for potentially stored past keys/values
        # We actually return a bias for the attention score, as this has the same
        # convention both in the builtin MHA in Pytorch, and Xformers functions.
        time_dim = _get_attention_time_dimension(self.memory_efficient)
        if self.memory_efficient:
            from xformers.ops import LowerTriangularMask
            if current_steps == 1:
                # If we only have one step, then we do not need a mask.
                return None
            elif 'past_keys' in self._streaming_state:
                raise RuntimeError("Not supported at the moment")
            else:
                # Then we can safely use a lower triangular mask
                return LowerTriangularMask()
        if self._streaming_state:
            past_keys = self._streaming_state['past_keys']
            past_steps = past_keys.shape[time_dim]
        else:
            past_steps = 0

        queries_pos = torch.arange(
            past_steps, current_steps + past_steps, device=device).view(-1, 1)
        keys_pos = torch.arange(past_steps + current_steps, device=device).view(1, -1)
        delta = queries_pos - keys_pos
        valid = delta >= 0
        if self.past_context is not None:
            valid &= (delta <= self.past_context)
        return torch.where(
            valid,
            torch.zeros([], device=device, dtype=dtype),
            torch.full([], float('-inf'), device=device, dtype=dtype))
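    # NOTE (editor): illustrative example of the bias returned by `_get_mask`, not part of
    # the original file. With current_steps=4, no streaming state and past_context=2, the
    # additive bias is (0 = attend, -inf = masked), one row per query position:
    #
    #     [[  0., -inf, -inf, -inf],
    #      [  0.,   0., -inf, -inf],
    #      [  0.,   0.,   0., -inf],
    #      [-inf,   0.,   0.,   0.]]
    #
    # i.e. each query sees itself plus at most `past_context` previous positions.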

    def _complete_kv(self, k, v):
        time_dim = _get_attention_time_dimension(self.memory_efficient)
        if self.cross_attention:
            # With cross attention we assume all keys and values
            # are already available, and streaming is with respect
            # to the queries only.
            return k, v
        # Complete the key/value pair using the streaming state.
        if self._streaming_state:
            pk = self._streaming_state['past_keys']
            nk = torch.cat([pk, k], dim=time_dim)
            if v is k:
                nv = nk
            else:
                pv = self._streaming_state['past_values']
                nv = torch.cat([pv, v], dim=time_dim)
        else:
            nk = k
            nv = v

        assert nk.shape[time_dim] == nv.shape[time_dim]
        offset = 0
        if self.past_context is not None:
            offset = max(0, nk.shape[time_dim] - self.past_context)
        if self._is_streaming:
            self._streaming_state['past_keys'] = nk[:, offset:]
            if v is not k:
                self._streaming_state['past_values'] = nv[:, offset:]
            if 'offset' in self._streaming_state:
                self._streaming_state['offset'] += offset
            else:
                self._streaming_state['offset'] = torch.tensor(0)
        return nk, nv

    def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
        time_dim = _get_attention_time_dimension(self.memory_efficient)
        # Apply rope embeddings to query and key tensors.
        assert self.rope is not None
        if 'past_keys' in self._streaming_state:
            past_keys_offset = self._streaming_state['past_keys'].shape[1]
        else:
            past_keys_offset = 0
        if 'offset' in self._streaming_state:
            past_context_offset = int(self._streaming_state['offset'].item())
        else:
            past_context_offset = 0
        streaming_offset = past_context_offset + past_keys_offset
        return self.rope.rotate_qk(query, key, start=streaming_offset, time_dim=time_dim)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                key_padding_mask=None, need_weights=False, attn_mask=None,
                average_attn_weights=True, is_causal=False):
        assert not is_causal, ("New param added in torch 2.0.1 not supported, "
                               "use the causal args in the constructor.")

        time_dim = _get_attention_time_dimension(self.memory_efficient)
        if time_dim == 2:
            layout = "b h t d"
        else:
            layout = "b t h d"
        dtype = query.dtype
        if self._is_streaming:
            assert self.causal or self.cross_attention, \
                "Streaming only available for causal or cross attention"

        custom_attn_mask = attn_mask is not None

        if self.causal:
            assert attn_mask is None
            # At the moment we specialize only for the self-attention case.
            assert query.shape[1] == key.shape[1], "Causal only for same length query / key / value"
            assert value.shape[1] == key.shape[1], "Causal only for same length query / key / value"
            attn_mask = self._get_mask(query.shape[1], query.device, query.dtype)

        if self.custom:
            # custom implementation
            assert need_weights is False
            assert key_padding_mask is None
            if self.cross_attention:
                # Different queries, keys, values, we have to split the weights manually
                # before applying the linear.
                dim = self.in_proj_weight.shape[0] // 3
                if self.in_proj_bias is None:
                    bias_q, bias_k, bias_v = None, None, None
                else:
                    bias_q = self.in_proj_bias[:dim]
                    bias_k = self.in_proj_bias[dim: 2 * dim]
                    bias_v = self.in_proj_bias[2 * dim:]
                q = nn.functional.linear(query, self.in_proj_weight[:dim], bias_q)
                # todo: when streaming, we could actually save k, v and check the shapes actually match.
                k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim], bias_k)
                v = nn.functional.linear(value, self.in_proj_weight[2 * dim:], bias_v)
                if self.qk_layer_norm is True:
                    q = self.q_layer_norm(q)
                    k = self.k_layer_norm(k)
                q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
            else:
                if not _is_profiled():
                    # profiling breaks that property somehow.
                    assert query is key, "specialized implementation"
                    assert value is key, "specialized implementation"
                projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias)
                if self.kv_repeat == 1:
                    if time_dim == 2:
                        bound_layout = "b h p t d"
                    else:
                        bound_layout = "b t p h d"
                    packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
                    q, k, v = ops.unbind(packed, dim=2)
                else:
                    embed_dim = self.embed_dim
                    per_head_dim = (embed_dim // self.num_heads)
                    kv_heads = self.num_heads // self.kv_repeat
                    q = projected[:, :, :embed_dim]
                    start = embed_dim
                    end = start + per_head_dim * kv_heads
                    k = projected[:, :, start: end]
                    v = projected[:, :, end:]
                    q = rearrange(q, f"b t (h d) -> {layout}", h=self.num_heads)
                    k = rearrange(k, f"b t (h d) -> {layout}", h=kv_heads)
                    v = rearrange(v, f"b t (h d) -> {layout}", h=kv_heads)

                if self.qk_layer_norm is True:
                    assert self.kv_repeat == 1
                    q, k = [rearrange(x, f"{layout} -> b t (h d)") for x in [q, k]]
                    q = self.q_layer_norm(q)
                    k = self.k_layer_norm(k)
                    q, k = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k]]
                if self.rope:
                    q, k = self._apply_rope(q, k)
                k, v = self._complete_kv(k, v)
                if self.kv_repeat > 1:
                    k = expand_repeated_kv(k, self.kv_repeat, self.memory_efficient)
                    v = expand_repeated_kv(v, self.kv_repeat, self.memory_efficient)
            if self.attention_as_float32:
                q, k, v = [x.float() for x in [q, k, v]]
            if self.memory_efficient:
                if custom_attn_mask:
                    # When using a custom attn mask:
                    # Move to query's device, repeat for each sample, remove align8 padding
                    seq_len = query.shape[1]
                    attn_mask = attn_mask.to(q.dtype)
                    attn_mask = attn_mask.repeat((q.shape[0], 1, 1, 1))
                    attn_mask = attn_mask[..., :seq_len, :seq_len]

                p = self.dropout if self.training else 0
                if _efficient_attention_backend == 'torch':
                    x = torch.nn.functional.scaled_dot_product_attention(
                        q, k, v, is_causal=attn_mask is not None, dropout_p=p)
                else:
                    x = ops.memory_efficient_attention(q, k, v, attn_mask, p=p)
            else:
                # We include the dot product as float32, for consistency
                # with the other implementations that include that step
                # as part of the attention. Note that when using `autocast`,
                # the einsums would be done as bfloat16, but the softmax
                # would be done as float32, so `attention_as_float32` will
                # extend a bit the range of operations done in float32,
                # although this should make no difference.
                q = q / q.shape[-1] ** 0.5
                key_layout = layout.replace('t', 'k')
                query_layout = layout
                if self._is_streaming and self.safe_streaming and q.device.type == 'cuda':
                    with torch.autocast(device_type=q.device.type, dtype=torch.float32):
                        pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
                else:
                    pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
                if attn_mask is not None:
                    pre_w = pre_w + attn_mask
                w = torch.softmax(pre_w, dim=-1)
                w = F.dropout(w, self.dropout, training=self.training).to(v)
                # Key and value have the same format.
                x = torch.einsum(f"b h t k, {key_layout} -> {layout}", w, v)
            x = x.to(dtype)
            x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
            x = self.out_proj(x)
        else:
            key, value = self._complete_kv(key, value)
            if self.attention_as_float32:
                query, key, value = [x.float() for x in [query, key, value]]
            x, _ = self.mha(
                query, key, value, key_padding_mask,
                need_weights, attn_mask, average_attn_weights)
            x = x.to(dtype)

        return x, None


class StreamingTransformerLayer(nn.TransformerEncoderLayer):
    """TransformerLayer with Streaming / Causal support.
    This also integrates cross_attention, when passing `cross_attention=True`,
    rather than having two separate classes like in PyTorch.

    Args:
        d_model (int): Dimension of the data.
        num_heads (int): Number of heads.
        dim_feedforward (int): Intermediate dimension of FF module.
        dropout (float): Dropout both for MHA and FF.
        bias_ff (bool): Use bias for FF.
        bias_attn (bool): Use bias for MHA.
        causal (bool): Causal mask applied automatically.
        past_context (int, optional): Receptive field for the causal mask, infinite if None.
        custom (bool): Use custom MHA implementation, for testing / benchmarking.
        memory_efficient (bool): Use xformers based memory efficient attention.
        attention_as_float32 (bool): Perform the attention as float32
            (especially important with memory_efficient as autocast won't do this automatically).
        qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product in attention.
        qk_layer_norm_cross (bool): Same for the cross attention.
        cross_attention (bool): If True, expect to get secondary input for cross-attention.
            Cross attention will use the default MHA, as it typically won't require
            special treatment.
        layer_scale (float, optional): If not None, LayerScale will be used with
            the given value as initial scale.
        rope (`RotaryEmbedding`, optional): Rope embedding to use.
        attention_dropout (float, optional): If not None, use a separate dropout value for the
            attention, distinct from the FFN dropout.
        kv_repeat (int): If > 1, will repeat keys and values multiple times (must divide num_heads).
            This will lead to faster decoding time on A100 or other GPUs with tensorcore.
        device (torch.device, optional): Device on which to initialize.
        dtype (torch.dtype, optional): dtype to use.
        **kwargs: See `nn.TransformerEncoderLayer`.
    """
    def __init__(self, d_model: int, num_heads: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 bias_ff: bool = True, bias_attn: bool = True, causal: bool = False,
                 past_context: tp.Optional[int] = None, custom: bool = False,
                 memory_efficient: bool = False, attention_as_float32: bool = False,
                 qk_layer_norm: bool = False, qk_layer_norm_cross: bool = False,
                 cross_attention: bool = False, layer_scale: tp.Optional[float] = None,
                 rope: tp.Optional[RotaryEmbedding] = None, attention_dropout: tp.Optional[float] = None,
                 kv_repeat: int = 1, norm: str = 'layer_norm', device=None, dtype=None, **kwargs):
        super().__init__(d_model, num_heads, dim_feedforward, dropout,
                         device=device, dtype=dtype, batch_first=True, **kwargs)
        factory_kwargs = {'device': device, 'dtype': dtype}
        # Redefine self_attn to our streaming multi-head attention
        attn_kwargs: tp.Dict[str, tp.Any] = {
            'embed_dim': d_model,
            'num_heads': num_heads,
            'dropout': dropout if attention_dropout is None else attention_dropout,
            'bias': bias_attn,
            'custom': custom,
            'memory_efficient': memory_efficient,
            'attention_as_float32': attention_as_float32,
        }
        self.self_attn: StreamingMultiheadAttention = StreamingMultiheadAttention(
            causal=causal, past_context=past_context, rope=rope, qk_layer_norm=qk_layer_norm,
            kv_repeat=kv_repeat, **attn_kwargs, **factory_kwargs)  # type: ignore
        # Redefine feedforward layers to expose bias parameter
        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias_ff, **factory_kwargs)
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias_ff, **factory_kwargs)

        self.layer_scale_1: nn.Module
        self.layer_scale_2: nn.Module
        if layer_scale is None:
            self.layer_scale_1 = nn.Identity()
            self.layer_scale_2 = nn.Identity()
        else:
            self.layer_scale_1 = LayerScale(d_model, layer_scale, **factory_kwargs)
            self.layer_scale_2 = LayerScale(d_model, layer_scale, **factory_kwargs)

        self.cross_attention: tp.Optional[nn.Module] = None
        if cross_attention:
            self.cross_attention = StreamingMultiheadAttention(
                cross_attention=True, qk_layer_norm=qk_layer_norm_cross,
                **attn_kwargs, **factory_kwargs)
            # Norm and dropout
            self.dropout_cross = nn.Dropout(dropout)
            # eps value matching that used in PyTorch reference implementation.
            self.norm_cross = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs)
            self.layer_scale_cross: nn.Module
            if layer_scale is None:
                self.layer_scale_cross = nn.Identity()
            else:
                self.layer_scale_cross = LayerScale(d_model, layer_scale, **factory_kwargs)
        self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs)  # type: ignore
        self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs)  # type: ignore

    def _cross_attention_block(self, src: torch.Tensor,
                               cross_attention_src: torch.Tensor) -> torch.Tensor:
        assert self.cross_attention is not None
        # queries are from src, keys and values from cross_attention_src.
        x = self.cross_attention(
            src, cross_attention_src, cross_attention_src, need_weights=False)[0]
        return self.dropout_cross(x)  # type: ignore

    def forward(self, src: torch.Tensor, src_mask: tp.Optional[torch.Tensor] = None,  # type: ignore
                src_key_padding_mask: tp.Optional[torch.Tensor] = None,
                cross_attention_src: tp.Optional[torch.Tensor] = None):
        if self.cross_attention is None:
            assert cross_attention_src is None
        else:
            assert cross_attention_src is not None
        x = src
        if self.norm_first:
            x = x + self.layer_scale_1(
                self._sa_block(self.norm1(x), src_mask, src_key_padding_mask))
            if cross_attention_src is not None:
                x = x + self.layer_scale_cross(
                    self._cross_attention_block(
                        self.norm_cross(x), cross_attention_src))
            x = x + self.layer_scale_2(self._ff_block(self.norm2(x)))
        else:
            x = self.norm1(x + self.layer_scale_1(
                self._sa_block(x, src_mask, src_key_padding_mask)))
            if cross_attention_src is not None:
                x = self.norm_cross(
                    x + self.layer_scale_cross(
                        self._cross_attention_block(src, cross_attention_src)))
            x = self.norm2(x + self.layer_scale_2(self._ff_block(x)))
        return x


class StreamingTransformer(StreamingModule):
    """Transformer with Streaming / Causal support.

    Args:
        d_model (int): Dimension of the data.
        num_heads (int): Number of heads.
        dim_feedforward (int): Intermediate dimension of FF module.
        dropout (float): Dropout both for MHA and FF.
        bias_ff (bool): Use bias for FF.
        bias_attn (bool): Use bias for MHA.
        causal (bool): Causal mask applied automatically.
        past_context (int, optional): Receptive field for the causal mask, infinite if None.
        custom (bool): Use custom MHA implementation, for testing / benchmarking.
        memory_efficient (bool): Use xformers based memory efficient attention.
        attention_as_float32 (bool): Perform the attention as float32
            (especially important with memory_efficient as autocast won't do this automatically).
        cross_attention (bool): If True, expect to get secondary input for cross-attention.
        layer_scale (float, optional): If not None, LayerScale will be used
            with the given value as initial scale.
        positional_embedding (str): Positional embedding strategy (sin, rope, or sin_rope).
        max_period (float): Maximum period of the time embedding.
        positional_scale (float): Scale of positional embedding, set to 0 to deactivate.
        xpos (bool): Apply xpos exponential decay to positional embedding (rope only).
        lr (float, optional): Learning rate override through the `make_optim_group` API.
        weight_decay (float, optional): Weight decay override through the `make_optim_group` API.
        layer_class (subclass of `StreamingTransformerLayer`): Class to use
            to initialize the layers, allowing further customization outside of AudioCraft.
        checkpointing (str): Checkpointing strategy to reduce memory usage.
            No checkpointing if set to 'none'. Per layer checkpointing using PyTorch
            if set to 'torch' (entire layer checkpointed, i.e. linears are evaluated twice,
            minimal memory usage, but maximal runtime). Finally, `xformers_default` provides
            a policy that opts some operations (like linear layers and attention) out of
            checkpointing, providing a middle ground between speed and memory.
        device (torch.device, optional): Device on which to initialize.
        dtype (torch.dtype, optional): dtype to use.
        **kwargs: See `nn.TransformerEncoderLayer`.
    """
    def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_feedforward: int = 2048,
                 dropout: float = 0.1, bias_ff: bool = True, bias_attn: bool = True,
                 causal: bool = False, past_context: tp.Optional[int] = None,
                 custom: bool = False, memory_efficient: bool = False, attention_as_float32: bool = False,
                 cross_attention: bool = False, layer_scale: tp.Optional[float] = None,
                 positional_embedding: str = 'sin', max_period: float = 10_000, positional_scale: float = 1.,
                 xpos: bool = False, lr: tp.Optional[float] = None, weight_decay: tp.Optional[float] = None,
                 layer_class: tp.Type[StreamingTransformerLayer] = StreamingTransformerLayer,
                 checkpointing: str = 'none', device=None, dtype=None, **kwargs):
        super().__init__()
        assert d_model % num_heads == 0

        self.positional_embedding = positional_embedding
        self.max_period = max_period
        self.positional_scale = positional_scale
        self.weight_decay = weight_decay
        self.lr = lr

        assert positional_embedding in ['sin', 'rope', 'sin_rope']
        self.rope: tp.Optional[RotaryEmbedding] = None
        if self.positional_embedding in ['rope', 'sin_rope']:
            assert _is_custom(custom, memory_efficient)
            self.rope = RotaryEmbedding(d_model // num_heads, max_period=max_period,
                                        xpos=xpos, scale=positional_scale, device=device)

        self.checkpointing = checkpointing

        assert checkpointing in ['none', 'torch', 'xformers_default', 'xformers_mm']
        if self.checkpointing.startswith('xformers'):
            _verify_xformers_internal_compat()

        self.layers = nn.ModuleList()
        for idx in range(num_layers):
            self.layers.append(
                layer_class(
                    d_model=d_model, num_heads=num_heads, dim_feedforward=dim_feedforward,
                    dropout=dropout, bias_ff=bias_ff, bias_attn=bias_attn,
                    causal=causal, past_context=past_context, custom=custom,
                    memory_efficient=memory_efficient, attention_as_float32=attention_as_float32,
                    cross_attention=cross_attention, layer_scale=layer_scale, rope=self.rope,
                    device=device, dtype=dtype, **kwargs))

        if self.checkpointing != 'none':
            for layer in self.layers:
                # see audiocraft/optim/fsdp.py, magic signal to indicate this requires fixing the
                # backward hook inside of FSDP...
                layer._magma_checkpointed = True  # type: ignore

    def _apply_layer(self, layer, *args, **kwargs):
        method = self.checkpointing
        if method == 'none':
            return layer(*args, **kwargs)
        elif method == 'torch':
            return torch_checkpoint(layer, *args, use_reentrant=False, **kwargs)
        elif method.startswith('xformers'):
            from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy
            if method == 'xformers_default':
                # those operations will be saved, and not recomputed.
                # According to Francisco we can get smarter policies but this is a good start.
                allow_list = [
                    "xformers.efficient_attention_forward_cutlass.default",
                    "xformers_flash.flash_fwd.default",
                    "aten.addmm.default",
                    "aten.mm.default",
                ]
            elif method == 'xformers_mm':
                # those operations will be saved, and not recomputed.
                # According to Francisco we can get smarter policies but this is a good start.
                allow_list = [
                    "aten.addmm.default",
                    "aten.mm.default",
                ]
            else:
                raise ValueError(f"xformers checkpointing xformers policy {method} is not known.")
            policy_fn = _get_default_policy(allow_list)
            return checkpoint(layer, *args, policy_fn=policy_fn, **kwargs)
        else:
            raise ValueError(f"Checkpointing method {method} is unknown.")

    def forward(self, x: torch.Tensor, *args, **kwargs):
        B, T, C = x.shape

        if 'offsets' in self._streaming_state:
            offsets = self._streaming_state['offsets']
        else:
            offsets = torch.zeros(B, dtype=torch.long, device=x.device)

        if self.positional_embedding in ['sin', 'sin_rope']:
            positions = torch.arange(T, device=x.device).view(1, -1, 1)
            positions = positions + offsets.view(-1, 1, 1)
            pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
            x = x + self.positional_scale * pos_emb

        for layer in self.layers:
            x = self._apply_layer(layer, x, *args, **kwargs)

        if self._is_streaming:
            self._streaming_state['offsets'] = offsets + T

        return x

    def make_optim_group(self):
        group = {"params": list(self.parameters())}
        if self.lr is not None:
            group["lr"] = self.lr
        if self.weight_decay is not None:
            group["weight_decay"] = self.weight_decay
        return group


# special attention related functions

def _verify_xformers_memory_efficient_compat():
    try:
        from xformers.ops import memory_efficient_attention, LowerTriangularMask  # noqa
    except ImportError:
        raise ImportError(
            "xformers is not installed. Please install it and try again.\n"
            "To install on AWS and Azure, run \n"
            "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n"
            "pip install -U git+https://[email protected]/fairinternal/xformers.git#egg=xformers\n"
            "To install on FAIR Cluster, run \n"
            "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n"
            "pip install -U git+https://[email protected]/fairinternal/xformers.git#egg=xformers\n")


def _verify_xformers_internal_compat():
    try:
        from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy  # noqa
    except ImportError:
        raise ImportError(
            "Francisco's fairinternal xformers is not installed. Please install it and try again.\n"
            "To install on AWS and Azure, run \n"
            "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n"
            "pip install -U git+https://[email protected]/fairinternal/xformers.git#egg=xformers\n"
            "To install on FAIR Cluster, run \n"
            "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n"
            "pip install -U git+https://[email protected]/fairinternal/xformers.git#egg=xformers\n")


def _is_custom(custom: bool, memory_efficient: bool):
    return custom or memory_efficient
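Taken together, a causal StreamingTransformer gives the same result whether it is run over the full sequence or step by step inside the streaming context, since the causal mask plus the cached past keys/values reproduce the offline attention pattern, and the 'sin' positional embedding is offset by the number of steps already consumed. A small equivalence sketch (illustration only, not part of the uploaded file; it assumes xformers is installed, since the module imports it at load time, and disables dropout so both passes match up to float tolerance):

import torch
from audiocraft.modules.transformer import StreamingTransformer

model = StreamingTransformer(d_model=64, num_heads=4, num_layers=2,
                             dropout=0.0, causal=True, custom=True).eval()
x = torch.randn(1, 10, 64)                             # [B, T, C], batch first
full = model(x)                                        # offline pass over the whole sequence
with model.streaming():                                # online pass, one step at a time
    steps = [model(x[:, t: t + 1]) for t in range(10)]
print((full - torch.cat(steps, dim=1)).abs().max())    # expected to be tiny (float tolerance)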