Spaces:
Runtime error
Runtime error
from multiprocessing.pool import Pool | |
import matplotlib | |
from utils.pl_utils import data_loader | |
from utils.training_utils import RSQRTSchedule | |
from vocoders.base_vocoder import get_vocoder_cls, BaseVocoder | |
from modules.fastspeech.pe import PitchExtractor | |
matplotlib.use('Agg') | |
import os | |
import numpy as np | |
from tqdm import tqdm | |
import torch.distributed as dist | |
from tasks.base_task import BaseTask | |
from utils.hparams import hparams | |
from utils.text_encoder import TokenTextEncoder | |
import json | |
import torch | |
import torch.optim | |
import torch.utils.data | |
import utils | |
class TtsTask(BaseTask):
    """Base text-to-speech task.

    Loads the phone encoder from ``hparams['binary_data_dir']``, uses AdamW with
    an RSQRT schedule, builds size-bucketed data loaders, and during testing
    owns a process pool used for asynchronous result saving.
    """

    def __init__(self, *args, **kwargs):
        self.vocoder = None
        self.phone_encoder = self.build_phone_encoder(hparams['binary_data_dir'])
        # Special token ids taken from the phone encoder.
        self.padding_idx = self.phone_encoder.pad()
        self.eos_idx = self.phone_encoder.eos()
        self.seg_idx = self.phone_encoder.seg()
        # Created in test_start(); None until then.
        self.saving_result_pool = None
        self.saving_results_futures = None
        self.stats = {}
        # Attributes are set before BaseTask.__init__ runs, as in the original ordering.
        super().__init__(*args, **kwargs)

    def build_scheduler(self, optimizer):
        """Wrap ``optimizer`` in an ``RSQRTSchedule``."""
        return RSQRTSchedule(optimizer)

    def build_optimizer(self, model):
        """Create an AdamW optimizer over ``model``'s parameters with lr ``hparams['lr']``.

        The optimizer is also stored on ``self.optimizer``.
        """
        self.optimizer = optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=hparams['lr'])
        return optimizer

    def build_dataloader(self, dataset, shuffle, max_tokens=None, max_sentences=None,
                         required_batch_size_multiple=-1, endless=False, batch_by_size=True):
        """Build a DataLoader that iterates a precomputed list of index batches.

        Args:
            dataset: object exposing ``ordered_indices()``, ``num_tokens``,
                ``collater`` and ``num_workers`` (the members used below).
            shuffle: if True, shuffle the order of the batches (not the samples
                inside a batch).
            max_tokens: token budget per batch per device. It is multiplied by the
                number of visible CUDA devices (at least 1).
            max_sentences: sample budget per batch per device. It is multiplied the
                same way. Required when ``batch_by_size`` is False.
            required_batch_size_multiple: forwarded to ``utils.batch_by_size``.
                -1 means "use the device count".
            endless: if True, repeat the batch list 1000 times. When ``shuffle`` is
                also set, each repetition is shuffled independently.
            batch_by_size: if True, group indices with ``utils.batch_by_size``.
                Otherwise cut ``indices`` into consecutive chunks of ``max_sentences``.

        Returns:
            torch.utils.data.DataLoader using ``dataset.collater`` as collate_fn.

        Raises:
            ValueError: if ``batch_by_size`` is False and ``max_sentences`` is None.
        """
        if not batch_by_size and max_sentences is None:
            raise ValueError("max_sentences must be set when batch_by_size is False")

        # Budgets are given per device; scale them to the whole batch (CPU counts as 1).
        devices_cnt = max(torch.cuda.device_count(), 1)
        if required_batch_size_multiple == -1:
            required_batch_size_multiple = devices_cnt
        if max_tokens is not None:
            max_tokens *= devices_cnt
        if max_sentences is not None:
            max_sentences *= devices_cnt

        def shuffle_batches(batches):
            # In-place shuffle of the batch order; returned for use in expressions.
            np.random.shuffle(batches)
            return batches

        indices = dataset.ordered_indices()
        if batch_by_size:
            batch_sampler = utils.batch_by_size(
                indices, dataset.num_tokens, max_tokens=max_tokens, max_sentences=max_sentences,
                required_batch_size_multiple=required_batch_size_multiple,
            )
        else:
            batch_sampler = [indices[i:i + max_sentences]
                             for i in range(0, len(indices), max_sentences)]

        if shuffle:
            batches = shuffle_batches(list(batch_sampler))
            if endless:
                # Each of the 1000 passes gets its own freshly shuffled copy.
                batches = [b for _ in range(1000) for b in shuffle_batches(list(batch_sampler))]
        else:
            batches = batch_sampler
            if endless:
                batches = [b for _ in range(1000) for b in batches]

        num_workers = dataset.num_workers
        if self.trainer.use_ddp:
            num_replicas = dist.get_world_size()
            rank = dist.get_rank()
            # Each rank takes a strided slice of every batch. Batches whose size is not
            # divisible by the world size are dropped, so every rank's slice is equal-sized.
            batches = [x[rank::num_replicas] for x in batches if len(x) % num_replicas == 0]
        return torch.utils.data.DataLoader(dataset,
                                           collate_fn=dataset.collater,
                                           batch_sampler=batches,
                                           num_workers=num_workers,
                                           pin_memory=False)

    def build_phone_encoder(self, data_dir):
        """Load ``<data_dir>/phone_set.json`` and build a ``TokenTextEncoder`` from it.

        Out-of-vocabulary tokens are replaced with ``','``.
        """
        phone_list_file = os.path.join(data_dir, 'phone_set.json')
        # Context manager so the file handle is closed promptly.
        with open(phone_list_file, encoding='utf-8') as f:
            phone_list = json.load(f)
        return TokenTextEncoder(None, vocab_list=phone_list, replace_oov=',')

    def test_start(self):
        """Prepare the test phase.

        Creates an 8-process result-saving pool and instantiates the configured
        vocoder. When ``hparams['pe_enable']`` is truthy, also loads the pitch
        extractor from ``hparams['pe_ckpt']`` onto CUDA in eval mode.
        """
        self.saving_result_pool = Pool(8)
        self.saving_results_futures = []
        self.vocoder: BaseVocoder = get_vocoder_cls(hparams)()
        if hparams.get('pe_enable'):
            self.pe = PitchExtractor().cuda()
            utils.load_ckpt(self.pe, hparams['pe_ckpt'], 'model', strict=True)
            self.pe.eval()

    def test_end(self, outputs):
        """Wait for all pending save jobs, shut down the pool, and return ``{}``.

        ``AsyncResult.get()`` re-raises any exception raised in a worker. The pool
        is joined even in that case.
        """
        self.saving_result_pool.close()
        try:
            for future in tqdm(self.saving_results_futures):
                future.get()
        finally:
            self.saving_result_pool.join()
        return {}

    ##########
    # utils
    ##########
    def weights_nonzero_speech(self, target):
        """Return a 0/1 float mask over the non-padding frames of ``target``.

        ``target`` is B x T x mel. A frame counts as padding when every value in
        it is zero. The per-frame mask is repeated along the last dimension, so
        the result has the same B x T x mel shape as ``target``.
        """
        dim = target.size(-1)
        return target.abs().sum(-1, keepdim=True).ne(0).float().repeat(1, 1, dim)
if __name__ == '__main__':
    # Script entry point: launch the task through its class-level start().
    TtsTask.start()