245 lines
8.3 KiB
Python
245 lines
8.3 KiB
Python
from contextlib import contextmanager, suppress
|
|
from dataclasses import replace
|
|
from functools import lru_cache
|
|
from types import new_class
|
|
from typing import (
|
|
Any,
|
|
Collection,
|
|
Generic,
|
|
Iterable,
|
|
Mapping,
|
|
Optional,
|
|
Sequence,
|
|
Tuple,
|
|
Type,
|
|
TypeVar,
|
|
Union,
|
|
)
|
|
|
|
from apischema.conversions import LazyConversion
|
|
from apischema.conversions.conversions import (
|
|
AnyConversion,
|
|
DefaultConversion,
|
|
ResolvedConversion,
|
|
ResolvedConversions,
|
|
handle_identity_conversion,
|
|
is_identity,
|
|
resolve_any_conversion,
|
|
)
|
|
from apischema.conversions.utils import is_convertible
|
|
from apischema.metadata.implem import ConversionMetadata
|
|
from apischema.metadata.keys import CONVERSION_METADATA
|
|
from apischema.type_names import type_name
|
|
from apischema.types import AnyType
|
|
from apischema.typing import get_args, is_type_var
|
|
from apischema.utils import (
|
|
context_setter,
|
|
get_args2,
|
|
get_origin_or_type,
|
|
has_type_vars,
|
|
is_subclass,
|
|
substitute_type_vars,
|
|
subtyping_substitution,
|
|
)
|
|
from apischema.visitor import Result, Unsupported, Visitor
|
|
|
|
# A deserialization is described by ``ResolvedConversions``; it is the ``Conv``
# parameter used by ``DeserializationVisitor``.
Deserialization = ResolvedConversions
# A serialization is described by a single ``ResolvedConversion``; it is the
# ``Conv`` parameter used by ``SerializationVisitor``.
Serialization = ResolvedConversion
# Type variable for the conversion kind handled by a ``ConversionsVisitor``.
Conv = TypeVar("Conv")
|
|
|
|
|
|
class ConversionsVisitor(Visitor[Result], Generic[Conv, Result]):
    """Type visitor that applies conversions while traversing types.

    The conversion currently in effect is kept in ``self._conversion``.
    Subclasses must implement ``_has_conversion``, ``_annotated_conversion``,
    ``_visited_union`` and ``_visit_conversion``, which all raise
    ``NotImplementedError`` here.
    """

    def __init__(self, default_conversion: DefaultConversion):
        """Store the default-conversion callable; no conversion is active yet."""
        self.default_conversion = default_conversion
        self._conversion: Optional[AnyConversion] = None

    def _has_conversion(
        self, tp: AnyType, conversion: Optional[AnyConversion]
    ) -> Tuple[bool, Optional[Conv]]:
        """Return ``(found, conversion)`` for ``tp``; implemented by subclasses."""
        raise NotImplementedError

    def _annotated_conversion(
        self, annotation: ConversionMetadata
    ) -> Optional[AnyConversion]:
        """Pick the relevant conversion out of a metadata annotation.

        Implemented by subclasses.
        """
        raise NotImplementedError

    def annotated(self, tp: AnyType, annotations: Sequence[Any]) -> Result:
        """Visit an annotated type, honoring conversion metadata if present.

        Annotations are scanned from last to first. The first one that is a
        mapping containing ``CONVERSION_METADATA`` provides the conversion that
        is active while the parent ``annotated`` implementation runs. Without
        such an annotation, the parent implementation is called unchanged.
        """
        for annotation in reversed(annotations):
            if isinstance(annotation, Mapping) and CONVERSION_METADATA in annotation:
                with self._replace_conversion(
                    self._annotated_conversion(annotation[CONVERSION_METADATA])
                ):
                    return super().annotated(tp, annotations)
        return super().annotated(tp, annotations)

    def _union_results(self, alternatives: Iterable[AnyType]) -> Sequence[Result]:
        """Visit every alternative, skipping the ones that are unsupported.

        Raises:
            Unsupported: if no alternative could be visited; the error carries
                the ``Union`` of all alternatives.
        """
        # Materialize once: ``alternatives`` is only required to be an
        # Iterable, and it is needed a second time to build the error below.
        # A one-shot iterator would otherwise be exhausted by the loop.
        alternatives = tuple(alternatives)
        results = []
        for alt in alternatives:
            with suppress(Unsupported):
                results.append(self.visit(alt))
        if not results:
            raise Unsupported(Union[alternatives])
        return results

    def _visited_union(self, results: Sequence[Result]) -> Result:
        """Combine the results of visited alternatives; implemented by subclasses."""
        raise NotImplementedError

    def union(self, alternatives: Sequence[AnyType]) -> Result:
        """Visit a union by combining the results of its supported alternatives."""
        return self._visited_union(self._union_results(alternatives))

    @contextmanager
    def _replace_conversion(self, conversion: Optional[AnyConversion]):
        """Context manager making ``conversion`` the active conversion.

        The assignment happens inside ``context_setter(self)``.
        NOTE(review): presumably ``context_setter`` restores the previous
        attribute values on exit -- confirm against its implementation.
        """
        with context_setter(self):
            # A falsy resolution result is normalized to None.
            self._conversion = resolve_any_conversion(conversion) or None
            yield

    def visit_with_conv(
        self, tp: AnyType, conversion: Optional[AnyConversion]
    ) -> Result:
        """Visit ``tp`` with ``conversion`` as the active conversion."""
        with self._replace_conversion(conversion):
            return self.visit(tp)

    def _visit_conversion(
        self,
        tp: AnyType,
        conversion: Conv,
        dynamic: bool,
        next_conversion: Optional[AnyConversion],
    ) -> Result:
        """Visit ``tp`` through an actual conversion; implemented by subclasses."""
        raise NotImplementedError

    def visit_conversion(
        self,
        tp: AnyType,
        conversion: Optional[Conv],
        dynamic: bool,
        next_conversion: Optional[AnyConversion] = None,
    ) -> Result:
        """Dispatch the visit of ``tp`` depending on whether it has a conversion.

        With a conversion, delegate to ``_visit_conversion``. Without one,
        run the parent ``visit`` with ``next_conversion`` as the active
        conversion.
        """
        if conversion is not None:
            return self._visit_conversion(tp, conversion, dynamic, next_conversion)
        with self._replace_conversion(next_conversion):
            return super().visit(tp)

    def visit(self, tp: AnyType) -> Result:
        """Visit ``tp``, resolving the conversion that applies to it.

        The active conversion is tried first. If it does not apply, the
        default conversion of the type's origin is used instead.
        """
        if not is_convertible(tp):
            # Not convertible: keep the active conversion for nested visits.
            return self.visit_conversion(tp, None, False, self._conversion)
        dynamic, conversion = self._has_conversion(tp, self._conversion)
        if not dynamic:
            _, conversion = self._has_conversion(
                tp, self.default_conversion(get_origin_or_type(tp))  # type: ignore
            )
        next_conversion = None
        # The active conversion did not apply to the collection type itself,
        # so it is handed on as ``next_conversion``.
        if not dynamic and is_subclass(tp, Collection):
            next_conversion = self._conversion
        return self.visit_conversion(tp, conversion, dynamic, next_conversion)
