177 lines
5.8 KiB
Python
177 lines
5.8 KiB
Python
import inspect
|
|
from dataclasses import Field, replace
|
|
from types import new_class
|
|
from typing import (
|
|
Any,
|
|
Callable,
|
|
Dict,
|
|
Generic,
|
|
Iterable,
|
|
Mapping,
|
|
Sequence,
|
|
Tuple,
|
|
Type,
|
|
TypeVar,
|
|
Union,
|
|
)
|
|
|
|
from apischema.methods import is_method, method_wrapper
|
|
from apischema.objects.fields import MISSING_DEFAULT, ObjectField, set_object_fields
|
|
from apischema.objects.getters import object_fields, parameters_as_fields
|
|
from apischema.type_names import type_name
|
|
from apischema.types import OrderedDict
|
|
from apischema.typing import get_type_hints
|
|
from apischema.utils import (
|
|
empty_dict,
|
|
substitute_type_vars,
|
|
subtyping_substitution,
|
|
to_pascal_case,
|
|
with_parameters,
|
|
)
|
|
|
|
T = TypeVar("T")
|
|
|
|
|
|
def object_deserialization(
    func: Callable[..., T],
    *input_class_modifiers: Callable[[type], Any],
    parameters_metadata: Mapping[str, Mapping] = None,
) -> Any:
    """Turn ``func`` into a single-argument function fed by a generated class.

    A class named after ``func`` (pascal-cased) is created; its ``__init__``
    stores the keyword arguments it receives in ``self.kwargs``. The fields
    derived from ``func``'s parameters are registered on that class with
    ``set_object_fields``. The returned ``wrapper`` takes an instance of that
    class and calls ``func`` with the stored keyword arguments.

    Args:
        func: Callable whose return type must be annotated.
        *input_class_modifiers: Callables applied, in order, to the generated
            input class before its fields are registered.
        parameters_metadata: Forwarded to ``parameters_as_fields`` together
            with ``func``.

    Returns:
        The ``wrapper`` function, annotated with the generated input class
        (passed through ``with_parameters``) and ``func``'s return type.

    Raises:
        TypeError: If ``func`` has no return annotation.
    """
    fields = parameters_as_fields(func, parameters_metadata)
    hints = get_type_hints(func, include_extras=True)
    if "return" not in hints:
        raise TypeError("Object deserialization must be typed")
    return_type = hints["return"]

    type_params = getattr(return_type, "__parameters__", ())
    if type_params:
        # Generic return type: the input class is made generic over the same
        # type variables and gets no automatic type name.
        bases = (Generic[type_params],)  # type: ignore
    else:
        bases = ()
        if func.__name__ != "<lambda>":
            # Named functions get a type-name modifier applied first.
            input_class_modifiers = (
                type_name(to_pascal_case(func.__name__)),
                *input_class_modifiers,
            )

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def _fill_namespace(ns):
        ns["__init__"] = __init__

    input_cls = new_class(
        to_pascal_case(func.__name__), bases, exec_body=_fill_namespace
    )
    for apply_modifier in input_class_modifiers:
        apply_modifier(input_cls)
    set_object_fields(input_cls, fields)

    # Name of the first field flagged ``additional_properties``, if any.
    kwargs_param = next((f.name for f in fields if f.additional_properties), None)

    if kwargs_param is None:

        def wrapper(input):
            return func(**input.kwargs)

    else:

        def wrapper(input):
            # Flatten the additional-properties mapping into the call kwargs.
            call_kwargs = input.kwargs.copy()
            extra = call_kwargs.pop(kwargs_param)
            call_kwargs.update(extra)
            return func(**call_kwargs)

    wrapper.__annotations__.update(
        {"input": with_parameters(input_cls), "return": return_type}
    )
    return wrapper
