Spaces:
Build error
Build error
| from tasks.tts.fs2 import FastSpeech2Task | |
| from modules.syntaspeech.multi_window_disc import Discriminator | |
| from utils.hparams import hparams | |
| from torch import nn | |
| import torch | |
| import torch.optim | |
| import torch.utils.data | |
| import utils | |
class FastSpeech2AdvTask(FastSpeech2Task):
    """FastSpeech2 task trained together with a mel-spectrogram discriminator.

    Two optimizers are used: index 0 updates the generator (``self.model``)
    and index 1 updates the discriminator (``self.mel_disc``). Adversarial
    terms are only active once ``disc_start`` (see ``_training_step``) holds.
    """

    def build_model(self):
        """Build the generator and the discriminator.

        Returns:
            The generator model (``self.model``).
        """
        self.build_tts_model()
        if hparams['load_ckpt'] != '':
            # Non-strict load: tolerate missing/unexpected keys in the checkpoint.
            self.load_ckpt(hparams['load_ckpt'], strict=False)
        utils.print_arch(self.model, 'Generator')
        self.build_disc_model()
        # Keep a previously assigned parameter list if one already exists.
        if not hasattr(self, 'gen_params'):
            self.gen_params = list(self.model.parameters())
        return self.model

    def build_disc_model(self):
        """Create ``self.mel_disc`` and cache its parameters in ``self.disc_params``."""
        disc_win_num = hparams['disc_win_num']
        h = hparams['mel_disc_hidden_size']
        self.mel_disc = Discriminator(
            # Use only the first ``disc_win_num`` window lengths.
            time_lengths=[32, 64, 128][:disc_win_num],
            freq_length=80, hidden_size=h, kernel=(3, 3)
        )
        self.disc_params = list(self.mel_disc.parameters())
        utils.print_arch(self.mel_disc, model_name='Mel Disc')

    def _training_step(self, sample, batch_idx, optimizer_idx):
        """Compute the loss for one optimizer on one batch.

        Args:
            sample: Batch dict; ``sample['mels']`` is read as the ground-truth
                mel and for the batch size.
            batch_idx: Index of the batch (unused here).
            optimizer_idx: 0 for the generator step, anything else for the
                discriminator step.

        Returns:
            ``None`` when the discriminator step produced no losses, otherwise
            a ``(total_loss, log_outputs)`` tuple where ``log_outputs`` also
            carries the batch size under ``'bs'``.
        """
        log_outputs = {}
        loss_weights = {}
        disc_start = hparams['mel_gan'] and self.global_step >= hparams["disc_start_steps"] and \
            hparams['lambda_mel_adv'] > 0
        if optimizer_idx == 0:
            #######################
            #      Generator      #
            #######################
            log_outputs, model_out = self.run_model(self.model, sample, return_output=True)
            # Detached copy of the tensor outputs, reused by the discriminator step.
            self.model_out = {k: v.detach() for k, v in model_out.items() if isinstance(v, torch.Tensor)}
            if disc_start:
                # Optional conditioning input for the discriminator; also
                # stored on self so the discriminator step sees the same value.
                self.disc_cond = disc_cond = self.model_out['decoder_inp'].detach() \
                    if hparams['use_cond_disc'] else None
                if hparams['mel_loss_no_noise']:
                    self.add_mel_loss(model_out['mel_out_nonoise'], sample['mels'], log_outputs)
                mel_p = model_out['mel_out']
                if hasattr(self.model, 'out2mel'):
                    mel_p = self.model.out2mel(mel_p)
                o_ = self.mel_disc(mel_p, disc_cond)
                p_, pc_ = o_['y'], o_['y_c']
                # Generator objective: push discriminator scores on the
                # predicted mel toward 1.
                if p_ is not None:
                    log_outputs['a'] = self.mse_loss_fn(p_, p_.new_ones(p_.size()))
                    loss_weights['a'] = hparams['lambda_mel_adv']
                if pc_ is not None:
                    log_outputs['ac'] = self.mse_loss_fn(pc_, pc_.new_ones(pc_.size()))
                    loss_weights['ac'] = hparams['lambda_mel_adv']
        else:
            #######################
            #    Discriminator    #
            #######################
            if disc_start and self.global_step % hparams['disc_interval'] == 0:
                if hparams['rerun_gen']:
                    # Fresh generator forward pass without building a graph.
                    with torch.no_grad():
                        _, model_out = self.run_model(self.model, sample, return_output=True)
                else:
                    # Reuse the detached outputs cached by the generator step.
                    model_out = self.model_out
                mel_g = sample['mels']
                mel_p = model_out['mel_out']
                if hasattr(self.model, 'out2mel'):
                    mel_p = self.model.out2mel(mel_p)

                o = self.mel_disc(mel_g, self.disc_cond)
                p, pc = o['y'], o['y_c']
                o_ = self.mel_disc(mel_p, self.disc_cond)
                p_, pc_ = o_['y'], o_['y_c']

                # Discriminator objective: ground truth -> 1, prediction -> 0.
                if p_ is not None:
                    log_outputs["r"] = self.mse_loss_fn(p, p.new_ones(p.size()))
                    log_outputs["f"] = self.mse_loss_fn(p_, p_.new_zeros(p_.size()))

                if pc_ is not None:
                    log_outputs["rc"] = self.mse_loss_fn(pc, pc.new_ones(pc.size()))
                    log_outputs["fc"] = self.mse_loss_fn(pc_, pc_.new_zeros(pc_.size()))

            # Nothing to optimise on this step (discriminator inactive or skipped).
            if not log_outputs:
                return None
        # Terms without an explicit weight default to a weight of 1.
        total_loss = sum(loss_weights.get(k, 1) * v for k, v in log_outputs.items())
        log_outputs['bs'] = sample['mels'].shape[0]
        return total_loss, log_outputs

    def configure_optimizers(self):
        """Create the generator and discriminator optimizers and schedulers.

        Returns:
            ``[optimizer_gen, optimizer_disc]``; ``optimizer_disc`` is ``None``
            when the discriminator has no parameters.
        """
        if not hasattr(self, 'gen_params'):
            self.gen_params = list(self.model.parameters())
        optimizer_gen = torch.optim.AdamW(
            self.gen_params,
            lr=hparams['lr'],
            betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']),
            weight_decay=hparams['weight_decay'])
        optimizer_disc = torch.optim.AdamW(
            self.disc_params,
            lr=hparams['disc_lr'],
            betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']),
            **hparams["discriminator_optimizer_params"]) if len(self.disc_params) > 0 else None
        self.scheduler = self.build_scheduler({'gen': optimizer_gen, 'disc': optimizer_disc})
        return [optimizer_gen, optimizer_disc]

    def build_scheduler(self, optimizer):
        """Build one LR scheduler per optimizer.

        Args:
            optimizer: Dict with keys ``'gen'`` and ``'disc'``; the ``'disc'``
                entry may be ``None``.

        Returns:
            Dict with keys ``'gen'`` and ``'disc'``; the ``'disc'`` entry is
            ``None`` when no discriminator optimizer was given.
        """
        return {
            "gen": super().build_scheduler(optimizer['gen']),
            "disc": torch.optim.lr_scheduler.StepLR(
                optimizer=optimizer["disc"],
                **hparams["discriminator_scheduler_params"]) if optimizer["disc"] is not None else None,
        }

    def on_before_optimization(self, opt_idx):
        """Clip the gradient norm of whichever network is about to be updated."""
        if opt_idx == 0:
            nn.utils.clip_grad_norm_(self.gen_params, hparams['generator_grad_norm'])
        else:
            nn.utils.clip_grad_norm_(self.disc_params, hparams["discriminator_grad_norm"])

    def on_after_optimization(self, epoch, batch_idx, optimizer, optimizer_idx):
        """Advance the LR scheduler that belongs to the optimizer just stepped."""
        if optimizer_idx == 0:
            self.scheduler['gen'].step(self.global_step)
        else:
            disc_scheduler = self.scheduler['disc']
            # build_scheduler() stores None when there is no discriminator
            # optimizer; skip instead of failing on None.step().
            if disc_scheduler is not None:
                # The discriminator schedule counts steps since it became
                # active, clamped to at least 1.
                disc_scheduler.step(max(self.global_step - hparams["disc_start_steps"], 1))