BenCzechMark / model_compare.py
Lakoc
v0.0.1
b66f230
raw
history blame
887 Bytes
from functools import cmp_to_key
from typing import Optional
class ModelCompare:
    """Order models per task from a table of pairwise comparison results.

    The table is indexed as ``ranks[model_a][model_b][task]``. A truthy entry
    makes ``model_a`` sort after ``model_b`` for that task; a falsy entry makes
    it sort before.
    NOTE(review): presumably a truthy entry means "model_a is significantly
    better than model_b" -- confirm against the code that builds ``ranks``.
    """

    def __init__(self, tasks, ranks: Optional[dict] = None):
        """Store the tasks to rank and, optionally, the pairwise results.

        Args:
            tasks: Task identifiers; each one is used as the innermost key
                of ``ranks``.
            ranks: Nested mapping ``ranks[model_a][model_b][task]``. May be
                omitted here and supplied later through ``get_tasks_ranks``.
        """
        # Task looked up by compare_models; get_tasks_ranks sets it before
        # every sort because a cmp function only receives the two models.
        self.current_task = None
        self.ranks = ranks
        self.tasks = tasks

    def compare_models(self, model_a, model_b):
        """Compare two models on ``self.current_task`` (cmp-style).

        Args:
            model_a: Outer key of ``self.ranks``.
            model_b: Key of ``self.ranks[model_a]``.

        Returns:
            ``0`` when both arguments are the same model, ``1`` when
            ``self.ranks[model_a][model_b][self.current_task]`` is truthy,
            ``-1`` otherwise.

        Raises:
            Exception: If no rankings have been provided (``self.ranks`` is
                ``None`` or empty).
            KeyError: If the pair or the current task is missing from
                ``self.ranks``.
        """
        if not self.ranks:
            raise Exception("Missing model rankings")
        # A comparator must report an element as equal to itself; this also
        # means the table does not need self-comparison entries.
        if model_a == model_b:
            return 0
        return 1 if self.ranks[model_a][model_b][self.current_task] else -1

    def get_tasks_ranks(self, ranks: dict) -> dict:
        """Order models based on the significance improvement.

        Replaces ``self.ranks`` with ``ranks`` and, for every task in
        ``self.tasks``, sorts the top-level keys of ``ranks`` in ascending
        order of ``compare_models``.

        Args:
            ranks: Nested mapping ``ranks[model_a][model_b][task]``.

        Returns:
            Dict mapping each task to the list of models sorted for that task.

        Side effects:
            ``self.ranks`` is overwritten and ``self.current_task`` is left
            set to the last task processed.
        """
        self.ranks = ranks
        compare = cmp_to_key(self.compare_models)
        tasks_ranks = {}
        for task in self.tasks:
            # compare_models reads the task from the instance, so it has to
            # be updated before each sort.
            self.current_task = task
            tasks_ranks[task] = sorted(ranks, key=compare)
        return tasks_ranks