import typing as tp

import flashy
import julius
import omegaconf
import torch
import torch.nn.functional as F

from . import builders
from . import base
from .. import models
from ..modules.diffusion_schedule import NoiseSchedule
from ..metrics import RelativeVolumeMel
from ..models.builders import get_processor
from ..utils.samples.manager import SampleManager
from ..solvers.compression import CompressionSolver


class PerStageMetrics:
    """Report metrics per stage of the diffusion process.

    Metrics are bucketed per range of diffusion steps,
    e.g. the average loss when t is in [250, 500].
    """

    def __init__(self, num_steps: int, num_stages: int = 4):
        self.num_steps = num_steps
        self.num_stages = num_stages

    def __call__(self, losses: dict, step: tp.Union[int, torch.Tensor]):
        if isinstance(step, int):
            stage = int((step / self.num_steps) * self.num_stages)
            return {f"{name}_{stage}": loss for name, loss in losses.items()}
        elif isinstance(step, torch.Tensor):
            stage_tensor = ((step / self.num_steps) * self.num_stages).long()
            out: tp.Dict[str, float] = {}
            for stage_idx in range(self.num_stages):
                mask = (stage_tensor == stage_idx)
                N = mask.sum()
                stage_out = {}
                if N > 0:
                    for name, loss in losses.items():
                        stage_loss = (mask * loss).sum() / N
                        stage_out[f"{name}_{stage_idx}"] = stage_loss
                out = {**out, **stage_out}
            return out


class DataProcess:
    """Apply filtering or resampling.

    Args:
        initial_sr (int): Initial sample rate of the dataset.
        target_sr (int): Target sample rate after resampling.
        use_resampling (bool): Whether to use resampling or not.
        use_filter (bool): When True, filter the data to keep only one frequency band.
        n_bands (int): Number of bands to consider.
        idx_band (int): Index of the kept frequency band, from 0 (lows) to n_bands - 1 (highs).
        device (torch.device or str): Device on which to run the band-splitting filter.
        cutoffs (list or None): Cutoff frequencies of the band filtering;
            if None, mel scale bands are used.
        boost (bool): Whether to rescale the data to match the scale of our music dataset.
    """

    def __init__(self, initial_sr: int = 24000, target_sr: int = 16000, use_resampling: bool = False,
                 use_filter: bool = False, n_bands: int = 4,
                 idx_band: int = 0, device: torch.device = torch.device('cpu'), cutoffs=None, boost=False):
        assert idx_band < n_bands
        self.idx_band = idx_band
        if use_filter:
            if cutoffs is not None:
                self.filter = julius.SplitBands(sample_rate=initial_sr, cutoffs=cutoffs).to(device)
            else:
                self.filter = julius.SplitBands(sample_rate=initial_sr, n_bands=n_bands).to(device)
        self.use_filter = use_filter
        self.use_resampling = use_resampling
        self.target_sr = target_sr
        self.initial_sr = initial_sr
        self.boost = boost

    def process_data(self, x, metric=False):
        if x is None:
            return None
        if self.boost:
            # Normalize the scale, then boost it to roughly match the music dataset.
            x /= torch.clamp(x.std(dim=(1, 2), keepdim=True), min=1e-4)
            x = x * 0.22
        if self.use_filter and not metric:
            x = self.filter(x)[self.idx_band]
        if self.use_resampling:
            x = julius.resample_frac(x, old_sr=self.initial_sr, new_sr=self.target_sr)
        return x

    def inverse_process(self, x):
        """Upsampling only."""
        if self.use_resampling:
            x = julius.resample_frac(x, old_sr=self.target_sr, new_sr=self.initial_sr)
        return x


class DiffusionSolver(base.StandardSolver):
    """Solver for the diffusion task.

    The diffusion task allows for training MultiBand diffusion models.

    Args:
        cfg (DictConfig): Configuration.
    """

    def __init__(self, cfg: omegaconf.DictConfig):
        super().__init__(cfg)
        self.cfg = cfg
        self.device = cfg.device
        self.sample_rate: int = self.cfg.sample_rate
        self.codec_model = CompressionSolver.model_from_checkpoint(
            cfg.compression_model_checkpoint, device=self.device)

        self.codec_model.set_num_codebooks(cfg.n_q)
        assert self.codec_model.sample_rate == self.sample_rate, (
            f"Codec model sample rate is {self.codec_model.sample_rate} but "
            f"Solver sample rate is {self.sample_rate}."
        )

        self.sample_processor = get_processor(cfg.processor, sample_rate=self.sample_rate)
        self.register_stateful('sample_processor')
        self.sample_processor.to(self.device)

        self.schedule = NoiseSchedule(
            **cfg.schedule, device=self.device, sample_processor=self.sample_processor)

        self.eval_metric: tp.Optional[torch.nn.Module] = None

        self.rvm = RelativeVolumeMel()
        self.data_processor = DataProcess(initial_sr=self.sample_rate, target_sr=cfg.resampling.target_sr,
                                          use_resampling=cfg.resampling.use, cutoffs=cfg.filter.cutoffs,
                                          use_filter=cfg.filter.use, n_bands=cfg.filter.n_bands,
                                          idx_band=cfg.filter.idx_band, device=self.device)

    @property
    def best_metric_name(self) -> tp.Optional[str]:
        if self._current_stage == "evaluate":
            return 'rvm'
        else:
            return 'loss'

    @torch.no_grad()
    def get_condition(self, wav: torch.Tensor) -> torch.Tensor:
        """Compute the conditioning signal: the continuous latent of the codec model."""
        codes, scale = self.codec_model.encode(wav)
        assert scale is None, "Scaled compression models not supported."
        emb = self.codec_model.decode_latent(codes)
        return emb

    def build_model(self):
        """Build model and optimizer as well as optional Exponential Moving Average of the model."""
        self.model = models.builders.get_diffusion_model(self.cfg).to(self.device)
        self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim)
        self.register_stateful('model', 'optimizer')
        self.register_best_state('model')
        self.register_ema('model')

    def build_dataloaders(self):
        """Build audio dataloaders for each stage."""
        self.dataloaders = builders.get_audio_datasets(self.cfg)

    def show(self):
        raise NotImplementedError()

    def run_step(self, idx: int, batch: torch.Tensor, metrics: dict):
        """Perform one training or valid step on a given batch."""
        x = batch.to(self.device)
        loss_fun = F.mse_loss if self.cfg.loss.kind == 'mse' else F.l1_loss

        condition = self.get_condition(x)
        sample = self.data_processor.process_data(x)

        input_, target, step = self.schedule.get_training_item(sample,
                                                               tensor_step=self.cfg.schedule.variable_step_batch)
        out = self.model(input_, step, condition=condition).sample

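        # The per-item loss below is normalized by the loss of the trivial prediction
        # that returns the noisy input unchanged, raised to cfg.loss.norm_power
        # (norm_power = 0 disables the normalization since the denominator becomes 1).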
        base_loss = loss_fun(out, target, reduction='none').mean(dim=(1, 2))
        reference_loss = loss_fun(input_, target, reduction='none').mean(dim=(1, 2))
        loss = base_loss / reference_loss ** self.cfg.loss.norm_power

        if self.is_training:
            loss.mean().backward()
            flashy.distrib.sync_model(self.model)
            self.optimizer.step()
            self.optimizer.zero_grad()
        metrics = {
            'loss': loss.mean(), 'normed_loss': (base_loss / reference_loss).mean(),
        }
        metrics.update(self.per_stage({'loss': loss, 'normed_loss': base_loss / reference_loss}, step))
        metrics.update({'std_in': input_.std(), 'std_out': out.std()})
        return metrics

    def run_epoch(self):
        # Reset the RNG and the per-stage metric buckets at the start of each epoch.
        self.rng = torch.Generator()
        self.rng.manual_seed(1234 + self.epoch)
        self.per_stage = PerStageMetrics(self.schedule.num_steps, self.cfg.metrics.num_stage)
        super().run_epoch()

    def evaluate(self):
        """Evaluate stage.
        Runs audio reconstruction evaluation.
        """
        self.model.eval()
        evaluate_stage_name = f'{self.current_stage}'
        loader = self.dataloaders['evaluate']
        updates = len(loader)
        lp = self.log_progress(f'{evaluate_stage_name} estimate', loader, total=updates, updates=self.log_updates)

        metrics = {}
        n = 1
        for idx, batch in enumerate(lp):
            x = batch.to(self.device)
            with torch.no_grad():
                y_pred = self.regenerate(x)

            y_pred = y_pred.cpu()
            y = batch.cpu()
            rvm = self.rvm(y_pred, y)
            lp.update(**rvm)
            if len(metrics) == 0:
                metrics = rvm
            else:
                # Running mean over batches; track how many batches have been merged.
                for key in rvm.keys():
                    metrics[key] = (metrics[key] * n + rvm[key]) / (n + 1)
                n += 1
        metrics = flashy.distrib.average_metrics(metrics)
        return metrics

    @torch.no_grad()
    def regenerate(self, wav: torch.Tensor, step_list: tp.Optional[list] = None):
        """Regenerate the given waveform."""
        condition = self.get_condition(wav)
        initial = self.schedule.get_initial_noise(self.data_processor.process_data(wav))
        result = self.schedule.generate_subsampled(self.model, initial=initial, condition=condition,
                                                   step_list=step_list)
        result = self.data_processor.inverse_process(result)
        return result

    def generate(self):
        """Generate stage."""
        sample_manager = SampleManager(self.xp)
        self.model.eval()
        generate_stage_name = f'{self.current_stage}'

        loader = self.dataloaders['generate']
        updates = len(loader)
        lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates)

        for batch in lp:
            reference, _ = batch
            reference = reference.to(self.device)
            estimate = self.regenerate(reference)
            reference = reference.cpu()
            estimate = estimate.cpu()
            sample_manager.add_samples(estimate, self.epoch, ground_truth_wavs=reference)
        flashy.distrib.barrier()