Spaces:
Sleeping
Sleeping
File size: 4,758 Bytes
1719436 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
from typing_extensions import get_args, get_origin, TypedDict
from typing import Any, Union, Literal, List, Tuple, Dict, Set, Annotated
from pydantic import create_model, BaseModel, RootModel
from src.common.env import build_default_namespace
def string_to_type(type_str: str) -> Union[type, Tuple[type, ...]]:
    """Evaluate a textual type expression and return the resulting object.

    The expression is evaluated with the mapping returned by
    ``build_default_namespace()`` as globals and a fresh empty dict as locals.

    Args:
        type_str: Python expression describing a type, e.g. ``"List[int]"``.

    Returns:
        Whatever object the expression evaluates to.
    """
    # NOTE(review): ``eval`` runs arbitrary expressions. If ``type_str`` can
    # come from an untrusted source this is a code-execution risk -- confirm
    # against the callers.
    global_names = build_default_namespace()
    local_names: dict = {}
    resolved = eval(type_str, global_names, local_names)
    return resolved
def matches_type(value: Any, type_hint: Union[type, Tuple[type, ...]]) -> bool:
"""Checks if a value matches a given type hint."""
origin = get_origin(type_hint)
args = get_args(type_hint)
if origin is Union:
return any(matches_type(value, arg) for arg in args)
if origin is Literal:
return value in args
if origin is Annotated:
return matches_type(value, args[0])
if origin is list or origin is List:
if not isinstance(value, list):
return False
if not args:
return True
return all(matches_type(item, args[0]) for item in value)
if origin is tuple or origin is Tuple:
if not isinstance(value, tuple):
return False
if not args:
return True
if len(args) == 2 and args[1] is Ellipsis:
return all(matches_type(item, args[0]) for item in value)
if len(args) != len(value):
return False
return all(matches_type(item, sub_type) for item, sub_type in zip(value, args))
if origin is dict or origin is Dict:
if not isinstance(value, dict):
return False
if not args:
return True
key_type, val_type = args
return all(
matches_type(k, key_type) and matches_type(v, val_type)
for k, v in value.items()
)
if origin is set or origin is Set:
if not isinstance(value, set):
return False
if not args:
return True
return all(matches_type(item, args[0]) for item in value)
if type_hint is type(None):
return value is None
if type_hint is Any:
return True
try:
return isinstance(value, type_hint)
except TypeError:
return False
def make_answer_model(
    type_str: str,
    field_name: str = "answer",
    model_name: str = "AnswerModel",
    add_thinking_field: bool = False,
) -> type[BaseModel]:
    """Build a Pydantic model class holding a single typed answer field.

    Args:
        type_str: Textual type expression, resolved via ``string_to_type``;
            it becomes the type of the ``field_name`` field.
        field_name: Name of the required answer field.
        model_name: Name given to the generated model class.
        add_thinking_field: When True, a required ``thinking`` field of type
            ``str`` is declared ahead of the answer field.

    Returns:
        The model class produced by ``create_model``.
    """
    answer_type = string_to_type(type_str)

    # Each definition is a ``(type, ...)`` pair; per the original docstring
    # these fields are required.
    field_definitions = {}
    if add_thinking_field:
        field_definitions["thinking"] = (str, ...)
    # Assigned last so that it wins if ``field_name`` is itself "thinking".
    field_definitions[field_name] = (answer_type, ...)

    return create_model(model_name, **field_definitions)
def _build_typed_dict(name: str, keys: tuple, value_type: Any):
annotations = {k: value_type for k in keys}
return TypedDict(name, annotations, total=True)
def _transform_required_dicts(tp: Any, name_base: str = "TD") -> Any:
origin = get_origin(tp)
if origin in (dict, Dict):
k_type, v_type = get_args(tp)
if get_origin(k_type) is Literal:
literal_keys = get_args(k_type)
v_type_t = _transform_required_dicts(v_type, name_base + "V")
return _build_typed_dict(f"{name_base}Required", literal_keys, v_type_t)
k_type_t = _transform_required_dicts(k_type, name_base + "K")
v_type_t = _transform_required_dicts(v_type, name_base + "V")
return Dict[k_type_t, v_type_t]
if origin in (list, List):
(inner,) = get_args(tp)
inner_t = _transform_required_dicts(inner, name_base + "Item")
return List[inner_t]
if origin is Union:
return Union[
tuple(_transform_required_dicts(a, name_base + "U") for a in get_args(tp))
]
return tp
def make_root_model(
    type_str: str, model_name: str = "Answer", make_required: bool = True
) -> type[BaseModel]:
    """Build a Pydantic root model class for a type given as a string.

    Args:
        type_str: Textual type expression, resolved via ``string_to_type``.
        model_name: Name of the generated class.
        make_required: When True, the resolved hint is first passed through
            ``_transform_required_dicts`` with ``model_name + "Dict"`` as the
            naming prefix.

    Returns:
        A new class named ``model_name`` that subclasses
        ``RootModel[<resolved type>]``.
    """
    resolved_hint = string_to_type(type_str)
    if make_required:
        typed_dict_prefix = model_name + "Dict"
        resolved_hint = _transform_required_dicts(
            resolved_hint, name_base=typed_dict_prefix
        )

    # Create an empty subclass so the model carries the requested name.
    root_base = RootModel[resolved_hint]
    return type(model_name, (root_base,), {})
|