Spaces:
Running
Running
import collections | |
from typing import Any, Callable | |
from src.common.env import build_default_namespace | |
def _dict_to_tuple(dict_obj: dict) -> tuple[tuple]: | |
return tuple(sorted(dict_obj.items())) | |
def Am(y_true: list, y_pred: list) -> bool:
    """Check that every element of y_pred is in y_true and vice versa.

    Only membership in both directions is tested; order and the number of
    repetitions of an element are ignored.
    """
    # Every predicted element must be a known true element ...
    for predicted in y_pred:
        if predicted not in y_true:
            return False
    # ... and every true element must have been predicted.
    for expected in y_true:
        if expected not in y_pred:
            return False
    return True
def am(y_true: list, y_pred: list) -> bool:
    """Check whether at least one element of y_pred also occurs in y_true.

    An empty y_pred never matches.
    """
    for candidate in y_pred:
        if candidate in y_true:
            return True
    return False
def em(y_true: Any, y_pred: Any) -> bool:
    """Check if the true answer and predicted answer are exactly the same.

    Strings are compared case-insensitively (both sides are lower-cased);
    every other type is compared with ``==``.

    Args:
        y_true: The reference answer.
        y_pred: The predicted answer.

    Returns:
        The result of the comparison. When ``y_true`` is a string and
        ``y_pred`` is not, the answers cannot match and ``False`` is returned.
    """
    if isinstance(y_true, str):
        # A non-string prediction (None, a number, ...) has no ``lower()``;
        # treat it as a mismatch instead of raising AttributeError.
        if not isinstance(y_pred, str):
            return False
        return y_true.lower() == y_pred.lower()
    return y_true == y_pred
def um(y_true: list, y_pred: list) -> bool:
    """Check if the true and predicted answers hold the same elements, ignoring order.

    The comparison is a multiset comparison: each element must occur the same
    number of times on both sides.

    Args:
        y_true: The reference answers.
        y_pred: The predicted answers.

    Returns:
        True if both sequences have the same length, the same container type
        (checked after dict normalisation) and the same elements with the
        same multiplicities; otherwise False.
    """
    if len(y_true) != len(y_pred):
        return False
    if not y_true:
        # Both are empty (lengths are equal), so they trivially match.
        return True

    if isinstance(y_true[0], dict):
        # Dicts are unhashable; turn them into sorted item tuples so they can
        # be counted. Non-dict items are left untouched so that a malformed
        # prediction yields a mismatch rather than an AttributeError.
        y_true = [
            _dict_to_tuple(item) if isinstance(item, dict) else item
            for item in y_true
        ]
        y_pred = [
            _dict_to_tuple(item) if isinstance(item, dict) else item
            for item in y_pred
        ]

    if type(y_true) != type(y_pred):
        return False

    try:
        return collections.Counter(y_true) == collections.Counter(y_pred)
    except TypeError:
        # At least one element is unhashable (e.g. a list). Fall back to a
        # quadratic multiset comparison that relies on ``==`` only.
        remaining = list(y_pred)
        for item in y_true:
            try:
                remaining.remove(item)
            except ValueError:
                return False
        # Lengths were equal, so nothing can be left over here.
        return True
def om(y_true: list, y_pred: list) -> bool:
    """Check if the true and predicted answers contain equal elements in the same order.

    Both arguments are converted to lists first, so the container types
    themselves (e.g. tuple versus list) do not influence the result.
    """
    expected = list(y_true)
    actual = list(y_pred)
    return expected == actual
def um_om(y_true: list[list], y_pred: list[list]) -> bool:
    """Check if both answers hold the same ordered sublists, in any outer order.

    Each sublist is converted to a tuple (so its internal order matters) and
    the two collections of tuples are compared as multisets.
    """

    def as_multiset(groups: list[list]) -> collections.Counter:
        # Tuples keep the inner order and are hashable, unlike lists.
        return collections.Counter(map(tuple, groups))

    return as_multiset(y_true) == as_multiset(y_pred)
