|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
API that can manage the storage and retrieval of generated samples produced by experiments. |
|
|
|
It offers the following benefits: |
|
* Samples are stored in a consistent way across epochs
|
* Metadata about the samples can be stored and retrieved |
|
* Can retrieve audio |
|
* Identifiers are reliable and deterministic for prompted and conditioned samples |
|
* Can request the samples for multiple XPs, grouped by sample identifier |
|
* For no-input samples (no prompt and no conditions), samples across XPs are matched
|
by sorting their identifiers |
|
""" |
|
|
|
from concurrent.futures import ThreadPoolExecutor |
|
from dataclasses import asdict, dataclass |
|
from functools import lru_cache |
|
import hashlib |
|
import json |
|
import logging |
|
from pathlib import Path |
|
import re |
|
import typing as tp |
|
import unicodedata |
|
import uuid |
|
|
|
import dora |
|
import torch |
|
|
|
from ...data.audio import audio_read, audio_write |
|
|
|
|
|
logger = logging.getLogger(__name__)  # Module-level logger, named after the module path.
|
|
|
|
|
@dataclass
class ReferenceSample:
    """Audio used as input for a generated sample (a prompt or a ground-truth reference).

    Instances are serialized inside the owning sample's JSON metadata file.
    """
    id: str  # hash id of the audio, or the owning sample's id (see SampleManager.map_reference_to_sample_id)
    path: str  # path to the stored audio file
    duration: float  # duration in seconds
|
|
|
|
|
@dataclass
class Sample:
    """A generated audio sample and its metadata, as stored in the XP's sample folder."""

    # Sample id: deterministic for prompted/conditioned samples, random ("noinput_...")
    # otherwise (see SampleManager._get_sample_id).
    id: str
    # Path to the stored audio file.
    path: str
    # Training epoch at which the sample was generated.
    epoch: int
    # Duration in seconds.
    duration: float
    # Conditioning used during generation, if any.
    conditioning: tp.Optional[tp.Dict[str, tp.Any]]
    # Prompt audio used during generation, if any.
    prompt: tp.Optional[ReferenceSample]
    # Ground-truth audio the prompt was extracted from, if any.
    reference: tp.Optional[ReferenceSample]
    # Other arguments used during generation, if any.
    generation_args: tp.Optional[tp.Dict[str, tp.Any]]

    def __hash__(self):
        # Samples are identified by id alone, so sets/dicts of samples key on the id.
        return hash(self.id)

    def audio(self) -> tp.Tuple[torch.Tensor, int]:
        """Load the sample audio via `audio_read`."""
        return audio_read(self.path)

    def audio_prompt(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]:
        """Load the prompt audio via `audio_read`, or return None if the sample had no prompt."""
        return audio_read(self.prompt.path) if self.prompt is not None else None

    def audio_reference(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]:
        """Load the ground-truth audio via `audio_read`, or return None if none was stored."""
        return audio_read(self.reference.path) if self.reference is not None else None
