d0rj's picture
feat: Initial commit
1719436
raw
history blame
4.76 kB
import types
from typing import Any, Union, Literal, List, Tuple, Dict, Set, Annotated

from pydantic import create_model, BaseModel, RootModel
from typing_extensions import get_args, get_origin, TypedDict

from src.common.env import build_default_namespace
def string_to_type(type_str: str) -> Union[type, Tuple[type, ...]]:
    """Evaluate a textual type expression and return the resulting object.

    The expression is evaluated with the mapping returned by
    ``build_default_namespace()`` as globals and an empty dict as locals.
    """
    # NOTE(review): ``eval`` executes arbitrary Python. Only pass strings that
    # come from a trusted source; confirm this against the callers.
    eval_globals = build_default_namespace()
    eval_locals: dict = {}
    return eval(type_str, eval_globals, eval_locals)
def matches_type(value: Any, type_hint: Union[type, Tuple[type, ...]]) -> bool:
"""Checks if a value matches a given type hint."""
origin = get_origin(type_hint)
args = get_args(type_hint)
if origin is Union:
return any(matches_type(value, arg) for arg in args)
if origin is Literal:
return value in args
if origin is Annotated:
return matches_type(value, args[0])
if origin is list or origin is List:
if not isinstance(value, list):
return False
if not args:
return True
return all(matches_type(item, args[0]) for item in value)
if origin is tuple or origin is Tuple:
if not isinstance(value, tuple):
return False
if not args:
return True
if len(args) == 2 and args[1] is Ellipsis:
return all(matches_type(item, args[0]) for item in value)
if len(args) != len(value):
return False
return all(matches_type(item, sub_type) for item, sub_type in zip(value, args))
if origin is dict or origin is Dict:
if not isinstance(value, dict):
return False
if not args:
return True
key_type, val_type = args
return all(
matches_type(k, key_type) and matches_type(v, val_type)
for k, v in value.items()
)
if origin is set or origin is Set:
if not isinstance(value, set):
return False
if not args:
return True
return all(matches_type(item, args[0]) for item in value)
if type_hint is type(None):
return value is None
if type_hint is Any:
return True
try:
return isinstance(value, type_hint)
except TypeError:
return False
def make_answer_model(
    type_str: str,
    field_name: str = "answer",
    model_name: str = "AnswerModel",
    add_thinking_field: bool = False,
) -> type[BaseModel]:
    """
    Build a Pydantic model named `model_name` holding a single answer field.

    The required field `field_name` gets the type obtained by passing
    `type_str` to `string_to_type`. When `add_thinking_field` is True, a
    required `thinking` field of type str is placed in the field mapping
    ahead of the answer field.
    """
    answer_type = string_to_type(type_str)

    # Insertion order matters: ``thinking`` (when requested) goes in first.
    # If ``field_name`` is itself "thinking", the answer definition replaces it.
    field_definitions: Dict[str, Any] = {}
    if add_thinking_field:
        field_definitions["thinking"] = (str, ...)
    field_definitions[field_name] = (answer_type, ...)

    return create_model(model_name, **field_definitions)
def _build_typed_dict(name: str, keys: tuple, value_type: Any):
    """Create a total TypedDict called `name` where every key in `keys` maps to `value_type`."""
    fields = dict.fromkeys(keys, value_type)
    return TypedDict(name, fields, total=True)
def _transform_required_dicts(tp: Any, name_base: str = "TD") -> Any:
origin = get_origin(tp)
if origin in (dict, Dict):
k_type, v_type = get_args(tp)
if get_origin(k_type) is Literal:
literal_keys = get_args(k_type)
v_type_t = _transform_required_dicts(v_type, name_base + "V")
return _build_typed_dict(f"{name_base}Required", literal_keys, v_type_t)
k_type_t = _transform_required_dicts(k_type, name_base + "K")
v_type_t = _transform_required_dicts(v_type, name_base + "V")
return Dict[k_type_t, v_type_t]
if origin in (list, List):
(inner,) = get_args(tp)
inner_t = _transform_required_dicts(inner, name_base + "Item")
return List[inner_t]
if origin is Union:
return Union[
tuple(_transform_required_dicts(a, name_base + "U") for a in get_args(tp))
]
return tp
def make_root_model(
    type_str: str, model_name: str = "Answer", make_required: bool = True
) -> type[BaseModel]:
    """
    Build a Pydantic root model class for the type hint described by `type_str`.

    The hint is obtained via `string_to_type`. When `make_required` is True it
    is first passed through `_transform_required_dicts`, using
    `model_name + "Dict"` as the base for generated names. The returned class
    is named `model_name` and subclasses `RootModel` parameterised with the
    resulting hint.
    """
    root_type = string_to_type(type_str)
    if make_required:
        root_type = _transform_required_dicts(root_type, name_base=f"{model_name}Dict")

    parametrized_base = RootModel[root_type]
    return type(model_name, (parametrized_base,), {})