Spaces:
Build error
Build error
| import matplotlib | |
| matplotlib.use('Agg') | |
| import torch | |
| import numpy as np | |
| import os | |
| from tasks.base_task import BaseDataset | |
| from tasks.tts.fs2 import FastSpeech2Task | |
| from modules.fastspeech.pe import PitchExtractor | |
| import utils | |
| from utils.indexed_datasets import IndexedDataset | |
| from utils.hparams import hparams | |
| from utils.plot import f0_to_figure | |
| from utils.pitch_utils import norm_interp_f0, denorm_f0 | |
class PeDataset(BaseDataset):
    """Dataset feeding binarized items to the pitch-extraction task.

    Items are read lazily from the ``IndexedDataset`` located at
    ``{binary_data_dir}/{prefix}``; per-item lengths are loaded from
    ``{binary_data_dir}/{prefix}_lengths.npy``.
    """

    def __init__(self, prefix, shuffle=False):
        """Load length metadata and, if present, the stored F0 statistics.

        Args:
            prefix: Split name used to build the data file names. The value
                ``'test'`` additionally enables the index filtering below.
            shuffle: Forwarded unchanged to ``BaseDataset.__init__``.

        Side effects:
            Overwrites ``hparams['f0_mean']`` and ``hparams['f0_std']`` with
            the values read from ``train_f0s_mean_std.npy`` (converted to
            ``float``), or with ``None`` when that file does not exist.
        """
        super().__init__(shuffle)
        self.data_dir = hparams['binary_data_dir']
        self.prefix = prefix
        self.hparams = hparams
        self.sizes = np.load(f'{self.data_dir}/{self.prefix}_lengths.npy')
        # The indexed dataset is only opened on first access in _get_item().
        self.indexed_ds = None

        # Pitch statistics: keep the raw loaded values on the instance and
        # publish plain floats through the global hparams.
        stats_path = f'{self.data_dir}/train_f0s_mean_std.npy'
        if not os.path.exists(stats_path):
            self.f0_mean = self.f0_std = None
            hparams['f0_mean'] = hparams['f0_std'] = None
        else:
            self.f0_mean, self.f0_std = np.load(stats_path)
            hparams['f0_mean'] = float(self.f0_mean)
            hparams['f0_std'] = float(self.f0_std)

        # For the test split, restrict the dataset to the first
        # `num_test_samples` items followed by the explicit `test_ids`.
        if prefix == 'test' and hparams['num_test_samples'] > 0:
            leading = list(range(hparams['num_test_samples']))
            self.avail_idxs = leading + hparams['test_ids']
            all_sizes = self.sizes
            self.sizes = [all_sizes[i] for i in self.avail_idxs]

    def _get_item(self, index):
        """Return the raw stored item for ``index``.

        When ``avail_idxs`` has been set (and is not ``None``), ``index`` is
        first translated through it. The backing ``IndexedDataset`` is
        created on the first call.
        """
        index_map = getattr(self, 'avail_idxs', None)
        if index_map is not None:
            index = index_map[index]
        if self.indexed_ds is None:
            self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}')
        return self.indexed_ds[index]

    def __getitem__(self, index):
        """Build one training sample, truncated to ``hparams['max_frames']``.

        Returns:
            A dict with keys ``id``, ``item_name``, ``text``, ``mel``,
            ``pitch``, ``f0`` and ``uv``. ``f0`` and ``uv`` are the two
            values returned by ``norm_interp_f0``.
        """
        hp = self.hparams
        item = self._get_item(index)
        frame_limit = hp['max_frames']
        mel = torch.Tensor(item['mel'])[:frame_limit]
        f0, uv = norm_interp_f0(item["f0"][:frame_limit], hp)
        pitch = torch.LongTensor(item.get("pitch"))[:frame_limit]
        # NOTE: 'mel2ph' and 'mel_nonpadding' are deliberately absent here;
        # they were commented out in the original implementation.
        return {
            "id": index,
            "item_name": item['item_name'],
            "text": item['txt'],
            "mel": mel,
            "pitch": pitch,
            "f0": f0,
            "uv": uv,
        }

    def collater(self, samples):
        """Merge a list of samples from ``__getitem__`` into one batch dict.

        Returns an empty dict for an empty ``samples`` sequence. Otherwise
        the per-sample tensors are combined with ``utils.collate_1d`` /
        ``utils.collate_2d`` and ``mel_lengths`` records the first-dimension
        size of every sample's ``mel``.
        """
        if len(samples) == 0:
            return {}

        def gather(key):
            # Collect one field across all samples, preserving sample order.
            return [entry[key] for entry in samples]

        ids = torch.LongTensor(gather('id'))
        names = gather('item_name')
        texts = gather('text')
        f0_batch = utils.collate_1d(gather('f0'), 0.0)
        pitch_batch = utils.collate_1d(gather('pitch'))
        uv_batch = utils.collate_1d(gather('uv'))
        mel_batch = utils.collate_2d(gather('mel'), 0.0)
        mel_lengths = torch.LongTensor([m.shape[0] for m in gather('mel')])

        return {
            'id': ids,
            'item_name': names,
            'nsamples': len(samples),
            'text': texts,
            'mels': mel_batch,
            'mel_lengths': mel_lengths,
            'pitch': pitch_batch,
            'f0': f0_batch,
            'uv': uv_batch,
        }