|
|
|
|
|
class SampleManager:
    """Audio samples IO handling within a given dora xp.

    The sample manager handles the dumping and loading logic for generated and
    reference samples across epochs for a given xp, providing a simple API to
    store, retrieve and compare audio samples.

    Args:
        xp (dora.XP): Dora experiment object. The XP contains information on the XP folder
            where all outputs are stored and the configuration of the experiment,
            which is useful to retrieve audio-related parameters.
        map_reference_to_sample_id (bool): Whether to use the sample_id for all reference samples
            instead of generating a dedicated hash id. This is useful to allow easier comparison
            with ground truth sample from the files directly without having to read the JSON metadata
            to do the mapping (at the cost of potentially dumping duplicate prompts/references
            depending on the task).
    """
    def __init__(self, xp: dora.XP, map_reference_to_sample_id: bool = False):
        self.xp = xp
        # All samples for this XP live under <xp.folder>/<cfg.generate.path>.
        self.base_folder: Path = xp.folder / xp.cfg.generate.path
        self.reference_folder = self.base_folder / 'reference'
        self.map_reference_to_sample_id = map_reference_to_sample_id
        self.samples: tp.List[Sample] = []
        self._load_samples()

    @property
    def latest_epoch(self):
        """Latest epoch across all samples, or 0 when no samples were stored yet."""
        return max(self.samples, key=lambda x: x.epoch).epoch if self.samples else 0

    def _load_samples(self):
        """Scan the sample folder and load existing samples."""
        jsons = self.base_folder.glob('**/*.json')
        # Many small metadata files may exist; read them concurrently.
        with ThreadPoolExecutor(6) as pool:
            self.samples = list(pool.map(self._load_sample, jsons))

    @staticmethod
    @lru_cache(2**26)
    def _load_sample(json_file: Path) -> Sample:
        """Load a single sample from its JSON metadata file.

        Cached on the file path so repeated scans (e.g. several managers over
        the same XP in one process) avoid re-reading identical files.
        """
        with open(json_file, 'r') as f:
            data: tp.Dict[str, tp.Any] = json.load(f)

        # Prompt and reference entries are optional in the metadata.
        prompt_data = data.get('prompt')
        prompt = ReferenceSample(id=prompt_data['id'], path=prompt_data['path'],
                                 duration=prompt_data['duration']) if prompt_data else None

        reference_data = data.get('reference')
        reference = ReferenceSample(id=reference_data['id'], path=reference_data['path'],
                                    duration=reference_data['duration']) if reference_data else None

        return Sample(id=data['id'], path=data['path'], epoch=data['epoch'], duration=data['duration'],
                      prompt=prompt, conditioning=data.get('conditioning'), reference=reference,
                      generation_args=data.get('generation_args'))

    def _init_hash(self):
        # Single place defining the hash function used for all deterministic ids.
        return hashlib.sha1()

    def _get_tensor_id(self, tensor: torch.Tensor) -> str:
        """Deterministic id for a tensor: sha1 digest over its raw bytes."""
        hash_id = self._init_hash()
        hash_id.update(tensor.numpy().data)
        return hash_id.hexdigest()

    def _get_sample_id(self, index: int, prompt_wav: tp.Optional[torch.Tensor],
                       conditions: tp.Optional[tp.Dict[str, str]]) -> str:
        """Computes an id for a sample given its input data.
        This id is deterministic if prompt and/or conditions are provided by using a sha1 hash on the input.
        Otherwise, a random id of the form "noinput_{uuid4().hex}" is returned.

        Args:
            index (int): Batch index, helpful to differentiate samples from the same batch.
            prompt_wav (torch.Tensor): Prompt used during generation.
            conditions (dict[str, str]): Conditioning used during generation.
        """
        # No deterministic input at all: fall back to a random, unique id.
        if prompt_wav is None and not conditions:
            return f"noinput_{uuid.uuid4().hex}"

        # Human-readable suffix appended to the hash so ids are self-describing.
        hr_label = ""
        # Deterministic hash built from the batch index and the provided inputs.
        hash_id = self._init_hash()
        hash_id.update(f"{index}".encode())
        if prompt_wav is not None:
            hash_id.update(prompt_wav.numpy().data)
            hr_label += "_prompted"
        else:
            hr_label += "_unprompted"
        if conditions:
            encoded_json = json.dumps(conditions, sort_keys=True).encode()
            hash_id.update(encoded_json)
            cond_str = "-".join([f"{key}={slugify(value)}"
                                 for key, value in sorted(conditions.items())])
            # Keep the label filesystem-friendly: bounded length and never empty.
            cond_str = cond_str[:100]
            cond_str = cond_str if len(cond_str) > 0 else "unconditioned"
            hr_label += f"_{cond_str}"
        else:
            hr_label += "_unconditioned"

        return hash_id.hexdigest() + hr_label

    def _store_audio(self, wav: torch.Tensor, stem_path: Path, overwrite: bool = False) -> Path:
        """Stores the audio with the given stem path using the XP's configuration.

        Args:
            wav (torch.Tensor): Audio to store.
            stem_path (Path): Path in sample output directory with file stem to use.
            overwrite (bool): When False (default), skips storing an existing audio file.
        Returns:
            Path: The path at which the audio is stored.
        """
        # Look for already-written audio with this stem (the extension depends on the
        # configured audio format), ignoring the JSON metadata file.
        existing_paths = [
            path for path in stem_path.parent.glob(stem_path.stem + '.*')
            if path.suffix != '.json'
        ]
        exists = len(existing_paths) > 0
        if exists and overwrite:
            logger.warning(f"Overwriting existing audio file with stem path {stem_path}")
        elif exists:
            return existing_paths[0]

        audio_path = audio_write(stem_path, wav, **self.xp.cfg.generate.audio)
        return audio_path

    def add_sample(self, sample_wav: torch.Tensor, epoch: int, index: int = 0,
                   conditions: tp.Optional[tp.Dict[str, str]] = None, prompt_wav: tp.Optional[torch.Tensor] = None,
                   ground_truth_wav: tp.Optional[torch.Tensor] = None,
                   generation_args: tp.Optional[tp.Dict[str, tp.Any]] = None) -> Sample:
        """Adds a single sample.
        The sample is stored in the XP's sample output directory, under a corresponding epoch folder.
        Each sample is assigned an id which is computed using the input data. In addition to the
        sample itself, a json file containing associated metadata is stored next to it.

        Args:
            sample_wav (torch.Tensor): sample audio to store. Tensor of shape [channels, shape].
            epoch (int): current training epoch.
            index (int): helpful to differentiate samples from the same batch.
            conditions (dict[str, str], optional): conditioning used during generation.
            prompt_wav (torch.Tensor, optional): prompt used during generation. Tensor of shape [channels, shape].
            ground_truth_wav (torch.Tensor, optional): reference audio where prompt was extracted from.
                Tensor of shape [channels, shape].
            generation_args (dict[str, any], optional): dictionary of other arguments used during generation.
        Returns:
            Sample: The saved sample.
        """
        sample_id = self._get_sample_id(index, prompt_wav, conditions)
        reuse_id = self.map_reference_to_sample_id
        prompt, ground_truth = None, None
        if prompt_wav is not None:
            # Hash a channel-summed version so the reference id comes from a mono mixdown.
            prompt_id = sample_id if reuse_id else self._get_tensor_id(prompt_wav.sum(0, keepdim=True))
            prompt_duration = prompt_wav.shape[-1] / self.xp.cfg.sample_rate
            prompt_path = self._store_audio(prompt_wav, self.base_folder / str(epoch) / 'prompt' / prompt_id)
            prompt = ReferenceSample(prompt_id, str(prompt_path), prompt_duration)
        if ground_truth_wav is not None:
            ground_truth_id = sample_id if reuse_id else self._get_tensor_id(ground_truth_wav.sum(0, keepdim=True))
            ground_truth_duration = ground_truth_wav.shape[-1] / self.xp.cfg.sample_rate
            # Ground truth is epoch-independent: stored once under the shared 'reference' folder.
            ground_truth_path = self._store_audio(ground_truth_wav, self.base_folder / 'reference' / ground_truth_id)
            ground_truth = ReferenceSample(ground_truth_id, str(ground_truth_path), ground_truth_duration)
        # The generated sample may change between runs for the same inputs, hence overwrite=True.
        sample_path = self._store_audio(sample_wav, self.base_folder / str(epoch) / sample_id, overwrite=True)
        duration = sample_wav.shape[-1] / self.xp.cfg.sample_rate
        sample = Sample(sample_id, str(sample_path), epoch, duration, conditions, prompt, ground_truth, generation_args)
        self.samples.append(sample)
        # Persist metadata next to the audio so the sample can be reloaded later.
        with open(sample_path.with_suffix('.json'), 'w') as f:
            json.dump(asdict(sample), f, indent=2)
        return sample

    def add_samples(self, samples_wavs: torch.Tensor, epoch: int,
                    conditioning: tp.Optional[tp.List[tp.Dict[str, tp.Any]]] = None,
                    prompt_wavs: tp.Optional[torch.Tensor] = None,
                    ground_truth_wavs: tp.Optional[torch.Tensor] = None,
                    generation_args: tp.Optional[tp.Dict[str, tp.Any]] = None) -> tp.List[Sample]:
        """Adds a batch of samples.
        The samples are stored in the XP's sample output directory, under a corresponding
        epoch folder. Each sample is assigned an id which is computed using the input data and their batch index.
        In addition to the sample itself, a json file containing associated metadata is stored next to it.

        Args:
            samples_wavs (torch.Tensor): Batch of audio wavs to store. Tensor of shape [batch_size, channels, shape].
            epoch (int): Current training epoch.
            conditioning (list of dict[str, str], optional): List of conditions used during generation,
                one per sample in the batch.
            prompt_wavs (torch.Tensor, optional): Prompts used during generation. Tensor of shape
                [batch_size, channels, shape].
            ground_truth_wavs (torch.Tensor, optional): Reference audio where prompts were extracted from.
                Tensor of shape [batch_size, channels, shape].
            generation_args (dict[str, Any], optional): Dictionary of other arguments used during generation.
        Returns:
            samples (list of Sample): The saved audio samples with prompts, ground truth and metadata.
        """
        samples = []
        for idx, wav in enumerate(samples_wavs):
            prompt_wav = prompt_wavs[idx] if prompt_wavs is not None else None
            gt_wav = ground_truth_wavs[idx] if ground_truth_wavs is not None else None
            conditions = conditioning[idx] if conditioning is not None else None
            samples.append(self.add_sample(wav, epoch, idx, conditions, prompt_wav, gt_wav, generation_args))
        return samples

    def get_samples(self, epoch: int = -1, max_epoch: int = -1, exclude_prompted: bool = False,
                    exclude_unprompted: bool = False, exclude_conditioned: bool = False,
                    exclude_unconditioned: bool = False) -> tp.Set[Sample]:
        """Returns a set of samples for this XP. Optionally, you can filter which samples to obtain.
        Please note that existing samples are loaded during the manager's initialization, and added samples through this
        manager are also tracked. Any other external changes are not tracked automatically, so creating a new manager
        is the only way to detect them.

        Args:
            epoch (int): If provided, only return samples corresponding to this epoch.
            max_epoch (int): If provided, only return samples corresponding to the latest epoch that is <= max_epoch.
            exclude_prompted (bool): If True, does not include samples that used a prompt.
            exclude_unprompted (bool): If True, does not include samples that did not use a prompt.
            exclude_conditioned (bool): If True, excludes samples that used conditioning.
            exclude_unconditioned (bool): If True, excludes samples that did not use conditioning.
        Returns:
            Samples (set of Sample): The retrieved samples matching the provided filters.
        """
        if max_epoch >= 0:
            # Latest stored epoch not exceeding max_epoch; default to 0 when no sample
            # qualifies (mirrors `latest_epoch` on an empty manager) instead of raising
            # ValueError from max() on an empty sequence.
            samples_epoch = max((sample.epoch for sample in self.samples if sample.epoch <= max_epoch), default=0)
        else:
            samples_epoch = self.latest_epoch if epoch < 0 else epoch
        samples = {
            sample
            for sample in self.samples
            if (
                (sample.epoch == samples_epoch) and
                (not exclude_prompted or sample.prompt is None) and
                (not exclude_unprompted or sample.prompt is not None) and
                (not exclude_conditioned or not sample.conditioning) and
                (not exclude_unconditioned or sample.conditioning)
            )
        }
        return samples
