""" | |
Copyright 2025 Balacoon | |
Utils to interact with the `metrics` dataset. | |
""" | |
from typing import Optional | |
from huggingface_hub.hf_api import RepoFolder | |
from api import api | |
baselines_repo = "balacoon/speech_gen_baselines" | |
def get_system_types() -> list[str]:
    """
    Return the system types a leaderboard can be shown for.

    The `balacoon/speech_gen_baselines` dataset stores synthesis from
    different models, organized by system type at the top level, e.g.:
    ```
    speech_gen_baselines/
        zero-tts/
        vocoder/
    ```
    Each top-level directory name is one system type.
    """
    entries = api.list_repo_tree(baselines_repo, repo_type="dataset", recursive=False)
    # Only directories count as system types; files at the top level are ignored.
    return [entry.path for entry in entries if isinstance(entry, RepoFolder)]
def get_models(system_type: str) -> list[str]:
    """
    List all model names stored under the given system type.

    For example, for system_type="zero-tts", returns ["xtts", "yourtts"].
    """
    entries = api.list_repo_tree(
        baselines_repo,
        repo_type="dataset",
        path_in_repo=system_type,
        recursive=False,
    )
    names: list[str] = []
    for entry in entries:
        if isinstance(entry, RepoFolder):
            # entry.path is e.g. "zero-tts/xtts"; keep only the last component.
            names.append(entry.path.rsplit("/", 1)[-1])
    return names
def get_datasets(system_type: str, model_dirs: Optional[list[str]] = None, return_union: bool = True) -> list[str]:
    """
    Get what metrics on which datasets are available for the given system type.

    Goes through all systems under the system type and checks the datasets
    stored under each system. The dataset repo has the following structure:
    ```
    speech_gen_baselines/
        zero-tts/
            xtts/
                vctk/
                daps_celeb/
            yourtts/
                vctk/
                daps_celeb/
    ```

    Args:
        system_type: top-level directory in the baselines repo, e.g. "zero-tts".
        model_dirs: model directory names to inspect; if None, all models
            under `system_type` are used.
        return_union: if True, return datasets present in ANY model;
            if False, return only datasets present in ALL models.

    Returns:
        Sorted list of dataset names.
    """
    if model_dirs is None:
        # Get all models under the system type
        model_dirs = get_models(system_type)
    # Collect dataset names per model
    datasets_per_model = []
    for model_dir in model_dirs:
        datasets_tree = api.list_repo_tree(
            baselines_repo,
            repo_type="dataset",
            path_in_repo=system_type + "/" + model_dir,
            recursive=False
        )
        model_datasets = [item.path.split('/')[-1] for item in datasets_tree if isinstance(item, RepoFolder)]
        datasets_per_model.append(model_datasets)
    if not datasets_per_model:
        # No models to inspect: `set.intersection(*[])` would raise TypeError
        # (it needs at least one argument), so return an empty list explicitly
        # for both branches.
        return []
    if return_union:
        # return all possible datasets for these models
        return sorted(set().union(*datasets_per_model))
    else:
        # return only datasets which are present in all models
        return sorted(set.intersection(*map(set, datasets_per_model)))