Spaces:
Sleeping
Sleeping
| # Copyright (c) 2023 Amphion. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import os | |
| import sys | |
| import time | |
| import torch | |
| import json | |
| import itertools | |
| import accelerate | |
| import torch.distributed as dist | |
| import torch.nn.functional as F | |
| from tqdm import tqdm | |
| from torch.nn.parallel import DistributedDataParallel | |
| from torch.utils.data import DataLoader | |
| from torch.utils.data.distributed import DistributedSampler | |
| from torch.utils.tensorboard import SummaryWriter | |
| from torch.optim import AdamW | |
| from torch.optim.lr_scheduler import ExponentialLR | |
| from librosa.filters import mel as librosa_mel_fn | |
| from accelerate.logging import get_logger | |
| from pathlib import Path | |
| from utils.io import save_audio | |
| from utils.data_utils import * | |
| from utils.util import ( | |
| Logger, | |
| ValueWindow, | |
| remove_older_ckpt, | |
| set_all_random_seed, | |
| save_config, | |
| ) | |
| from utils.mel import extract_mel_features | |
| from models.vocoders.vocoder_trainer import VocoderTrainer | |
| from models.vocoders.gan.gan_vocoder_dataset import ( | |
| GANVocoderDataset, | |
| GANVocoderCollator, | |
| ) | |
| from models.vocoders.gan.generator.bigvgan import BigVGAN | |
| from models.vocoders.gan.generator.hifigan import HiFiGAN | |
| from models.vocoders.gan.generator.melgan import MelGAN | |
| from models.vocoders.gan.generator.nsfhifigan import NSFHiFiGAN | |
| from models.vocoders.gan.generator.apnet import APNet | |
| from models.vocoders.gan.discriminator.mpd import MultiPeriodDiscriminator | |
| from models.vocoders.gan.discriminator.mrd import MultiResolutionDiscriminator | |
| from models.vocoders.gan.discriminator.mssbcqtd import MultiScaleSubbandCQTDiscriminator | |
| from models.vocoders.gan.discriminator.msd import MultiScaleDiscriminator | |
| from models.vocoders.gan.discriminator.msstftd import MultiScaleSTFTDiscriminator | |
| from models.vocoders.gan.gan_vocoder_inference import vocoder_inference | |
| supported_generators = { | |
| "bigvgan": BigVGAN, | |
| "hifigan": HiFiGAN, | |
| "melgan": MelGAN, | |
| "nsfhifigan": NSFHiFiGAN, | |
| "apnet": APNet, | |
| } | |
| supported_discriminators = { | |
| "mpd": MultiPeriodDiscriminator, | |
| "msd": MultiScaleDiscriminator, | |
| "mrd": MultiResolutionDiscriminator, | |
| "msstftd": MultiScaleSTFTDiscriminator, | |
| "mssbcqtd": MultiScaleSubbandCQTDiscriminator, | |
| } | |
| class GANVocoderTrainer(VocoderTrainer): | |
| def __init__(self, args, cfg): | |
| super().__init__() | |
| self.args = args | |
| self.cfg = cfg | |
| cfg.exp_name = args.exp_name | |
| # Init accelerator | |
| self._init_accelerator() | |
| self.accelerator.wait_for_everyone() | |
| # Init logger | |
| with self.accelerator.main_process_first(): | |
| self.logger = get_logger(args.exp_name, log_level=args.log_level) | |
| self.logger.info("=" * 56) | |
| self.logger.info("||\t\t" + "New training process started." + "\t\t||") | |
| self.logger.info("=" * 56) | |
| self.logger.info("\n") | |
| self.logger.debug(f"Using {args.log_level.upper()} logging level.") | |
| self.logger.info(f"Experiment name: {args.exp_name}") | |
| self.logger.info(f"Experiment directory: {self.exp_dir}") | |
| self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint") | |
| if self.accelerator.is_main_process: | |
| os.makedirs(self.checkpoint_dir, exist_ok=True) | |
| self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}") | |
| # Init training status | |
| self.batch_count: int = 0 | |
| self.step: int = 0 | |
| self.epoch: int = 0 | |
| self.max_epoch = ( | |
| self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf") | |
| ) | |
| self.logger.info( | |
| "Max epoch: {}".format( | |
| self.max_epoch if self.max_epoch < float("inf") else "Unlimited" | |
| ) | |
| ) | |
| # Check potential erorrs | |
| if self.accelerator.is_main_process: | |
| self._check_basic_configs() | |
| self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride | |
| self.checkpoints_path = [ | |
| [] for _ in range(len(self.save_checkpoint_stride)) | |
| ] | |
| self.run_eval = self.cfg.train.run_eval | |
| # Set random seed | |
| with self.accelerator.main_process_first(): | |
| start = time.monotonic_ns() | |
| self._set_random_seed(self.cfg.train.random_seed) | |
| end = time.monotonic_ns() | |
| self.logger.debug( | |
| f"Setting random seed done in {(end - start) / 1e6:.2f}ms" | |
| ) | |
| self.logger.debug(f"Random seed: {self.cfg.train.random_seed}") | |
| # Build dataloader | |
| with self.accelerator.main_process_first(): | |
| self.logger.info("Building dataset...") | |
| start = time.monotonic_ns() | |
| self.train_dataloader, self.valid_dataloader = self._build_dataloader() | |
| end = time.monotonic_ns() | |
| self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms") | |
| # Build model | |
| with self.accelerator.main_process_first(): | |
| self.logger.info("Building model...") | |
| start = time.monotonic_ns() | |
| self.generator, self.discriminators = self._build_model() | |
| end = time.monotonic_ns() | |
| self.logger.debug(self.generator) | |
| for _, discriminator in self.discriminators.items(): | |
| self.logger.debug(discriminator) | |
| self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms") | |
| self.logger.info(f"Model parameters: {self._count_parameters()/1e6:.2f}M") | |
| # Build optimizers and schedulers | |
| with self.accelerator.main_process_first(): | |
| self.logger.info("Building optimizer and scheduler...") | |
| start = time.monotonic_ns() | |
| ( | |
| self.generator_optimizer, | |
| self.discriminator_optimizer, | |
| ) = self._build_optimizer() | |
| ( | |
| self.generator_scheduler, | |
| self.discriminator_scheduler, | |
| ) = self._build_scheduler() | |
| end = time.monotonic_ns() | |
| self.logger.info( | |
| f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms" | |
| ) | |
| # Accelerator preparing | |
| self.logger.info("Initializing accelerate...") | |
| start = time.monotonic_ns() | |
| ( | |
| self.train_dataloader, | |
| self.valid_dataloader, | |
| self.generator, | |
| self.generator_optimizer, | |
| self.discriminator_optimizer, | |
| self.generator_scheduler, | |
| self.discriminator_scheduler, | |
| ) = self.accelerator.prepare( | |
| self.train_dataloader, | |
| self.valid_dataloader, | |
| self.generator, | |
| self.generator_optimizer, | |
| self.discriminator_optimizer, | |
| self.generator_scheduler, | |
| self.discriminator_scheduler, | |
| ) | |
| for key, discriminator in self.discriminators.items(): | |
| self.discriminators[key] = self.accelerator.prepare_model(discriminator) | |
| end = time.monotonic_ns() | |
| self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms") | |
| # Build criterions | |
| with self.accelerator.main_process_first(): | |
| self.logger.info("Building criterion...") | |
| start = time.monotonic_ns() | |
| self.criterions = self._build_criterion() | |
| end = time.monotonic_ns() | |
| self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms") | |
| # Resume checkpoints | |
| with self.accelerator.main_process_first(): | |
| if args.resume_type: | |
| self.logger.info("Resuming from checkpoint...") | |
| start = time.monotonic_ns() | |
| ckpt_path = Path(args.checkpoint) | |
| if self._is_valid_pattern(ckpt_path.parts[-1]): | |
| ckpt_path = self._load_model( | |
| None, args.checkpoint, args.resume_type | |
| ) | |
| else: | |
| ckpt_path = self._load_model( | |
| args.checkpoint, resume_type=args.resume_type | |
| ) | |
| end = time.monotonic_ns() | |
| self.logger.info( | |
| f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms" | |
| ) | |
| self.checkpoints_path = json.load( | |
| open(os.path.join(ckpt_path, "ckpts.json"), "r") | |
| ) | |
| self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint") | |
| if self.accelerator.is_main_process: | |
| os.makedirs(self.checkpoint_dir, exist_ok=True) | |
| self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}") | |
| # Save config | |
| self.config_save_path = os.path.join(self.exp_dir, "args.json") | |
| def _build_dataset(self): | |
| return GANVocoderDataset, GANVocoderCollator | |
| def _build_criterion(self): | |
| class feature_criterion(torch.nn.Module): | |
| def __init__(self, cfg): | |
| super(feature_criterion, self).__init__() | |
| self.cfg = cfg | |
| self.l1Loss = torch.nn.L1Loss(reduction="mean") | |
| self.l2Loss = torch.nn.MSELoss(reduction="mean") | |
| self.relu = torch.nn.ReLU() | |
| def __call__(self, fmap_r, fmap_g): | |
| loss = 0 | |
| if self.cfg.model.generator in [ | |
| "hifigan", | |
| "nsfhifigan", | |
| "bigvgan", | |
| "apnet", | |
| ]: | |
| for dr, dg in zip(fmap_r, fmap_g): | |
| for rl, gl in zip(dr, dg): | |
| loss += torch.mean(torch.abs(rl - gl)) | |
| loss = loss * 2 | |
| elif self.cfg.model.generator in ["melgan"]: | |
| for dr, dg in zip(fmap_r, fmap_g): | |
| for rl, gl in zip(dr, dg): | |
| loss += self.l1Loss(rl, gl) | |
| loss = loss * 10 | |
| elif self.cfg.model.generator in ["codec"]: | |
| for dr, dg in zip(fmap_r, fmap_g): | |
| for rl, gl in zip(dr, dg): | |
| loss = loss + self.l1Loss(rl, gl) / torch.mean( | |
| torch.abs(rl) | |
| ) | |
| KL_scale = len(fmap_r) * len(fmap_r[0]) | |
| loss = 3 * loss / KL_scale | |
| else: | |
| raise NotImplementedError | |
| return loss | |
| class discriminator_criterion(torch.nn.Module): | |
| def __init__(self, cfg): | |
| super(discriminator_criterion, self).__init__() | |
| self.cfg = cfg | |
| self.l1Loss = torch.nn.L1Loss(reduction="mean") | |
| self.l2Loss = torch.nn.MSELoss(reduction="mean") | |
| self.relu = torch.nn.ReLU() | |
| def __call__(self, disc_real_outputs, disc_generated_outputs): | |
| loss = 0 | |
| r_losses = [] | |
| g_losses = [] | |
| if self.cfg.model.generator in [ | |
| "hifigan", | |
| "nsfhifigan", | |
| "bigvgan", | |
| "apnet", | |
| ]: | |
| for dr, dg in zip(disc_real_outputs, disc_generated_outputs): | |
| r_loss = torch.mean((1 - dr) ** 2) | |
| g_loss = torch.mean(dg**2) | |
| loss += r_loss + g_loss | |
| r_losses.append(r_loss.item()) | |
| g_losses.append(g_loss.item()) | |
| elif self.cfg.model.generator in ["melgan"]: | |
| for dr, dg in zip(disc_real_outputs, disc_generated_outputs): | |
| r_loss = torch.mean(self.relu(1 - dr)) | |
| g_loss = torch.mean(self.relu(1 + dg)) | |
| loss = loss + r_loss + g_loss | |
| r_losses.append(r_loss.item()) | |
| g_losses.append(g_loss.item()) | |
| elif self.cfg.model.generator in ["codec"]: | |
| for dr, dg in zip(disc_real_outputs, disc_generated_outputs): | |
| r_loss = torch.mean(self.relu(1 - dr)) | |
| g_loss = torch.mean(self.relu(1 + dg)) | |
| loss = loss + r_loss + g_loss | |
| r_losses.append(r_loss.item()) | |
| g_losses.append(g_loss.item()) | |
| loss = loss / len(disc_real_outputs) | |
| else: | |
| raise NotImplementedError | |
| return loss, r_losses, g_losses | |
| class generator_criterion(torch.nn.Module): | |
| def __init__(self, cfg): | |
| super(generator_criterion, self).__init__() | |
| self.cfg = cfg | |
| self.l1Loss = torch.nn.L1Loss(reduction="mean") | |
| self.l2Loss = torch.nn.MSELoss(reduction="mean") | |
| self.relu = torch.nn.ReLU() | |
| def __call__(self, disc_outputs): | |
| loss = 0 | |
| gen_losses = [] | |
| if self.cfg.model.generator in [ | |
| "hifigan", | |
| "nsfhifigan", | |
| "bigvgan", | |
| "apnet", | |
| ]: | |
| for dg in disc_outputs: | |
| l = torch.mean((1 - dg) ** 2) | |
| gen_losses.append(l) | |
| loss += l | |
| elif self.cfg.model.generator in ["melgan"]: | |
| for dg in disc_outputs: | |
| l = -torch.mean(dg) | |
| gen_losses.append(l) | |
| loss += l | |
| elif self.cfg.model.generator in ["codec"]: | |
| for dg in disc_outputs: | |
| l = torch.mean(self.relu(1 - dg)) / len(disc_outputs) | |
| gen_losses.append(l) | |
| loss += l | |
| else: | |
| raise NotImplementedError | |
| return loss, gen_losses | |
| class mel_criterion(torch.nn.Module): | |
| def __init__(self, cfg): | |
| super(mel_criterion, self).__init__() | |
| self.cfg = cfg | |
| self.l1Loss = torch.nn.L1Loss(reduction="mean") | |
| self.l2Loss = torch.nn.MSELoss(reduction="mean") | |
| self.relu = torch.nn.ReLU() | |
| def __call__(self, y_gt, y_pred): | |
| loss = 0 | |
| if self.cfg.model.generator in [ | |
| "hifigan", | |
| "nsfhifigan", | |
| "bigvgan", | |
| "melgan", | |
| "codec", | |
| "apnet", | |
| ]: | |
| y_gt_mel = extract_mel_features(y_gt, self.cfg.preprocess) | |
| y_pred_mel = extract_mel_features( | |
| y_pred.squeeze(1), self.cfg.preprocess | |
| ) | |
| loss = self.l1Loss(y_gt_mel, y_pred_mel) * 45 | |
| else: | |
| raise NotImplementedError | |
| return loss | |
| class wav_criterion(torch.nn.Module): | |
| def __init__(self, cfg): | |
| super(wav_criterion, self).__init__() | |
| self.cfg = cfg | |
| self.l1Loss = torch.nn.L1Loss(reduction="mean") | |
| self.l2Loss = torch.nn.MSELoss(reduction="mean") | |
| self.relu = torch.nn.ReLU() | |
| def __call__(self, y_gt, y_pred): | |
| loss = 0 | |
| if self.cfg.model.generator in [ | |
| "hifigan", | |
| "nsfhifigan", | |
| "bigvgan", | |
| "apnet", | |
| ]: | |
| loss = self.l2Loss(y_gt, y_pred.squeeze(1)) * 100 | |
| elif self.cfg.model.generator in ["melgan"]: | |
| loss = self.l1Loss(y_gt, y_pred.squeeze(1)) / 10 | |
| elif self.cfg.model.generator in ["codec"]: | |
| loss = self.l1Loss(y_gt, y_pred.squeeze(1)) + self.l2Loss( | |
| y_gt, y_pred.squeeze(1) | |
| ) | |
| loss /= 10 | |
| else: | |
| raise NotImplementedError | |
| return loss | |
| class phase_criterion(torch.nn.Module): | |
| def __init__(self, cfg): | |
| super(phase_criterion, self).__init__() | |
| self.cfg = cfg | |
| self.l1Loss = torch.nn.L1Loss(reduction="mean") | |
| self.l2Loss = torch.nn.MSELoss(reduction="mean") | |
| self.relu = torch.nn.ReLU() | |
| def __call__(self, phase_gt, phase_pred): | |
| n_fft = self.cfg.preprocess.n_fft | |
| frames = phase_gt.size()[-1] | |
| GD_matrix = ( | |
| torch.triu(torch.ones(n_fft // 2 + 1, n_fft // 2 + 1), diagonal=1) | |
| - torch.triu(torch.ones(n_fft // 2 + 1, n_fft // 2 + 1), diagonal=2) | |
| - torch.eye(n_fft // 2 + 1) | |
| ) | |
| GD_matrix = GD_matrix.to(phase_pred.device) | |
| GD_r = torch.matmul(phase_gt.permute(0, 2, 1), GD_matrix) | |
| GD_g = torch.matmul(phase_pred.permute(0, 2, 1), GD_matrix) | |
| PTD_matrix = ( | |
| torch.triu(torch.ones(frames, frames), diagonal=1) | |
| - torch.triu(torch.ones(frames, frames), diagonal=2) | |
| - torch.eye(frames) | |
| ) | |
| PTD_matrix = PTD_matrix.to(phase_pred.device) | |
| PTD_r = torch.matmul(phase_gt, PTD_matrix) | |
| PTD_g = torch.matmul(phase_pred, PTD_matrix) | |
| IP_loss = torch.mean(-torch.cos(phase_gt - phase_pred)) | |
| GD_loss = torch.mean(-torch.cos(GD_r - GD_g)) | |
| PTD_loss = torch.mean(-torch.cos(PTD_r - PTD_g)) | |
| return 100 * (IP_loss + GD_loss + PTD_loss) | |
| class amplitude_criterion(torch.nn.Module): | |
| def __init__(self, cfg): | |
| super(amplitude_criterion, self).__init__() | |
| self.cfg = cfg | |
| self.l1Loss = torch.nn.L1Loss(reduction="mean") | |
| self.l2Loss = torch.nn.MSELoss(reduction="mean") | |
| self.relu = torch.nn.ReLU() | |
| def __call__(self, log_amplitude_gt, log_amplitude_pred): | |
| amplitude_loss = self.l2Loss(log_amplitude_gt, log_amplitude_pred) | |
| return 45 * amplitude_loss | |
| class consistency_criterion(torch.nn.Module): | |
| def __init__(self, cfg): | |
| super(consistency_criterion, self).__init__() | |
| self.cfg = cfg | |
| self.l1Loss = torch.nn.L1Loss(reduction="mean") | |
| self.l2Loss = torch.nn.MSELoss(reduction="mean") | |
| self.relu = torch.nn.ReLU() | |
| def __call__( | |
| self, | |
| rea_gt, | |
| rea_pred, | |
| rea_pred_final, | |
| imag_gt, | |
| imag_pred, | |
| imag_pred_final, | |
| ): | |
| C_loss = torch.mean( | |
| torch.mean( | |
| (rea_pred - rea_pred_final) ** 2 | |
| + (imag_pred - imag_pred_final) ** 2, | |
| (1, 2), | |
| ) | |
| ) | |
| L_R = self.l1Loss(rea_gt, rea_pred) | |
| L_I = self.l1Loss(imag_gt, imag_pred) | |
| return 20 * (C_loss + 2.25 * (L_R + L_I)) | |
| criterions = dict() | |
| for key in self.cfg.train.criterions: | |
| if key == "feature": | |
| criterions["feature"] = feature_criterion(self.cfg) | |
| elif key == "discriminator": | |
| criterions["discriminator"] = discriminator_criterion(self.cfg) | |
| elif key == "generator": | |
| criterions["generator"] = generator_criterion(self.cfg) | |
| elif key == "mel": | |
| criterions["mel"] = mel_criterion(self.cfg) | |
| elif key == "wav": | |
| criterions["wav"] = wav_criterion(self.cfg) | |
| elif key == "phase": | |
| criterions["phase"] = phase_criterion(self.cfg) | |
| elif key == "amplitude": | |
| criterions["amplitude"] = amplitude_criterion(self.cfg) | |
| elif key == "consistency": | |
| criterions["consistency"] = consistency_criterion(self.cfg) | |
| else: | |
| raise NotImplementedError | |
| return criterions | |
| def _build_model(self): | |
| generator = supported_generators[self.cfg.model.generator](self.cfg) | |
| discriminators = dict() | |
| for key in self.cfg.model.discriminators: | |
| discriminators[key] = supported_discriminators[key](self.cfg) | |
| return generator, discriminators | |
| def _build_optimizer(self): | |
| optimizer_params_generator = [dict(params=self.generator.parameters())] | |
| generator_optimizer = AdamW( | |
| optimizer_params_generator, | |
| lr=self.cfg.train.adamw.lr, | |
| betas=(self.cfg.train.adamw.adam_b1, self.cfg.train.adamw.adam_b2), | |
| ) | |
| optimizer_params_discriminator = [] | |
| for discriminator in self.discriminators.keys(): | |
| optimizer_params_discriminator.append( | |
| dict(params=self.discriminators[discriminator].parameters()) | |
| ) | |
| discriminator_optimizer = AdamW( | |
| optimizer_params_discriminator, | |
| lr=self.cfg.train.adamw.lr, | |
| betas=(self.cfg.train.adamw.adam_b1, self.cfg.train.adamw.adam_b2), | |
| ) | |
| return generator_optimizer, discriminator_optimizer | |
| def _build_scheduler(self): | |
| discriminator_scheduler = ExponentialLR( | |
| self.discriminator_optimizer, | |
| gamma=self.cfg.train.exponential_lr.lr_decay, | |
| last_epoch=self.epoch - 1, | |
| ) | |
| generator_scheduler = ExponentialLR( | |
| self.generator_optimizer, | |
| gamma=self.cfg.train.exponential_lr.lr_decay, | |
| last_epoch=self.epoch - 1, | |
| ) | |
| return generator_scheduler, discriminator_scheduler | |
| def train_loop(self): | |
| """Training process""" | |
| self.accelerator.wait_for_everyone() | |
| # Dump config | |
| if self.accelerator.is_main_process: | |
| self._dump_cfg(self.config_save_path) | |
| self.generator.train() | |
| for key in self.discriminators.keys(): | |
| self.discriminators[key].train() | |
| self.generator_optimizer.zero_grad() | |
| self.discriminator_optimizer.zero_grad() | |
| # Sync and start training | |
| self.accelerator.wait_for_everyone() | |
| while self.epoch < self.max_epoch: | |
| self.logger.info("\n") | |
| self.logger.info("-" * 32) | |
| self.logger.info("Epoch {}: ".format(self.epoch)) | |
| # Train and Validate | |
| train_total_loss, train_losses = self._train_epoch() | |
| for key, loss in train_losses.items(): | |
| self.logger.info(" |- Train/{} Loss: {:.6f}".format(key, loss)) | |
| self.accelerator.log( | |
| {"Epoch/Train {} Loss".format(key): loss}, | |
| step=self.epoch, | |
| ) | |
| valid_total_loss, valid_losses = self._valid_epoch() | |
| for key, loss in valid_losses.items(): | |
| self.logger.info(" |- Valid/{} Loss: {:.6f}".format(key, loss)) | |
| self.accelerator.log( | |
| {"Epoch/Valid {} Loss".format(key): loss}, | |
| step=self.epoch, | |
| ) | |
| self.accelerator.log( | |
| { | |
| "Epoch/Train Total Loss": train_total_loss, | |
| "Epoch/Valid Total Loss": valid_total_loss, | |
| }, | |
| step=self.epoch, | |
| ) | |
| # Update scheduler | |
| self.accelerator.wait_for_everyone() | |
| self.generator_scheduler.step() | |
| self.discriminator_scheduler.step() | |
| # Check save checkpoint interval | |
| run_eval = False | |
| if self.accelerator.is_main_process: | |
| save_checkpoint = False | |
| for i, num in enumerate(self.save_checkpoint_stride): | |
| if self.epoch % num == 0: | |
| save_checkpoint = True | |
| run_eval |= self.run_eval[i] | |
| # Save checkpoints | |
| self.accelerator.wait_for_everyone() | |
| if self.accelerator.is_main_process and save_checkpoint: | |
| path = os.path.join( | |
| self.checkpoint_dir, | |
| "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format( | |
| self.epoch, self.step, valid_total_loss | |
| ), | |
| ) | |
| self.accelerator.save_state(path) | |
| json.dump( | |
| self.checkpoints_path, | |
| open(os.path.join(path, "ckpts.json"), "w"), | |
| ensure_ascii=False, | |
| indent=4, | |
| ) | |
| # Save eval audios | |
| self.accelerator.wait_for_everyone() | |
| if self.accelerator.is_main_process and run_eval: | |
| for i in range(len(self.valid_dataloader.dataset.eval_audios)): | |
| if self.cfg.preprocess.use_frame_pitch: | |
| eval_audio = self._inference( | |
| self.valid_dataloader.dataset.eval_mels[i], | |
| eval_pitch=self.valid_dataloader.dataset.eval_pitchs[i], | |
| use_pitch=True, | |
| ) | |
| else: | |
| eval_audio = self._inference( | |
| self.valid_dataloader.dataset.eval_mels[i] | |
| ) | |
| path = os.path.join( | |
| self.checkpoint_dir, | |
| "epoch-{:04d}_step-{:07d}_loss-{:.6f}_eval_audio_{}.wav".format( | |
| self.epoch, | |
| self.step, | |
| valid_total_loss, | |
| self.valid_dataloader.dataset.eval_dataset_names[i], | |
| ), | |
| ) | |
| path_gt = os.path.join( | |
| self.checkpoint_dir, | |
| "epoch-{:04d}_step-{:07d}_loss-{:.6f}_eval_audio_{}_gt.wav".format( | |
| self.epoch, | |
| self.step, | |
| valid_total_loss, | |
| self.valid_dataloader.dataset.eval_dataset_names[i], | |
| ), | |
| ) | |
| save_audio(path, eval_audio, self.cfg.preprocess.sample_rate) | |
| save_audio( | |
| path_gt, | |
| self.valid_dataloader.dataset.eval_audios[i], | |
| self.cfg.preprocess.sample_rate, | |
| ) | |
| self.accelerator.wait_for_everyone() | |
| self.epoch += 1 | |
| # Finish training | |
| self.accelerator.wait_for_everyone() | |
| path = os.path.join( | |
| self.checkpoint_dir, | |
| "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format( | |
| self.epoch, self.step, valid_total_loss | |
| ), | |
| ) | |
| self.accelerator.save_state(path) | |
| def _train_epoch(self): | |
| """Training epoch. Should return average loss of a batch (sample) over | |
| one epoch. See ``train_loop`` for usage. | |
| """ | |
| self.generator.train() | |
| for key, _ in self.discriminators.items(): | |
| self.discriminators[key].train() | |
| epoch_losses: dict = {} | |
| epoch_total_loss: int = 0 | |
| for batch in tqdm( | |
| self.train_dataloader, | |
| desc=f"Training Epoch {self.epoch}", | |
| unit="batch", | |
| colour="GREEN", | |
| leave=False, | |
| dynamic_ncols=True, | |
| smoothing=0.04, | |
| disable=not self.accelerator.is_main_process, | |
| ): | |
| # Get losses | |
| total_loss, losses = self._train_step(batch) | |
| self.batch_count += 1 | |
| # Log info | |
| if self.batch_count % self.cfg.train.gradient_accumulation_step == 0: | |
| self.accelerator.log( | |
| { | |
| "Step/Generator Learning Rate": self.generator_optimizer.param_groups[ | |
| 0 | |
| ][ | |
| "lr" | |
| ], | |
| "Step/Discriminator Learning Rate": self.discriminator_optimizer.param_groups[ | |
| 0 | |
| ][ | |
| "lr" | |
| ], | |
| }, | |
| step=self.step, | |
| ) | |
| for key, _ in losses.items(): | |
| self.accelerator.log( | |
| { | |
| "Step/Train {} Loss".format(key): losses[key], | |
| }, | |
| step=self.step, | |
| ) | |
| if not epoch_losses: | |
| epoch_losses = losses | |
| else: | |
| for key, value in losses.items(): | |
| epoch_losses[key] += value | |
| epoch_total_loss += total_loss | |
| self.step += 1 | |
| # Get and log total losses | |
| self.accelerator.wait_for_everyone() | |
| epoch_total_loss = ( | |
| epoch_total_loss | |
| / len(self.train_dataloader) | |
| * self.cfg.train.gradient_accumulation_step | |
| ) | |
| for key in epoch_losses.keys(): | |
| epoch_losses[key] = ( | |
| epoch_losses[key] | |
| / len(self.train_dataloader) | |
| * self.cfg.train.gradient_accumulation_step | |
| ) | |
| return epoch_total_loss, epoch_losses | |
| def _train_step(self, data): | |
| """Training forward step. Should return average loss of a sample over | |
| one batch. Provoke ``_forward_step`` is recommended except for special case. | |
| See ``_train_epoch`` for usage. | |
| """ | |
| # Init losses | |
| train_losses = {} | |
| total_loss = 0 | |
| generator_losses = {} | |
| generator_total_loss = 0 | |
| discriminator_losses = {} | |
| discriminator_total_loss = 0 | |
| # Use input feature to get predictions | |
| mel_input = data["mel"] | |
| audio_gt = data["audio"] | |
| if self.cfg.preprocess.extract_amplitude_phase: | |
| logamp_gt = data["logamp"] | |
| pha_gt = data["pha"] | |
| rea_gt = data["rea"] | |
| imag_gt = data["imag"] | |
| if self.cfg.preprocess.use_frame_pitch: | |
| pitch_input = data["frame_pitch"] | |
| if self.cfg.preprocess.use_frame_pitch: | |
| pitch_input = pitch_input.float() | |
| audio_pred = self.generator.forward(mel_input, pitch_input) | |
| elif self.cfg.preprocess.extract_amplitude_phase: | |
| ( | |
| logamp_pred, | |
| pha_pred, | |
| rea_pred, | |
| imag_pred, | |
| audio_pred, | |
| ) = self.generator.forward(mel_input) | |
| from utils.mel import amplitude_phase_spectrum | |
| _, _, rea_pred_final, imag_pred_final = amplitude_phase_spectrum( | |
| audio_pred.squeeze(1), self.cfg.preprocess | |
| ) | |
| else: | |
| audio_pred = self.generator.forward(mel_input) | |
| # Calculate and BP Discriminator losses | |
| self.discriminator_optimizer.zero_grad() | |
| for key, _ in self.discriminators.items(): | |
| y_r, y_g, _, _ = self.discriminators[key].forward( | |
| audio_gt.unsqueeze(1), audio_pred.detach() | |
| ) | |
| ( | |
| discriminator_losses["{}_discriminator".format(key)], | |
| _, | |
| _, | |
| ) = self.criterions["discriminator"](y_r, y_g) | |
| discriminator_total_loss += discriminator_losses[ | |
| "{}_discriminator".format(key) | |
| ] | |
| self.accelerator.backward(discriminator_total_loss) | |
| self.discriminator_optimizer.step() | |
| # Calculate and BP Generator losses | |
| self.generator_optimizer.zero_grad() | |
| for key, _ in self.discriminators.items(): | |
| y_r, y_g, f_r, f_g = self.discriminators[key].forward( | |
| audio_gt.unsqueeze(1), audio_pred | |
| ) | |
| generator_losses["{}_feature".format(key)] = self.criterions["feature"]( | |
| f_r, f_g | |
| ) | |
| generator_losses["{}_generator".format(key)], _ = self.criterions[ | |
| "generator" | |
| ](y_g) | |
| generator_total_loss += generator_losses["{}_feature".format(key)] | |
| generator_total_loss += generator_losses["{}_generator".format(key)] | |
| if "mel" in self.criterions.keys(): | |
| generator_losses["mel"] = self.criterions["mel"](audio_gt, audio_pred) | |
| generator_total_loss += generator_losses["mel"] | |
| if "wav" in self.criterions.keys(): | |
| generator_losses["wav"] = self.criterions["wav"](audio_gt, audio_pred) | |
| generator_total_loss += generator_losses["wav"] | |
| if "amplitude" in self.criterions.keys(): | |
| generator_losses["amplitude"] = self.criterions["amplitude"]( | |
| logamp_gt, logamp_pred | |
| ) | |
| generator_total_loss += generator_losses["amplitude"] | |
| if "phase" in self.criterions.keys(): | |
| generator_losses["phase"] = self.criterions["phase"](pha_gt, pha_pred) | |
| generator_total_loss += generator_losses["phase"] | |
| if "consistency" in self.criterions.keys(): | |
| generator_losses["consistency"] = self.criterions["consistency"]( | |
| rea_gt, rea_pred, rea_pred_final, imag_gt, imag_pred, imag_pred_final | |
| ) | |
| generator_total_loss += generator_losses["consistency"] | |
| self.accelerator.backward(generator_total_loss) | |
| self.generator_optimizer.step() | |
| # Get the total losses | |
| total_loss = discriminator_total_loss + generator_total_loss | |
| train_losses.update(discriminator_losses) | |
| train_losses.update(generator_losses) | |
| for key, _ in train_losses.items(): | |
| train_losses[key] = train_losses[key].item() | |
| return total_loss.item(), train_losses | |
| def _valid_epoch(self): | |
| """Testing epoch. Should return average loss of a batch (sample) over | |
| one epoch. See ``train_loop`` for usage. | |
| """ | |
| self.generator.eval() | |
| for key, _ in self.discriminators.items(): | |
| self.discriminators[key].eval() | |
| epoch_losses: dict = {} | |
| epoch_total_loss: int = 0 | |
| for batch in tqdm( | |
| self.valid_dataloader, | |
| desc=f"Validating Epoch {self.epoch}", | |
| unit="batch", | |
| colour="GREEN", | |
| leave=False, | |
| dynamic_ncols=True, | |
| smoothing=0.04, | |
| disable=not self.accelerator.is_main_process, | |
| ): | |
| # Get losses | |
| total_loss, losses = self._valid_step(batch) | |
| # Log info | |
| for key, _ in losses.items(): | |
| self.accelerator.log( | |
| { | |
| "Step/Valid {} Loss".format(key): losses[key], | |
| }, | |
| step=self.step, | |
| ) | |
| if not epoch_losses: | |
| epoch_losses = losses | |
| else: | |
| for key, value in losses.items(): | |
| epoch_losses[key] += value | |
| epoch_total_loss += total_loss | |
| # Get and log total losses | |
| self.accelerator.wait_for_everyone() | |
| epoch_total_loss = epoch_total_loss / len(self.valid_dataloader) | |
| for key in epoch_losses.keys(): | |
| epoch_losses[key] = epoch_losses[key] / len(self.valid_dataloader) | |
| return epoch_total_loss, epoch_losses | |
| def _valid_step(self, data): | |
| """Testing forward step. Should return average loss of a sample over | |
| one batch. Provoke ``_forward_step`` is recommended except for special case. | |
| See ``_test_epoch`` for usage. | |
| """ | |
| # Init losses | |
| valid_losses = {} | |
| total_loss = 0 | |
| generator_losses = {} | |
| generator_total_loss = 0 | |
| discriminator_losses = {} | |
| discriminator_total_loss = 0 | |
| # Use feature inputs to get the predicted audio | |
| mel_input = data["mel"] | |
| audio_gt = data["audio"] | |
| if self.cfg.preprocess.extract_amplitude_phase: | |
| logamp_gt = data["logamp"] | |
| pha_gt = data["pha"] | |
| rea_gt = data["rea"] | |
| imag_gt = data["imag"] | |
| if self.cfg.preprocess.use_frame_pitch: | |
| pitch_input = data["frame_pitch"] | |
| if self.cfg.preprocess.use_frame_pitch: | |
| pitch_input = pitch_input.float() | |
| audio_pred = self.generator.forward(mel_input, pitch_input) | |
| elif self.cfg.preprocess.extract_amplitude_phase: | |
| ( | |
| logamp_pred, | |
| pha_pred, | |
| rea_pred, | |
| imag_pred, | |
| audio_pred, | |
| ) = self.generator.forward(mel_input) | |
| from utils.mel import amplitude_phase_spectrum | |
| _, _, rea_pred_final, imag_pred_final = amplitude_phase_spectrum( | |
| audio_pred.squeeze(1), self.cfg.preprocess | |
| ) | |
| else: | |
| audio_pred = self.generator.forward(mel_input) | |
| # Get Discriminator losses | |
| for key, _ in self.discriminators.items(): | |
| y_r, y_g, _, _ = self.discriminators[key].forward( | |
| audio_gt.unsqueeze(1), audio_pred | |
| ) | |
| ( | |
| discriminator_losses["{}_discriminator".format(key)], | |
| _, | |
| _, | |
| ) = self.criterions["discriminator"](y_r, y_g) | |
| discriminator_total_loss += discriminator_losses[ | |
| "{}_discriminator".format(key) | |
| ] | |
| for key, _ in self.discriminators.items(): | |
| y_r, y_g, f_r, f_g = self.discriminators[key].forward( | |
| audio_gt.unsqueeze(1), audio_pred | |
| ) | |
| generator_losses["{}_feature".format(key)] = self.criterions["feature"]( | |
| f_r, f_g | |
| ) | |
| generator_losses["{}_generator".format(key)], _ = self.criterions[ | |
| "generator" | |
| ](y_g) | |
| generator_total_loss += generator_losses["{}_feature".format(key)] | |
| generator_total_loss += generator_losses["{}_generator".format(key)] | |
| if "mel" in self.criterions.keys(): | |
| generator_losses["mel"] = self.criterions["mel"](audio_gt, audio_pred) | |
| generator_total_loss += generator_losses["mel"] | |
| if "mel" in self.criterions.keys(): | |
| generator_losses["mel"] = self.criterions["mel"](audio_gt, audio_pred) | |
| generator_total_loss += generator_losses["mel"] | |
| if "wav" in self.criterions.keys(): | |
| generator_losses["wav"] = self.criterions["wav"](audio_gt, audio_pred) | |
| generator_total_loss += generator_losses["wav"] | |
| if "wav" in self.criterions.keys(): | |
| generator_losses["wav"] = self.criterions["wav"](audio_gt, audio_pred) | |
| generator_total_loss += generator_losses["wav"] | |
| if "amplitude" in self.criterions.keys(): | |
| generator_losses["amplitude"] = self.criterions["amplitude"]( | |
| logamp_gt, logamp_pred | |
| ) | |
| generator_total_loss += generator_losses["amplitude"] | |
| if "phase" in self.criterions.keys(): | |
| generator_losses["phase"] = self.criterions["phase"](pha_gt, pha_pred) | |
| generator_total_loss += generator_losses["phase"] | |
| if "consistency" in self.criterions.keys(): | |
| generator_losses["consistency"] = self.criterions["consistency"]( | |
| rea_gt, | |
| rea_pred, | |
| rea_pred_final, | |
| imag_gt, | |
| imag_pred, | |
| imag_pred_final, | |
| ) | |
| generator_total_loss += generator_losses["consistency"] | |
| total_loss = discriminator_total_loss + generator_total_loss | |
| valid_losses.update(discriminator_losses) | |
| valid_losses.update(generator_losses) | |
| for item in valid_losses: | |
| valid_losses[item] = valid_losses[item].item() | |
| for item in valid_losses: | |
| valid_losses[item] = valid_losses[item].item() | |
| return total_loss.item(), valid_losses | |
| return total_loss.item(), valid_losses | |
| def _inference(self, eval_mel, eval_pitch=None, use_pitch=False): | |
| """Inference during training for test audios.""" | |
| if use_pitch: | |
| eval_pitch = align_length(eval_pitch, eval_mel.shape[1]) | |
| eval_audio = vocoder_inference( | |
| self.cfg, | |
| self.generator, | |
| torch.from_numpy(eval_mel).unsqueeze(0), | |
| f0s=torch.from_numpy(eval_pitch).unsqueeze(0).float(), | |
| device=next(self.generator.parameters()).device, | |
| ).squeeze(0) | |
| else: | |
| eval_audio = vocoder_inference( | |
| self.cfg, | |
| self.generator, | |
| torch.from_numpy(eval_mel).unsqueeze(0), | |
| device=next(self.generator.parameters()).device, | |
| ).squeeze(0) | |
| return eval_audio | |
| def _load_model(self, checkpoint_dir, checkpoint_path=None, resume_type="resume"): | |
| """Load model from checkpoint. If checkpoint_path is None, it will | |
| load the latest checkpoint in checkpoint_dir. If checkpoint_path is not | |
| None, it will load the checkpoint specified by checkpoint_path. **Only use this | |
| method after** ``accelerator.prepare()``. | |
| """ | |
| if checkpoint_path is None: | |
| ls = [str(i) for i in Path(checkpoint_dir).glob("*")] | |
| ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True) | |
| checkpoint_path = ls[0] | |
| if resume_type == "resume": | |
| self.accelerator.load_state(checkpoint_path) | |
| self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1 | |
| self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1 | |
| elif resume_type == "finetune": | |
| accelerate.load_checkpoint_and_dispatch( | |
| self.accelerator.unwrap_model(self.generator), | |
| os.path.join(checkpoint_path, "pytorch_model.bin"), | |
| ) | |
| for key, _ in self.discriminators.items(): | |
| accelerate.load_checkpoint_and_dispatch( | |
| self.accelerator.unwrap_model(self.discriminators[key]), | |
| os.path.join(checkpoint_path, "pytorch_model.bin"), | |
| ) | |
| self.logger.info("Load model weights for finetune SUCCESS!") | |
| else: | |
| raise ValueError("Unsupported resume type: {}".format(resume_type)) | |
| return checkpoint_path | |
| def _count_parameters(self): | |
| result = sum(p.numel() for p in self.generator.parameters()) | |
| for _, discriminator in self.discriminators.items(): | |
| result += sum(p.numel() for p in discriminator.parameters()) | |
| return result | |