Spaces:
Build error
Build error
| import torch | |
| from text_to_speech.modules.tts.portaspeech.portaspeech_flow import PortaSpeechFlow | |
| from tasks.tts.fs import FastSpeechTask | |
| from tasks.tts.ps import PortaSpeechTask | |
| from text_to_speech.utils.audio.pitch.utils import denorm_f0 | |
| from text_to_speech.utils.commons.hparams import hparams | |
class PortaSpeechFlowTask(PortaSpeechTask):
    """PortaSpeech task variant that trains a ``PortaSpeechFlow`` model.

    On top of the parent task it adds an optional post-flow ("post glow")
    stage driven by hparams:

    * ``use_post_flow`` -- whether the post-flow loss is used at all;
    * ``post_glow_training_start`` -- global step from which the post-flow
      stage is considered active during training;
    * ``two_stage`` -- when true (together with ``use_post_flow``) the
      post-flow parameters get their own optimizer, and each optimizer is
      only stepped during its own stage (see ``_training_step`` and
      ``build_optimizer``).
    """

    # At inference/validation time the post-flow is only switched on this many
    # steps after ``post_glow_training_start``. The reason is not stated in
    # this file; presumably it gives the post-flow a short warm-up before its
    # output is used -- TODO confirm.
    POST_GLOW_INFER_DELAY_STEPS = 1000

    def __init__(self):
        super().__init__()
        # Whether the current step belongs to the post-flow stage; refreshed at
        # the start of every training and validation step.
        self.training_post_glow = False

    def build_tts_model(self):
        """Instantiate ``self.model`` sized by the phoneme and word vocabularies."""
        ph_dict_size = len(self.token_encoder)
        word_dict_size = len(self.word_encoder)
        self.model = PortaSpeechFlow(ph_dict_size, word_dict_size, hparams)

    def _is_training_post_glow(self):
        """Return whether the post-flow training stage is active at the current step."""
        return self.global_step >= hparams['post_glow_training_start'] \
            and hparams['use_post_flow']

    def _training_step(self, sample, batch_idx, opt_idx):
        """Compute the loss for one optimizer's share of a training step.

        Args:
            sample: collated batch dict.
            batch_idx: index of the batch (unused here).
            opt_idx: index into the list returned by ``build_optimizer``
                (0 = main-model optimizer, 1 = post-flow optimizer, the latter
                only existing when ``two_stage`` and ``use_post_flow`` are set).

        Returns:
            ``(total_loss, loss_output)``, where ``total_loss`` is the sum of
            every loss tensor that requires grad, or ``None`` to skip this
            optimizer (wrong stage in two-stage mode, or the model produced no
            post-flow loss).
        """
        self.training_post_glow = self._is_training_post_glow()
        if hparams['two_stage'] and (
                (opt_idx == 0 and self.training_post_glow)
                or (opt_idx == 1 and not self.training_post_glow)):
            # Two-stage mode: only the optimizer owning the current stage is stepped.
            return None
        loss_output, _ = self.run_model(sample)
        # Detached (logging-only) losses do not require grad and are left out.
        total_loss = sum(v for v in loss_output.values()
                         if isinstance(v, torch.Tensor) and v.requires_grad)
        loss_output['batch_size'] = sample['txt_tokens'].size()[0]
        if 'postflow' in loss_output and loss_output['postflow'] is None:
            return None
        return total_loss, loss_output

    def run_model(self, sample, infer=False, *args, **kwargs):
        """Run the model on ``sample``.

        Args:
            sample: collated batch dict; reads ``txt_tokens``, ``word_tokens``,
                ``ph2word``, ``mel2word``, ``mel2ph``, ``word_lengths``,
                ``mels`` (training only) and the optional ``pitch``,
                ``spk_embed``, ``spk_ids`` and ``bert_feats`` entries.
            infer: ``False`` to compute training/validation losses with
                ground-truth targets, ``True`` for inference.
            **kwargs: ``infer_use_gt_dur`` overrides ``hparams['use_gt_dur']``
                in inference mode.

        Returns:
            ``(losses, output)`` when ``infer`` is false, otherwise the raw
            model output.
        """
        if not infer:
            training_post_glow = self.training_post_glow
            output = self.model(sample['txt_tokens'],
                                sample['word_tokens'],
                                ph2word=sample['ph2word'],
                                mel2word=sample['mel2word'],
                                mel2ph=sample['mel2ph'],
                                word_len=sample['word_lengths'].max(),
                                tgt_mels=sample['mels'],
                                pitch=sample.get('pitch'),
                                spk_embed=sample.get('spk_embed'),
                                spk_id=sample.get('spk_ids'),
                                infer=False,
                                forward_post_glow=training_post_glow,
                                two_stage=hparams['two_stage'],
                                global_step=self.global_step,
                                bert_feats=sample.get('bert_feats'))
            losses = {}
            self.add_mel_loss(output['mel_out'], sample['mels'], losses)
            if (training_post_glow or not hparams['two_stage']) and hparams['use_post_flow']:
                losses['postflow'] = output['postflow']
                # Keep the mel reconstruction losses for logging only: detached
                # tensors are excluded from the total loss in ``_training_step``.
                losses['l1'] = losses['l1'].detach()
                losses['ssim'] = losses['ssim'].detach()
            if not training_post_glow or not hparams['two_stage'] or not self.training:
                # KL and duration losses belong to the main stage; in eval mode
                # (``not self.training``) they are always computed.
                losses['kl'] = output['kl']
                if self.global_step < hparams['kl_start_steps']:
                    # Before kl_start_steps the KL term is reported but not optimized.
                    losses['kl'] = losses['kl'].detach()
                else:
                    # Floor the KL term at kl_min.
                    losses['kl'] = torch.clamp(losses['kl'], min=hparams['kl_min'])
                losses['kl'] = losses['kl'] * hparams['lambda_kl']
                if hparams['dur_level'] == 'word':
                    self.add_dur_loss(
                        output['dur'], sample['mel2word'], sample['word_lengths'], sample['txt_tokens'], losses)
                    self.get_attn_stats(output['attn'], sample, losses)
                else:
                    # This class does not override add_dur_loss, so super() would
                    # resolve to the same word-level method called above, which
                    # takes (dur, mel2token, word_len, txt_tokens, losses). Call the
                    # phoneme-level FastSpeechTask implementation explicitly instead.
                    FastSpeechTask.add_dur_loss(
                        self, output['dur'], sample['mel2ph'], sample['txt_tokens'], losses)
            return losses, output

        use_gt_dur = kwargs.get('infer_use_gt_dur', hparams['use_gt_dur'])
        forward_post_glow = self.global_step >= \
            hparams['post_glow_training_start'] + self.POST_GLOW_INFER_DELAY_STEPS \
            and hparams['use_post_flow']
        return self.model(
            sample['txt_tokens'],
            sample['word_tokens'],
            ph2word=sample['ph2word'],
            word_len=sample['word_lengths'].max(),
            pitch=sample.get('pitch'),
            mel2ph=sample['mel2ph'] if use_gt_dur else None,
            # Use the resolved flag (not hparams directly) so ``infer_use_gt_dur``
            # controls word-level durations the same way it controls mel2ph.
            mel2word=sample['mel2word'] if hparams['profile_infer'] or use_gt_dur else None,
            infer=True,
            forward_post_glow=forward_post_glow,
            spk_embed=sample.get('spk_embed'),
            spk_id=sample.get('spk_ids'),
            two_stage=hparams['two_stage'],
            bert_feats=sample.get('bert_feats'))

    def validation_step(self, sample, batch_idx):
        """Refresh the post-flow stage flag, then defer to the parent validation step."""
        self.training_post_glow = self._is_training_post_glow()
        return super().validation_step(sample, batch_idx)

    def save_valid_result(self, sample, batch_idx, model_out):
        """Log the parent's validation results, plus the FVAE mel output.

        When ``use_post_flow`` is set and ``global_step > 0``, the
        ``mel_out_fvae`` entry of ``model_out`` is vocoded and logged as audio
        and plotted against the ground-truth mel. The ground-truth F0 is used
        when the batch provides ``f0``.
        """
        super().save_valid_result(sample, batch_idx, model_out)
        if self.global_step <= 0 or not hparams['use_post_flow']:
            return
        # NOTE(review): assumes model_out contains 'mel_out_fvae' whenever
        # use_post_flow is set -- confirm against PortaSpeechFlow's outputs.
        sr = hparams['audio_sample_rate']
        f0_gt = None
        if sample.get('f0') is not None:
            f0_gt = denorm_f0(sample['f0'][0].cpu(), sample['uv'][0].cpu())
        mel_fvae = model_out['mel_out_fvae'][0]
        wav_pred = self.vocoder.spec2wav(mel_fvae.cpu(), f0=f0_gt)
        self.logger.add_audio(f'wav_fvae_{batch_idx}', wav_pred, self.global_step, sr)
        self.plot_mel(batch_idx, sample['mels'], mel_fvae,
                      f'mel_fvae_{batch_idx}', f0s=f0_gt)

    def build_optimizer(self, model):
        """Create the AdamW optimizer(s).

        With ``two_stage`` and ``use_post_flow`` both set, returns
        ``[main_optimizer, post_flow_optimizer]``. The first covers every
        parameter whose name does not contain ``'post_flow'``. The second
        covers ``self.model.post_flow`` with ``hparams['post_flow_lr']``.
        Otherwise returns a one-element list with an optimizer over all model
        parameters.

        Note: parameters are always taken from ``self.model``; the ``model``
        argument is ignored (kept for interface compatibility).
        """
        betas = (hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2'])
        weight_decay = hparams['weight_decay']
        if hparams['two_stage'] and hparams['use_post_flow']:
            self.optimizer = torch.optim.AdamW(
                [p for name, p in self.model.named_parameters() if 'post_flow' not in name],
                lr=hparams['lr'],
                betas=betas,
                weight_decay=weight_decay)
            self.post_flow_optimizer = torch.optim.AdamW(
                self.model.post_flow.parameters(),
                lr=hparams['post_flow_lr'],
                betas=betas,
                weight_decay=weight_decay)
            return [self.optimizer, self.post_flow_optimizer]
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=hparams['lr'],
            betas=betas,
            weight_decay=weight_decay)
        return [self.optimizer]

    def build_scheduler(self, optimizer):
        """Build the FastSpeech LR scheduler for the main optimizer only.

        ``optimizer`` is the list returned by ``build_optimizer``; no scheduler
        is built here for the post-flow optimizer.
        """
        # Call FastSpeechTask's implementation directly, bypassing any override
        # in intermediate parent classes.
        return FastSpeechTask.build_scheduler(self, optimizer[0])