Spaces:
				
			
			
	
			
			
		Build error
		
	
	
	
			
			
	
	
	
	
		
		
		Build error
		
	| import os | |
| import torch | |
| import torch.distributed as dist | |
| from torch.utils.data import DistributedSampler | |
| from tasks.base_task import BaseTask | |
| from tasks.base_task import data_loader | |
| from tasks.vocoder.dataset_utils import VocoderDataset, EndlessDistributedSampler | |
| from utils.hparams import hparams | |
class VocoderBaseTask(BaseTask):
    """Base task for vocoders: builds train/valid/test loaders and the test output dir."""

    def __init__(self):
        """Read batch-size limits from ``hparams`` and select the dataset class."""
        super().__init__()
        train_cap = hparams['max_sentences']
        valid_cap = hparams['max_valid_sentences']
        if valid_cap == -1:
            # -1 means "reuse the training limit"; the resolved value is
            # written back to hparams as well as stored on the task.
            valid_cap = train_cap
            hparams['max_valid_sentences'] = valid_cap
        self.max_sentences = train_cap
        self.max_valid_sentences = valid_cap
        self.dataset_cls = VocoderDataset

    def train_dataloader(self):
        """Return a shuffled loader over the 'train' split."""
        dataset = self.dataset_cls('train', shuffle=True)
        return self.build_dataloader(dataset, True, self.max_sentences, hparams['endless_ds'])

    def val_dataloader(self):
        """Return an unshuffled loader over the 'valid' split."""
        dataset = self.dataset_cls('valid', shuffle=False)
        return self.build_dataloader(dataset, False, self.max_valid_sentences)

    def test_dataloader(self):
        """Return an unshuffled loader over the 'test' split."""
        dataset = self.dataset_cls('test', shuffle=False)
        return self.build_dataloader(dataset, False, self.max_valid_sentences)

    def build_dataloader(self, dataset, shuffle, max_sentences, endless=False):
        """Wrap ``dataset`` in a DataLoader driven by a distributed sampler.

        Args:
            dataset: Dataset exposing ``collater`` and ``num_workers``.
            shuffle: Forwarded to the sampler as its ``shuffle`` argument.
            max_sentences: Used as the DataLoader ``batch_size``.
            endless: If truthy, use ``EndlessDistributedSampler`` instead of
                ``DistributedSampler``.

        Returns:
            A ``torch.utils.data.DataLoader`` over ``dataset``.
        """
        # Without an initialized process group, act as a single replica.
        if dist.is_initialized():
            num_replicas, rank = dist.get_world_size(), dist.get_rank()
        else:
            num_replicas, rank = 1, 0

        if endless:
            sampler_cls = EndlessDistributedSampler
        else:
            sampler_cls = DistributedSampler
        sampler = sampler_cls(
            dataset=dataset,
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
        )

        # Ordering is delegated to the sampler, so the loader's own shuffle stays off.
        return torch.utils.data.DataLoader(
            dataset=dataset,
            shuffle=False,
            collate_fn=dataset.collater,
            batch_size=max_sentences,
            num_workers=dataset.num_workers,
            sampler=sampler,
            pin_memory=True,
        )

    def test_start(self):
        """Create the generation output directory under ``hparams['work_dir']``."""
        work_dir = hparams['work_dir']
        dir_name = f'generated_{self.trainer.global_step}_{hparams["gen_dir_name"]}'
        self.gen_dir = os.path.join(work_dir, dir_name)
        os.makedirs(self.gen_dir, exist_ok=True)

    def test_end(self, outputs):
        """Return an empty dict; ``outputs`` is not used."""
        return dict()