Spaces:
Running
Running
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from collections import OrderedDict | |
| import hashlib | |
| import math | |
| import json | |
| from pathlib import Path | |
| import julius | |
| import torch as th | |
| from torch import distributed | |
| import torchaudio as ta | |
| from torch.nn import functional as F | |
| from .audio import convert_audio_channels | |
| from .compressed import get_musdb_tracks | |
| MIXTURE = "mixture" | |
| EXT = ".wav" | |
def _track_metadata(track, sources):
    """Gather length, sample rate and mixture statistics for one track folder.

    Each file `{source}{EXT}` inside `track` (one per entry of `sources`, then
    the mixture) is probed with `torchaudio.info`. All files must report the
    same number of frames and the same sample rate as the first file probed.
    The mixture file is additionally loaded in full to compute its statistics.

    Args:
        track: folder of the track; must support the `/` operator.
        sources: list of source names (it is concatenated with `[MIXTURE]`).

    Returns:
        A dict with keys "length" (frames), "mean", "std" (both Python floats
        computed on the mixture) and "samplerate".

    Raises:
        ValueError: if a file disagrees with the first one on length or
            sample rate.
    """
    expected_frames = None
    expected_rate = None
    for stem in sources + [MIXTURE]:
        path = track / f"{stem}{EXT}"
        header = ta.info(str(path))
        frames = header.num_frames
        rate = header.sample_rate

        if expected_frames is None:
            # The first file probed defines the reference for all the others.
            expected_frames, expected_rate = frames, rate
        elif frames != expected_frames:
            raise ValueError(
                f"Invalid length for file {path}: "
                f"expecting {expected_frames} but got {frames}.")
        elif rate != expected_rate:
            raise ValueError(
                f"Invalid sample rate for file {path}: "
                f"expecting {expected_rate} but got {rate}.")

        if stem != MIXTURE:
            continue

        # Statistics are taken on the mixture averaged over its first
        # dimension (presumably the channel axis -- confirm against torchaudio).
        audio, _ = ta.load(str(path))
        mono = audio.mean(0)
        mean = mono.mean().item()
        std = mono.std().item()

    return {"length": frames, "mean": mean, "std": std, "samplerate": expected_rate}
def _build_metadata(path, sources):
    """Build the metadata of every track folder found directly under `path`.

    Args:
        path: dataset folder; each sub-folder is treated as one track.
        sources: list of source names, forwarded to `_track_metadata`.

    Returns:
        A dict mapping the track folder name to the dict returned by
        `_track_metadata` for that folder. Entries are inserted in sorted
        folder-name order.
    """
    root = Path(path)
    meta = {}
    # `iterdir()` order is filesystem dependent; sorting makes the resulting
    # metadata (and thus the dataset ordering built from it) reproducible.
    for entry in sorted(root.iterdir()):
        # Stray files (e.g. `.DS_Store`, READMEs) are not tracks and would make
        # `_track_metadata` fail when looking for `<entry>/<source>.wav`.
        if not entry.is_dir():
            continue
        meta[entry.name] = _track_metadata(entry, sources)
    return meta
class Wavset:
    """Dataset of fixed-size (or full-track) excerpts read from wav files."""

    def __init__(
            self,
            root, metadata, sources,
            length=None, stride=None, normalize=True,
            samplerate=44100, channels=2):
        """
        Waveset (or mp3 set for that matter). Can be used to train
        with arbitrary sources. Each track should be one folder inside of `root`.
        The folder should contain one file named `{source}{EXT}` per source.

        Sample rate and channels will be converted on the fly.

        Args:
            root: folder containing one sub-folder per track.
            metadata: mapping from track name to a dict with the keys
                'length', 'samplerate', 'mean' and 'std' (as produced by
                `_build_metadata`). Its iteration order defines the example order.
            sources: list of source names to load for each example.
            length: sample size to extract, in samples (not duration) at the
                target `samplerate`. When None, each track is a single example.
            stride: how many samples to move by between each example;
                defaults to `length`.
            normalize: if True, examples are shifted and scaled with the
                track 'mean' and 'std' from the metadata.
            samplerate: target sample rate of the returned examples.
            channels: target number of channels, passed to
                `convert_audio_channels`.
        """
        self.root = Path(root)
        self.metadata = OrderedDict(metadata)
        self.length = length
        self.stride = stride or length
        self.normalize = normalize
        self.sources = sources
        self.channels = channels
        self.samplerate = samplerate
        # Number of examples per track, in the same order as `self.metadata`.
        self.num_examples = []
        for meta in self.metadata.values():
            # Track length once converted to the target sample rate.
            track_length = int(self.samplerate * meta['length'] / meta['samplerate'])
            if length is None or track_length < length:
                examples = 1
            else:
                # `ceil` keeps the trailing partial window; it is zero padded
                # up to `length` in `__getitem__`.
                examples = int(math.ceil((track_length - self.length) / self.stride) + 1)
            self.num_examples.append(examples)

    def __len__(self):
        """Return the total number of examples over all tracks."""
        return sum(self.num_examples)

    def get_file(self, name, source):
        """Return the path of the file holding `source` for the track `name`."""
        return self.root / name / f"{source}{EXT}"

    def __getitem__(self, index):
        """Load the example number `index`.

        Args:
            index: example position; negative values count from the end.

        Returns:
            A tensor stacking one entry per source (first dimension is
            `len(self.sources)`), resampled to `self.samplerate`. When
            `self.length` is set, the last dimension is trimmed or zero padded
            to exactly `self.length`.

        Raises:
            IndexError: if `index` is out of range.
        """
        requested = index
        if index < 0:
            index += len(self)
        if index < 0:
            raise IndexError(
                f"Wavset index {requested} out of range for {len(self)} examples.")

        for name, examples in zip(self.metadata, self.num_examples):
            if index >= examples:
                # Not in this track: make the index relative to the next one.
                index -= examples
                continue
            meta = self.metadata[name]
            # By default read the whole file.
            num_frames = -1
            offset = 0
            if self.length is not None:
                # `stride` and `length` are expressed at the target sample
                # rate; convert them to the sample rate of the file on disk.
                offset = int(math.ceil(
                    meta['samplerate'] * self.stride * index / self.samplerate))
                num_frames = int(math.ceil(
                    meta['samplerate'] * self.length / self.samplerate))
            wavs = []
            for source in self.sources:
                file = self.get_file(name, source)
                wav, _ = ta.load(str(file), frame_offset=offset, num_frames=num_frames)
                wav = convert_audio_channels(wav, self.channels)
                wavs.append(wav)

            example = th.stack(wavs)
            example = julius.resample_frac(example, meta['samplerate'], self.samplerate)
            if self.normalize:
                example = (example - meta['mean']) / meta['std']
            if self.length is not None:
                # Resampling and the end of the file can give a few samples
                # more or fewer than requested: trim, then right pad with zeros.
                example = example[..., :self.length]
                example = F.pad(example, (0, self.length - example.shape[-1]))
            return example

        # Falling through means `index` is past the last example. Raising (and
        # not returning None) is what lets plain iteration terminate.
        raise IndexError(
            f"Wavset index {requested} out of range for {len(self)} examples.")
