File size: 5,299 Bytes
1719436
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e35a01
 
 
1719436
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import collections
from typing import Any, Callable

from src.common.env import build_default_namespace


def _dict_to_tuple(dict_obj: dict) -> tuple[tuple]:
    return tuple(sorted(dict_obj.items()))


def Am(y_true: list, y_pred: list) -> bool:
    """Check that every element of y_pred occurs in y_true and vice versa.

    Only membership (``in``) is tested, so order and multiplicity are ignored.
    """
    # Every prediction must be a known true answer ...
    for predicted in y_pred:
        if predicted not in y_true:
            return False
    # ... and every true answer must have been predicted.
    for expected in y_true:
        if expected not in y_pred:
            return False
    return True


def am(y_true: list, y_pred: list) -> bool:
    """Check whether at least one element of y_pred occurs in y_true."""
    for candidate in y_pred:
        if candidate in y_true:
            return True
    # No prediction matched (also the result for an empty y_pred).
    return False


def em(y_true: Any, y_pred: Any) -> bool:
    """Check if the true answer and predicted answer are exactly the same.

    Strings are compared case-insensitively (both sides are lower-cased);
    every other value is compared with ``==``.

    Args:
        y_true: The reference answer.
        y_pred: The predicted answer.

    Returns:
        The result of the comparison. A string reference compared with a
        non-string prediction is simply a mismatch rather than an error.
    """
    # Lower-case only when BOTH sides are strings: a non-string prediction
    # (None, a number, a list, ...) has no ``.lower()`` and must not raise.
    if isinstance(y_true, str) and isinstance(y_pred, str):
        return y_true.lower() == y_pred.lower()
    return y_true == y_pred


def um(y_true: list, y_pred: list) -> bool:
    """Check if the true answer and predicted answer are unordered but contain the same elements.

    The two sequences are compared as multisets: order is ignored, duplicates
    count.

    Args:
        y_true: The reference elements.
        y_pred: The predicted elements.

    Returns:
        True when both sequences hold the same elements with the same
        multiplicities, otherwise False. Two empty sequences match. When the
        reference elements are not dicts, the containers must also be of the
        same type (so a list never matches a string or a tuple).
    """
    if len(y_true) != len(y_pred):
        return False
    if len(y_true) == 0:
        return True

    # Dict payloads have always been normalised into fresh lists before the
    # container-type comparison, so that comparison only ever applied to
    # non-dict payloads; keep it that way.
    if not isinstance(y_true[0], dict) and type(y_true) != type(y_pred):
        return False

    def _key(item: Any) -> Any:
        # Dicts are unhashable, so count them via a sorted tuple of their
        # items. The ``dict`` tag keeps a real tuple of pairs from being
        # mistaken for a dict with the same content.
        if isinstance(item, dict):
            return (dict, tuple(sorted(item.items())))
        return item

    try:
        return collections.Counter(map(_key, y_true)) == collections.Counter(
            map(_key, y_pred)
        )
    except TypeError:
        # Some element is unhashable (e.g. a list, or a dict holding a list)
        # or a dict has keys that cannot be sorted. Fall back to pairing the
        # elements up by equality; lengths are already known to be equal.
        unmatched = list(y_pred)
        for item in y_true:
            try:
                unmatched.remove(item)
            except ValueError:
                return False
        return True


def om(y_true: list, y_pred: list) -> bool:
    """Check if the predicted answer has the same elements in the same order as the true answer.

    Both arguments are first materialised as lists, so the container type
    itself (list, tuple, ...) does not take part in the comparison.
    """
    expected = [*y_true]
    actual = [*y_pred]
    return expected == actual


def um_om(y_true: list[list], y_pred: list[list]) -> bool:
    """Check if the answers match as unordered lists of ordered sublists.

    The outer order is ignored (duplicates count); inside each sublist the
    element order must be identical.
    """

    def _bag(groups: list[list]) -> collections.Counter:
        # Tuples keep the inner order and are hashable, so they can be counted.
        bag = collections.Counter()
        for group in groups:
            bag[tuple(group)] += 1
        return bag

    return _bag(y_true) == _bag(y_pred)


def um_um(y_true: list[list], y_pred: list[list]) -> bool:
    """Check if the answers match as unordered lists of unordered sublists.

    Neither the outer order nor the order inside a sublist matters;
    duplicates count on both levels.
    """

    def _normalise(group: list) -> tuple:
        # Sorting gives every sublist one canonical, hashable form.
        ordered = list(group)
        ordered.sort()
        return tuple(ordered)

    expected_keys = list(map(_normalise, y_true))
    predicted_keys = list(map(_normalise, y_pred))
    return collections.Counter(expected_keys) == collections.Counter(predicted_keys)