class PitchExtractionTask(FastSpeech2Task):
    """Task that trains a ``PitchExtractor`` on batches from ``PeDataset``."""

    def __init__(self):
        super().__init__()
        self.dataset_cls = PeDataset

    def build_tts_model(self):
        """Instantiate the model and store it on ``self.model``."""
        self.model = PitchExtractor(conv_layers=hparams['pitch_extractor_conv_layers'])

    # def build_scheduler(self, optimizer):
    #     return torch.optim.lr_scheduler.StepLR(optimizer, hparams['decay_steps'], gamma=0.5)

    def _training_step(self, sample, batch_idx, _):
        """Run one training step.

        Returns:
            ``(total_loss, loss_output)`` where ``total_loss`` is the sum of
            every tensor in ``loss_output`` that requires grad, and
            ``loss_output`` additionally carries ``'batch_size'``.
        """
        loss_output = self.run_model(self.model, sample)
        # Only differentiable tensor entries contribute to the optimized loss.
        total_loss = sum(
            value for value in loss_output.values()
            if isinstance(value, torch.Tensor) and value.requires_grad
        )
        loss_output['batch_size'] = sample['mels'].size()[0]
        return total_loss, loss_output

    def validation_step(self, sample, batch_idx):
        """Compute validation losses and plot F0 for the first few batches.

        Returns:
            The result of ``utils.tensors_to_scalars`` applied to a dict with
            keys ``'losses'``, ``'total_loss'`` and ``'nsamples'``.
        """
        losses, model_out = self.run_model(self.model, sample, return_output=True, infer=True)
        outputs = {
            'losses': losses,
            'total_loss': sum(losses.values()),
            'nsamples': sample['nsamples'],
        }
        outputs = utils.tensors_to_scalars(outputs)
        # Plotting is limited to the first `num_valid_plots` batches.
        if batch_idx < hparams['num_valid_plots']:
            self.plot_pitch(batch_idx, model_out, sample)
        return outputs

    def run_model(self, model, sample, return_output=False, infer=False):
        """Forward ``sample['mels']`` through ``model`` and collect losses.

        Args:
            model: Callable invoked with the batched mels.
            sample: Batch dict; ``'mels'`` is read here, the remaining keys
                are read by ``add_pitch_loss``.
            return_output: When true, also return the raw model output.
            infer: Accepted for interface compatibility; not used by this
                method.

        Returns:
            ``losses`` alone, or ``(losses, output)`` when ``return_output``
            is true.
        """
        output = model(sample['mels'])
        losses = {}
        self.add_pitch_loss(output, sample, losses)
        if return_output:
            return losses, output
        return losses

    def plot_pitch(self, batch_idx, model_out, sample):
        """Log a figure for the first item of the batch.

        The figure is built by ``f0_to_figure`` from the denormalized ground
        truth F0 and ``model_out['f0_denorm_pred']``, and is logged under the
        tag ``f0_{batch_idx}`` at the current global step.
        """
        gt_f0 = denorm_f0(sample['f0'], sample['uv'], hparams)
        figure = f0_to_figure(gt_f0[0], None, model_out['f0_denorm_pred'][0])
        self.logger.experiment.add_figure(f'f0_{batch_idx}', figure, self.global_step)

    def add_pitch_loss(self, output, sample, losses):
        """Add the F0 loss terms for ``output['pitch_pred']`` to ``losses``.

        The actual loss computation is delegated to ``self.add_f0_loss``,
        which is not defined in this class (presumably inherited from
        ``FastSpeech2Task`` -- confirm against the base class).
        """
        mel = sample['mels']
        f0 = sample['f0']
        uv = sample['uv']
        # Positions whose absolute mel values sum to zero over the last axis
        # are treated as padding and masked out of the loss.
        nonpadding = (mel.abs().sum(-1) > 0).float()
        self.add_f0_loss(output['pitch_pred'], f0, uv, losses, nonpadding=nonpadding)