Spaces:
Running
Running
import json | |
import os | |
from typing import Any, Dict | |
import pandas as pd | |
from huggingface_hub import HfApi, hf_hub_download, metadata_load | |
from .dataset_handler import DEPRECATED_VIDORE_2_DATASETS_KEYWORDS, DEPRECATED_VIDORE_DATASETS_KEYWORDS, deprecated_get_datasets_nickname | |
# Hub organizations (the part of a repo id before the first "/") whose
# repositories are skipped when collecting results.
BLOCKLIST = ["impactframes"]
class DeprecatedModelHandler:
    """Collect ViDoRe results published on the Hugging Face Hub and tabulate them.

    Results live in ``self.model_infos``, a dict keyed by sanitized model name
    whose values have the shape ``{"meta": ..., "results": {dataset: metrics}}``.
    """

    def __init__(self, model_infos_path="model_infos.json"):
        """Create the Hub client and load previously saved model infos, if any.

        Args:
            model_infos_path: Path of the JSON file read by ``_load_model_infos``
                and written by ``_save_model_infos``.
        """
        self.api = HfApi()
        self.model_infos_path = model_infos_path
        self.model_infos = self._load_model_infos()

    def _load_model_infos(self) -> Dict:
        """Return the content of the model-infos JSON file, or ``{}`` if it does not exist."""
        if not os.path.exists(self.model_infos_path):
            return {}
        with open(self.model_infos_path, encoding="utf-8") as f:
            return json.load(f)

    def _save_model_infos(self):
        """Write ``self.model_infos`` to the model-infos JSON file."""
        with open(self.model_infos_path, "w", encoding="utf-8") as f:
            json.dump(self.model_infos, f)

    def _are_results_in_new_vidore_format(self, results: Dict[str, Any]) -> bool:
        """Return True when ``results`` has both a "metadata" and a "metrics" entry."""
        return "metadata" in results and "metrics" in results

    def _is_baseline_repo(self, repo_id: str) -> bool:
        """Return True when ``repo_id`` is the shared baseline results repository."""
        return repo_id == "vidore/baseline-results"

    @staticmethod
    def _benchmark_keywords(benchmark_version):
        """Return the dataset keywords for version 1, or the version-2 keywords otherwise."""
        if benchmark_version == 1:
            return DEPRECATED_VIDORE_DATASETS_KEYWORDS
        return DEPRECATED_VIDORE_2_DATASETS_KEYWORDS

    def sanitize_model_name(self, model_name):
        """Return ``model_name`` with "/" replaced by "_" and "." by "-thisisapoint-"."""
        return model_name.replace("/", "_").replace(".", "-thisisapoint-")

    def fuze_model_infos(self, model_name, results):
        """Add to an existing model entry the datasets it does not have yet.

        Datasets already present for ``model_name`` keep their stored metrics.

        Args:
            model_name: Key of an existing entry in ``self.model_infos``.
            results: Mapping of dataset name to metrics to merge in.
        """
        for dataset, metrics in results.items():
            self.model_infos[model_name]["results"].setdefault(dataset, metrics)

    def _store_results(self, repo_id, model_name, meta, results):
        """Record the results of one file, giving priority to non-baseline results.

        Baseline results for a model that already has an entry only fill in the
        missing datasets; the stored "meta" and existing datasets are left
        untouched. In every other case the entry is (re)written.
        """
        if self._is_baseline_repo(repo_id) and model_name in self.model_infos:
            self.fuze_model_infos(model_name, results)
        else:
            self.model_infos[model_name] = {"meta": meta, "results": results}

    def get_vidore_data(self, metric="ndcg_at_5"):
        """Populate ``self.model_infos`` from every Hub model repo matching the "vidore" filter.

        For each repository, every ``*_metrics.json`` file and a top-level
        ``results.json`` is downloaded and stored under a sanitized model name
        (derived from the repo id for ``results.json``, from the file name
        otherwise). A file that fails to download or parse is reported on
        stdout and skipped.

        Args:
            metric: Not used by this method; kept so existing callers keep working.
        """
        repositories = [model.modelId for model in self.api.list_models(filter="vidore")]  # type: ignore
        # False sorts before True and the sort is stable, so the baseline repo
        # is processed last and can only complete what other repos provided.
        repositories.sort(key=self._is_baseline_repo)

        for repo_id in repositories:
            if repo_id.split("/")[0] in BLOCKLIST:
                continue

            files = [f for f in self.api.list_repo_files(repo_id) if f.endswith("_metrics.json") or f == "results.json"]
            if not files:
                continue

            # The model card is the same for every result file of a repo, so it
            # is fetched once; all entries created from this repo share `meta`.
            readme_path = hf_hub_download(repo_id, filename="README.md")
            meta = metadata_load(readme_path)

            for file in files:
                if file == "results.json":
                    model_name = self.sanitize_model_name(repo_id)
                else:
                    model_name = self.sanitize_model_name(file.split("_metrics.json")[0])

                try:
                    result_path = hf_hub_download(repo_id, filename=file)
                    with open(result_path, encoding="utf-8") as f:
                        results = json.load(f)

                    # The new format wraps the per-dataset metrics in "metrics".
                    if self._are_results_in_new_vidore_format(results):
                        results = results["metrics"]

                    self._store_results(repo_id, model_name, meta, results)
                except Exception as e:
                    # Best effort: one broken result file must not stop the crawl.
                    print(f"Error loading {model_name} - {e}")
                    continue

    def filter_models_by_benchmark(self, benchmark_version=1):
        """Return the model infos having at least one dataset of the given benchmark.

        A dataset belongs to the benchmark when its name contains one of the
        benchmark's keywords.

        Args:
            benchmark_version: 1 selects the first keyword list, any other
                value the second one.

        Returns:
            A new dict holding the matching ``model -> info`` pairs.
        """
        keywords = self._benchmark_keywords(benchmark_version)
        return {
            model: info
            for model, info in self.model_infos.items()
            if any(keyword in dataset for dataset in info["results"] for keyword in keywords)
        }

    def render_df(self, metric="ndcg_at_5", benchmark_version=1):
        """Build a table of one metric per model and per benchmark dataset.

        Args:
            metric: Key read from each dataset's metrics.
            benchmark_version: Benchmark whose datasets are kept (see
                ``filter_models_by_benchmark``).

        Returns:
            A DataFrame with one row per model and one column per dataset
            nickname, or an empty DataFrame when no model matches.

        Raises:
            KeyError: If a kept dataset has no value for ``metric``.
        """
        filtered_model_infos = self.filter_models_by_benchmark(benchmark_version)
        if not filtered_model_infos:
            return pd.DataFrame()

        keywords = self._benchmark_keywords(benchmark_version)
        model_res = {}
        for model, info in filtered_model_infos.items():
            dataset_res = {}
            for dataset, metrics in info["results"].items():
                if any(keyword in dataset for keyword in keywords):
                    dataset_res[deprecated_get_datasets_nickname(dataset)] = metrics[metric]
            model_res[model] = dataset_res
        return pd.DataFrame(model_res).T