|
|
|
|
|
|
def _fields_and_init(
    cls: type, fields_and_methods: Union[Iterable[Any], Callable[[], Iterable[Any]]]
) -> Tuple[Sequence[ObjectField], Callable[[Any, Any], None]]:
    """Resolve serialization members of ``cls`` into fields and an initializer.

    Args:
        cls: Class whose members are being selected for serialization.
        fields_and_methods: Iterable of members, or a zero-argument callable
            returning one. Each member is one of:

            * ``...`` -- include every field returned by
              ``object_fields(cls, serialization=True)``;
            * a dataclass ``Field`` or a field name present in those fields;
            * an ``ObjectField``;
            * a method (per ``is_method``) or any other callable, which is
              called with the serialized object to compute the value;
            * a ``(member, metadata)`` tuple, ``metadata`` being a mapping.

    Returns:
        A pair ``(fields, init)``: the tuple of resulting ``ObjectField``
        values in declaration order (later members replace earlier ones of
        the same name), and an ``__init__(self, obj)`` function that copies
        field attributes from ``obj`` and stores each method's result on
        ``self``.

    Raises:
        TypeError: If metadata is not a mapping, if a member is of an
            unsupported kind, or if a callable member accepts no parameter.
    """
    fields = object_fields(cls, serialization=True)
    output_fields: Dict[str, ObjectField] = OrderedDict()
    methods = []
    if callable(fields_and_methods):
        fields_and_methods = fields_and_methods()
    for elt in fields_and_methods:
        if elt is ...:
            output_fields.update(fields)
            continue
        if isinstance(elt, tuple):
            elt, metadata = elt
        else:
            metadata = empty_dict
        if not isinstance(metadata, Mapping):
            raise TypeError(f"Invalid metadata {metadata}")
        # Normalize: dataclass Field -> name -> known ObjectField.
        if isinstance(elt, Field):
            elt = elt.name
        if isinstance(elt, str) and elt in fields:
            elt = fields[elt]
        if is_method(elt):
            elt = method_wrapper(elt)
        if isinstance(elt, ObjectField):
            if metadata:
                # Extra metadata is merged over the field's own; the default
                # is reset to MISSING_DEFAULT on the replaced field.
                output_fields[elt.name] = replace(
                    elt, metadata={**elt.metadata, **metadata}, default=MISSING_DEFAULT
                )
            else:
                output_fields[elt.name] = elt
        elif callable(elt):
            types = get_type_hints(elt)
            # A default avoids leaking a bare StopIteration when the callable
            # takes no parameter; it could never receive the serialized object.
            first_param = next(iter(inspect.signature(elt).parameters), None)
            if first_param is None:
                raise TypeError(
                    f"Invalid serialization member {elt} for class {cls}:"
                    " it must accept the serialized object as first parameter"
                )
            # NOTE(review): presumably binds the type variables of the first
            # parameter's annotation to ``cls`` -- confirm against
            # ``subtyping_substitution``.
            substitution, _ = subtyping_substitution(types.get(first_param, cls), cls)
            ret = substitute_type_vars(types.get("return", Any), substitution)
            output_fields[elt.__name__] = ObjectField(
                elt.__name__, ret, metadata=metadata
            )
            methods.append((elt, output_fields[elt.__name__]))
        else:
            raise TypeError(f"Invalid serialization member {elt} for class {cls}")

    # Keep only methods whose field was not replaced by a later member.
    serialized_methods = [m for m, f in methods if output_fields[f.name] is f]
    method_names = {m.__name__ for m in serialized_methods}
    # Preserve declaration order so attribute copying is deterministic.
    serialized_fields = [name for name in output_fields if name not in method_names]

    def __init__(self, obj):
        for field in serialized_fields:
            setattr(self, field, getattr(obj, field))
        for method in serialized_methods:
            setattr(self, method.__name__, method(obj))

    return tuple(output_fields.values()), __init__
|
|
|
|
|
|
def object_serialization(
    cls: Type[T],
    fields_and_methods: Union[Iterable[Any], Callable[[], Iterable[Any]]],
    *output_class_modifiers: Callable[[type], Any],
) -> Callable[[T], Any]:
    """Build an output class wrapping instances of ``cls`` for serialization.

    The generated class is named ``f"{cls.__name__}Serialization"``. Its
    ``__init__`` takes the object to serialize; the members to expose are
    resolved by ``_fields_and_init`` on first instantiation and when the
    fields callback given to ``set_object_fields`` is invoked.

    Args:
        cls: Class being serialized. When it has ``__parameters__``, the
            output class is made generic over the same type variables.
        fields_and_methods: Members to serialize, or a callable returning
            them (see ``_fields_and_init``). For a non-generic ``cls``, a
            named (non-lambda) callable also provides the type name of the
            output class.
        *output_class_modifiers: Callables applied, in order, to the
            generated output class.

    Returns:
        The generated output class, used as a conversion from ``cls``.
    """
    generic, bases = cls, ()
    if getattr(cls, "__parameters__", ()):
        generic = cls[cls.__parameters__]  # type: ignore
        # ``types.new_class`` requires a tuple of bases; passing the bare
        # ``Generic[...]`` alias fails when the bases are resolved.
        bases = (Generic[cls.__parameters__],)  # type: ignore
    elif (
        callable(fields_and_methods)
        # Callables without ``__name__`` (e.g. ``functools.partial``) are
        # treated as anonymous instead of raising AttributeError.
        and getattr(fields_and_methods, "__name__", "<lambda>") != "<lambda>"
    ):
        output_class_modifiers = (
            type_name(to_pascal_case(fields_and_methods.__name__)),
            *output_class_modifiers,
        )

    def __init__(self, obj):
        # Lazy bootstrap: resolve the members on first instantiation, then
        # install the resolved initializer on the class for later calls.
        _, new_init = _fields_and_init(cls, fields_and_methods)
        new_init.__annotations__ = {"obj": generic}
        output_cls.__init__ = new_init
        new_init(self, obj)

    __init__.__annotations__ = {"obj": generic}
    output_cls = new_class(
        f"{cls.__name__}Serialization",
        bases,
        exec_body=lambda ns: ns.update({"__init__": __init__}),
    )
    for modifier in output_class_modifiers:
        modifier(output_cls)
    set_object_fields(output_cls, lambda: _fields_and_init(cls, fields_and_methods)[0])

    return output_cls
|