def um_um(y_true: list[list], y_pred: list[list]) -> bool:
    """Check if both answers hold the same sublists, ignoring outer and inner order.

    Every sublist is sorted and frozen into a tuple, after which the two
    collections are compared as multisets.
    """

    def canonical(groups: list[list]) -> list[tuple]:
        # Sorting removes the inner order; the tuple makes the result hashable.
        return [tuple(sorted(members)) for members in groups]

    expected, actual = canonical(y_true), canonical(y_pred)
    return collections.Counter(expected) == collections.Counter(actual)
def _build_custom(check_code: str) -> Callable[[Any, Any], bool]:
    """
    Builds a custom check function from the provided check code.

    The check code is used as the *body* of a function
    ``check(y_true, y_pred)``: every line is indented and placed under a
    generated ``def`` line, so it should ``return`` its verdict itself.

    The generated source is run with ``exec`` inside the namespace returned
    by ``build_default_namespace()``.
    NOTE(security): ``exec`` runs arbitrary code; only pass check code from a
    trusted source.
    """
    # Indent each line so it becomes part of the generated function body.
    body = "\n".join(f" {line}" for line in check_code.splitlines())
    source = "def check(y_true: Any, y_pred: Any) -> bool:\n" + body
    # presumably the default namespace provides the names used in the
    # generated signature (e.g. `Any`) - confirm against build_default_namespace
    scope = build_default_namespace()
    exec(source, scope)
    return scope["check"]
def _build_dict(type_dict: dict[Any, str]) -> Callable[[Any, Any], bool]: | |
""" | |
Builds a function that checks if the predicted answer matches the true answer | |
for each field in the type dictionary. | |
""" | |
def check(y_true, y_pred) -> bool: | |
assert set(type_dict.keys()) == set(y_true.keys()) | |
try: | |
for key, value in y_true.items(): | |
key_check = build_check_function(type_dict[key]) | |
if not key_check(y_true=value, y_pred=y_pred[key]): | |
return False | |
return True | |
except KeyError: | |
return False | |
return check | |
def build_check_function(
    check_type: str | dict[Any, str], check_code: str | None = None
) -> Callable[[Any, Any], bool]:
    """
    Returns a function that checks if the predicted answer matches the true answer.

    Args:
        check_type (str | dict): The type of check to perform. Can be one of:
            - "Am": All match
            - "am": Any match
            - "em": Exact match
            - "um": Unordered match
            - "um_f": Currently an alias of "um"
            - "om": Ordered match
            - "um[om]": Unordered match with ordered sublists
            - "um[um]": Unordered match with unordered sublists
            - "custom": Custom check defined by `check_code`
            - A dictionary (or a string that evaluates to one) where keys are
              field names and values are check types for each field.
        check_code (str, optional): Custom check code to be executed if `check_type` is "custom".
            It should define a function body without the function definition line.

    Returns:
        Callable[[Any, Any], bool]: A function that takes two arguments (true answer and predicted answer)
        and returns True if they match according to the specified check type, otherwise False.

    Raises:
        ValueError: If `check_type` is not recognised, is an empty dictionary,
            or is "custom" without `check_code`.
    """
    check_functions = {
        "Am": Am,
        "am": am,
        "em": em,
        "um": um,
        "um_f": um,  # TODO: fraction of matched answers
        "om": om,
        "um[om]": um_om,
        "um[um]": um_um,
    }

    if isinstance(check_type, dict):
        # Already a field -> check-type mapping; nothing to parse.
        check_type_dict = check_type
    elif check_type in check_functions:
        # Resolve plain names first so they are never passed to eval().
        return check_functions[check_type]
    elif check_type == "custom":
        if check_code is None:
            raise ValueError("Check type 'custom' requires `check_code` to be provided.")
        return _build_custom(check_code)
    else:
        # NOTE(security): eval() executes arbitrary expressions. It is kept for
        # backward compatibility with existing specs; if `check_type` can come
        # from an untrusted source, switch to ast.literal_eval.
        try:
            parsed = eval(check_type)
        except Exception:
            # Not a valid expression: fall through to the "unknown" error.
            parsed = None
        check_type_dict = parsed if isinstance(parsed, dict) else None

    if check_type_dict:
        return _build_dict(check_type_dict)

    raise ValueError(
        f"Unknown check type: {check_type}. Available types: {list(check_functions.keys()) + ['custom']}."
    )