Spaces:
Build error
Build error
| from multiprocessing.pool import Pool | |
| import matplotlib | |
| from utils.pl_utils import data_loader | |
| from utils.training_utils import RSQRTSchedule | |
| from vocoders.base_vocoder import get_vocoder_cls, BaseVocoder | |
| from modules.fastspeech.pe import PitchExtractor | |
| matplotlib.use('Agg') | |
| import os | |
| import numpy as np | |
| from tqdm import tqdm | |
| import torch.distributed as dist | |
| from tasks.base_task import BaseTask | |
| from utils.hparams import hparams | |
| from utils.text_encoder import TokenTextEncoder | |
| import json | |
| import torch | |
| import torch.optim | |
| import torch.utils.data | |
| import utils | |
class TtsTask(BaseTask):
    """Base task for text-to-speech models.

    Builds a phoneme encoder from ``hparams['binary_data_dir']`` and provides
    optimizer / scheduler / dataloader factories. During testing it also manages
    a process pool that collects asynchronous result-saving jobs.
    """

    def __init__(self, *args, **kwargs):
        self.vocoder = None
        self.phone_encoder = self.build_phone_encoder(hparams['binary_data_dir'])
        # Special token ids of the phoneme vocabulary, cached as attributes.
        self.padding_idx = self.phone_encoder.pad()
        self.eos_idx = self.phone_encoder.eos()
        self.seg_idx = self.phone_encoder.seg()
        # Result-saving pool and its pending jobs; created in test_start().
        self.saving_result_pool = None
        self.saving_results_futures = None
        self.stats = {}
        # Attributes above are set first so they exist before the base
        # constructor runs.
        super().__init__(*args, **kwargs)

    def build_scheduler(self, optimizer):
        """Wrap ``optimizer`` in an ``RSQRTSchedule`` learning-rate schedule."""
        return RSQRTSchedule(optimizer)

    def build_optimizer(self, model):
        """Create an AdamW optimizer over all of ``model``'s parameters.

        The learning rate comes from ``hparams['lr']``; every other AdamW
        hyperparameter is left at the PyTorch default. The optimizer is also
        stored on ``self.optimizer``.
        """
        self.optimizer = optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=hparams['lr'])
        return optimizer

    def build_dataloader(self, dataset, shuffle, max_tokens=None, max_sentences=None,
                         required_batch_size_multiple=-1, endless=False, batch_by_size=True):
        """Build a DataLoader that iterates over precomputed batches of ``dataset``.

        Batch limits are given per device. They are multiplied by the number of
        visible CUDA devices (at least 1) before batching.

        Args:
            dataset: object providing ``ordered_indices()``, ``num_tokens``,
                ``collater`` and ``num_workers`` (all used below).
            shuffle: shuffle the order of the batches (not the samples inside
                a batch) with ``np.random.shuffle``.
            max_tokens: per-device token budget, forwarded to
                ``utils.batch_by_size``; may be None.
            max_sentences: per-device number of samples per batch. Forwarded to
                ``utils.batch_by_size``, and required (positive) when
                ``batch_by_size`` is False.
            required_batch_size_multiple: forwarded to ``utils.batch_by_size``;
                -1 means "use the device count".
            endless: repeat the batch list 1000 times to emulate an endless loader.
            batch_by_size: if True, group samples with ``utils.batch_by_size``.
                Otherwise split ``ordered_indices()`` into consecutive chunks of
                ``max_sentences``.

        Returns:
            A ``torch.utils.data.DataLoader`` driven by an explicit batch sampler.

        Raises:
            ValueError: if ``batch_by_size`` is False and ``max_sentences`` is
                None or not positive.
        """
        devices_cnt = max(torch.cuda.device_count(), 1)
        if required_batch_size_multiple == -1:
            required_batch_size_multiple = devices_cnt

        def shuffle_batches(batches):
            # In-place shuffle of the batch list; returned for convenience.
            np.random.shuffle(batches)
            return batches

        if max_tokens is not None:
            max_tokens *= devices_cnt
        if max_sentences is not None:
            max_sentences *= devices_cnt
        indices = dataset.ordered_indices()
        if batch_by_size:
            batch_sampler = utils.batch_by_size(
                indices, dataset.num_tokens, max_tokens=max_tokens, max_sentences=max_sentences,
                required_batch_size_multiple=required_batch_size_multiple,
            )
        else:
            if max_sentences is None or max_sentences <= 0:
                raise ValueError(
                    'max_sentences must be a positive integer when batch_by_size=False, '
                    f'got {max_sentences!r}')
            batch_sampler = [indices[i:i + max_sentences]
                             for i in range(0, len(indices), max_sentences)]
        if shuffle:
            # NOTE: when `endless` is set, this first shuffle is discarded. It is
            # kept so that seeded runs consume the RNG, and thus order batches,
            # identically.
            batches = shuffle_batches(list(batch_sampler))
            if endless:
                batches = [b for _ in range(1000) for b in shuffle_batches(list(batch_sampler))]
        else:
            batches = batch_sampler
            if endless:
                batches = [b for _ in range(1000) for b in batches]
        num_workers = dataset.num_workers
        if self.trainer.use_ddp:
            # Each rank takes a strided slice of every batch. Batches whose size
            # is not a multiple of the world size are dropped, so every rank's
            # slice has the same size. This assumes all ranks built the same
            # `batches` (e.g. via a shared seed).
            num_replicas = dist.get_world_size()
            rank = dist.get_rank()
            batches = [x[rank::num_replicas] for x in batches if len(x) % num_replicas == 0]
        return torch.utils.data.DataLoader(dataset,
                                           collate_fn=dataset.collater,
                                           batch_sampler=batches,
                                           num_workers=num_workers,
                                           pin_memory=False)

    def build_phone_encoder(self, data_dir):
        """Load ``phone_set.json`` from ``data_dir`` and wrap it in a TokenTextEncoder.

        ``replace_oov=','`` presumably maps out-of-vocabulary tokens to ','
        (see TokenTextEncoder).
        """
        phone_list_file = os.path.join(data_dir, 'phone_set.json')
        # Use a context manager so the file handle is always closed. JSON is
        # UTF-8 by spec, so decode it explicitly rather than relying on the
        # platform locale.
        with open(phone_list_file, encoding='utf-8') as f:
            phone_list = json.load(f)
        return TokenTextEncoder(None, vocab_list=phone_list, replace_oov=',')

    def test_start(self):
        """Prepare for testing: result-saving pool, vocoder and optional pitch extractor."""
        self.saving_result_pool = Pool(8)
        self.saving_results_futures = []
        self.vocoder: BaseVocoder = get_vocoder_cls(hparams)()
        if hparams.get('pe_enable'):
            # Optional pitch extractor: moved to GPU, weights loaded strictly
            # from `pe_ckpt`, then switched to eval mode.
            self.pe = PitchExtractor().cuda()
            utils.load_ckpt(self.pe, hparams['pe_ckpt'], 'model', strict=True)
            self.pe.eval()

    def test_end(self, outputs):
        """Wait for all pending result-saving jobs, then shut the pool down.

        ``self.saving_results_futures`` presumably holds ``AsyncResult`` objects
        appended elsewhere (e.g. from ``apply_async``). Their ``get()`` re-raises
        any exception raised in a worker.
        """
        self.saving_result_pool.close()
        for future in tqdm(self.saving_results_futures):
            future.get()
        self.saving_result_pool.join()
        return {}

    ##########
    # utils
    ##########
    def weights_nonzero_speech(self, target):
        """Return a float mask marking frames that are not all-zero.

        A frame (a vector along the last axis of ``target``) gets weight 1.0 if
        any of its values is nonzero, and 0.0 if it is entirely zero (e.g.
        padding). The mask has the same shape as ``target``.

        Args:
            target: tensor whose last axis is the feature axis, typically
                B x T x mel.
        """
        dim = target.size(-1)
        nonzero = target.abs().sum(-1, keepdim=True).ne(0).float()
        # Tile the per-frame mask across the feature axis and keep the leading
        # dims unchanged. For B x T x mel this is repeat(1, 1, mel); it also
        # works for any other rank >= 1.
        return nonzero.repeat(*([1] * (target.dim() - 1)), dim)
if __name__ == '__main__':
    # Script entry point: hand control to the class-level `start` method, which
    # TtsTask does not define itself (presumably inherited from BaseTask).
    TtsTask.start()