# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
| """ | |
| API that can manage the storage and retrieval of generated samples produced by experiments. | |
| It offers the following benefits: | |
| * Samples are stored in a consistent way across epoch | |
| * Metadata about the samples can be stored and retrieved | |
| * Can retrieve audio | |
| * Identifiers are reliable and deterministic for prompted and conditioned samples | |
| * Can request the samples for multiple XPs, grouped by sample identifier | |
| * For no-input samples (not prompt and no conditions), samples across XPs are matched | |
| by sorting their identifiers | |
| """ | |

from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict, dataclass
from functools import lru_cache
import hashlib
import json
import logging
from pathlib import Path
import re
import typing as tp
import unicodedata
import uuid

import dora
import torch

from ...data.audio import audio_read, audio_write


logger = logging.getLogger(__name__)


@dataclass
class ReferenceSample:
    id: str
    path: str
    duration: float


@dataclass
class Sample:
    id: str
    path: str
    epoch: int
    duration: float
    conditioning: tp.Optional[tp.Dict[str, tp.Any]]
    prompt: tp.Optional[ReferenceSample]
    reference: tp.Optional[ReferenceSample]
    generation_args: tp.Optional[tp.Dict[str, tp.Any]]

    def __hash__(self):
        return hash(self.id)

    def audio(self) -> tp.Tuple[torch.Tensor, int]:
        return audio_read(self.path)

    def audio_prompt(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]:
        return audio_read(self.prompt.path) if self.prompt is not None else None

    def audio_reference(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]:
        return audio_read(self.reference.path) if self.reference is not None else None
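
# Illustrative sketch of reading a stored sample back (assumes `sample` is a
# Sample whose paths still exist on disk; audio_read returns (wav, sample_rate)):
#
#     wav, sr = sample.audio()
#     if sample.prompt is not None:
#         prompt_wav, prompt_sr = sample.audio_prompt()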


class SampleManager:
    """Audio samples IO handling within a given dora xp.

    The sample manager handles the dumping and loading logic for generated and
    reference samples across epochs for a given xp, providing a simple API to
    store, retrieve and compare audio samples.

    Args:
        xp (dora.XP): Dora experiment object. The XP contains information on the XP folder
            where all outputs are stored and the configuration of the experiment,
            which is useful to retrieve audio-related parameters.
        map_reference_to_sample_id (bool): Whether to use the sample_id for all reference samples
            instead of generating a dedicated hash id. This is useful to allow easier comparison
            with ground truth samples from the files directly without having to read the JSON metadata
            to do the mapping (at the cost of potentially dumping duplicate prompts/references
            depending on the task).
    """
    def __init__(self, xp: dora.XP, map_reference_to_sample_id: bool = False):
        self.xp = xp
        self.base_folder: Path = xp.folder / xp.cfg.generate.path
        self.reference_folder = self.base_folder / 'reference'
        self.map_reference_to_sample_id = map_reference_to_sample_id
        self.samples: tp.List[Sample] = []
        self._load_samples()

    @property
    def latest_epoch(self):
        """Latest epoch across all samples."""
        return max(self.samples, key=lambda x: x.epoch).epoch if self.samples else 0

    def _load_samples(self):
        """Scan the sample folder and load existing samples."""
        jsons = self.base_folder.glob('**/*.json')
        with ThreadPoolExecutor(6) as pool:
            self.samples = list(pool.map(self._load_sample, jsons))

    @staticmethod
    @lru_cache(2**26)  # cache parsed sample metadata across repeated folder scans
    def _load_sample(json_file: Path) -> Sample:
        with open(json_file, 'r') as f:
            data: tp.Dict[str, tp.Any] = json.load(f)
        # fetch prompt data
        prompt_data = data.get('prompt')
        prompt = ReferenceSample(id=prompt_data['id'], path=prompt_data['path'],
                                 duration=prompt_data['duration']) if prompt_data else None
        # fetch reference data
        reference_data = data.get('reference')
        reference = ReferenceSample(id=reference_data['id'], path=reference_data['path'],
                                    duration=reference_data['duration']) if reference_data else None
        # build sample object
        return Sample(id=data['id'], path=data['path'], epoch=data['epoch'], duration=data['duration'],
                      prompt=prompt, conditioning=data.get('conditioning'), reference=reference,
                      generation_args=data.get('generation_args'))

    def _init_hash(self):
        return hashlib.sha1()

    def _get_tensor_id(self, tensor: torch.Tensor) -> str:
        hash_id = self._init_hash()
        hash_id.update(tensor.numpy().data)
        return hash_id.hexdigest()
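
    # Illustrative only (`manager` is an assumed SampleManager instance): equal
    # tensors hash to the same id, so re-running with the same audio reuses the
    # stored file. Tensors must live on CPU for .numpy() to work:
    #
    #     t = torch.zeros(1, 16000)
    #     assert manager._get_tensor_id(t) == manager._get_tensor_id(t.clone())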

    def _get_sample_id(self, index: int, prompt_wav: tp.Optional[torch.Tensor],
                       conditions: tp.Optional[tp.Dict[str, str]]) -> str:
        """Computes an id for a sample given its input data.
        This id is deterministic if prompt and/or conditions are provided by using a sha1 hash on the input.
        Otherwise, a random id of the form "noinput_{uuid4().hex}" is returned.

        Args:
            index (int): Batch index, helpful to differentiate samples from the same batch.
            prompt_wav (torch.Tensor): Prompt used during generation.
            conditions (dict[str, str]): Conditioning used during generation.
        """
        # For totally unconditioned generations we will just use a random UUID.
        # The function get_samples_for_xps will do a simple ordered match with a custom key.
        if prompt_wav is None and not conditions:
            return f"noinput_{uuid.uuid4().hex}"

        # Human readable portion
        hr_label = ""
        # Create a deterministic id using hashing
        hash_id = self._init_hash()
        hash_id.update(f"{index}".encode())
        if prompt_wav is not None:
            hash_id.update(prompt_wav.numpy().data)
            hr_label += "_prompted"
        else:
            hr_label += "_unprompted"
        if conditions:
            encoded_json = json.dumps(conditions, sort_keys=True).encode()
            hash_id.update(encoded_json)
            cond_str = "-".join([f"{key}={slugify(value)}"
                                 for key, value in sorted(conditions.items())])
            cond_str = cond_str[:100]  # some raw text might be too long to be a valid filename
            cond_str = cond_str if len(cond_str) > 0 else "unconditioned"
            hr_label += f"_{cond_str}"
        else:
            hr_label += "_unconditioned"
        return hash_id.hexdigest() + hr_label
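
    # A sketch of the resulting identifiers (hash values are hypothetical):
    #
    #     _get_sample_id(0, None, None)
    #         -> 'noinput_3f2b...'                           (random on every call)
    #     _get_sample_id(0, None, {'description': 'Happy Rock!'})
    #         -> '<sha1>_unprompted_description=happy-rock'  (stable across runs)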

    def _store_audio(self, wav: torch.Tensor, stem_path: Path, overwrite: bool = False) -> Path:
        """Stores the audio with the given stem path using the XP's configuration.

        Args:
            wav (torch.Tensor): Audio to store.
            stem_path (Path): Path in sample output directory with file stem to use.
            overwrite (bool): When False (default), skips storing an existing audio file.
        Returns:
            Path: The path at which the audio is stored.
        """
        existing_paths = [
            path for path in stem_path.parent.glob(stem_path.stem + '.*')
            if path.suffix != '.json'
        ]
        exists = len(existing_paths) > 0
        if exists and overwrite:
            logger.warning(f"Overwriting existing audio file with stem path {stem_path}")
        elif exists:
            return existing_paths[0]
        audio_path = audio_write(stem_path, wav, **self.xp.cfg.generate.audio)
        return audio_path
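
    # Illustrative only: the actual file extension is decided by the XP's audio
    # config, e.g. with a hypothetical cfg.generate.audio = {'format': 'wav'}:
    #
    #     _store_audio(wav, base_folder / '100' / 'abcd')
    #         -> base_folder / '100' / 'abcd.wav'
    #
    # The '.json' metadata sidecar is deliberately ignored when checking existence.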

    def add_sample(self, sample_wav: torch.Tensor, epoch: int, index: int = 0,
                   conditions: tp.Optional[tp.Dict[str, str]] = None, prompt_wav: tp.Optional[torch.Tensor] = None,
                   ground_truth_wav: tp.Optional[torch.Tensor] = None,
                   generation_args: tp.Optional[tp.Dict[str, tp.Any]] = None) -> Sample:
        """Adds a single sample.
        The sample is stored in the XP's sample output directory, under a corresponding epoch folder.
        Each sample is assigned an id which is computed using the input data. In addition to the
        sample itself, a json file containing associated metadata is stored next to it.

        Args:
            sample_wav (torch.Tensor): Sample audio to store. Tensor of shape [channels, length].
            epoch (int): Current training epoch.
            index (int): Helpful to differentiate samples from the same batch.
            conditions (dict[str, str], optional): Conditioning used during generation.
            prompt_wav (torch.Tensor, optional): Prompt used during generation. Tensor of shape [channels, length].
            ground_truth_wav (torch.Tensor, optional): Reference audio where prompt was extracted from.
                Tensor of shape [channels, length].
            generation_args (dict[str, any], optional): Dictionary of other arguments used during generation.
        Returns:
            Sample: The saved sample.
        """
        sample_id = self._get_sample_id(index, prompt_wav, conditions)
        reuse_id = self.map_reference_to_sample_id
        prompt, ground_truth = None, None
        if prompt_wav is not None:
            prompt_id = sample_id if reuse_id else self._get_tensor_id(prompt_wav.sum(0, keepdim=True))
            prompt_duration = prompt_wav.shape[-1] / self.xp.cfg.sample_rate
            prompt_path = self._store_audio(prompt_wav, self.base_folder / str(epoch) / 'prompt' / prompt_id)
            prompt = ReferenceSample(prompt_id, str(prompt_path), prompt_duration)
        if ground_truth_wav is not None:
            ground_truth_id = sample_id if reuse_id else self._get_tensor_id(ground_truth_wav.sum(0, keepdim=True))
            ground_truth_duration = ground_truth_wav.shape[-1] / self.xp.cfg.sample_rate
            ground_truth_path = self._store_audio(ground_truth_wav, self.base_folder / 'reference' / ground_truth_id)
            ground_truth = ReferenceSample(ground_truth_id, str(ground_truth_path), ground_truth_duration)
        sample_path = self._store_audio(sample_wav, self.base_folder / str(epoch) / sample_id, overwrite=True)
        duration = sample_wav.shape[-1] / self.xp.cfg.sample_rate
        sample = Sample(sample_id, str(sample_path), epoch, duration, conditions, prompt, ground_truth, generation_args)
        self.samples.append(sample)
        with open(sample_path.with_suffix('.json'), 'w') as f:
            json.dump(asdict(sample), f, indent=2)
        return sample
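
    # A hedged usage sketch (assumes a training loop with conditioned generation;
    # `wav`, `prompt` and `gt` are [channels, length] tensors at cfg.sample_rate):
    #
    #     sample = manager.add_sample(wav, epoch=100, index=0,
    #                                 conditions={'description': 'happy rock'},
    #                                 prompt_wav=prompt, ground_truth_wav=gt)
    #     print(sample.id, sample.path)  # deterministic id, path to stored audio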

    def add_samples(self, samples_wavs: torch.Tensor, epoch: int,
                    conditioning: tp.Optional[tp.List[tp.Dict[str, tp.Any]]] = None,
                    prompt_wavs: tp.Optional[torch.Tensor] = None,
                    ground_truth_wavs: tp.Optional[torch.Tensor] = None,
                    generation_args: tp.Optional[tp.Dict[str, tp.Any]] = None) -> tp.List[Sample]:
        """Adds a batch of samples.
        The samples are stored in the XP's sample output directory, under a corresponding
        epoch folder. Each sample is assigned an id which is computed using the input data and their batch index.
        In addition to the sample itself, a json file containing associated metadata is stored next to it.

        Args:
            samples_wavs (torch.Tensor): Batch of audio wavs to store. Tensor of shape [batch_size, channels, length].
            epoch (int): Current training epoch.
            conditioning (list of dict[str, str], optional): List of conditions used during generation,
                one per sample in the batch.
            prompt_wavs (torch.Tensor, optional): Prompts used during generation. Tensor of shape
                [batch_size, channels, length].
            ground_truth_wavs (torch.Tensor, optional): Reference audio where prompts were extracted from.
                Tensor of shape [batch_size, channels, length].
            generation_args (dict[str, Any], optional): Dictionary of other arguments used during generation.
        Returns:
            samples (list of Sample): The saved audio samples with prompts, ground truth and metadata.
        """
        samples = []
        for idx, wav in enumerate(samples_wavs):
            prompt_wav = prompt_wavs[idx] if prompt_wavs is not None else None
            gt_wav = ground_truth_wavs[idx] if ground_truth_wavs is not None else None
            conditions = conditioning[idx] if conditioning is not None else None
            samples.append(self.add_sample(wav, epoch, idx, conditions, prompt_wav, gt_wav, generation_args))
        return samples
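
    # Batch sketch (illustrative; four mono 8-second clips at a hypothetical
    # 32 kHz sample rate, each with its own text condition):
    #
    #     wavs = torch.randn(4, 1, 8 * 32000)
    #     conds = [{'description': d} for d in ['rock', 'jazz', 'piano', 'edm']]
    #     samples = manager.add_samples(wavs, epoch=100, conditioning=conds)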

    def get_samples(self, epoch: int = -1, max_epoch: int = -1, exclude_prompted: bool = False,
                    exclude_unprompted: bool = False, exclude_conditioned: bool = False,
                    exclude_unconditioned: bool = False) -> tp.Set[Sample]:
        """Returns a set of samples for this XP. Optionally, you can filter which samples to obtain.
        Please note that existing samples are loaded during the manager's initialization, and samples
        added through this manager are also tracked. Any other external changes are not tracked
        automatically, so creating a new manager is the only way to detect them.

        Args:
            epoch (int): If provided, only return samples corresponding to this epoch.
            max_epoch (int): If provided, only return samples corresponding to the latest epoch that is <= max_epoch.
            exclude_prompted (bool): If True, does not include samples that used a prompt.
            exclude_unprompted (bool): If True, does not include samples that did not use a prompt.
            exclude_conditioned (bool): If True, excludes samples that used conditioning.
            exclude_unconditioned (bool): If True, excludes samples that did not use conditioning.
        Returns:
            Samples (set of Sample): The retrieved samples matching the provided filters.
        """
        if max_epoch >= 0:
            samples_epoch = max(sample.epoch for sample in self.samples if sample.epoch <= max_epoch)
        else:
            samples_epoch = self.latest_epoch if epoch < 0 else epoch
        samples = {
            sample
            for sample in self.samples
            if (
                (sample.epoch == samples_epoch) and
                (not exclude_prompted or sample.prompt is None) and
                (not exclude_unprompted or sample.prompt is not None) and
                (not exclude_conditioned or not sample.conditioning) and
                (not exclude_unconditioned or sample.conditioning)
            )
        }
        return samples
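
    # Filtering sketch (illustrative): keep only the conditioned, unprompted
    # samples from the most recent epoch at or below epoch 200:
    #
    #     samples = manager.get_samples(max_epoch=200,
    #                                   exclude_prompted=True,
    #                                   exclude_unconditioned=True)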


def slugify(value: tp.Any, allow_unicode: bool = False):
    """Process string for safer file naming.

    Taken from https://github.com/django/django/blob/master/django/utils/text.py

    Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
    dashes to single dashes. Remove characters that aren't alphanumerics,
    underscores, or hyphens. Convert to lowercase. Also strip leading and
    trailing whitespace, dashes, and underscores.
    """
    value = str(value)
    if allow_unicode:
        value = unicodedata.normalize("NFKC", value)
    else:
        value = (
            unicodedata.normalize("NFKD", value)
            .encode("ascii", "ignore")
            .decode("ascii")
        )
    value = re.sub(r"[^\w\s-]", "", value.lower())
    return re.sub(r"[-\s]+", "-", value).strip("-_")
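
# A few input/output examples of slugify:
#
#     slugify("Happy Rock!")              -> 'happy-rock'
#     slugify("  Café -- déjà vu ")       -> 'cafe-deja-vu'
#     slugify("Déjà", allow_unicode=True) -> 'déjà'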


def _match_stable_samples(samples_per_xp: tp.List[tp.Set[Sample]]) -> tp.Dict[str, tp.List[Sample]]:
    # Create a dictionary of stable id -> sample per XP
    stable_samples_per_xp = [{
        sample.id: sample for sample in samples
        if sample.prompt is not None or sample.conditioning
    } for samples in samples_per_xp]
    # Set of all stable ids
    stable_ids = {id for samples in stable_samples_per_xp for id in samples.keys()}
    # Dictionary of stable id -> list of samples. If an XP does not have it, assign None
    stable_samples = {id: [xp.get(id) for xp in stable_samples_per_xp] for id in stable_ids}
    # Filter out ids that contain None values (we only want matched samples after all)
    # cast is necessary to avoid mypy linter errors.
    return {id: tp.cast(tp.List[Sample], samples) for id, samples in stable_samples.items() if None not in samples}


def _match_unstable_samples(samples_per_xp: tp.List[tp.Set[Sample]]) -> tp.Dict[str, tp.List[Sample]]:
    # For unstable ids, we use a sorted list since we'll match them in order
    unstable_samples_per_xp = [[
        sample for sample in sorted(samples, key=lambda x: x.id)
        if sample.prompt is None and not sample.conditioning
    ] for samples in samples_per_xp]
    # Trim samples per xp so all samples can have a match
    min_len = min([len(samples) for samples in unstable_samples_per_xp])
    unstable_samples_per_xp = [samples[:min_len] for samples in unstable_samples_per_xp]
    # Dictionary of index -> list of matched samples
    return {
        f'noinput_{i}': [samples[i] for samples in unstable_samples_per_xp] for i in range(min_len)
    }
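
# Matching sketch (hypothetical ids): with two XPs holding unconditioned samples
# {noinput_a1, noinput_c3} and {noinput_b2, noinput_d4, noinput_e5}, sorting and
# trimming to the shortest list yields:
#
#     {'noinput_0': [noinput_a1, noinput_b2],
#      'noinput_1': [noinput_c3, noinput_d4]}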


def get_samples_for_xps(xps: tp.List[dora.XP], **kwargs) -> tp.Dict[str, tp.List[Sample]]:
    """Gets a dictionary of matched samples across the given XPs.
    Each dictionary entry maps a sample id to a list of samples for that id. The number of samples per id
    will always match the number of XPs provided and will correspond to each XP in the same order given.
    In other words, only samples that can be matched across all provided XPs will be returned
    in order to satisfy this rule.

    There are two types of ids that can be returned: stable and unstable.
    * Stable IDs are deterministic ids that were computed by the SampleManager given a sample's inputs
      (prompts/conditioning). This is why we can match them across XPs.
    * Unstable IDs are of the form "noinput_{idx}" and are generated on-the-fly, in order to map samples
      that used non-deterministic, random ids. This is the case for samples that did not use prompts or
      conditioning for their generation. This function will sort these samples by their id and match them
      by their index.

    Args:
        xps: a list of XPs to match samples from.
        epoch (int): If provided, only return samples corresponding to this epoch.
        max_epoch (int): If provided, only return samples corresponding to the latest epoch that is <= max_epoch.
        exclude_prompted (bool): If True, does not include samples that used a prompt.
        exclude_unprompted (bool): If True, does not include samples that did not use a prompt.
        exclude_conditioned (bool): If True, excludes samples that used conditioning.
        exclude_unconditioned (bool): If True, excludes samples that did not use conditioning.
    """
    managers = [SampleManager(xp) for xp in xps]
    samples_per_xp = [manager.get_samples(**kwargs) for manager in managers]
    stable_samples = _match_stable_samples(samples_per_xp)
    unstable_samples = _match_unstable_samples(samples_per_xp)
    return dict(stable_samples, **unstable_samples)
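
# A hedged end-to-end sketch comparing two experiments (assumes `xp_a` and `xp_b`
# are dora.XP objects that generated samples with the same conditionings):
#
#     matched = get_samples_for_xps([xp_a, xp_b], max_epoch=500, exclude_prompted=True)
#     for sample_id, (sample_a, sample_b) in matched.items():
#         wav_a, sr_a = sample_a.audio()
#         wav_b, sr_b = sample_b.audio()
#         # ... compute a pairwise metric on (wav_a, wav_b) ...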