# This module handles the task queue for model evaluations
import os
import multiprocessing
from typing import TypedDict
from datetime import datetime

from datasets import load_from_disk

from metrics import per, fer
from hf import get_repo_info, add_leaderboard_entry
from inference import clear_cache, load_model, transcribe

leaderboard_lock = multiprocessing.Lock()


class Task(TypedDict):
    status: str
    display_name: str
    repo_id: str
    repo_hash: str
    repo_last_modified: datetime
    submission_timestamp: datetime
    url: str
    error: str | None


tasks: list[Task] = []


def get_status(query: str) -> dict:
    """Check the status of an evaluation task by repo_id or repo_hash."""
    query = query.strip().lower()
    if not query:
        return {"error": "Please enter a model id or task id"}
    # Search newest-first so the most recent submission for a repo wins
    for task in reversed(tasks):
        if task["repo_id"].lower() == query or task["repo_hash"].lower() == query:
            return dict(task)
    return {"error": f"No results found for '{query}'"}


def start_eval_task(display_name: str, repo_id: str, url: str) -> str:
    """Start an evaluation task in the background.

    Returns a task ID (the repo hash) that can be used to check status.
    """
    repo_hash, last_modified = get_repo_info(repo_id)
    # TODO: check if hash is different from the most recent submission if any for repo_id, otherwise don't recompute
    task = Task(
        status="submitted",
        display_name=display_name,
        repo_id=repo_id,
        repo_hash=repo_hash,
        repo_last_modified=last_modified,
        submission_timestamp=datetime.now(),
        url=url,
        error=None,
    )
    # Wrap the task in a manager-backed dict so the child process can update it
    # and the parent can read the latest status
    manager = multiprocessing.Manager()
    task_proxy = manager.dict(task)
    tasks.append(task_proxy)  # type: ignore
    multiprocessing.Process(
        target=_eval_task, args=[task_proxy, leaderboard_lock]
    ).start()
    return repo_hash


# Test split shared by all evaluation tasks; loaded once at import time
test_ds = load_from_disk(os.path.join(os.path.dirname(__file__), "data", "test"))


def _eval_task(task: Task, leaderboard_lock):
    """Background task to evaluate a model and save the updated results."""
    try:
        # Indicate the task is processing
        task["status"] = "evaluating"

        # Evaluate the model on every row of the test split, accumulating
        # overall PER/FER and a per-dataset FER breakdown in a single pass
        average_per = 0.0
        average_fer = 0.0
        per_dataset_fers = {}
        per_dataset_counts = {}
        clear_cache()
        model, processor = load_model(task["repo_id"])
        for row in test_ds:
            transcript = transcribe(row["audio"]["array"], model, processor)  # type: ignore
            row_per = per(transcript, row["ipa"])  # type: ignore
            row_fer = fer(transcript, row["ipa"])  # type: ignore
            average_per += row_per
            average_fer += row_fer
            dataset = row["dataset"]  # type: ignore
            per_dataset_fers[dataset] = per_dataset_fers.get(dataset, 0.0) + row_fer
            per_dataset_counts[dataset] = per_dataset_counts.get(dataset, 0) + 1
        # Turn the accumulated sums into averages
        for key in per_dataset_fers:
            per_dataset_fers[key] /= per_dataset_counts[key]
        average_per /= len(test_ds)
        average_fer /= len(test_ds)

        # Save the results under the lock so concurrent tasks don't clobber
        # each other's leaderboard writes
        with leaderboard_lock:
            add_leaderboard_entry(
                display_name=task["display_name"],
                repo_id=task["repo_id"],
                repo_hash=task["repo_hash"],
                repo_last_modified=task["repo_last_modified"],
                submission_timestamp=task["submission_timestamp"],
                average_per=average_per,
                average_fer=average_fer,
                url=task["url"],
                per_dataset_fers=per_dataset_fers,
            )

        # Mark the task as complete
        task["status"] = "completed"
    except Exception as e:
        task["status"] = "failed"
        task["error"] = str(e)