romb-leaderboard / src /eval /matchers.py
d0rj's picture
style: code blacked
3e35a01
import collections
from typing import Any, Callable
from src.common.env import build_default_namespace
def _dict_to_tuple(dict_obj: dict) -> tuple[tuple]:
return tuple(sorted(dict_obj.items()))
def Am(y_true: list, y_pred: list) -> bool:
    """Check that every predicted element occurs in y_true and every true element occurs in y_pred."""
    # Every prediction must be a known true answer...
    for candidate in y_pred:
        if candidate not in y_true:
            return False
    # ...and every true answer must have been predicted.
    for expected in y_true:
        if expected not in y_pred:
            return False
    return True
def am(y_true: list, y_pred: list) -> bool:
    """Check that at least one predicted element occurs in y_true."""
    for candidate in y_pred:
        if candidate in y_true:
            return True
    return False
def em(y_true: Any, y_pred: Any) -> bool:
    """
    Check if the true answer and predicted answer are exactly the same.

    String answers are compared case-insensitively (both sides are lowercased).
    Any other type is compared with ``==``.

    Args:
        y_true: The reference answer.
        y_pred: The predicted answer; may be of any type.

    Returns:
        True if the answers match, otherwise False.
    """
    if isinstance(y_true, str):
        # A non-string prediction can never equal a string answer; treat it as
        # a mismatch instead of failing on the missing ``.lower()`` method.
        if not isinstance(y_pred, str):
            return False
        return y_true.lower() == y_pred.lower()
    return y_true == y_pred
def um(y_true: list, y_pred: list) -> bool:
    """
    Check if the true answer and predicted answer contain the same elements,
    with the same multiplicities, regardless of order.

    When the first true element is a dict, the elements of both answers are
    converted to sorted (key, value) tuples so they can be counted.

    Args:
        y_true: The reference collection.
        y_pred: The predicted collection.

    Returns:
        True if both collections hold the same multiset of elements.

    Raises:
        TypeError: If an element of ``y_true`` is unhashable (and is not a dict
            handled by the conversion above).
    """
    if len(y_true) != len(y_pred):
        return False
    if len(y_true) == 0:
        return True
    if type(y_true[0]) is dict:
        # A prediction containing non-dict items cannot match a list of dicts;
        # report a mismatch instead of failing while converting it.
        if not all(isinstance(item, dict) for item in y_pred):
            return False
        y_true = [_dict_to_tuple(item) for item in y_true]
        y_pred = [_dict_to_tuple(item) for item in y_pred]
    # Containers of different types (e.g. list vs tuple) are never a match.
    if type(y_true) is not type(y_pred):
        return False
    true_counts = collections.Counter(y_true)
    try:
        pred_counts = collections.Counter(y_pred)
    except TypeError:
        # The true elements were all hashable, so an unhashable predicted
        # element (e.g. a nested list) cannot be equal to any of them.
        return False
    return true_counts == pred_counts
def om(y_true: list, y_pred: list) -> bool:
    """Check that both answers hold equal elements in the same positions."""
    # Materialise both sides as lists so the container type does not matter.
    expected = [*y_true]
    actual = [*y_pred]
    return expected == actual
def um_om(y_true: list[list], y_pred: list[list]) -> bool:
    """Check that both answers hold the same ordered sublists, in any outer order."""

    def as_bag(groups) -> collections.Counter:
        # Count each sublist as an order-preserving tuple.
        bag = collections.Counter()
        for sub in groups:
            bag[tuple(sub)] += 1
        return bag

    return as_bag(y_true) == as_bag(y_pred)
def um_um(y_true: list[list], y_pred: list[list]) -> bool:
    """Check that both answers hold the same sublists, ignoring order at both levels."""

    def canonical(sub) -> tuple:
        # Sorting gives every permutation of a sublist the same representation.
        ordered = list(sub)
        ordered.sort()
        return tuple(ordered)

    true_keys = list(map(canonical, y_true))
    pred_keys = list(map(canonical, y_pred))
    true_bag = collections.Counter(true_keys)
    pred_bag = collections.Counter(pred_keys)
    return true_bag == pred_bag
def _build_custom(check_code: str) -> Callable[[Any, Any], bool]:
    """
    Compile a user-supplied check into a callable.

    `check_code` is used as the body of a generated function
    `check(y_true, y_pred)`: every line is indented and placed under the
    function header, then the result is executed in the namespace returned by
    `build_default_namespace()`.

    NOTE(review): the code is run with `exec`, so `check_code` must come from a
    trusted source.
    """
    indented_lines = [f" {line}" for line in check_code.splitlines()]
    body = "\n".join(indented_lines)
    source = f"def check(y_true: Any, y_pred: Any) -> bool:\n{body}"
    scope = build_default_namespace()
    # Defines `check` inside `scope`.
    exec(source, scope)
    return scope["check"]
def _build_dict(type_dict: dict[Any, str]) -> Callable[[Any, Any], bool]:
"""
Builds a function that checks if the predicted answer matches the true answer
for each field in the type dictionary.
"""
def check(y_true, y_pred) -> bool:
assert set(type_dict.keys()) == set(y_true.keys())
try:
for key, value in y_true.items():
key_check = build_check_function(type_dict[key])
if not key_check(y_true=value, y_pred=y_pred[key]):
return False
return True
except KeyError:
return False
return check
def build_check_function(
    check_type: str | dict[Any, Any], check_code: str | None = None
) -> Callable[[Any, Any], bool]:
    """
    Returns a function that checks if the predicted answer matches the true answer.

    Args:
        check_type (str | dict): The type of check to perform. Can be one of:
            - "Am": All match
            - "am": Any match
            - "em": Exact match
            - "um": Unordered match
            - "um_f": Currently the same as "um"
            - "om": Ordered match
            - "um[om]": Unordered match with ordered sublists
            - "um[um]": Unordered match with unordered sublists
            - "custom": Custom check defined by `check_code`
            - A dictionary (or a string that evaluates to one) where keys are
              field names and values are check types for each field.
        check_code (str, optional): Custom check code to be executed if `check_type` is "custom".
            It should define a function body without the function definition line.

    Returns:
        Callable[[Any, Any], bool]: A function that takes two arguments (true answer and predicted answer)
            and returns True if they match according to the specified check type, otherwise False.

    Raises:
        ValueError: If `check_type` is not a known name, is "custom" without
            `check_code`, or does not describe a non-empty dictionary.
    """
    check_functions = {
        "Am": Am,
        "am": am,
        "em": em,
        "um": um,
        "um_f": um,  # TODO: fraction of matched answers
        "om": om,
        "um[om]": um_om,
        "um[um]": um_um,
    }

    if isinstance(check_type, dict):
        # Already a dictionary (e.g. a nested per-field type); a dict is
        # unhashable, so it must not reach the name lookup below.
        check_type_dict = check_type
    elif isinstance(check_type, str):
        if check_type in check_functions:
            return check_functions[check_type]
        if check_type == "custom" and check_code is not None:
            return _build_custom(check_code)
        # Only now try to read the string as a dictionary of per-field types.
        try:
            # NOTE(security): eval runs arbitrary code; `check_type` must come
            # from a trusted source.
            parsed = eval(check_type)
        except Exception:
            parsed = None
        check_type_dict = parsed if isinstance(parsed, dict) else None
    else:
        check_type_dict = None

    if check_type_dict:
        return _build_dict(check_type_dict)
    raise ValueError(
        f"Unknown check type: {check_type}. Available types: {list(check_functions.keys()) + ['custom']}."
    )