|
|
|
|
|
def slugify(value: tp.Any, allow_unicode: bool = False):
    """Process string for safer file naming.

    Taken from https://github.com/django/django/blob/master/django/utils/text.py

    Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
    dashes to single dashes. Remove characters that aren't alphanumerics,
    underscores, or hyphens. Convert to lowercase. Also strip leading and
    trailing whitespace, dashes, and underscores.
    """
    text = str(value)
    if allow_unicode:
        # Keep unicode characters, only applying compatibility normalization.
        text = unicodedata.normalize("NFKC", text)
    else:
        # Decompose accents, then drop anything that does not survive ASCII encoding.
        normalized = unicodedata.normalize("NFKD", text)
        text = normalized.encode("ascii", "ignore").decode("ascii")
    # Lowercase and remove every character that is not alphanumeric, underscore,
    # whitespace or hyphen, then collapse whitespace/hyphen runs into single dashes.
    cleaned = re.sub(r"[^\w\s-]", "", text.lower())
    return re.sub(r"[-\s]+", "-", cleaned).strip("-_")
|
|
|
|
|
def _match_stable_samples(samples_per_xp: tp.List[tp.Set[Sample]]) -> tp.Dict[str, tp.List[Sample]]:
    """Match samples across XPs by their deterministic ("stable") ids.

    Only samples that used a prompt and/or conditioning have stable ids. Returns
    a dict mapping each id that is present in every XP to the list of matching
    samples, one per XP, in the same order as the input.
    """
    # Index each XP's stable samples by id.
    indexed: tp.List[tp.Dict[str, Sample]] = []
    for xp_samples in samples_per_xp:
        by_id = {}
        for candidate in xp_samples:
            if candidate.prompt is not None or candidate.conditioning:
                by_id[candidate.id] = candidate
        indexed.append(by_id)

    # Every stable id observed in any XP.
    all_ids: tp.Set[str] = set()
    for by_id in indexed:
        all_ids.update(by_id)

    # Keep only the ids resolvable in all XPs, so each row has one sample per XP.
    matched: tp.Dict[str, tp.List[Sample]] = {}
    for sample_id in all_ids:
        row = [by_id.get(sample_id) for by_id in indexed]
        if all(sample is not None for sample in row):
            matched[sample_id] = tp.cast(tp.List[Sample], row)
    return matched