|
|
|
|
|
|
def sub_conversion(
    conversion: ResolvedConversion, next_conversion: Optional[AnyConversion]
) -> Optional[AnyConversion]:
    """Build the conversions to use once ``conversion`` has been applied.

    Returns a pair of ``LazyConversion`` objects: the first yields
    ``conversion.sub_conversion``, the second yields ``next_conversion``.
    Both values are only read when the lazy wrappers call their lambdas.
    """
    lazy_sub = LazyConversion(lambda: conversion.sub_conversion)
    lazy_next = LazyConversion(lambda: next_conversion)
    return lazy_sub, lazy_next
|
|
|
|
|
|
# NOTE(review): with ``maxsize=0`` CPython's ``lru_cache`` stores no results, so
# a fresh wrapper class is built on every call -- confirm this is intended.
@lru_cache(maxsize=0)
def self_deserialization_wrapper(cls: Type) -> Type:
    """Create a ``<cls name>SelfDeserializer`` subclass of ``cls``.

    The subclass overrides ``__new__`` so that instantiating it calls ``cls``
    with the same arguments. When ``cls`` has type variables, the base is
    ``cls`` parametrized with its own ``__parameters__``. The new class is
    passed through ``type_name(None)`` before being returned.
    """

    def fill_namespace(namespace):
        # Instantiating the wrapper forwards everything to ``cls``.
        namespace.update(
            {"__new__": lambda _, *args, **kwargs: cls(*args, **kwargs)}
        )

    wrapper_name = f"{cls.__name__}SelfDeserializer"
    base = cls[cls.__parameters__] if has_type_vars(cls) else cls
    wrapper = new_class(wrapper_name, (base,), exec_body=fill_namespace)
    return type_name(None)(wrapper)