def _build_custom(check_code: str) -> Callable[[Any, Any], bool]:
    """
    Builds a custom check function from the source code of its body.

    ``check_code`` is the body of ``check(y_true, y_pred)``: one or more Python
    statements without the ``def`` line (not a bare expression, so a result has
    to be produced with ``return``). Every line is indented by four spaces and
    placed under a generated ``def check(y_true: Any, y_pred: Any) -> bool:``
    header, which is then executed in the namespace returned by
    ``build_default_namespace()``.

    SECURITY: the code is executed with ``exec`` and can do anything the
    interpreter can do. Only pass check code from a trusted source.

    Args:
        check_code (str): Source of the function body.
    Returns:
        Callable[[Any, Any], bool]: The compiled ``check`` function.
    Raises:
        SyntaxError: If the generated function does not compile, e.g. when
            ``check_code`` is empty (raised as ``IndentationError``).
    """
    body = "\n".join(f"    {line}" for line in check_code.splitlines())
    source = f"def check(y_true: Any, y_pred: Any) -> bool:\n{body}"
    namespace = build_default_namespace()
    # The generated header is annotated with ``Any``; on interpreters that
    # evaluate annotations when the ``def`` runs, that name must resolve inside
    # the namespace. ``setdefault`` keeps whatever the namespace already binds.
    namespace.setdefault("Any", Any)
    # NOTE(review): arbitrary code execution by design - see SECURITY above.
    exec(source, namespace)
    return namespace["check"]


def _build_dict(type_dict: dict[Any, str]) -> Callable[[Any, Any], bool]:
    """
    Builds a function that checks if the predicted answer matches the true answer
    for each field in the type dictionary.

    Args:
        type_dict: Maps each field name to the check type used for that field
            (anything ``build_check_function`` accepts as ``check_type``).
    Returns:
        Callable[[Any, Any], bool]: ``check(y_true, y_pred)``. It returns True
            only when every field of ``y_true`` passes its check against the
            same field of ``y_pred``. A field missing from ``y_pred``, or a
            ``y_pred`` that cannot be indexed at all, yields False. It raises
            ``AssertionError`` when the fields of ``y_true`` differ from the
            fields of ``type_dict``.
    """
    # Per-field checkers are built on first use and then reused, so the spec
    # is not re-parsed on every comparison. Building lazily keeps an invalid
    # check type failing at the same point as before: when the field is checked.
    field_checks: dict[Any, Callable[[Any, Any], bool]] = {}

    def check(y_true, y_pred) -> bool:
        expected_fields = set(type_dict.keys())
        actual_fields = set(y_true.keys())
        # Explicit raise instead of ``assert`` so the validation also runs
        # under ``python -O``; the exception type is unchanged.
        if expected_fields != actual_fields:
            raise AssertionError(
                f"y_true fields {actual_fields!r} do not match the spec fields {expected_fields!r}"
            )
        try:
            for key, value in y_true.items():
                key_check = field_checks.get(key)
                if key_check is None:
                    key_check = build_check_function(type_dict[key])
                    field_checks[key] = key_check
                try:
                    predicted = y_pred[key]
                except (TypeError, IndexError):
                    # y_pred is not a mapping (None, a number, a list, ...):
                    # a mismatch, not an error.
                    return False
                if not key_check(y_true=value, y_pred=predicted):
                    return False
            return True
        except KeyError:
            # Field missing from y_pred (or a KeyError raised by a field check).
            return False

    return check


def build_check_function(
    check_type: str | dict[Any, Any], check_code: str | None = None
) -> Callable[[Any, Any], bool]:
    """
    Returns a function that checks if the predicted answer matches the true answer.

    Args:
        check_type (str | dict): The type of check to perform. Can be one of:
            - "Am": All match
            - "am": Any match
            - "em": Exact match
            - "um": Unordered match
            - "um_f": Currently the same as "um"
            - "om": Ordered match
            - "um[om]": Unordered match with ordered sublists
            - "um[um]": Unordered match with unordered sublists
            - "custom": Custom check defined by `check_code`
            - A dictionary where keys are field names and values are check types for each field,
              given either as a dict object or as a string holding a Python dict expression
              (e.g. "{'name': 'em', 'tags': 'um'}"). Values may themselves be dictionaries.
        check_code (str, optional): Custom check code to be executed if `check_type` is "custom".
            It should define a function body without the function definition line.
    Returns:
        Callable[[Any, Any], bool]: A function that takes two arguments (true answer and predicted answer)
            and returns True if they match according to the specified check type, otherwise False.
    Raises:
        ValueError: If `check_type` is not recognised (an empty dictionary counts as
            unrecognised), or if it is "custom" and no `check_code` was given.
    """
    check_functions = {
        "Am": Am,
        "am": am,
        "em": em,
        "um": um,
        "um_f": um,  # TODO: fraction of matched answers
        "om": om,
        "um[om]": um_om,
        "um[um]": um_um,
    }

    if isinstance(check_type, dict):
        # Already-parsed spec, e.g. the value of a field in a nested dictionary
        # spec. A dict is unhashable, so it must not reach the registry lookup.
        check_type_dict = check_type
    else:
        if check_type in check_functions:
            return check_functions[check_type]
        if check_type == "custom":
            if check_code is not None:
                return _build_custom(check_code)
            raise ValueError("Check type 'custom' requires `check_code`.")
        # Anything else is only valid if it is the text of a dict spec.
        try:
            # SECURITY NOTE(review): `eval` executes arbitrary expressions, so
            # `check_type` must come from a trusted source. If specs are always
            # plain literals, `ast.literal_eval` would be a safe replacement.
            check_type_dict = eval(check_type)
        except Exception:
            # Not an evaluable expression. `Exception` rather than a bare
            # `except` so KeyboardInterrupt / SystemExit are not swallowed.
            check_type_dict = None
        if not isinstance(check_type_dict, dict):
            check_type_dict = None

    if check_type_dict:
        return _build_dict(check_type_dict)
    raise ValueError(
        f"Unknown check type: {check_type}. Available types: {list(check_functions.keys()) + ['custom']}."
    )