# This module handles the task queue
import os
import multiprocessing
from typing import TypedDict
from datetime import datetime
from metrics import per, fer
from datasets import load_from_disk
from hf import get_repo_info, add_leaderboard_entry
from inference import clear_cache, load_model, transcribe
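# Lock shared with evaluation worker processes to serialize leaderboard writes.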
leaderboard_lock = multiprocessing.Lock()


class Task(TypedDict):
    status: str
    display_name: str
    repo_id: str
    repo_hash: str
    repo_last_modified: datetime
    submission_timestamp: datetime
    url: str
    error: str | None
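
# Submitted tasks. Entries are Manager dict proxies, so status updates made by
# worker processes are visible to get_status() in this process.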
tasks: list[Task] = []


def get_status(query: str) -> dict:
    """Check status of an evaluation task by repo_id or repo_hash"""
    query = query.strip().lower()
    if not query:
        return {"error": "Please enter a model id or task id"}
    for task in reversed(tasks):
        if task["repo_id"].lower() == query or task["repo_hash"].lower() == query:
            return dict(task)
    return {"error": f"No results found for '{query}'"}


def start_eval_task(display_name: str, repo_id: str, url: str) -> str:
    """Start evaluation task in background. Returns task ID that can be used to check status."""
    repo_hash, last_modified = get_repo_info(repo_id)
    # TODO: check if hash is different from the most recent submission if any for repo_id, otherwise don't recompute
    task = Task(
        status="submitted",
        display_name=display_name,
        repo_id=repo_id,
        repo_hash=repo_hash,
        repo_last_modified=last_modified,
        submission_timestamp=datetime.now(),
        url=url,
        error=None,
    )
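    # Share the task as a Manager dict proxy so the background process can update
    # its status/error fields and get_status() in this process sees the changes.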
    manager = multiprocessing.Manager()
    task_proxy = manager.dict(task)
    tasks.append(task_proxy)  # type: ignore
    multiprocessing.Process(
        target=_eval_task, args=[task_proxy, leaderboard_lock]
    ).start()
    return repo_hash
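

# Evaluation dataset, loaded once at module import so it is available to evaluation workers.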
test_ds = load_from_disk(os.path.join(os.path.dirname(__file__), "data", "test"))


def _eval_task(task: Task, leaderboard_lock):
    """Background task to evaluate model and save updated results"""
    try:
        # Indicate task is processing
        task["status"] = "evaluating"
        # Evaluate model
        average_per = 0
        average_fer = 0
        per_dataset_fers = {}
        clear_cache()
        model, processor = load_model(task["repo_id"])
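        # Sum PER/FER across the test set; per_dataset_fers collects FER sums per
        # source dataset (normalized to averages below).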
        for row in test_ds:
            transcript = transcribe(row["audio"]["array"], model, processor)  # type: ignore
            row_per = per(transcript, row["ipa"])  # type: ignore
            row_fer = fer(transcript, row["ipa"])  # type: ignore
            average_per += row_per
            average_fer += row_fer
            per_dataset_fers[row["dataset"]] = per_dataset_fers.get(row["dataset"], 0) + row_fer  # type: ignore
        # Convert sums into averages; each dataset's FER sum is divided by that dataset's row count
        for key in per_dataset_fers.keys():
            per_dataset_fers[key] /= len(test_ds.filter(lambda r: r["dataset"] == key))
        average_per /= len(test_ds)
        average_fer /= len(test_ds)
        # Save results
        with leaderboard_lock:
            add_leaderboard_entry(
                display_name=task["display_name"],
                repo_id=task["repo_id"],
                repo_hash=task["repo_hash"],
                repo_last_modified=task["repo_last_modified"],
                submission_timestamp=task["submission_timestamp"],
                average_per=average_per,
                average_fer=average_fer,
                url=task["url"],
                per_dataset_fers=per_dataset_fers,
            )
        # Mark task as complete
        task["status"] = "completed"
    except Exception as e:
        task["status"] = "failed"
        task["error"] = str(e)