# This module handles the task queue

import os
import multiprocessing
from typing import TypedDict
from datetime import datetime

from metrics import per, fer
from datasets import load_from_disk
from hf import get_repo_info, add_leaderboard_entry
from inference import clear_cache, load_model, transcribe

# Guards concurrent writes to the leaderboard from the evaluation worker processes
leaderboard_lock = multiprocessing.Lock()


class Task(TypedDict):
    status: str
    display_name: str
    repo_id: str
    repo_hash: str
    repo_last_modified: datetime
    submission_timestamp: datetime
    url: str
    error: str | None


# In-memory registry of submissions; entries are Manager dict proxies so the
# worker processes can update their status fields in place.
tasks: list[Task] = []


def get_status(query: str) -> dict:
    """Check the status of an evaluation task by repo_id or repo_hash."""
    query = query.strip().lower()
    if not query:
        return {"error": "Please enter a model id or task id"}
    # Walk newest-first so the most recent submission for a repo wins
    for task in reversed(tasks):
        if task["repo_id"].lower() == query or task["repo_hash"].lower() == query:
            return dict(task)
    return {"error": f"No results found for '{query}'"}


def start_eval_task(display_name: str, repo_id: str, url: str) -> str:
    """Start an evaluation task in the background. Returns a task ID that can be used to check status."""
    repo_hash, last_modified = get_repo_info(repo_id)
    # TODO: check if the hash differs from the most recent submission (if any) for repo_id, otherwise don't recompute
    task = Task(
        status="submitted",
        display_name=display_name,
        repo_id=repo_id,
        repo_hash=repo_hash,
        repo_last_modified=last_modified,
        submission_timestamp=datetime.now(),
        url=url,
        error=None,
    )
    # Wrap the task in a Manager dict so updates made in the worker process are visible here
    manager = multiprocessing.Manager()
    task_proxy = manager.dict(task)
    tasks.append(task_proxy)  # type: ignore
    multiprocessing.Process(
        target=_eval_task, args=[task_proxy, leaderboard_lock]
    ).start()
    return repo_hash


test_ds = load_from_disk(os.path.join(os.path.dirname(__file__), "data", "test"))
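# Note: each test row is expected to provide "audio" (with an "array" of samples),
# "ipa" (the reference transcription), and "dataset" (the source corpus name),
# matching the columns consumed in _eval_task below.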


def _eval_task(task: Task, leaderboard_lock):
    """Background task to evaluate a model and save the updated results."""
    try:
        # Indicate the task is processing
        task["status"] = "evaluating"

        # Evaluate the model over the full test set
        average_per = 0
        average_fer = 0
        per_dataset_fers = {}
        clear_cache()
        model, processor = load_model(task["repo_id"])
        for row in test_ds:
            transcript = transcribe(row["audio"]["array"], model, processor)  # type: ignore
            row_per = per(transcript, row["ipa"])  # type: ignore
            row_fer = fer(transcript, row["ipa"])  # type: ignore
            average_per += row_per
            average_fer += row_fer
            per_dataset_fers[row["dataset"]] = per_dataset_fers.get(row["dataset"], 0) + row_fer  # type: ignore

        # Turn the accumulated sums into means, overall and per source dataset
        for key in per_dataset_fers.keys():
            per_dataset_fers[key] /= len(test_ds.filter(lambda r: r["dataset"] == key))
        average_per /= len(test_ds)
        average_fer /= len(test_ds)

        # Save results
        with leaderboard_lock:
            add_leaderboard_entry(
                display_name=task["display_name"],
                repo_id=task["repo_id"],
                repo_hash=task["repo_hash"],
                repo_last_modified=task["repo_last_modified"],
                submission_timestamp=task["submission_timestamp"],
                average_per=average_per,
                average_fer=average_fer,
                url=task["url"],
                per_dataset_fers=per_dataset_fers,
            )

        # Mark task as complete
        task["status"] = "completed"
    except Exception as e:
        task["status"] = "failed"
        task["error"] = str(e)