def get_wav_datasets(args, samples, sources):
    """Build the train and valid `Wavset` from the folder `args.wav`.

    The tracks are read from `args.wav / "train"` and `args.wav / "valid"`.
    Track metadata is cached in a JSON file inside `args.metadata`, named after
    a hash of `args.wav`; it is only computed by rank 0 when that file is missing.

    Args:
        args: configuration object; the attributes read here are `wav`,
            `metadata`, `rank`, `world_size`, `data_stride`, `samplerate`,
            `audio_channels` and `norm_wav`.
        samples: length in samples of each training example.
        sources: list of source names.

    Returns:
        A tuple `(train_set, valid_set)`. The valid set yields entire tracks
        and has the mixture prepended to the sources.
    """
    sig = hashlib.sha1(str(args.wav).encode()).hexdigest()[:8]
    metadata_file = args.metadata / (sig + ".json")
    train_path = args.wav / "train"
    valid_path = args.wav / "valid"
    if not metadata_file.is_file() and args.rank == 0:
        train = _build_metadata(train_path, sources)
        valid = _build_metadata(valid_path, sources)
        metadata_file.parent.mkdir(parents=True, exist_ok=True)
        # Close (and thus flush) the file before the barrier below, so that
        # the other ranks are guaranteed to read a complete JSON document.
        with open(metadata_file, "w") as out:
            json.dump([train, valid], out)
    if args.world_size > 1:
        # Other ranks wait here until rank 0 is done writing the metadata.
        distributed.barrier()
    with open(metadata_file) as src:
        train, valid = json.load(src)
    train_set = Wavset(train_path, train, sources,
                       length=samples, stride=args.data_stride,
                       samplerate=args.samplerate, channels=args.audio_channels,
                       normalize=args.norm_wav)
    valid_set = Wavset(valid_path, valid, [MIXTURE] + sources,
                       samplerate=args.samplerate, channels=args.audio_channels,
                       normalize=args.norm_wav)
    return train_set, valid_set
def get_musdb_wav_datasets(args, samples, sources):
    """Build the train and valid `Wavset` from the wav version of MusDB.

    All tracks are read from `args.musdb / "train"`. Those listed by
    `get_musdb_tracks(..., split="train")` form the train set; every other
    track found in the metadata forms the valid set. Track metadata is cached
    in `args.metadata / "musdb_wav.json"` and only computed by rank 0 when
    that file is missing.

    Args:
        args: configuration object; the attributes read here are `musdb`,
            `metadata`, `rank`, `world_size`, `data_stride`, `samplerate`,
            `audio_channels` and `norm_wav`.
        samples: length in samples of each training example.
        sources: list of source names.

    Returns:
        A tuple `(train_set, valid_set)`. The valid set yields entire tracks
        and has the mixture prepended to the sources.
    """
    metadata_file = args.metadata / "musdb_wav.json"
    root = args.musdb / "train"
    if not metadata_file.is_file() and args.rank == 0:
        metadata = _build_metadata(root, sources)
        metadata_file.parent.mkdir(parents=True, exist_ok=True)
        # Close (and thus flush) the file before the barrier below, so that
        # the other ranks are guaranteed to read a complete JSON document.
        with open(metadata_file, "w") as out:
            json.dump(metadata, out)
    if args.world_size > 1:
        # Other ranks wait here until rank 0 is done writing the metadata.
        distributed.barrier()
    with open(metadata_file) as src:
        metadata = json.load(src)

    train_tracks = get_musdb_tracks(args.musdb, is_wav=True, subsets=["train"], split="train")
    metadata_train = {name: meta for name, meta in metadata.items() if name in train_tracks}
    metadata_valid = {name: meta for name, meta in metadata.items() if name not in train_tracks}
    train_set = Wavset(root, metadata_train, sources,
                       length=samples, stride=args.data_stride,
                       samplerate=args.samplerate, channels=args.audio_channels,
                       normalize=args.norm_wav)
    valid_set = Wavset(root, metadata_valid, [MIXTURE] + sources,
                       samplerate=args.samplerate, channels=args.audio_channels,
                       normalize=args.norm_wav)
    return train_set, valid_set