File size: 12,018 Bytes
26246bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Noise schedules for diffusion: the forward (noising) process, the reverse (denoising) process and sample processors.
"""

from collections import namedtuple
import random
import typing as tp
import julius
import torch

# One diffusion training example: the noised input, the noise that was added to it,
# and the diffusion step(s) at which it was noised.
TrainingItem = namedtuple("TrainingItem", ["noisy", "noise", "step"])


def betas_from_alpha_bar(alpha_bar):
    """Recover per-step betas from cumulative alphas.

    Inverts `alpha_bar = cumprod(1 - betas)`: the first beta is `1 - alpha_bar[0]` and
    each following one is `1 - alpha_bar[i] / alpha_bar[i - 1]`.

    Args:
        alpha_bar (torch.Tensor): Cumulative alphas indexed along dim 0 (1D as used in this file).

    Returns:
        torch.Tensor: Betas with the same length, dtype and device as `alpha_bar`
        (empty input gives an empty result).
    """
    # Slicing the first element (instead of wrapping it in a new tensor) keeps it on the
    # same device and in the same dtype as the ratios it is concatenated with.
    alphas = torch.cat([alpha_bar[:1], alpha_bar[1:] / alpha_bar[:-1]])
    return 1 - alphas


class SampleProcessor(torch.nn.Module):
    """Identity sample processor: diffusion happens directly in sample space.

    Subclasses override both hooks to map samples into, and back out of,
    a space better suited to diffusion.
    """

    def project_sample(self, x: torch.Tensor):
        """Project the original sample to the 'space' where the diffusion will happen.

        The base implementation is the identity: `x` itself is returned.
        """
        projected = x
        return projected

    def return_sample(self, z: torch.Tensor):
        """Project back from diffusion space to the actual sample space.

        The base implementation is the identity: `z` itself is returned.
        """
        restored = z
        return restored


class MultiBandProcessor(SampleProcessor):
    """
    MultiBand sample processor. The input audio is split across
    frequency bands evenly distributed in mel-scale.

    Each band will be rescaled to match the power distribution
    of Gaussian noise in that band, using online metrics
    computed on the first few samples.

    Args:
        n_bands (int): Number of mel-bands to split the signal over.
        sample_rate (int): Sample rate of the audio.
        num_samples (int): Number of samples to use to fit the rescaling
            for each band. The processor won't be stable
            until it has seen that many samples.
        power_std (float or list/tensor): The rescaling factor computed to match the
            power of Gaussian noise in each band is taken to
            that power, i.e. `1.` means full correction of the energy
            in each band, and values less than `1` mean only partial
            correction. Can be used to balance the relative importance
            of low vs. high freq in typical audio signals. A list gives
            one exponent per band and must have exactly `n_bands` entries.

    Raises:
        ValueError: If `power_std` is a list whose length differs from `n_bands`.
    """
    def __init__(self, n_bands: int = 8, sample_rate: float = 24_000,
                 num_samples: int = 10_000, power_std: tp.Union[float, tp.List[float], torch.Tensor] = 1.):
        super().__init__()
        self.n_bands = n_bands
        self.split_bands = julius.SplitBands(sample_rate, n_bands=n_bands)
        self.num_samples = num_samples
        # Convert list input to a tensor *before* storing it: `tensor ** list` is not
        # supported, so the stored exponent must be a number or a tensor.
        if isinstance(power_std, list):
            if len(power_std) != n_bands:
                raise ValueError(f"power_std has {len(power_std)} entries, expected n_bands={n_bands}.")
            power_std = torch.tensor(power_std)
        if isinstance(power_std, torch.Tensor):
            # Non-persistent buffer: follows the module across `.to(device)` so it can be
            # combined with the (buffer-derived) rescale factors, without adding a
            # state_dict key.
            self.register_buffer('power_std', power_std, persistent=False)
        else:
            self.power_std = power_std
        # Running statistics, accumulated by `project_sample` until `num_samples` items are seen.
        self.register_buffer('counts', torch.zeros(1))
        self.register_buffer('sum_x', torch.zeros(n_bands))
        self.register_buffer('sum_x2', torch.zeros(n_bands))
        self.register_buffer('sum_target_x2', torch.zeros(n_bands))
        self.counts: torch.Tensor
        self.sum_x: torch.Tensor
        self.sum_x2: torch.Tensor
        self.sum_target_x2: torch.Tensor

    @property
    def mean(self):
        """Per-band running mean of the input, shape [n_bands]."""
        mean = self.sum_x / self.counts
        return mean

    @property
    def std(self):
        """Per-band running standard deviation of the input (negative variance clamped to 0)."""
        std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
        return std

    @property
    def target_std(self):
        """Per-band mean power of reference Gaussian noise, shape [n_bands].

        NOTE(review): unlike `std`, no square root is taken here, so this is a
        mean-square (variance-like) value rather than a standard deviation. It is
        left unchanged because every projection depends on it; confirm the intended
        definition before changing it.
        """
        target_std = self.sum_target_x2 / self.counts
        return target_std

    def project_sample(self, x: torch.Tensor):
        """Normalize each frequency band of `x` and sum the bands back into one signal.

        While fewer than `num_samples` items have been seen, the batch also updates
        the running statistics (there is no train/eval check, so this happens in
        both modes).
        """
        assert x.dim() == 3
        bands = self.split_bands(x)  # bands stacked along dim 0
        if self.counts.item() < self.num_samples:
            ref_bands = self.split_bands(torch.randn_like(x))
            self.counts += len(x)
            self.sum_x += bands.mean(dim=(2, 3)).sum(dim=1)
            self.sum_x2 += bands.pow(2).mean(dim=(2, 3)).sum(dim=1)
            self.sum_target_x2 += ref_bands.pow(2).mean(dim=(2, 3)).sum(dim=1)
        rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std  # one factor per band
        bands = (bands - self.mean.view(-1, 1, 1, 1)) * rescale.view(-1, 1, 1, 1)
        return bands.sum(dim=0)

    def return_sample(self, x: torch.Tensor):
        """Approximately invert `project_sample`.

        Undoes the per-band rescaling, re-adds the band means, and sums the bands.
        """
        assert x.dim() == 3
        bands = self.split_bands(x)
        # Clamp `std` exactly as `project_sample` does so the two rescalings are reciprocal.
        rescale = (self.std.clamp(min=1e-12) / self.target_std) ** self.power_std
        bands = bands * rescale.view(-1, 1, 1, 1) + self.mean.view(-1, 1, 1, 1)
        return bands.sum(dim=0)


class NoiseSchedule:
    """Noise schedule for diffusion.

    Args:
        beta_t0 (float): Variance of the first diffusion step.
        beta_t1 (float): Variance of the last diffusion step.
        num_steps (int): Number of diffusion steps.
        variance (str): Choice of the sigma value for the denoising eq.
            Choices: "beta", "beta_tilde" or "none".
        clip (float): Clipping value for the denoising steps; a falsy value disables clipping.
        rescale (float): Rescaling value to avoid vanishing signals, unused by default (i.e. 1).
            The clean signal is divided by it when noising and multiplied back at the end of sampling.
        device: Device on which the beta schedule is created.
        beta_exp (float): Power schedule exponent.
        repartition (str): Shape of the schedule; only "power" is supported.
        alpha_sigmoid (dict, optional): Not used by this class.
        n_bands (int, optional): Must be None; per-band schedules are not supported.
        sample_processor (SampleProcessor, optional): Module that normalizes data to better match
            the Gaussian distribution. Defaults to a new identity `SampleProcessor`.
        noise_scale (float): Scaling factor for the noise.
        **kwargs: Ignored.
    """
    def __init__(self, beta_t0: float = 1e-4, beta_t1: float = 0.02, num_steps: int = 1000, variance: str = 'beta',
                 clip: float = 5., rescale: float = 1., device='cuda', beta_exp: float = 1,
                 repartition: str = "power", alpha_sigmoid: tp.Optional[dict] = None, n_bands: tp.Optional[int] = None,
                 sample_processor: tp.Optional[SampleProcessor] = None, noise_scale: float = 1.0, **kwargs):

        self.beta_t0 = beta_t0
        self.beta_t1 = beta_t1
        self.variance = variance
        self.num_steps = num_steps
        self.clip = clip
        # Build the default processor per instance instead of sharing a single module
        # created once at import time through the default argument.
        self.sample_processor = SampleProcessor() if sample_processor is None else sample_processor
        self.rescale = rescale
        self.n_bands = n_bands
        self.noise_scale = noise_scale
        assert n_bands is None
        if repartition == "power":
            # Linear in beta ** (1 / beta_exp); beta_exp == 1 gives a linear beta schedule.
            self.betas = torch.linspace(beta_t0 ** (1 / beta_exp), beta_t1 ** (1 / beta_exp), num_steps,
                                        device=device, dtype=torch.float) ** beta_exp
        else:
            raise RuntimeError('Not implemented')
        # Fixed seed: the scalar steps drawn in `get_training_item` are reproducible.
        self.rng = random.Random(1234)

    def get_beta(self, step: tp.Union[int, torch.Tensor]):
        """Return beta for the given step(s)."""
        if self.n_bands is None:
            return self.betas[step]
        else:
            return self.betas[:, step]  # [n_bands, len(step)]

    def get_initial_noise(self, x: torch.Tensor):
        """Draw standard Gaussian noise to start the reverse process from, shaped after `x`."""
        if self.n_bands is None:
            return torch.randn_like(x)
        # Multi-band layout (unreachable while `__init__` requires n_bands is None);
        # allocate next to `x` like the single-band branch does.
        return torch.randn((x.size(0), self.n_bands, x.size(2)), device=x.device, dtype=x.dtype)

    def get_alpha_bar(self, step: tp.Optional[tp.Union[int, torch.Tensor]] = None) -> torch.Tensor:
        """Return 'alpha_bar', either for a given step, or as a tensor with its value for each step.

        alpha_bar[t] is the product of (1 - beta_s) for s <= t. For an int step, `step = -1`
        gives 1 (empty product). A tensor of steps is returned viewed as [-1, 1, 1] so it
        broadcasts against batched samples.
        """
        if step is None:
            return (1 - self.betas).cumprod(dim=-1)  # works for single and multi bands
        if type(step) is int:
            return (1 - self.betas[:step + 1]).prod()
        else:
            return (1 - self.betas).cumprod(dim=0)[step].view(-1, 1, 1)

    def get_training_item(self, x: torch.Tensor, tensor_step: bool = False) -> TrainingItem:
        """Create a noisy data item for diffusion model training.

        Args:
            x (torch.Tensor): clean audio data torch.tensor(bs, 1, T)
            tensor_step (bool): If False, a single step t is sampled, the whole batch is
                diffused to that step and t is an int. If True, t is a tensor of size
                (x.size(0),) and each batch element is diffused to an independently sampled step.

        Returns:
            TrainingItem: `noisy` and `noise` live in the sample processor's projected space;
            `step` is the sampled step(s).
        """
        step: tp.Union[int, torch.Tensor]
        if tensor_step:
            bs = x.size(0)
            step = torch.randint(0, self.num_steps, size=(bs,), device=x.device)
        else:
            step = self.rng.randrange(self.num_steps)
        alpha_bar = self.get_alpha_bar(step)  # scalar, or [batch_size, 1, 1] for tensor steps

        x = self.sample_processor.project_sample(x)
        noise = torch.randn_like(x)
        noisy = (alpha_bar.sqrt() / self.rescale) * x + (1 - alpha_bar).sqrt() * noise * self.noise_scale
        return TrainingItem(noisy, noise, step)

    def generate(self, model: torch.nn.Module, initial: tp.Optional[torch.Tensor] = None,
                 condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
        """Full ddpm reverse process.

        Args:
            model (nn.Module): Diffusion model, called as `model(x, step, condition=condition)`;
                the `.sample` attribute of its output is used as the noise estimate.
            initial (tensor): Initial noise. Required despite the `None` default.
            condition (tensor): Input conditioning tensor (e.g. encodec compressed representation).
            return_list (bool): Whether to return the whole process or only the sampled point.

        Returns:
            If `return_list`, a list starting with `initial` followed by a CPU copy of every
            denoised iterate (not passed through the sample processor). Otherwise the final
            sample mapped back through `sample_processor.return_sample`.

        Raises:
            ValueError: If `initial` is None, or `variance` is not a supported choice.
        """
        if initial is None:
            raise ValueError("generate() requires an initial noise tensor.")
        alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
        current = initial
        iterates = [initial]
        for step in range(self.num_steps)[::-1]:
            with torch.no_grad():
                estimate = model(current, step, condition=condition).sample
            alpha = 1 - self.betas[step]
            previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
            # step - 1 == -1 at the last step, which yields alpha_bar == 1.
            previous_alpha_bar = self.get_alpha_bar(step=step - 1)
            if step == 0:
                sigma2 = 0
            elif self.variance == 'beta':
                sigma2 = 1 - alpha
            elif self.variance == 'beta_tilde':
                sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
            elif self.variance == 'none':
                sigma2 = 0
            else:
                raise ValueError(f'Invalid variance type {self.variance}')

            if sigma2 > 0:
                previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
            if self.clip:
                previous = previous.clamp(-self.clip, self.clip)
            current = previous
            alpha_bar = previous_alpha_bar
            if step == 0:
                # Undo the 1 / rescale applied to the clean signal when noising.
                previous *= self.rescale
            if return_list:
                iterates.append(previous.cpu())

        if return_list:
            return iterates
        else:
            return self.sample_processor.return_sample(previous)

    def generate_subsampled(self, model: torch.nn.Module, initial: torch.Tensor, step_list: tp.Optional[list] = None,
                            condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
        """Reverse process that only goes through Markov chain states in step_list.

        Args:
            model (nn.Module): Diffusion model, called as in `generate`.
            initial (tensor): Initial noise; multiplied by `noise_scale` before the first step.
            step_list (list, optional): Decreasing steps to visit; the last entry is the clean
                state (normally 0). Defaults to every 50th step counting down from
                `num_steps - 1`, followed by 0.
            condition (tensor): Input conditioning tensor.
            return_list (bool): Whether to return the whole process or only the sampled point.

        Raises:
            ValueError: If `step_list` has fewer than two entries.
        """
        if step_list is None:
            # Derived from `num_steps` so the default never indexes past the schedule.
            step_list = list(range(self.num_steps))[::-50] + [0]
        if len(step_list) < 2:
            raise ValueError(f"step_list needs at least two steps, got {step_list}.")
        last_idx = len(step_list) - 2  # index of the final jump
        # alpha_bar of the state the chain starts from, i.e. the first visited step.
        alpha_bar = self.get_alpha_bar(step=int(step_list[0]))
        alpha_bars_subsampled = (1 - self.betas).cumprod(dim=0)[list(reversed(step_list))].cpu()
        betas_subsampled = betas_from_alpha_bar(alpha_bars_subsampled)
        current = initial * self.noise_scale
        iterates = [current]
        for idx, step in enumerate(step_list[:-1]):
            with torch.no_grad():
                estimate = model(current, step, condition=condition).sample * self.noise_scale
            # 1 - alpha_bar[step_list[idx]] / alpha_bar[step_list[idx + 1]]: effective beta of this jump.
            alpha = 1 - betas_subsampled[-1 - idx]
            previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
            previous_alpha_bar = self.get_alpha_bar(step_list[idx + 1])
            if idx == last_idx:
                # The final jump lands on the clean sample: no noise is injected.
                sigma2 = 0
                previous_alpha_bar = torch.tensor(1.0)
            else:
                sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
            if sigma2 > 0:
                previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
            if self.clip:
                previous = previous.clamp(-self.clip, self.clip)
            current = previous
            alpha_bar = previous_alpha_bar
            if idx == last_idx:
                # Undo the 1 / rescale applied to the clean signal when noising. The loop
                # iterates step_list[:-1] and so never visits step 0 itself; key on the
                # final iteration instead.
                previous *= self.rescale
            if return_list:
                iterates.append(previous.cpu())
        if return_list:
            return iterates
        else:
            return self.sample_processor.return_sample(previous)