TTSLeaderboard / dataset_utils.py
clementruhm's picture
Initial version of leaderboard
0dc360b
"""
Copyright 2025 Balacoon
Utils to interact with the `metrics` dataset.
"""
from typing import Optional
from huggingface_hub.hf_api import RepoFolder
from api import api
baselines_repo = "balacoon/speech_gen_baselines"
def get_system_types() -> list[str]:
    """
    List the system types a leaderboard can be shown for.

    These correspond to the top-level directories of the
    `balacoon/speech_gen_baselines` dataset, where synthesis
    results from different models are stored, e.g.:
    ```
    speech_gen_baselines/
        zero-tts/
        vocoder/
    ```
    """
    entries = api.list_repo_tree(
        baselines_repo,
        repo_type="dataset",
        recursive=False,
    )
    # Only folders at the repo root are system types; skip plain files.
    return [entry.path for entry in entries if isinstance(entry, RepoFolder)]
def get_models(system_type: str) -> list[str]:
    """
    List all models stored under the given system type.

    For example, for system_type="zero-tts", returns ["xtts", "yourtts"].
    """
    entries = api.list_repo_tree(
        baselines_repo,
        repo_type="dataset",
        path_in_repo=system_type,
        recursive=False,
    )
    # Repo paths look like "<system_type>/<model>"; keep just the model name.
    return [
        entry.path.split("/")[-1]
        for entry in entries
        if isinstance(entry, RepoFolder)
    ]
def get_datasets(system_type: str, model_dirs: Optional[list[str]] = None, return_union: bool = True) -> list[str]:
    """
    Get which datasets (with metrics) are available for the given system type.

    Goes through all models under the system type, and collects the dataset
    directories stored under each model. The repo has the following structure:
    ```
    speech_gen_baselines/
        zero-tts/
            xtts/
                vctk/
                daps_celeb/
            yourtts/
                vctk/
                daps_celeb/
    ```

    Args:
        system_type: top-level directory of the repo, e.g. "zero-tts".
        model_dirs: model names to inspect; if None, all models under
            `system_type` are used.
        return_union: if True, return datasets present under *any* of the
            models; otherwise only datasets present under *all* models.

    Returns:
        Sorted list of dataset names.
    """
    if model_dirs is None:
        # Get all models under the system type
        model_dirs = get_models(system_type)
    # Collect the dataset directories found under each model
    datasets_per_model = []
    for model_dir in model_dirs:
        datasets_tree = api.list_repo_tree(
            baselines_repo,
            repo_type="dataset",
            path_in_repo=system_type + "/" + model_dir,
            recursive=False
        )
        model_datasets = [item.path.split('/')[-1] for item in datasets_tree if isinstance(item, RepoFolder)]
        datasets_per_model.append(model_datasets)
    # Guard the degenerate case: with no models at all,
    # `set.intersection(*[])` would raise TypeError.
    if not datasets_per_model:
        return []
    if return_union:
        # return all possible datasets for these models
        return sorted(set().union(*datasets_per_model))
    # return only datasets which are present in all models
    return sorted(set.intersection(*map(set, datasets_per_model)))