|
|
|
|
|
|
class DeserializationVisitor(ConversionsVisitor[Deserialization, Result]):
    """Conversions visitor specialized for deserialization.

    A deserialization is a tuple of resolved conversions whose targets are the
    visited type; each conversion's source is visited in turn.
    """

    @staticmethod
    def _has_conversion(
        tp: AnyType, conversion: Optional[AnyConversion]
    ) -> Tuple[bool, Optional[Deserialization]]:
        """Collect the conversions of ``conversion`` that can produce ``tp``.

        Returns ``(True, None)`` when the only match is an identity
        conversion. Otherwise returns whether anything matched, together with
        the matches as a tuple (``None`` when there are none).
        """
        seen_identity = False
        matches = []
        for candidate in resolve_any_conversion(conversion):
            candidate = handle_identity_conversion(candidate, tp)
            if not is_subclass(candidate.target, tp):
                continue
            if is_identity(candidate):
                # Only the first identity conversion is kept.
                if seen_identity:
                    continue
                seen_identity = True
                wrapper: AnyType = self_deserialization_wrapper(
                    get_origin_or_type(tp)
                )
                type_args = get_args(tp)
                if type_args:
                    wrapper = wrapper[type_args]
                candidate = ResolvedConversion(replace(candidate, source=wrapper))
            source_is_generic = is_type_var(candidate.source) or any(
                is_type_var(arg) for arg in get_args2(candidate.source)
            )
            if source_is_generic:
                # Substitute the source's type variables using the mapping
                # computed from ``tp`` and the conversion target.
                _, substitution = subtyping_substitution(tp, candidate.target)
                candidate = replace(
                    candidate,
                    source=substitute_type_vars(candidate.source, substitution),
                )
            matches.append(ResolvedConversion(replace(candidate, target=tp)))
        if seen_identity and len(matches) == 1:
            return True, None
        return bool(matches), tuple(matches) or None

    def _annotated_conversion(
        self, annotation: ConversionMetadata
    ) -> Optional[AnyConversion]:
        """Return the ``deserialization`` attribute of the metadata."""
        return annotation.deserialization

    def _visit_conversion(
        self,
        tp: AnyType,
        conversion: Deserialization,
        dynamic: bool,
        next_conversion: Optional[AnyConversion],
    ) -> Result:
        """Visit every conversion source and combine the results as a union."""
        results = []
        for conv in conversion:
            nested = sub_conversion(conv, next_conversion)
            results.append(self.visit_with_conv(conv.source, nested))
        return self._visited_union(results)
|
|
|
|
|
|
class SerializationVisitor(ConversionsVisitor[Serialization, Result]):
    """Conversions visitor specialized for serialization.

    A serialization is a single resolved conversion whose source is the
    visited type; its target is visited next.
    """

    @staticmethod
    def _has_conversion(
        tp: AnyType, conversion: Optional[AnyConversion]
    ) -> Tuple[bool, Optional[Serialization]]:
        """Find the first conversion of ``conversion`` whose source accepts ``tp``.

        Returns ``(True, None)`` for an identity conversion, ``(True, conv)``
        for any other match, and ``(False, None)`` when nothing matches.
        """
        for candidate in resolve_any_conversion(conversion):
            candidate = handle_identity_conversion(candidate, tp)
            if not is_subclass(tp, candidate.source):
                continue
            if is_identity(candidate):
                return True, None
            target_is_generic = is_type_var(candidate.target) or any(
                is_type_var(arg) for arg in get_args2(candidate.target)
            )
            if target_is_generic:
                # Substitute the target's type variables using the mapping
                # computed from the conversion source and ``tp``.
                substitution, _ = subtyping_substitution(candidate.source, tp)
                candidate = replace(
                    candidate,
                    target=substitute_type_vars(candidate.target, substitution),
                )
            return True, ResolvedConversion(replace(candidate, source=tp))
        # The loop ended without any applicable conversion.
        return False, None

    def _annotated_conversion(
        self, annotation: ConversionMetadata
    ) -> Optional[AnyConversion]:
        """Return the ``serialization`` attribute of the metadata."""
        return annotation.serialization

    def _visit_conversion(
        self,
        tp: AnyType,
        conversion: Serialization,
        dynamic: bool,
        next_conversion: Optional[AnyConversion],
    ) -> Result:
        """Visit the conversion target with the nested conversions active."""
        nested = sub_conversion(conversion, next_conversion)
        return self.visit_with_conv(conversion.target, nested)
|