import ast
import collections
from typing import Any, Callable

from src.common.env import build_default_namespace


def _freeze(value: Any) -> Any:
    """Recursively convert dicts, lists/tuples and sets into hashable equivalents.

    Dicts become sorted item tuples (see ``_dict_to_tuple``), lists and tuples
    become tuples of frozen items, sets become frozensets. Any other value is
    returned unchanged.
    """
    if isinstance(value, dict):
        return _dict_to_tuple(value)
    if isinstance(value, (list, tuple)):
        return tuple(_freeze(item) for item in value)
    if isinstance(value, (set, frozenset)):
        return frozenset(value)
    return value


def _dict_to_tuple(dict_obj: dict) -> tuple[tuple]:
    """Return a hashable, key-order-independent representation of ``dict_obj``.

    Items are sorted by key, so two dicts with equal contents map to the same
    tuple. Values are frozen recursively, so dicts holding lists or nested dicts
    can still be counted with ``collections.Counter``.
    """
    # Keys in a dict are unique, so sorting the (key, value) pairs never has to
    # compare values — only the keys must be mutually comparable.
    return tuple(sorted((key, _freeze(value)) for key, value in dict_obj.items()))


def Am(y_true: list, y_pred: list) -> bool:
    """Check if all elements in y_pred are present in y_true and vice versa.

    Membership is tested with ``in``, so order and multiplicity are ignored,
    e.g. ``Am(["a", "a"], ["a"])`` is True.
    """
    return all(y in y_true for y in y_pred) and all(y in y_pred for y in y_true)


def am(y_true: list, y_pred: list) -> bool:
    """Check if any elements in y_pred are present in y_true.

    Returns False when y_pred is empty.
    """
    return any(y in y_true for y in y_pred)


def em(y_true: Any, y_pred: Any) -> bool:
    """Check if the true answer and predicted answer are exactly the same.

    When y_true is a string the comparison is case-insensitive; a non-string
    prediction for a string answer is a mismatch rather than an error.
    """
    if isinstance(y_true, str):
        # Previously ``y_pred.lower()`` raised AttributeError for None/numbers.
        if not isinstance(y_pred, str):
            return False
        return y_true.lower() == y_pred.lower()
    return y_true == y_pred


def um(y_true: list, y_pred: list) -> bool:
    """Check if the true answer and predicted answer are unordered but contain the same elements.

    Multiplicity matters (multiset comparison). If either side holds dicts
    (detected from the first element), every element is frozen into a hashable
    form before counting. Otherwise both containers must also have the same
    type, which e.g. stops the string ``"ab"`` from matching ``["a", "b"]``.
    """
    if len(y_true) != len(y_pred):
        return False
    if len(y_true) == 0:
        return True
    if isinstance(y_true[0], dict) or isinstance(y_pred[0], dict):
        # Dicts are unhashable; freeze each element so Counter can count them.
        # Non-dict elements are frozen too instead of crashing on ``.items()``.
        y_true = [_freeze(item) for item in y_true]
        y_pred = [_freeze(item) for item in y_pred]
    elif type(y_true) is not type(y_pred):
        return False
    return collections.Counter(y_true) == collections.Counter(y_pred)


def om(y_true: list, y_pred: list) -> bool:
    """Check if the true answer and predicted answer are in the same order."""
    return list(y_true) == list(y_pred)


def um_om(y_true: list[list], y_pred: list[list]) -> bool:
    """Check if the true answer and predicted answer are unordered lists of ordered sublists."""
    true_bags = collections.Counter(tuple(sub) for sub in y_true)
    pred_bags = collections.Counter(tuple(sub) for sub in y_pred)
    return true_bags == pred_bags


def um_um(y_true: list[list], y_pred: list[list]) -> bool:
    """Check if the true answer and predicted answer are unordered lists of unordered sublists.

    Each sublist is sorted, so its elements must be mutually comparable.
    """
    true_sets = [tuple(sorted(sub)) for sub in y_true]
    pred_sets = [tuple(sorted(sub)) for sub in y_pred]
    return collections.Counter(true_sets) == collections.Counter(pred_sets)


