File size: 4,758 Bytes
1719436
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import types
from typing import Any, Union, Literal, List, Tuple, Dict, Set, Annotated

from pydantic import create_model, BaseModel, RootModel
from typing_extensions import get_args, get_origin, TypedDict

from src.common.env import build_default_namespace


def string_to_type(type_str: str) -> Union[type, Tuple[type, ...]]:
    """Evaluate a type expression such as ``"List[int]"`` into a runtime hint.

    ``type_str`` is evaluated with ``eval``, using the mapping returned by
    ``build_default_namespace()`` as globals and a fresh empty dict as locals.

    SECURITY: this is plain ``eval`` -- any Python expression in ``type_str``
    is executed, not just type expressions. Never pass untrusted input.
    NOTE(review): confirm whether build_default_namespace restricts
    ``__builtins__``; if it does not, eval injects the real builtins.
    """
    return eval(type_str, build_default_namespace(), {})


def matches_type(value: Any, type_hint: Union[type, Tuple[type, ...]]) -> bool:
    """Checks if a value matches a given type hint."""
    origin = get_origin(type_hint)
    args = get_args(type_hint)

    if origin is Union:
        return any(matches_type(value, arg) for arg in args)

    if origin is Literal:
        return value in args

    if origin is Annotated:
        return matches_type(value, args[0])

    if origin is list or origin is List:
        if not isinstance(value, list):
            return False
        if not args:
            return True
        return all(matches_type(item, args[0]) for item in value)

    if origin is tuple or origin is Tuple:
        if not isinstance(value, tuple):
            return False
        if not args:
            return True
        if len(args) == 2 and args[1] is Ellipsis:
            return all(matches_type(item, args[0]) for item in value)
        if len(args) != len(value):
            return False
        return all(matches_type(item, sub_type) for item, sub_type in zip(value, args))

    if origin is dict or origin is Dict:
        if not isinstance(value, dict):
            return False
        if not args:
            return True
        key_type, val_type = args
        return all(
            matches_type(k, key_type) and matches_type(v, val_type)
            for k, v in value.items()
        )

    if origin is set or origin is Set:
        if not isinstance(value, set):
            return False
        if not args:
            return True
        return all(matches_type(item, args[0]) for item in value)

    if type_hint is type(None):
        return value is None

    if type_hint is Any:
        return True

    try:
        return isinstance(value, type_hint)
    except TypeError:
        return False


def make_answer_model(
    type_str: str,
    field_name: str = "answer",
    model_name: str = "AnswerModel",
    add_thinking_field: bool = False,
) -> type[BaseModel]:
    """
    Build a Pydantic model class named ``model_name`` with one required field.

    The field is called ``field_name`` and its type is obtained by evaluating
    ``type_str`` with ``string_to_type``. When ``add_thinking_field`` is True,
    a required ``thinking: str`` field is declared ahead of it. If
    ``field_name`` is itself ``"thinking"``, the evaluated type replaces the
    ``str`` declaration.
    """
    answer_type = string_to_type(type_str)

    fields: Dict[str, Any] = {}
    if add_thinking_field:
        # Inserted first so it precedes the answer field in the definitions
        # passed to create_model.
        fields["thinking"] = (str, ...)
    # Assigned last: overrides the "thinking" entry when the names collide.
    fields[field_name] = (answer_type, ...)

    return create_model(model_name, **fields)


def _build_typed_dict(name: str, keys: tuple, value_type: Any):
    """Create a total TypedDict named ``name`` with every key typed ``value_type``.

    Because ``total=True``, each key in ``keys`` is a required key.
    """
    fields = dict.fromkeys(keys, value_type)
    return TypedDict(name, fields, total=True)


def _transform_required_dicts(tp: Any, name_base: str = "TD") -> Any:
    origin = get_origin(tp)

    if origin in (dict, Dict):
        k_type, v_type = get_args(tp)
        if get_origin(k_type) is Literal:
            literal_keys = get_args(k_type)
            v_type_t = _transform_required_dicts(v_type, name_base + "V")
            return _build_typed_dict(f"{name_base}Required", literal_keys, v_type_t)
        k_type_t = _transform_required_dicts(k_type, name_base + "K")
        v_type_t = _transform_required_dicts(v_type, name_base + "V")
        return Dict[k_type_t, v_type_t]

    if origin in (list, List):
        (inner,) = get_args(tp)
        inner_t = _transform_required_dicts(inner, name_base + "Item")
        return List[inner_t]

    if origin is Union:
        return Union[
            tuple(_transform_required_dicts(a, name_base + "U") for a in get_args(tp))
        ]

    return tp


def make_root_model(
    type_str: str, model_name: str = "Answer", make_required: bool = True
) -> type[BaseModel]:
    """
    Create a Pydantic ``RootModel`` subclass whose root type is parsed from a string.

    The class is named ``model_name``. Pydantic v2 root models store the
    validated value in the ``root`` attribute (the pydantic v1 ``__root__``
    field does not exist here). An instance is validated directly from a
    value of the target type, e.g. ``Model.model_validate([1, 2])`` for
    ``"List[int]"``.

    Args:
        type_str: Type expression evaluated by ``string_to_type``,
            e.g. ``"Dict[Literal['a', 'b'], int]"``.
        model_name: Name of the generated class; ``model_name + "Dict"`` is
            also the prefix for any TypedDicts generated when
            ``make_required`` is True.
        make_required: If True, dicts keyed by a ``Literal`` (nested inside
            dict, list or union hints) are rewritten via
            ``_transform_required_dicts`` so that every literal key must be
            present.

    Returns:
        The newly created ``RootModel`` subclass.
    """
    type_hint = string_to_type(type_str)
    if make_required:
        type_hint = _transform_required_dicts(type_hint, name_base=model_name + "Dict")

    # Subclass the parameterized RootModel so the result carries model_name
    # as its class name instead of the generic ``RootModel[...]`` name.
    return type(model_name, (RootModel[type_hint],), {})