|
|
|
|
|
def _match_unstable_samples(samples_per_xp: tp.List[tp.Set[Sample]]) -> tp.Dict[str, tp.List[Sample]]:
    """Match samples that used random ("unstable") ids across XPs.

    Samples without a prompt and without conditioning cannot be matched by id,
    so within each XP they are ordered by id and paired up across XPs by rank.
    Extra samples beyond the smallest per-XP count are dropped. Keys are the
    synthetic ids 'noinput_{rank}'.
    """
    # Per XP: keep only no-input samples, in a deterministic (id-sorted) order.
    ordered_per_xp = []
    for xp_samples in samples_per_xp:
        no_input = [s for s in sorted(xp_samples, key=lambda s: s.id)
                    if s.prompt is None and not s.conditioning]
        ordered_per_xp.append(no_input)

    # Truncate to the shortest list so every rank exists in every XP.
    shortest = min(len(no_input) for no_input in ordered_per_xp)
    matched: tp.Dict[str, tp.List[Sample]] = {}
    for rank in range(shortest):
        matched[f'noinput_{rank}'] = [no_input[rank] for no_input in ordered_per_xp]
    return matched
|
|
|
|
|
def get_samples_for_xps(xps: tp.List[dora.XP], **kwargs) -> tp.Dict[str, tp.List[Sample]]:
    """Gets a dictionary of matched samples across the given XPs.
    Each dictionary entry maps a sample id to a list of samples for that id. The number of samples per id
    will always match the number of XPs provided and will correspond to each XP in the same order given.
    In other words, only samples that can be matched across all provided XPs will be returned
    in order to satisfy this rule.

    There are two types of ids that can be returned: stable and unstable.
    * Stable IDs are deterministic ids that were computed by the SampleManager given a sample's inputs
      (prompts/conditioning). This is why we can match them across XPs.
    * Unstable IDs are of the form "noinput_{idx}" and are generated on-the-fly, in order to map samples
      that used non-deterministic, random ids. This is the case for samples that did not use prompts or
      conditioning for their generation. This function will sort these samples by their id and match them
      by their index.

    Args:
        xps: a list of XPs to match samples from.
        **kwargs: Forwarded to :meth:`SampleManager.get_samples` for each XP, i.e.:
            epoch (int): If provided, only return samples corresponding to this epoch.
            max_epoch (int): If provided, only return samples corresponding to the latest epoch that is <= max_epoch.
            exclude_prompted (bool): If True, does not include samples that used a prompt.
            exclude_unprompted (bool): If True, does not include samples that did not use a prompt.
            exclude_conditioned (bool): If True, excludes samples that used conditioning.
            exclude_unconditioned (bool): If True, excludes samples that did not use conditioning.
    Returns:
        dict[str, list[Sample]]: Matched samples keyed by stable or 'noinput_{idx}' id.
    """
    managers = [SampleManager(xp) for xp in xps]
    samples_per_xp = [manager.get_samples(**kwargs) for manager in managers]
    stable_samples = _match_stable_samples(samples_per_xp)
    unstable_samples = _match_unstable_samples(samples_per_xp)
    # Stable and unstable key spaces are disjoint, so a plain merge is safe.
    return dict(stable_samples, **unstable_samples)
|
|