Spaces:
Sleeping
Sleeping
File size: 5,299 Bytes
1719436 3e35a01 1719436 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import ast
import collections
from typing import Any, Callable

from src.common.env import build_default_namespace
def _dict_to_tuple(dict_obj: dict) -> tuple[tuple]:
return tuple(sorted(dict_obj.items()))
def Am(y_true: list, y_pred: list) -> bool:
    """Check that every element of y_pred is in y_true and every element of y_true is in y_pred.

    Only membership is tested, so order and repetition of elements do not matter.
    """
    for predicted in y_pred:
        if predicted not in y_true:
            return False
    for expected in y_true:
        if expected not in y_pred:
            return False
    return True
def am(y_true: list, y_pred: list) -> bool:
    """Check that at least one element of y_pred is also an element of y_true."""
    for predicted in y_pred:
        if predicted in y_true:
            return True
    return False
def em(y_true: Any, y_pred: Any) -> bool:
    """Check if the true answer and predicted answer are exactly the same.

    When both values are strings the comparison is case-insensitive;
    otherwise the values are compared with ``==`` as they are.

    Args:
        y_true: The expected answer.
        y_pred: The predicted answer.

    Returns:
        True if the answers match, otherwise False.
    """
    # Lower-case only when BOTH sides are strings: a non-string prediction
    # (None, int, list, ...) has no `.lower()` and must simply not match.
    if isinstance(y_true, str) and isinstance(y_pred, str):
        return y_true.lower() == y_pred.lower()
    return y_true == y_pred
def um(y_true: list, y_pred: list) -> bool:
    """Check if the true answer and predicted answer are unordered but contain the same elements.

    The two sequences are compared as multisets: every element must occur the
    same number of times in both. Dict elements are converted to sorted
    tuples of their items first so that they can be counted.

    Args:
        y_true: The expected elements.
        y_pred: The predicted elements.

    Returns:
        True if both hold the same elements with the same multiplicities,
        otherwise False.
    """
    if len(y_true) != len(y_pred):
        return False
    if len(y_true) == 0:
        return True
    # The first expected element decides whether this is a list of dicts.
    if isinstance(y_true[0], dict):
        # A prediction with non-dict elements cannot match a list of dicts;
        # report a mismatch instead of failing on the missing `.items()`.
        if not all(isinstance(item, dict) for item in y_pred):
            return False
        y_true = [_dict_to_tuple(item) for item in y_true]
        y_pred = [_dict_to_tuple(item) for item in y_pred]
    # Containers of different types (e.g. list vs tuple) never match.
    if type(y_true) != type(y_pred):
        return False
    return collections.Counter(y_true) == collections.Counter(y_pred)
def om(y_true: list, y_pred: list) -> bool:
    """Check that both answers hold equal elements in the same order.

    Both arguments are converted to lists before comparing, so the container
    type itself does not affect the result.
    """
    expected = list(y_true)
    predicted = list(y_pred)
    return expected == predicted
def um_om(y_true: list[list], y_pred: list[list]) -> bool:
    """Check that both answers hold the same ordered sublists, in any outer order.

    Each sublist is turned into a tuple and the tuples are counted, so the
    order inside a sublist matters while the order of the sublists does not.
    """
    expected_counts = collections.Counter(map(tuple, y_true))
    predicted_counts = collections.Counter(map(tuple, y_pred))
    return expected_counts == predicted_counts
def um_um(y_true: list[list], y_pred: list[list]) -> bool:
    """Check that both answers hold the same sublists, ignoring order at both levels.

    Each sublist is sorted and turned into a tuple; the resulting tuples are
    then compared as multisets.
    """

    def canonical(groups: list[list]) -> collections.Counter:
        # Sorting removes the inner order, counting removes the outer order.
        return collections.Counter(tuple(sorted(group)) for group in groups)

    return canonical(y_true) == canonical(y_pred)
def _build_custom(check_code: str) -> Callable[[Any, Any], bool]:
"""
Builds a custom function based on the provided check code.
The check code should be a string representing a Python expression.
"""
code = "\n".join([f" {line}" for line in check_code.splitlines()])
code = f"def check(y_true: Any, y_pred: Any) -> bool:\n{code}"
namespace = build_default_namespace()
exec(code, namespace)
return namespace["check"]
def _build_dict(type_dict: dict[Any, str]) -> Callable[[Any, Any], bool]:
"""
Builds a function that checks if the predicted answer matches the true answer
for each field in the type dictionary.
"""
def check(y_true, y_pred) -> bool:
assert set(type_dict.keys()) == set(y_true.keys())
try:
for key, value in y_true.items():
key_check = build_check_function(type_dict[key])
if not key_check(y_true=value, y_pred=y_pred[key]):
return False
return True
except KeyError:
return False
return check
def build_check_function(
    check_type: str | dict, check_code: str | None = None
) -> Callable[[Any, Any], bool]:
    """
    Returns a function that checks if the predicted answer matches the true answer.

    Args:
        check_type (str | dict): The type of check to perform. Can be one of:
            - "Am": All match
            - "am": Any match
            - "em": Exact match
            - "um": Unordered match
            - "um_f": Currently the same as "um"
            - "om": Ordered match
            - "um[om]": Unordered match with ordered sublists
            - "um[um]": Unordered match with unordered sublists
            - "custom": Custom check defined by `check_code`
            - A dictionary (or the string form of a dictionary literal) where keys
              are field names and values are check types for each field.
        check_code (str, optional): Custom check code to be executed if `check_type` is "custom".
            It should define a function body without the function definition line.

    Returns:
        Callable[[Any, Any], bool]: A function that takes two arguments (true answer and predicted answer)
        and returns True if they match according to the specified check type, otherwise False.

    Raises:
        ValueError: If `check_type` is not recognized, or if it is "custom"
            and no `check_code` was given.
    """
    check_functions = {
        "Am": Am,
        "am": am,
        "em": em,
        "um": um,
        "um_f": um,  # TODO: fraction of matched answers
        "om": om,
        "um[om]": um_om,
        "um[um]": um_um,
    }

    type_dict = None
    if isinstance(check_type, dict):
        # Already-parsed spec, e.g. a nested field spec handed on by the
        # per-field dict checker.
        type_dict = check_type
    elif isinstance(check_type, str):
        if check_type in check_functions:
            return check_functions[check_type]
        if check_type == "custom":
            if check_code is None:
                raise ValueError(
                    "Check type 'custom' requires `check_code` to be provided."
                )
            return _build_custom(check_code)
        # NOTE(security): this used to be a bare `eval(check_type)`, which
        # executes arbitrary code. `ast.literal_eval` only accepts Python
        # literals, which is all a dict-of-check-types spec needs; specs
        # written as non-literal expressions (e.g. `dict(a="em")`) are no
        # longer accepted.
        try:
            parsed = ast.literal_eval(check_type)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            parsed = None
        if isinstance(parsed, dict):
            type_dict = parsed

    # An empty dict is not a usable spec and falls through to the error.
    if type_dict:
        return _build_dict(type_dict)

    raise ValueError(
        f"Unknown check type: {check_type}. Available types: {list(check_functions.keys()) + ['custom']}."
    )
|