def _build_custom(check_code: str) -> Callable[[Any, Any], bool]:
    """Compile ``check_code`` into a ``check(y_true, y_pred)`` function.

    ``check_code`` is the *body* of the function, without the ``def`` line.
    It must ``return`` its verdict explicitly; otherwise the function returns
    None. The function is defined inside the namespace returned by
    ``build_default_namespace()``.

    Security: the code is run with ``exec``. It must come from a trusted source,
    never from model output or other external input.
    """
    body = "\n".join(f" {line}" for line in check_code.splitlines())
    source = f"def check(y_true: Any, y_pred: Any) -> bool:\n{body}"
    namespace = build_default_namespace()
    # The generated signature references ``Any``; make sure the name resolves
    # even if the default namespace does not already provide it.
    namespace.setdefault("Any", Any)
    exec(source, namespace)
    return namespace["check"]


def _build_dict(type_dict: dict[Any, str | dict]) -> Callable[[Any, Any], bool]:
    """Build a checker that compares dict answers field by field.

    Each value of ``type_dict`` is a check type accepted by
    ``build_check_function`` (a string, or a nested dict for nested fields).
    The returned function gives True only when every field of y_true matches
    the same field of y_pred under its own check.

    The returned function raises ValueError if y_true's fields differ from the
    fields in ``type_dict``. A missing field, or a y_pred that cannot be indexed
    by field name, counts as a mismatch.
    """
    # Build each per-field checker once, instead of on every call.
    field_checks = {
        key: build_check_function(field_type) for key, field_type in type_dict.items()
    }
    expected_keys = set(type_dict)

    def check(y_true, y_pred) -> bool:
        # Explicit check instead of ``assert`` (asserts are stripped under -O).
        if set(y_true.keys()) != expected_keys:
            raise ValueError(
                f"y_true fields {set(y_true.keys())} do not match the "
                f"type specification fields {expected_keys}."
            )
        for key, value in y_true.items():
            try:
                predicted = y_pred[key]
            except (KeyError, IndexError, TypeError):
                # Missing field, or the prediction is not a mapping at all.
                return False
            try:
                matched = field_checks[key](y_true=value, y_pred=predicted)
            except KeyError:
                # As before, a KeyError raised by a field check (e.g. custom
                # code indexing a missing key) counts as a mismatch.
                return False
            if not matched:
                return False
        return True

    return check


def _parse_type_dict(check_type: Any) -> dict | None:
    """Parse a dict literal such as ``"{'name': 'em', 'tags': 'um'}"``.

    Returns None when ``check_type`` is not a literal that evaluates to a dict.
    ``ast.literal_eval`` only accepts Python literals, so unlike ``eval`` it
    cannot run arbitrary code.
    """
    try:
        parsed = ast.literal_eval(check_type)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        return None
    return parsed if isinstance(parsed, dict) else None


def build_check_function(
    check_type: str | dict, check_code: str | None = None
) -> Callable[[Any, Any], bool]:
    """
    Returns a function that checks if the predicted answer matches the true answer.

    Args:
        check_type (str | dict): The type of check to perform. Can be one of:
            - "Am": All match
            - "am": Any match
            - "em": Exact match
            - "um": Unordered match
            - "um_f": Currently an alias of "um" (fractional scoring not yet implemented)
            - "om": Ordered match
            - "um[om]": Unordered match with ordered sublists
            - "um[um]": Unordered match with unordered sublists
            - "custom": Custom check defined by `check_code`
            - A dictionary, or the string form of a Python dict literal, mapping
              field names to check types for each field. Field check types may
              themselves be nested dictionaries.
        check_code (str, optional): Custom check code to be executed if `check_type` is "custom".
            It should define a function body without the function definition line.

    Returns:
        Callable[[Any, Any], bool]: A function that takes two arguments (true answer and predicted answer)
        and returns True if they match according to the specified check type, otherwise False.

    Raises:
        ValueError: If `check_type` is not recognized (including an empty dict), or is
            "custom" without `check_code`.
    """
    check_functions = {
        "Am": Am,
        "am": am,
        "em": em,
        "um": um,
        "um_f": um,  # TODO: fraction of matched answers
        "om": om,
        "um[om]": um_om,
        "um[um]": um_um,
    }

    if isinstance(check_type, dict):
        # Accepting the dict directly lets nested field specs work.
        check_type_dict = check_type
    elif check_type in check_functions:
        return check_functions[check_type]
    elif check_type == "custom" and check_code is not None:
        return _build_custom(check_code)
    else:
        check_type_dict = _parse_type_dict(check_type)

    # An empty dict is rejected, as before (it is falsy).
    if check_type_dict:
        return _build_dict(check_type_dict)
    raise ValueError(
        f"Unknown check type: {check_type}. Available types: {list(check_functions.keys()) + ['custom']}."
    )