# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

import torch
import torch.nn as nn
from torch.nn import functional as F

from common import get_mask_from_lengths


def compute_flow_loss(
    z, log_det_W_list, log_s_list, n_elements, n_dims, mask, sigma=1.0
):
    """Negative log-likelihood of a normalizing flow: a Gaussian prior term on
    z minus the accumulated coupling log-scales (log_s) and the
    log-determinants of the invertible 1x1 convolutions (log_det_W)."""
    # Initialize both accumulators so an empty log_s_list cannot leave
    # log_s_total undefined when it is used below.
    log_s_total = 0.0
    log_det_W_total = 0.0
    for i, log_s in enumerate(log_s_list):
        log_s_total = log_s_total + torch.sum(log_s * mask)
        if len(log_det_W_list):
            log_det_W_total = log_det_W_total + log_det_W_list[i]

    if len(log_det_W_list):
        # log_det_W is a per-frame quantity; scale by the number of valid frames
        log_det_W_total *= n_elements

    z = z * mask
    prior_NLL = torch.sum(z * z) / (2 * sigma * sigma)

    loss = prior_NLL - log_s_total - log_det_W_total

    denom = n_elements * n_dims
    loss = loss / denom
    loss_prior = prior_NLL / denom
    return loss, loss_prior


def compute_regression_loss(x_hat, x, mask, name=False):
    x = x[:, None] if len(x.shape) == 2 else x  # add channel dim
    mask = mask[:, None] if len(mask.shape) == 2 else mask  # add channel dim
    assert len(x.shape) == len(mask.shape)

    x = x * mask
    x_hat = x_hat * mask

    if name == "vpred":  # voiced/unvoiced flag is a binary target
        loss = F.binary_cross_entropy_with_logits(x_hat, x, reduction="sum")
    else:
        loss = F.mse_loss(x_hat, x, reduction="sum")
    loss = loss / mask.sum()

    loss_dict = {"loss_{}".format(name): loss}

    return loss_dict


class AttributePredictionLoss(torch.nn.Module):
    def __init__(self, name, model_config, loss_weight, sigma=1.0):
        super(AttributePredictionLoss, self).__init__()
        self.name = name
        self.sigma = sigma
        self.model_name = model_config["name"]
        self.loss_weight = loss_weight
        self.n_group_size = 1
        if "n_group_size" in model_config["hparams"]:
            self.n_group_size = model_config["hparams"]["n_group_size"]

    def forward(self, model_output, lens):
        mask = get_mask_from_lengths(lens // self.n_group_size)
        mask = mask[:, None].float()
        loss_dict = {}
        if "z" in model_output:
            # flow-based attribute predictor: maximize the likelihood of z
            n_elements = lens.sum() // self.n_group_size
            n_dims = model_output["z"].size(1)
            loss, loss_prior = compute_flow_loss(
                model_output["z"],
                model_output["log_det_W_list"],
                model_output["log_s_list"],
                n_elements,
                n_dims,
                mask,
                self.sigma,
            )
            loss_dict = {
                "loss_{}".format(self.name): (loss, self.loss_weight),
                "loss_prior_{}".format(self.name): (loss_prior, 0.0),
            }
        elif "x_hat" in model_output:
            # deterministic attribute predictor: masked regression loss
            loss_dict = compute_regression_loss(
                model_output["x_hat"], model_output["x"], mask, self.name
            )
            for k, v in loss_dict.items():
                loss_dict[k] = (v, self.loss_weight)

        if len(loss_dict) == 0:
            raise Exception("loss not supported")

        return loss_dict
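
# Usage sketch (an illustrative addition, not part of the original training
# code): the config dict and tensor shapes below are assumptions chosen to
# exercise the regression ("x_hat") branch of AttributePredictionLoss.
#
#   dur_config = {"name": "dap", "hparams": {}}          # hypothetical config
#   loss_fn = AttributePredictionLoss("duration", dur_config, loss_weight=1.0)
#   lens = torch.tensor([8, 6])                          # tokens per utterance
#   model_output = {
#       "x_hat": torch.randn(2, 1, 8),                   # predicted durations
#       "x": torch.randn(2, 8).abs(),                    # target durations
#   }
#   loss_dict = loss_fn(model_output, lens)  # {"loss_duration": (tensor, 1.0)}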
class AttentionCTCLoss(torch.nn.Module):
    def __init__(self, blank_logprob=-1):
        super(AttentionCTCLoss, self).__init__()
        self.log_softmax = torch.nn.LogSoftmax(dim=3)
        self.blank_logprob = blank_logprob
        self.CTCLoss = nn.CTCLoss(zero_infinity=True)

    def forward(self, attn_logprob, in_lens, out_lens):
        key_lens = in_lens
        query_lens = out_lens
        # prepend a column of blank_logprob at key index 0 so CTC has a blank
        attn_logprob_padded = F.pad(
            input=attn_logprob, pad=(1, 0, 0, 0, 0, 0, 0, 0), value=self.blank_logprob
        )
        cost_total = 0.0
        for bid in range(attn_logprob.shape[0]):
            # targets are the key indices 1..key_len; 0 is the CTC blank
            target_seq = torch.arange(1, key_lens[bid] + 1).unsqueeze(0)
            curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)[
                : query_lens[bid], :, : key_lens[bid] + 1
            ]
            curr_logprob = self.log_softmax(curr_logprob[None])[0]
            ctc_cost = self.CTCLoss(
                curr_logprob,
                target_seq,
                input_lengths=query_lens[bid : bid + 1],
                target_lengths=key_lens[bid : bid + 1],
            )
            cost_total += ctc_cost
        cost = cost_total / attn_logprob.shape[0]
        return cost


class AttentionBinarizationLoss(torch.nn.Module):
    def __init__(self):
        super(AttentionBinarizationLoss, self).__init__()

    def forward(self, hard_attention, soft_attention):
        # clamp avoids log(0) where the soft attention is exactly zero
        log_sum = torch.log(
            torch.clamp(soft_attention[hard_attention == 1], min=1e-12)
        ).sum()
        return -log_sum / hard_attention.sum()


class RADTTSLoss(torch.nn.Module):
    def __init__(
        self,
        sigma=1.0,
        n_group_size=1,
        dur_model_config=None,
        f0_model_config=None,
        energy_model_config=None,
        vpred_model_config=None,
        loss_weights=None,
    ):
        super(RADTTSLoss, self).__init__()
        self.sigma = sigma
        self.n_group_size = n_group_size
        self.loss_weights = loss_weights
        self.attn_ctc_loss = AttentionCTCLoss(
            blank_logprob=loss_weights.get("blank_logprob", -1)
        )
        self.loss_fns = {}
        if dur_model_config is not None:
            self.loss_fns["duration_model_outputs"] = AttributePredictionLoss(
                "duration", dur_model_config, loss_weights["dur_loss_weight"]
            )

        if f0_model_config is not None:
            self.loss_fns["f0_model_outputs"] = AttributePredictionLoss(
                "f0", f0_model_config, loss_weights["f0_loss_weight"], sigma=1.0
            )

        if energy_model_config is not None:
            self.loss_fns["energy_model_outputs"] = AttributePredictionLoss(
                "energy", energy_model_config, loss_weights["energy_loss_weight"]
            )

        if vpred_model_config is not None:
            self.loss_fns["vpred_model_outputs"] = AttributePredictionLoss(
                "vpred", vpred_model_config, loss_weights["vpred_loss_weight"]
            )

    def forward(self, model_output, in_lens, out_lens):
        loss_dict = {}
        if len(model_output["z_mel"]):
            # mel-decoder flow loss
            n_elements = out_lens.sum() // self.n_group_size
            mask = get_mask_from_lengths(out_lens // self.n_group_size)
            mask = mask[:, None].float()
            n_dims = model_output["z_mel"].size(1)
            loss_mel, loss_prior_mel = compute_flow_loss(
                model_output["z_mel"],
                model_output["log_det_W_list"],
                model_output["log_s_list"],
                n_elements,
                n_dims,
                mask,
                self.sigma,
            )
            loss_dict["loss_mel"] = (loss_mel, 1.0)  # loss, weight
            loss_dict["loss_prior_mel"] = (loss_prior_mel, 0.0)

        ctc_cost = self.attn_ctc_loss(model_output["attn_logprob"], in_lens, out_lens)
        loss_dict["loss_ctc"] = (ctc_cost, self.loss_weights["ctc_loss_weight"])

        for k in model_output:
            if k in self.loss_fns:
                if model_output[k] is not None and len(model_output[k]) > 0:
                    # durations are per input token; f0/energy/vpred are per frame
                    t_lens = in_lens if "dur" in k else out_lens
                    mout = model_output[k]
                    for loss_name, v in self.loss_fns[k](mout, t_lens).items():
                        loss_dict[loss_name] = v

        return loss_dict
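

if __name__ == "__main__":
    # Minimal smoke test (an illustrative addition, not part of the original
    # file). The shapes are assumptions based on how RADTTSLoss.forward reads
    # attn_logprob: (batch, 1, max_out_len, max_in_len) attention scores.
    torch.manual_seed(0)
    B, T_out, T_in = 2, 12, 7
    attn_logprob = torch.randn(B, 1, T_out, T_in)
    in_lens = torch.tensor([7, 5])
    out_lens = torch.tensor([12, 9])

    ctc_loss = AttentionCTCLoss(blank_logprob=-1)
    print("attention CTC loss:", ctc_loss(attn_logprob, in_lens, out_lens).item())

    # The binarization loss compares a hard (0/1) alignment against the soft
    # one; here the hard alignment is simply a per-frame argmax over keys.
    soft_attention = torch.softmax(attn_logprob, dim=3)
    hard_attention = torch.zeros_like(soft_attention)
    hard_attention.scatter_(3, soft_attention.argmax(dim=3, keepdim=True), 1.0)
    bin_loss = AttentionBinarizationLoss()
    print("binarization loss:", bin_loss(hard_attention, soft_attention).item())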