|
|
|
|
|
|
|
|
|
|
|
|
|
from pathlib import Path |
|
import time |
|
import typing as tp |
|
|
|
import flashy |
|
import math |
|
import omegaconf |
|
import torch |
|
from torch.nn import functional as F |
|
|
|
from . import base, builders |
|
from .compression import CompressionSolver |
|
from .. import metrics as eval_metrics |
|
from .. import models |
|
from ..data.audio_dataset import AudioDataset |
|
from ..data.music_dataset import MusicDataset, MusicInfo, AudioInfo |
|
from ..data.audio_utils import normalize_audio |
|
from ..modules.conditioners import JointEmbedCondition, SegmentWithAttributes, WavCondition |
|
from ..utils.cache import CachedBatchWriter, CachedBatchLoader |
|
from ..utils.samples.manager import SampleManager |
|
from ..utils.utils import get_dataset_from_loader, is_jsonable, warn_once |
|
|
|
|
|
class MusicGenSolver(base.StandardSolver):
    """Solver for MusicGen training task.

    Trains a language model (``self.model``) over the discrete tokens produced by a
    pretrained compression model (``self.compression_model``). Optionally writes or
    reads a cache of pre-computed tokens, as configured by ``cfg.cache``.

    Used in: https://arxiv.org/abs/2306.05284
    """
    DATASET_TYPE: builders.DatasetType = builders.DatasetType.MUSIC

    def __init__(self, cfg: omegaconf.DictConfig):
        super().__init__(cfg)
        # Sampling parameters forwarded to the LM's `generate` method.
        self.generation_params = {
            'use_sampling': self.cfg.generate.lm.use_sampling,
            'temp': self.cfg.generate.lm.temp,
            'top_k': self.cfg.generate.lm.top_k,
            'top_p': self.cfg.generate.lm.top_p,
        }
        self._best_metric_name: tp.Optional[str] = 'ce'

        self._cached_batch_writer = None
        self._cached_batch_loader = None
        if cfg.cache.path:
            if cfg.cache.write:
                # Writing mode: tokens computed during training are saved to the cache.
                self._cached_batch_writer = CachedBatchWriter(Path(cfg.cache.path))
                if self.cfg.cache.write_num_shards:
                    # `run_epoch` skips epochs that belong to other shards, so a
                    # "best" metric is not meaningful for a single shard run.
                    self.logger.warning("Multiple shard cache, best_metric_name will be set to None.")
                    self._best_metric_name = None
            else:
                # Reading mode: the train loader is replaced by the cached one. The
                # original loader is kept so its dataset (e.g. its paraphraser) stays reachable.
                self._cached_batch_loader = CachedBatchLoader(
                    Path(cfg.cache.path), cfg.dataset.batch_size, cfg.dataset.num_workers,
                    min_length=self.cfg.optim.updates_per_epoch or 1)
                self.dataloaders['original_train'] = self.dataloaders['train']
                self.dataloaders['train'] = self._cached_batch_loader  # type: ignore

    @staticmethod
    def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None,
                                 device: tp.Optional[str] = None, autocast: bool = True,
                                 batch_size: tp.Optional[int] = None,
                                 override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None,
                                 **kwargs):
        """Mostly a convenience function around audiocraft.train.get_solver_from_sig,
        populating all the proper param, deactivating EMA, FSDP, loading the best state,
        basically all you need to get a solver ready to "play" with in single GPU mode
        and with minimal memory overhead.

        Args:
            sig (str): signature to load.
            dtype (str or None): potential dtype, as a string, i.e. 'float16'.
            device (str or None): potential device, as a string, i.e. 'cuda'.
            autocast (bool): value assigned to the `autocast` config entry.
            batch_size (int or None): optional override of `dataset.batch_size`.
            override_cfg (dict or omegaconf.DictConfig or None): additional config overrides.
                They are merged first, so the overrides derived from the other
                arguments take precedence over them.
            **kwargs: forwarded to `train.get_solver_from_sig`.
        Returns:
            The loaded solver, with its model put in eval mode.
        """
        from audiocraft import train
        our_override_cfg: tp.Dict[str, tp.Any] = {'optim': {'ema': {'use': False}}}
        our_override_cfg['autocast'] = autocast
        if dtype is not None:
            our_override_cfg['dtype'] = dtype
        if device is not None:
            our_override_cfg['device'] = device
        if batch_size is not None:
            our_override_cfg['dataset'] = {'batch_size': batch_size}
        if override_cfg is None:
            override_cfg = {}
        # Later configs win in OmegaConf.merge: our overrides take precedence.
        override_cfg = omegaconf.OmegaConf.merge(
            omegaconf.DictConfig(override_cfg), omegaconf.DictConfig(our_override_cfg))  # type: ignore
        solver = train.get_solver_from_sig(
            sig, override_cfg=override_cfg,
            load_best=True, disable_fsdp=True,
            ignore_state_keys=['optimizer', 'ema'], **kwargs)
        solver.model.eval()
        return solver

    def get_formatter(self, stage_name: str) -> flashy.Formatter:
        """Return the formatter used to log metrics (per-codebook metrics are excluded)."""
        return flashy.Formatter({
            'lr': '.2E',
            'ce': '.3f',
            'ppl': '.3f',
            'grad_norm': '.3E',
        }, exclude_keys=['ce_q*', 'ppl_q*'])

    @property
    def best_metric_name(self) -> tp.Optional[str]:
        """Name of the metric used to select the best state (None when disabled)."""
        return self._best_metric_name

    def build_model(self) -> None:
        """Instantiate models and optimizer."""
        # Compression model, used to turn audio into the discrete tokens the LM is trained on.
        self.compression_model = CompressionSolver.wrapped_model_from_checkpoint(
            self.cfg, self.cfg.compression_model_checkpoint, device=self.device)
        assert self.compression_model.sample_rate == self.cfg.sample_rate, (
            f"Compression model sample rate is {self.compression_model.sample_rate} but "
            f"Solver sample rate is {self.cfg.sample_rate}."
        )
        # The LM vocabulary and number of streams must match the compression model.
        assert self.cfg.transformer_lm.card == self.compression_model.cardinality, (
            "Cardinalities of the LM and compression model don't match: "
            f"LM cardinality is {self.cfg.transformer_lm.card} vs "
            f"compression model cardinality is {self.compression_model.cardinality}"
        )
        assert self.cfg.transformer_lm.n_q == self.compression_model.num_codebooks, (
            "Numbers of codebooks of the LM and compression models don't match: "
            f"LM number of codebooks is {self.cfg.transformer_lm.n_q} vs "
            f"compression model number of codebooks is {self.compression_model.num_codebooks}"
        )
        self.logger.info("Compression model has %d codebooks with %d cardinality, and a framerate of %d",
                         self.compression_model.num_codebooks, self.compression_model.cardinality,
                         self.compression_model.frame_rate)

        # Language model.
        self.model: models.LMModel = models.builders.get_lm_model(self.cfg).to(self.device)
        if self.cfg.fsdp.use:
            assert not self.cfg.autocast, "Cannot use autocast with fsdp"
            self.model = self.wrap_with_fsdp(self.model)
        self.register_ema('model')

        # Optimization.
        self.optimizer = builders.get_optimizer(builders.get_optim_parameter_groups(self.model), self.cfg.optim)
        self.lr_scheduler = builders.get_lr_scheduler(self.optimizer, self.cfg.schedule, self.total_updates)
        self.register_stateful('compression_model', 'model', 'optimizer', 'lr_scheduler')
        self.register_best_state('model')
        self.autocast_dtype = {
            'float16': torch.float16, 'bfloat16': torch.bfloat16
        }[self.cfg.autocast_dtype]
        # A gradient scaler is only needed when computing in float16.
        self.scaler: tp.Optional[torch.cuda.amp.GradScaler] = None
        if self.cfg.fsdp.use:
            need_scaler = self.cfg.fsdp.param_dtype == 'float16'
        else:
            need_scaler = self.cfg.autocast and self.autocast_dtype is torch.float16
        if need_scaler:
            if self.cfg.fsdp.use:
                from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
                self.scaler = ShardedGradScaler()  # type: ignore
            else:
                self.scaler = torch.cuda.amp.GradScaler()
            self.register_stateful('scaler')

    def build_dataloaders(self) -> None:
        """Instantiate audio dataloaders for each stage."""
        self.dataloaders = builders.get_audio_datasets(self.cfg, dataset_type=self.DATASET_TYPE)

    def show(self) -> None:
        """Show the compression model and LM model."""
        self.logger.info("Compression model:")
        self.log_model_summary(self.compression_model)
        self.logger.info("LM model:")
        self.log_model_summary(self.model)

    def load_state_dict(self, state: dict) -> None:
        """Load a solver state.

        If the state holds a top-level 'condition_provider' entry, its items are moved
        into the 'model' state under the 'condition_provider.' prefix before loading.
        """
        if 'condition_provider' in state:
            model_state = state['model']
            condition_provider_state = state.pop('condition_provider')
            prefix = 'condition_provider.'
            for key, value in condition_provider_state.items():
                key = prefix + key
                assert key not in model_state
                model_state[key] = value
        super().load_state_dict(state)

    def load_from_pretrained(self, name: str):
        """Build a solver state holding only the best LM state of the pretrained model `name`."""
        lm_pkg = models.loaders.load_lm_model_ckpt(name)
        state: dict = {
            'best_state': {
                'model': lm_pkg['best_state'],
            },
        }
        return state

    def _compute_cross_entropy(
        self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor
    ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]:
        """Compute cross entropy between multi-codebook targets and model's logits.
        The cross entropy is computed per codebook to provide codebook-level cross entropy.
        Valid timesteps for each of the codebook are pulled from the mask, where invalid
        timesteps are set to 0.

        Args:
            logits (torch.Tensor): Model's logits of shape [B, K, T, card].
            targets (torch.Tensor): Target codes, of shape [B, K, T].
            mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T].
        Returns:
            ce (torch.Tensor): Cross entropy averaged over the codebooks
            ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached).
        """
        B, K, T = targets.shape
        assert logits.shape[:-1] == targets.shape
        assert mask.shape == targets.shape
        ce = torch.zeros([], device=targets.device)
        ce_per_codebook: tp.List[torch.Tensor] = []
        for k in range(K):
            logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1))  # [B x T, card]
            targets_k = targets[:, k, ...].contiguous().view(-1)  # [B x T]
            mask_k = mask[:, k, ...].contiguous().view(-1)  # [B x T]
            # Only valid (unmasked) positions contribute to the loss.
            ce_targets = targets_k[mask_k]
            ce_logits = logits_k[mask_k]
            q_ce = F.cross_entropy(ce_logits, ce_targets)
            ce += q_ce
            ce_per_codebook.append(q_ce.detach())
        # Average cross entropy across codebooks.
        ce = ce / K
        return ce, ce_per_codebook

    @torch.no_grad()
    def _prepare_tokens_and_attributes(
        self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
        check_synchronization_points: bool = False
    ) -> tp.Tuple[dict, torch.Tensor, torch.Tensor]:
        """Prepare input batches for language model training.

        Args:
            batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): Input batch with audio tensor of shape [B, C, T]
                and corresponding metadata as SegmentWithAttributes (with B items).
                When training from the token cache, the batch is instead a 1-tuple of infos
                carrying pre-computed `audio_tokens`.
            check_synchronization_points (bool): Whether to check for synchronization points slowing down training.
        Returns:
            Condition tensors (dict[str, any]): Preprocessed condition attributes.
            Tokens (torch.Tensor): Audio tokens from compression model, of shape [B, K, T_s],
                with B the batch size, K the number of codebooks, T_s the token timesteps.
            Padding mask (torch.Tensor): Mask with valid positions in the tokens tensor, of shape [B, K, T_s].
        """
        if self._cached_batch_loader is None or self.current_stage != "train":
            audio, infos = batch
            audio = audio.to(self.device)
            audio_tokens = None
            assert audio.size(0) == len(infos), (
                f"Mismatch between number of items in audio batch ({audio.size(0)})"
                f" and in metadata ({len(infos)})"
            )
        else:
            audio = None
            # The batch comes from the cache written by the `_cached_batch_writer` branch below.
            infos, = batch  # type: ignore
            assert all([isinstance(info, AudioInfo) for info in infos])
            assert all([info.audio_tokens is not None for info in infos])  # type: ignore
            audio_tokens = torch.stack([info.audio_tokens for info in infos]).to(self.device)  # type: ignore
            audio_tokens = audio_tokens.long()
            for info in infos:
                if isinstance(info, MusicInfo):
                    # The cache does not store the waveform: a NaN-filled placeholder is
                    # used, so any conditioner actually reading `self_wav` gets NaNs.
                    info.self_wav = WavCondition(
                        torch.full([1, info.channels, info.total_frames], float('NaN')),
                        length=torch.tensor([info.n_frames]),
                        sample_rate=[info.sample_rate],
                        path=[info.meta.path],
                        seek_time=[info.seek_time])
                    dataset = get_dataset_from_loader(self.dataloaders['original_train'])
                    assert isinstance(dataset, MusicDataset), type(dataset)
                    if dataset.paraphraser is not None and info.description is not None:
                        # Re-apply the paraphraser, which the cached items did not go through.
                        info.description = dataset.paraphraser.sample_paraphrase(
                            info.meta.path, info.description)
        # Prepare attributes (with conditioning dropout).
        attributes = [info.to_condition_attributes() for info in infos]
        attributes = self.model.cfg_dropout(attributes)
        attributes = self.model.att_dropout(attributes)
        tokenized = self.model.condition_provider.tokenize(attributes)

        # From here on the code is expected to be synchronization free.
        if self.device == "cuda" and check_synchronization_points:
            torch.cuda.set_sync_debug_mode("warn")

        if audio_tokens is None:
            with torch.no_grad():
                audio_tokens, scale = self.compression_model.encode(audio)
                assert scale is None, "Scaled compression model not supported with LM."

        with self.autocast:
            condition_tensors = self.model.condition_provider(tokenized)

        # Padding mask holding valid (True) vs invalid (False) token positions.
        padding_mask = torch.ones_like(audio_tokens, dtype=torch.bool, device=audio_tokens.device)
        # Replace tokens coming from padded audio with the special token id.
        if self.cfg.tokens.padding_with_special_token:
            audio_tokens = audio_tokens.clone()
            padding_mask = padding_mask.clone()
            token_sample_rate = self.compression_model.frame_rate
            B, K, T_s = audio_tokens.shape
            for i in range(B):
                n_samples = infos[i].n_frames
                audio_sample_rate = infos[i].sample_rate
                # Number of tokens generated from actual (non-padded) audio frames.
                valid_tokens = math.floor(float(n_samples) / audio_sample_rate * token_sample_rate)
                audio_tokens[i, :, valid_tokens:] = self.model.special_token_id
                padding_mask[i, :, valid_tokens:] = 0

        if self.device == "cuda" and check_synchronization_points:
            torch.cuda.set_sync_debug_mode("default")

        if self._cached_batch_writer is not None and self.current_stage == 'train':
            assert self._cached_batch_loader is None
            assert audio_tokens is not None
            for info, one_audio_tokens in zip(infos, audio_tokens):
                assert isinstance(info, AudioInfo)
                if isinstance(info, MusicInfo):
                    assert not info.joint_embed, "joint_embed and cache not supported yet."
                    info.self_wav = None
                # Tokens are stored as int16 to save space, so they must fit in it.
                assert one_audio_tokens.max() < 2**15, one_audio_tokens.max().item()
                info.audio_tokens = one_audio_tokens.short().cpu()  # type: ignore
            self._cached_batch_writer.save(infos)

        return condition_tensors, audio_tokens, padding_mask

    def run_step(self, idx: int, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]], metrics: dict) -> dict:
        """Perform one training or valid step on a given batch.

        Computes the per-codebook cross entropy and, in training, runs the backward
        pass and optimizer/scheduler steps. Fills and returns `metrics` with 'ce',
        'ppl', per-codebook 'ce_q{k}'/'ppl_q{k}', and training-only entries.
        """
        # Only check CUDA synchronization points on a single step, as it slows things down.
        check_synchronization_points = idx == 1 and self.device == 'cuda'

        condition_tensors, audio_tokens, padding_mask = self._prepare_tokens_and_attributes(
            batch, check_synchronization_points)

        self.deadlock_detect.update('tokens_and_conditions')

        if check_synchronization_points:
            torch.cuda.set_sync_debug_mode('warn')

        with self.autocast:
            model_output = self.model.compute_predictions(audio_tokens, [], condition_tensors)  # type: ignore
            logits = model_output.logits
            # A position counts only if valid in both the padding mask and the model mask.
            mask = padding_mask & model_output.mask
            ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask)
            loss = ce
        self.deadlock_detect.update('loss')

        if check_synchronization_points:
            torch.cuda.set_sync_debug_mode('default')

        if self.is_training:
            metrics['lr'] = self.optimizer.param_groups[0]['lr']
            if self.scaler is not None:
                loss = self.scaler.scale(loss)
            self.deadlock_detect.update('scale')
            if self.cfg.fsdp.use:
                loss.backward()
                flashy.distrib.average_tensors(self.model.buffers())
            elif self.cfg.optim.eager_sync:
                with flashy.distrib.eager_sync_model(self.model):
                    loss.backward()
            else:
                # Synchronize the model after the backward pass (non-eager path).
                loss.backward()
                flashy.distrib.sync_model(self.model)
            self.deadlock_detect.update('backward')

            if self.scaler is not None:
                # Unscale before clipping so max_norm applies to the true gradients.
                self.scaler.unscale_(self.optimizer)
            if self.cfg.optim.max_norm:
                if self.cfg.fsdp.use:
                    metrics['grad_norm'] = self.model.clip_grad_norm_(self.cfg.optim.max_norm)  # type: ignore
                else:
                    metrics['grad_norm'] = torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.cfg.optim.max_norm
                    )
            if self.scaler is None:
                self.optimizer.step()
            else:
                self.scaler.step(self.optimizer)
                self.scaler.update()
            if self.lr_scheduler:
                self.lr_scheduler.step()
            self.optimizer.zero_grad()
            self.deadlock_detect.update('optim')
            if self.scaler is not None:
                scale = self.scaler.get_scale()
                metrics['grad_scale'] = scale
            if not loss.isfinite().all():
                raise RuntimeError("Model probably diverged.")

        metrics['ce'] = ce
        metrics['ppl'] = torch.exp(ce)
        for k, ce_q in enumerate(ce_per_codebook):
            metrics[f'ce_q{k + 1}'] = ce_q
            metrics[f'ppl_q{k + 1}'] = torch.exp(ce_q)

        return metrics

    @torch.no_grad()
    def run_generate_step(self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
                          gen_duration: float, prompt_duration: tp.Optional[float] = None,
                          remove_prompt: bool = False,
                          **generation_params) -> dict:
        """Run generate step on a batch of optional audio tensor and corresponding attributes.

        Args:
            batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): Audio tensor and matching metadata.
            gen_duration (float): Target audio duration for the generation.
            prompt_duration (float, optional): Duration for the audio prompt to use for continuation.
                When None, generation is unprompted.
            remove_prompt (bool, optional): Currently ignored; kept for API compatibility.
            generation_params: Additional generation parameters, overriding the solver's
                default `self.generation_params`.
        Returns:
            gen_outputs (dict): Generation outputs, consisting in audio, audio tokens from both the generation
                and the prompt along with additional information.
        """
        bench_start = time.time()
        audio, meta = batch
        assert audio.size(0) == len(meta), (
            f"Mismatch between number of items in audio batch ({audio.size(0)})"
            f" and in metadata ({len(meta)})"
        )
        # Prepare attributes.
        attributes = [x.to_condition_attributes() for x in meta]

        # Prepare the audio prompt, if any.
        if prompt_duration is None:
            prompt_audio = None
        else:
            assert prompt_duration < gen_duration, "Prompt duration must be lower than target generation duration"
            prompt_audio_frames = int(prompt_duration * self.compression_model.sample_rate)
            prompt_audio = audio[..., :prompt_audio_frames]

        # Get audio tokens from the prompt; without prompt, the batch size is given explicitly.
        if prompt_audio is None or prompt_audio.nelement() == 0:
            num_samples = len(attributes)
            prompt_tokens = None
        else:
            num_samples = None
            prompt_audio = prompt_audio.to(self.device)
            prompt_tokens, scale = self.compression_model.encode(prompt_audio)
            assert scale is None, "Compression model in MusicGen should not require rescaling."

        # Caller-provided parameters take precedence over the solver defaults.
        sampling_params = {**self.generation_params, **generation_params}

        # Generate by sampling from the LM.
        with self.autocast:
            total_gen_len = math.ceil(gen_duration * self.compression_model.frame_rate)
            gen_tokens = self.model.generate(
                prompt_tokens, attributes, max_gen_len=total_gen_len,
                num_samples=num_samples, **sampling_params)

        # Decode generated tokens back to audio.
        assert gen_tokens.dim() == 3
        gen_audio = self.compression_model.decode(gen_tokens, None)

        bench_end = time.time()
        gen_outputs = {
            'rtf': (bench_end - bench_start) / gen_duration,
            'ref_audio': audio,
            'gen_audio': gen_audio,
            'gen_tokens': gen_tokens,
            'prompt_audio': prompt_audio,
            'prompt_tokens': prompt_tokens,
        }
        return gen_outputs

    def generate_audio(self) -> dict:
        """Audio generation stage.

        For each batch of the 'generate' dataloader, produces unprompted and/or prompted
        samples (per `cfg.generate.lm`) and stores them through a SampleManager.

        Returns:
            dict: averaged metrics; contains the real-time factor 'rtf' when at least
                one kind of generation is enabled.
        """
        generate_stage_name = f'{self.current_stage}'
        sample_manager = SampleManager(self.xp)
        self.logger.info(f"Generating samples in {sample_manager.base_folder}")
        loader = self.dataloaders['generate']
        updates = len(loader)
        lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates)

        dataset = get_dataset_from_loader(loader)
        # Check the type before reading dataset-specific attributes.
        assert isinstance(dataset, AudioDataset)
        dataset_duration = dataset.segment_duration
        assert dataset_duration is not None
        target_duration = self.cfg.generate.lm.gen_duration
        prompt_duration = self.cfg.generate.lm.prompt_duration
        if target_duration is None:
            target_duration = dataset_duration
        if prompt_duration is None:
            prompt_duration = dataset_duration / 4
        assert prompt_duration < dataset_duration, (
            f"Specified prompt duration ({prompt_duration}s) is longer"
            f" than reference audio duration ({dataset_duration}s)"
        )

        def get_hydrated_conditions(meta: tp.List[SegmentWithAttributes]):
            """Turn each item's condition attributes into a JSON-friendly dict, keeping
            only the conditions the model's condition provider knows about."""
            hydrated_conditions = []
            for sample in [x.to_condition_attributes() for x in meta]:
                cond_dict = {}
                for cond_type in sample.__annotations__.keys():
                    for cond_key, cond_val in getattr(sample, cond_type).items():
                        if cond_key not in self.model.condition_provider.conditioners.keys():
                            continue
                        if is_jsonable(cond_val):
                            cond_dict[cond_key] = cond_val
                        elif isinstance(cond_val, WavCondition):
                            cond_dict[cond_key] = cond_val.path
                        elif isinstance(cond_val, JointEmbedCondition):
                            cond_dict[cond_key] = cond_val.text
                        else:
                            # Non-serializable condition: only record its type.
                            cond_dict[cond_key] = str(type(cond_val))
                hydrated_conditions.append(cond_dict)
            return hydrated_conditions

        metrics: dict = {}
        average = flashy.averager()
        for batch in lp:
            audio, meta = batch
            # Metadata for the sample manager.
            hydrated_conditions = get_hydrated_conditions(meta)
            sample_generation_params = {
                **{f'classifier_free_guidance_{k}': v for k, v in self.cfg.classifier_free_guidance.items()},
                **self.generation_params
            }
            # Real-time factor of this batch; stays None when no generation is enabled.
            rtf: tp.Optional[float] = None
            if self.cfg.generate.lm.unprompted_samples:
                if self.cfg.generate.lm.gen_gt_samples:
                    # Use the ground truth instead of a generation.
                    warn_once(self.logger,
                              "Use ground truth instead of audio generation as generate.lm.gen_gt_samples=true")
                    gen_unprompted_audio = audio
                    rtf = 1.
                else:
                    # Unprompted generation: no audio prompt is used.
                    gen_unprompted_outputs = self.run_generate_step(
                        batch, gen_duration=target_duration, prompt_duration=None,
                        **self.generation_params)
                    gen_unprompted_audio = gen_unprompted_outputs['gen_audio'].cpu()
                    rtf = gen_unprompted_outputs['rtf']
                sample_manager.add_samples(
                    gen_unprompted_audio, self.epoch, hydrated_conditions,
                    ground_truth_wavs=audio, generation_args=sample_generation_params)

            if self.cfg.generate.lm.prompted_samples:
                gen_outputs = self.run_generate_step(
                    batch, gen_duration=target_duration, prompt_duration=prompt_duration,
                    **self.generation_params)
                gen_audio = gen_outputs['gen_audio'].cpu()
                prompt_audio = gen_outputs['prompt_audio'].cpu()
                sample_manager.add_samples(
                    gen_audio, self.epoch, hydrated_conditions,
                    prompt_wavs=prompt_audio, ground_truth_wavs=audio,
                    generation_args=sample_generation_params)
                if rtf is None:
                    rtf = gen_outputs['rtf']

            if rtf is not None:
                metrics['rtf'] = rtf
            metrics = average(metrics)

        flashy.distrib.barrier()
        return metrics

    def generate(self) -> dict:
        """Generate stage."""
        self.model.eval()
        with torch.no_grad():
            return self.generate_audio()

    def run_epoch(self):
        """Run an epoch, skipping epochs belonging to other shards when writing the cache."""
        if self.cfg.cache.write:
            if ((self.epoch - 1) % self.cfg.cache.write_num_shards) != self.cfg.cache.write_shard:
                return
        super().run_epoch()

    def train(self):
        """Train stage.

        Notifies the cache writer/loader (or the train dataset) of the current epoch
        before running the standard training stage.
        """
        if self._cached_batch_writer is not None:
            self._cached_batch_writer.start_epoch(self.epoch)
        if self._cached_batch_loader is None:
            dataset = get_dataset_from_loader(self.dataloaders['train'])
            assert isinstance(dataset, AudioDataset)
            dataset.current_epoch = self.epoch
        else:
            self._cached_batch_loader.start_epoch(self.epoch)
        return super().train()

    def evaluate_audio_generation(self) -> dict:
        """Evaluate audio generation with off-the-shelf metrics.

        Each enabled metric (FAD, KLD, text consistency, chroma cosine) compares either
        the generated audio or, when its `use_gt` option is set, a reference derived
        from the ground truth, against the ground truth.
        """
        evaluate_stage_name = f'{self.current_stage}_generation'
        # Instantiate the enabled evaluation metrics.
        fad: tp.Optional[eval_metrics.FrechetAudioDistanceMetric] = None
        kldiv: tp.Optional[eval_metrics.KLDivergenceMetric] = None
        text_consistency: tp.Optional[eval_metrics.TextConsistencyMetric] = None
        chroma_cosine: tp.Optional[eval_metrics.ChromaCosineSimilarityMetric] = None
        should_run_eval = False
        eval_chroma_wavs: tp.Optional[torch.Tensor] = None
        if self.cfg.evaluate.metrics.fad:
            fad = builders.get_fad(self.cfg.metrics.fad).to(self.device)
            should_run_eval = True
        if self.cfg.evaluate.metrics.kld:
            kldiv = builders.get_kldiv(self.cfg.metrics.kld).to(self.device)
            should_run_eval = True
        if self.cfg.evaluate.metrics.text_consistency:
            text_consistency = builders.get_text_consistency(self.cfg.metrics.text_consistency).to(self.device)
            should_run_eval = True
        if self.cfg.evaluate.metrics.chroma_cosine:
            chroma_cosine = builders.get_chroma_cosine_similarity(self.cfg.metrics.chroma_cosine).to(self.device)
            # Predefined eval chromas are removed while computing the cosine metric
            # and restored once evaluation is done.
            has_predefined_eval_chromas = 'self_wav' in self.model.condition_provider.conditioners and \
                self.model.condition_provider.conditioners['self_wav'].has_eval_wavs()
            if has_predefined_eval_chromas:
                warn_once(self.logger, "Attempting to run cosine eval for config with pre-defined eval chromas! "
                                       'Resetting eval chromas to None for evaluation.')
                eval_chroma_wavs = self.model.condition_provider.conditioners.self_wav.eval_wavs  # type: ignore
                self.model.condition_provider.conditioners.self_wav.reset_eval_wavs(None)  # type: ignore
            should_run_eval = True

        def get_compressed_audio(audio: torch.Tensor) -> torch.Tensor:
            """Encode then decode `audio` with the compression model, cropped to its length."""
            audio_tokens, scale = self.compression_model.encode(audio.to(self.device))
            compressed_audio = self.compression_model.decode(audio_tokens, scale)
            return compressed_audio[..., :audio.shape[-1]]

        metrics: dict = {}
        if should_run_eval:
            loader = self.dataloaders['evaluate']
            updates = len(loader)
            lp = self.log_progress(f'{evaluate_stage_name} inference', loader, total=updates, updates=self.log_updates)
            average = flashy.averager()
            dataset = get_dataset_from_loader(loader)
            assert isinstance(dataset, AudioDataset)
            self.logger.info(f"Computing evaluation metrics on {len(dataset)} samples")

            try:
                for batch in lp:
                    audio, meta = batch
                    assert all([self.cfg.sample_rate == m.sample_rate for m in meta])

                    target_duration = audio.shape[-1] / self.cfg.sample_rate
                    if self.cfg.evaluate.fixed_generation_duration:
                        target_duration = self.cfg.evaluate.fixed_generation_duration

                    gen_outputs = self.run_generate_step(
                        batch, gen_duration=target_duration,
                        **self.generation_params
                    )
                    y_pred = gen_outputs['gen_audio'].detach()
                    y_pred = y_pred[..., :audio.shape[-1]]

                    normalize_kwargs = dict(self.cfg.generate.audio)
                    normalize_kwargs.pop('format', None)
                    y_pred = torch.stack([normalize_audio(w, **normalize_kwargs) for w in y_pred], dim=0).cpu()
                    y = audio.cpu()  # should already be on CPU, but just in case
                    sizes = torch.tensor([m.n_frames for m in meta])  # actual sizes without padding
                    sample_rates = torch.tensor([m.sample_rate for m in meta])
                    audio_stems = [Path(m.meta.path).stem + f"_{m.seek_time}" for m in meta]

                    # Each metric picks its own prediction, so that one metric's `use_gt`
                    # does not leak into the inputs of the following metrics.
                    if fad is not None:
                        fad_pred = get_compressed_audio(y).cpu() if self.cfg.metrics.fad.use_gt else y_pred
                        fad.update(fad_pred, y, sizes, sample_rates, audio_stems)
                    if kldiv is not None:
                        kld_pred = get_compressed_audio(y).cpu() if self.cfg.metrics.kld.use_gt else y_pred
                        kldiv.update(kld_pred, y, sizes, sample_rates)
                    if text_consistency is not None:
                        texts = [m.description for m in meta]
                        text_pred = y if self.cfg.metrics.text_consistency.use_gt else y_pred
                        text_consistency.update(text_pred, texts, sizes, sample_rates)
                    if chroma_cosine is not None:
                        if self.cfg.metrics.chroma_cosine.use_gt:
                            chroma_pred = get_compressed_audio(y).cpu()
                        else:
                            chroma_pred = y_pred
                        chroma_cosine.update(chroma_pred, y, sizes, sample_rates)
            finally:
                # Restore the chroma conditioner's eval wavs only after all batches were
                # generated, so that no batch is generated with the predefined chromas.
                if eval_chroma_wavs is not None:
                    self.model.condition_provider.conditioners['self_wav'].reset_eval_wavs(eval_chroma_wavs)

            flashy.distrib.barrier()
            if fad is not None:
                metrics['fad'] = fad.compute()
            if kldiv is not None:
                kld_metrics = kldiv.compute()
                metrics.update(kld_metrics)
            if text_consistency is not None:
                metrics['text_consistency'] = text_consistency.compute()
            if chroma_cosine is not None:
                metrics['chroma_cosine'] = chroma_cosine.compute()
            metrics = average(metrics)
            metrics = flashy.distrib.average_metrics(metrics, len(loader))

        return metrics

    def evaluate(self) -> dict:
        """Evaluate stage."""
        self.model.eval()
        with torch.no_grad():
            metrics: dict = {}
            if self.cfg.evaluate.metrics.base:
                metrics.update(self.common_train_valid('evaluate'))
            gen_metrics = self.evaluate_audio_generation()
            return {**metrics, **gen_metrics}
|
|