import json
import os

import torch
import torch.nn as nn

from models.base.new_trainer import BaseTrainer
from models.svc.base.svc_dataset import SVCCollator, SVCDataset


class SVCTrainer(BaseTrainer):
    r"""The base trainer for all SVC models. It inherits from BaseTrainer and
    implements the ``_build_criterion``, ``_build_dataset``, and
    ``_build_singer_lut`` methods. Subclasses should implement
    ``_build_model`` and ``_forward_step``.
    """

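    # A minimal subclass sketch. ``MyAcousticModel`` and the batch keys
    # ("mel", "mask") are illustrative assumptions, not names defined here:
    #
    #     class MySVCTrainer(SVCTrainer):
    #         def _build_model(self):
    #             return MyAcousticModel(self.cfg)
    #
    #         def _forward_step(self, batch):
    #             y_pred = self.model(batch)
    #             return self._compute_loss(
    #                 self.criterion, y_pred, batch["mel"], batch["mask"]
    #             )
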
    def __init__(self, args=None, cfg=None):
        self.args = args
        self.cfg = cfg

        self._init_accelerator()

        # Build the singer look-up table on the main process first so that
        # every process ends up with the same spk2id mapping.
        with self.accelerator.main_process_first():
            self.singers = self._build_singer_lut()

        BaseTrainer.__init__(self, args, cfg)

        self.task_type = "SVC"
        self.logger.info("Task type: {}".format(self.task_type))

    def _build_dataset(self):
        return SVCDataset, SVCCollator

    @staticmethod
    def _build_criterion():
        # reduction="none" keeps the element-wise losses so that padded
        # frames can be masked out in _compute_loss.
        criterion = nn.MSELoss(reduction="none")
        return criterion

    @staticmethod
    def _compute_loss(criterion, y_pred, y_gt, loss_mask):
        """
        Args:
            criterion: MSELoss(reduction='none')
            y_pred, y_gt: (bs, seq_len, D)
            loss_mask: (bs, seq_len, 1)
        Returns:
            loss: Tensor of shape []
        """
        loss = criterion(y_pred, y_gt)
        # Broadcast the mask over the feature dimension, then average the
        # loss over the unmasked elements only.
        loss_mask = loss_mask.repeat(1, 1, loss.shape[-1])
        loss = torch.sum(loss * loss_mask) / torch.sum(loss_mask)
        return loss
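
    # Usage sketch (illustrative shapes, arbitrary values):
    #
    #     criterion = SVCTrainer._build_criterion()
    #     y_pred = torch.randn(2, 100, 80)  # (bs, seq_len, D)
    #     y_gt = torch.randn(2, 100, 80)
    #     mask = torch.ones(2, 100, 1)      # 1 = real frame, 0 = padding
    #     loss = SVCTrainer._compute_loss(criterion, y_pred, y_gt, mask)
    #     assert loss.dim() == 0  # scalar
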
    def _save_auxiliary_states(self):
        """
        Save the singer look-up table to the checkpoint saving path.
        """
        with open(
            os.path.join(self.tmp_checkpoint_save_path, self.cfg.preprocess.spk2id),
            "w",
        ) as f:
            json.dump(self.singers, f, indent=4, ensure_ascii=False)
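
    # The dumped file is a flat JSON mapping from singer name to integer id,
    # e.g. (illustrative names): {"singer_a": 0, "singer_b": 1}
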
    def _build_singer_lut(self):
        resumed_singer_path = None
        if self.args.resume_from_ckpt_path and self.args.resume_from_ckpt_path != "":
            resumed_singer_path = os.path.join(
                self.args.resume_from_ckpt_path, self.cfg.preprocess.spk2id
            )
        if os.path.exists(os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)):
            resumed_singer_path = os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)

        if resumed_singer_path:
            with open(resumed_singer_path, "r") as f:
                singers = json.load(f)
        else:
            singers = dict()

        # Merge the per-dataset look-up tables, assigning a fresh id to each
        # singer that is not already in the table.
        for dataset in self.cfg.dataset:
            singer_lut_path = os.path.join(
                self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.spk2id
            )
            with open(singer_lut_path, "r") as f:
                singer_lut = json.load(f)
            for singer in singer_lut.keys():
                if singer not in singers:
                    singers[singer] = len(singers)

        with open(
            os.path.join(self.exp_dir, self.cfg.preprocess.spk2id), "w"
        ) as singer_file:
            json.dump(singers, singer_file, indent=4, ensure_ascii=False)
        print(
            "singers have been dumped to {}".format(
                os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
            )
        )
        return singers
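
    # Merge behavior, by example (illustrative names): resuming with
    # {"a": 0, "b": 1} and scanning a dataset LUT that contains "b" and "c"
    # yields {"a": 0, "b": 1, "c": 2}; existing ids are never reassigned.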