Spaces:
Sleeping
Sleeping
import types
from typing import Any, Union, Literal, List, Tuple, Dict, Set, Annotated

from pydantic import create_model, BaseModel, RootModel
from typing_extensions import get_args, get_origin, TypedDict

from src.common.env import build_default_namespace
def string_to_type(type_str: str) -> Union[type, Tuple[type, ...]]:
    """Evaluate a textual type expression and return the resulting object.

    The expression is evaluated with the mapping returned by
    ``build_default_namespace()`` as its globals and an empty mapping as its
    locals.

    NOTE(review): this relies on ``eval``, so ``type_str`` must come from a
    trusted source -- confirm that no caller forwards untrusted input here.
    """
    eval_globals = build_default_namespace()
    eval_locals: Dict[str, Any] = {}
    return eval(type_str, eval_globals, eval_locals)
def matches_type(value: Any, type_hint: Union[type, Tuple[type, ...]]) -> bool: | |
"""Checks if a value matches a given type hint.""" | |
origin = get_origin(type_hint) | |
args = get_args(type_hint) | |
if origin is Union: | |
return any(matches_type(value, arg) for arg in args) | |
if origin is Literal: | |
return value in args | |
if origin is Annotated: | |
return matches_type(value, args[0]) | |
if origin is list or origin is List: | |
if not isinstance(value, list): | |
return False | |
if not args: | |
return True | |
return all(matches_type(item, args[0]) for item in value) | |
if origin is tuple or origin is Tuple: | |
if not isinstance(value, tuple): | |
return False | |
if not args: | |
return True | |
if len(args) == 2 and args[1] is Ellipsis: | |
return all(matches_type(item, args[0]) for item in value) | |
if len(args) != len(value): | |
return False | |
return all(matches_type(item, sub_type) for item, sub_type in zip(value, args)) | |
if origin is dict or origin is Dict: | |
if not isinstance(value, dict): | |
return False | |
if not args: | |
return True | |
key_type, val_type = args | |
return all( | |
matches_type(k, key_type) and matches_type(v, val_type) | |
for k, v in value.items() | |
) | |
if origin is set or origin is Set: | |
if not isinstance(value, set): | |
return False | |
if not args: | |
return True | |
return all(matches_type(item, args[0]) for item in value) | |
if type_hint is type(None): | |
return value is None | |
if type_hint is Any: | |
return True | |
try: | |
return isinstance(value, type_hint) | |
except TypeError: | |
return False | |
def make_answer_model(
    type_str: str,
    field_name: str = "answer",
    model_name: str = "AnswerModel",
    add_thinking_field: bool = False,
) -> type[BaseModel]:
    """Build a Pydantic model class holding a single required answer field.

    Args:
        type_str: Type expression, resolved through ``string_to_type``, used
            as the annotation of the answer field.
        field_name: Name of the required answer field.
        model_name: Name given to the generated model class.
        add_thinking_field: When True, a required ``thinking`` field of type
            ``str`` is declared ahead of the answer field.

    Returns:
        The model class produced by ``create_model``.
    """
    answer_type = string_to_type(type_str)

    # Insertion order matters: ``thinking`` (if any) comes before the answer,
    # and the answer definition wins if both share the same name.
    field_definitions: Dict[str, Any] = {}
    if add_thinking_field:
        field_definitions["thinking"] = (str, ...)
    field_definitions[field_name] = (answer_type, ...)

    return create_model(model_name, **field_definitions)
def _build_typed_dict(name: str, keys: tuple, value_type: Any):
    """Create a total ``TypedDict`` named ``name``.

    Every entry of ``keys`` becomes a field annotated with ``value_type``;
    because the class is built with ``total=True`` all of them are required.
    """
    field_types = dict.fromkeys(keys, value_type)
    return TypedDict(name, field_types, total=True)
def _transform_required_dicts(tp: Any, name_base: str = "TD") -> Any: | |
origin = get_origin(tp) | |
if origin in (dict, Dict): | |
k_type, v_type = get_args(tp) | |
if get_origin(k_type) is Literal: | |
literal_keys = get_args(k_type) | |
v_type_t = _transform_required_dicts(v_type, name_base + "V") | |
return _build_typed_dict(f"{name_base}Required", literal_keys, v_type_t) | |
k_type_t = _transform_required_dicts(k_type, name_base + "K") | |
v_type_t = _transform_required_dicts(v_type, name_base + "V") | |
return Dict[k_type_t, v_type_t] | |
if origin in (list, List): | |
(inner,) = get_args(tp) | |
inner_t = _transform_required_dicts(inner, name_base + "Item") | |
return List[inner_t] | |
if origin is Union: | |
return Union[ | |
tuple(_transform_required_dicts(a, name_base + "U") for a in get_args(tp)) | |
] | |
return tp | |
def make_root_model(
    type_str: str, model_name: str = "Answer", make_required: bool = True
) -> type[BaseModel]:
    """Build a Pydantic ``RootModel`` subclass for an arbitrary type expression.

    Args:
        type_str: Type expression, resolved through ``string_to_type``, that
            becomes the root type of the model.
        model_name: Name given to the generated class. It is also used, with
            a ``"Dict"`` suffix, as the name prefix passed to
            ``_transform_required_dicts``.
        make_required: When True, the resolved hint is first passed through
            ``_transform_required_dicts``.

    Returns:
        A new class named ``model_name`` deriving from ``RootModel[<type>]``,
        so a value of that type can be validated directly; with pydantic v2
        the parsed value is exposed on the ``root`` attribute.
    """
    root_type = string_to_type(type_str)
    if make_required:
        root_type = _transform_required_dicts(root_type, name_base=model_name + "Dict")

    bases = (RootModel[root_type],)
    return type(model_name, bases, {})