from functools import partial
import sys

import torch
from torch.nn.utils.rnn import pad_sequence

from wenet.dataset import processor
from wenet.dataset.datapipes import (WenetRawDatasetSource,
                                     WenetTarShardDatasetSource)
def padding(data):
    """ Pad a list of samples into a training batch.

    Args:
        data: List[{key, feat, label}]

    Returns:
        dict with keys, padded feats, feats lengths, targets and
        target lengths, sorted by feats length in descending order
    """
    sample = data
    assert isinstance(sample, list)
    feats_length = torch.tensor([x['feat'].size(0) for x in sample],
                                dtype=torch.int32)
    order = torch.argsort(feats_length, descending=True)
    feats_lengths = torch.tensor([sample[i]['feat'].size(0) for i in order],
                                 dtype=torch.int32)
    sorted_feats = [sample[i]['feat'] for i in order]
    sorted_keys = [sample[i]['key'] for i in order]
    padded_feats = pad_sequence(sorted_feats,
                                batch_first=True,
                                padding_value=0)
    batch = {
        "keys": sorted_keys,
        "feats": padded_feats,
        "feats_lengths": feats_lengths,
        # NOTE(Mddct): cv needs targets, refine later; for SSL the padded
        # feats double as the targets.
        "target": padded_feats,
        "target_lengths": feats_lengths,
    }
    return batch
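
# A minimal sketch of how `padding` collates a batch (illustrative only;
# the keys, shapes, and values below are assumptions, not from any recipe):
#
#   fake = [{'key': 'utt1', 'feat': torch.randn(7, 80)},
#           {'key': 'utt2', 'feat': torch.randn(5, 80)}]
#   batch = padding(fake)
#   # batch['feats'].shape    -> torch.Size([2, 7, 80]), zero padded
#   # batch['feats_lengths']  -> tensor([7, 5], dtype=torch.int32)
#   # batch['keys']           -> ['utt1', 'utt2'] (sorted, longest first)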

def Dataset(data_type, data_list_file, conf=None, partition=True):
    """ Construct a dataset from arguments for an SSL model.

    There are two shuffle stages in the Dataset. The first is a global
    shuffle at the shards tar/raw file level. The second is a local
    shuffle at the training-samples level (within a buffer).

    Args:
        data_type(str): raw/shard
        data_list_file(str): data list file path
        conf(dict): dataset configuration
        partition(bool): whether to do data partition in terms of rank
    """
    assert conf is not None
    assert data_type in ['raw', 'shard']
    # cycle dataset
    cycle = conf.get('cycle', 1)
    # stage1 shuffle: source (shard/raw file list)
    list_shuffle = conf.get('list_shuffle', True)
    list_shuffle_size = sys.maxsize
    if list_shuffle:
        list_shuffle_conf = conf.get('list_shuffle_conf', {})
        list_shuffle_size = list_shuffle_conf.get('shuffle_size',
                                                  list_shuffle_size)
    if data_type == 'raw':
        dataset = WenetRawDatasetSource(data_list_file,
                                        partition=partition,
                                        shuffle=list_shuffle,
                                        shuffle_size=list_shuffle_size,
                                        cycle=cycle)
        dataset = dataset.map(processor.parse_json)
    else:
        dataset = WenetTarShardDatasetSource(data_list_file,
                                             partition=partition,
                                             shuffle=list_shuffle,
                                             shuffle_size=list_shuffle_size,
                                             cycle=cycle)
    dataset = dataset.map_ignore_error(processor.decode_wav)

    # NOTE: 'singal_channel' follows the spelling used by
    # wenet.dataset.processor.
    singal_channel_conf = conf.get('singal_channel_conf', {})
    dataset = dataset.map(
        partial(processor.singal_channel, **singal_channel_conf))
    filter_conf = conf.get('filter_conf', {})
    dataset = dataset.filter(partial(processor.filter, **filter_conf))

    resample_conf = conf.get('resample_conf', {})
    dataset = dataset.map(partial(processor.resample, **resample_conf))

    speed_perturb = conf.get('speed_perturb', False)
    if speed_perturb:
        dataset = dataset.map(processor.speed_perturb)

    # feature extraction: fbank / mfcc / log-mel spectrogram
    feats_type = conf.get('feats_type', 'fbank')
    assert feats_type in ['fbank', 'mfcc', 'log_mel_spectrogram']
    if feats_type == 'fbank':
        fbank_conf = conf.get('fbank_conf', {})
        dataset = dataset.map(partial(processor.compute_fbank, **fbank_conf))
    elif feats_type == 'mfcc':
        mfcc_conf = conf.get('mfcc_conf', {})
        dataset = dataset.map(partial(processor.compute_mfcc, **mfcc_conf))
    elif feats_type == 'log_mel_spectrogram':
        log_mel_spectrogram_conf = conf.get('log_mel_spectrogram_conf', {})
        dataset = dataset.map(
            partial(processor.compute_log_mel_spectrogram,
                    **log_mel_spectrogram_conf))

    # spectrum augmentations
    spec_aug = conf.get('spec_aug', True)
    spec_sub = conf.get('spec_sub', False)
    spec_trim = conf.get('spec_trim', False)
    if spec_aug:
        spec_aug_conf = conf.get('spec_aug_conf', {})
        dataset = dataset.map(partial(processor.spec_aug, **spec_aug_conf))
    if spec_sub:
        spec_sub_conf = conf.get('spec_sub_conf', {})
        dataset = dataset.map(partial(processor.spec_sub, **spec_sub_conf))
    if spec_trim:
        spec_trim_conf = conf.get('spec_trim_conf', {})
        dataset = dataset.map(partial(processor.spec_trim, **spec_trim_conf))
    # stage2 shuffle: training samples, within a buffer
    shuffle = conf.get('shuffle', True)
    if shuffle:
        shuffle_conf = conf.get('shuffle_conf', {})
        dataset = dataset.shuffle(buffer_size=shuffle_conf['shuffle_size'])

    sort = conf.get('sort', True)
    if sort:
        sort_conf = conf.get('sort_conf', {})
        dataset = dataset.sort(buffer_size=sort_conf['sort_size'],
                               key_func=processor.sort_by_feats)
    batch_conf = conf.get('batch_conf', {})
    batch_type = batch_conf.get('batch_type', 'static')
    assert batch_type in ['static', 'bucket', 'dynamic']
    if batch_type == 'static':
        assert 'batch_size' in batch_conf
        batch_size = batch_conf.get('batch_size', 16)
        dataset = dataset.batch(batch_size, wrapper_class=padding)
    elif batch_type == 'bucket':
        assert 'bucket_boundaries' in batch_conf
        assert 'bucket_batch_sizes' in batch_conf
        dataset = dataset.bucket_by_sequence_length(
            processor.feats_length_fn,
            batch_conf['bucket_boundaries'],
            batch_conf['bucket_batch_sizes'],
            wrapper_class=padding)
    else:
        max_frames_in_batch = batch_conf.get('max_frames_in_batch', 12000)
        dataset = dataset.dynamic_batch(
            processor.DynamicBatchWindow(max_frames_in_batch),
            wrapper_class=padding,
        )
    return dataset
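
# Sketch of the three batch strategies accepted in batch_conf (the values
# here are illustrative assumptions, not defaults from any recipe):
#
#   {'batch_type': 'static',  'batch_size': 16}
#   {'batch_type': 'bucket',  'bucket_boundaries': [500, 1000, 1500],
#    'bucket_batch_sizes': [32, 24, 16, 8]}   # one size per bucket,
#                                             # including the overflow bucket
#   {'batch_type': 'dynamic', 'max_frames_in_batch': 12000}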

def init_dataset(data_type, data_list_file, conf=None, partition=True):
    return Dataset(data_type, data_list_file, conf, partition)
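
# End-to-end usage sketch (assumptions: 'data.list' and every value in
# `conf` below are hypothetical, chosen only to show the expected layout):
#
#   conf = {
#       'filter_conf': {'max_length': 40960, 'min_length': 10},
#       'resample_conf': {'resample_rate': 16000},
#       'speed_perturb': False,
#       'feats_type': 'fbank',
#       'fbank_conf': {'num_mel_bins': 80, 'frame_shift': 10,
#                      'frame_length': 25, 'dither': 0.1},
#       'shuffle': True, 'shuffle_conf': {'shuffle_size': 1500},
#       'sort': True, 'sort_conf': {'sort_size': 500},
#       'batch_conf': {'batch_type': 'dynamic', 'max_frames_in_batch': 12000},
#   }
#   dataset = init_dataset('raw', 'data.list', conf)
#   for batch in dataset:
#       feats = batch['feats']  # batch is the dict produced by `padding`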