tokencrawler/.venv/lib/python3.9/site-packages/apischema/conversions/visitor.py
2022-03-17 22:16:30 +01:00

245 lines
8.3 KiB
Python

from contextlib import contextmanager, suppress
from dataclasses import replace
from functools import lru_cache
from types import new_class
from typing import (
Any,
Collection,
Generic,
Iterable,
Mapping,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
)
from apischema.conversions import LazyConversion
from apischema.conversions.conversions import (
AnyConversion,
DefaultConversion,
ResolvedConversion,
ResolvedConversions,
handle_identity_conversion,
is_identity,
resolve_any_conversion,
)
from apischema.conversions.utils import is_convertible
from apischema.metadata.implem import ConversionMetadata
from apischema.metadata.keys import CONVERSION_METADATA
from apischema.type_names import type_name
from apischema.types import AnyType
from apischema.typing import get_args, is_type_var
from apischema.utils import (
context_setter,
get_args2,
get_origin_or_type,
has_type_vars,
is_subclass,
substitute_type_vars,
subtyping_substitution,
)
from apischema.visitor import Result, Unsupported, Visitor
# Conversion representation used for deserialization: DeserializationVisitor
# iterates over it, visiting one source type per resolved conversion.
Deserialization = ResolvedConversions
# Conversion representation used for serialization: a single resolved conversion
# whose target type is visited by SerializationVisitor.
Serialization = ResolvedConversion
# Type parameter of ConversionsVisitor standing for one of the two shapes above.
Conv = TypeVar("Conv")
class ConversionsVisitor(Visitor[Result], Generic[Conv, Result]):
    """Visitor that takes conversions into account while traversing types.

    ``Conv`` is the conversion representation returned by ``_has_conversion``
    and consumed by ``_visit_conversion``. Subclasses must implement the hooks
    that raise ``NotImplementedError`` here.
    """

    def __init__(self, default_conversion: DefaultConversion):
        """Store the default-conversion lookup; no conversion is in effect yet."""
        self.default_conversion = default_conversion
        # Conversion currently in effect; swapped by ``_replace_conversion``.
        self._conversion: Optional[AnyConversion] = None

    def _has_conversion(
        self, tp: AnyType, conversion: Optional[AnyConversion]
    ) -> Tuple[bool, Optional[Conv]]:
        """Return ``(applies, conv)`` for ``tp`` under ``conversion`` (hook)."""
        raise NotImplementedError

    def _annotated_conversion(
        self, annotation: ConversionMetadata
    ) -> Optional[AnyConversion]:
        """Pick the relevant conversion out of conversion metadata (hook)."""
        raise NotImplementedError

    def annotated(self, tp: AnyType, annotations: Sequence[Any]) -> Result:
        """Visit an annotated type, honoring conversion metadata when present.

        Annotations are scanned from last to first; the first mapping holding
        ``CONVERSION_METADATA`` provides the conversion put in effect for the
        delegated ``super().annotated`` call.
        """
        for annotation in reversed(annotations):
            if isinstance(annotation, Mapping) and CONVERSION_METADATA in annotation:
                with self._replace_conversion(
                    self._annotated_conversion(annotation[CONVERSION_METADATA])
                ):
                    return super().annotated(tp, annotations)
        return super().annotated(tp, annotations)

    def _union_results(self, alternatives: Iterable[AnyType]) -> Sequence[Result]:
        """Visit every alternative, skipping those raising ``Unsupported``.

        Raises:
            Unsupported: when no alternative could be visited; the error carries
                the union of all alternatives.
        """
        # Materialize once: the alternatives are walked here and again to build
        # the error type, and a one-shot iterator would be empty the second time.
        alts = tuple(alternatives)
        results = []
        for alt in alts:
            with suppress(Unsupported):
                results.append(self.visit(alt))
        if not results:
            raise Unsupported(Union[alts])
        return results

    def _visited_union(self, results: Sequence[Result]) -> Result:
        """Combine the results of visited alternatives into one result (hook)."""
        raise NotImplementedError

    def union(self, alternatives: Sequence[AnyType]) -> Result:
        """Visit a union by combining the results of its supported alternatives."""
        return self._visited_union(self._union_results(alternatives))

    @contextmanager
    def _replace_conversion(self, conversion: Optional[AnyConversion]):
        """Context manager putting ``conversion`` in effect inside its body.

        The conversion is resolved first; a falsy resolution is stored as ``None``.
        """
        # NOTE(review): context_setter presumably restores the visitor's previous
        # attributes on exit -- confirm in apischema.utils.
        with context_setter(self):
            self._conversion = resolve_any_conversion(conversion) or None
            yield

    def visit_with_conv(
        self, tp: AnyType, conversion: Optional[AnyConversion]
    ) -> Result:
        """Visit ``tp`` with ``conversion`` as the conversion in effect."""
        with self._replace_conversion(conversion):
            return self.visit(tp)

    def _visit_conversion(
        self,
        tp: AnyType,
        conversion: Conv,
        dynamic: bool,
        next_conversion: Optional[AnyConversion],
    ) -> Result:
        """Visit ``tp`` through an applicable ``conversion`` (hook)."""
        raise NotImplementedError

    def visit_conversion(
        self,
        tp: AnyType,
        conversion: Optional[Conv],
        dynamic: bool,
        next_conversion: Optional[AnyConversion] = None,
    ) -> Result:
        """Dispatch the visit of ``tp`` depending on whether a conversion applies.

        With a conversion, delegate to ``_visit_conversion``; otherwise visit
        ``tp`` with the base visitor while ``next_conversion`` is in effect.
        """
        if conversion is not None:
            return self._visit_conversion(tp, conversion, dynamic, next_conversion)
        else:
            with self._replace_conversion(next_conversion):
                return super().visit(tp)

    def visit(self, tp: AnyType) -> Result:
        """Visit ``tp``, applying the current or the default conversion.

        Non-convertible types are visited without conversion, forwarding the
        current conversion as ``next_conversion``. Otherwise the current
        conversion is tried first; if it does not apply, the default conversion
        of the type's origin is looked up instead.
        """
        if not is_convertible(tp):
            return self.visit_conversion(tp, None, False, self._conversion)
        dynamic, conversion = self._has_conversion(tp, self._conversion)
        if not dynamic:
            # Current conversion does not apply: fall back to the default one.
            _, conversion = self._has_conversion(
                tp, self.default_conversion(get_origin_or_type(tp))  # type: ignore
            )
        next_conversion = None
        if not dynamic and is_subclass(tp, Collection):
            # Unused current conversion is forwarded when visiting a collection.
            next_conversion = self._conversion
        return self.visit_conversion(tp, conversion, dynamic, next_conversion)
def sub_conversion(
    conversion: ResolvedConversion, next_conversion: Optional[AnyConversion]
) -> Optional[AnyConversion]:
    """Build the conversion to use when visiting the other side of ``conversion``.

    Returns a pair of lazy conversions: the first yields
    ``conversion.sub_conversion`` and the second yields ``next_conversion``,
    each only read when the lazy wrapper's callable is invoked.
    """
    lazy_sub = LazyConversion(lambda: conversion.sub_conversion)
    lazy_next = LazyConversion(lambda: next_conversion)
    return lazy_sub, lazy_next
# NOTE(review): with maxsize=0, lru_cache keeps no entries, so every call builds
# a fresh wrapper class -- confirm that disabling the cache is intended.
@lru_cache(maxsize=0)
def self_deserialization_wrapper(cls: Type) -> Type:
    """Create a ``<cls name>SelfDeserializer`` subclass wrapping ``cls``.

    The wrapper derives from ``cls`` (parameterized with its own type
    parameters when ``cls`` has type variables) and overrides ``__new__`` so
    that calling the wrapper returns ``cls(*args, **kwargs)``. The new class is
    passed through ``type_name(None)`` before being returned.
    """
    wrapper_name = f"{cls.__name__}SelfDeserializer"
    base = cls[cls.__parameters__] if has_type_vars(cls) else cls
    members = {"__new__": lambda _, *args, **kwargs: cls(*args, **kwargs)}

    def fill_namespace(ns):
        ns.update(members)

    wrapper = new_class(wrapper_name, (base,), exec_body=fill_namespace)
    return type_name(None)(wrapper)
class DeserializationVisitor(ConversionsVisitor[Deserialization, Result]):
    """Conversions visitor for deserialization.

    Several conversions may produce the same type, so the conversion handed to
    ``_visit_conversion`` is a tuple of resolved conversions.
    """

    @staticmethod
    def _has_conversion(
        tp: AnyType, conversion: Optional[AnyConversion]
    ) -> Tuple[bool, Optional[Deserialization]]:
        """Collect the conversions of ``conversion`` whose target fits ``tp``.

        Returns ``(True, None)`` when the only match is an identity conversion,
        otherwise ``(found, conversions)`` where ``conversions`` is a tuple of
        resolved conversions retargeted to ``tp``, or ``None`` if nothing matched.
        """
        seen_identity = False
        resolved = []
        for candidate in resolve_any_conversion(conversion):
            candidate = handle_identity_conversion(candidate, tp)
            if not is_subclass(candidate.target, tp):
                continue
            if is_identity(candidate):
                # Only the first identity conversion is kept.
                if seen_identity:
                    continue
                seen_identity = True
                # The identity's source is replaced by a self-deserialization
                # wrapper, parameterized like ``tp`` when ``tp`` has arguments.
                wrapper: AnyType = self_deserialization_wrapper(
                    get_origin_or_type(tp)
                )
                tp_args = get_args(tp)
                if tp_args:
                    wrapper = wrapper[tp_args]
                candidate = ResolvedConversion(replace(candidate, source=wrapper))
            source = candidate.source
            if is_type_var(source) or any(
                is_type_var(arg) for arg in get_args2(source)
            ):
                # Replace the source's type variables using the substitution
                # obtained by matching ``tp`` against the conversion target.
                _, substitution = subtyping_substitution(tp, candidate.target)
                candidate = replace(
                    candidate, source=substitute_type_vars(source, substitution)
                )
            resolved.append(ResolvedConversion(replace(candidate, target=tp)))
        if seen_identity and len(resolved) == 1:
            return True, None
        return bool(resolved), tuple(resolved) or None

    def _annotated_conversion(
        self, annotation: ConversionMetadata
    ) -> Optional[AnyConversion]:
        """Return the deserialization side of the conversion metadata."""
        return annotation.deserialization

    def _visit_conversion(
        self,
        tp: AnyType,
        conversion: Deserialization,
        dynamic: bool,
        next_conversion: Optional[AnyConversion],
    ) -> Result:
        """Visit the source of each conversion and combine them as a union."""
        visited = []
        for conv in conversion:
            inner = sub_conversion(conv, next_conversion)
            visited.append(self.visit_with_conv(conv.source, inner))
        return self._visited_union(visited)
class SerializationVisitor(ConversionsVisitor[Serialization, Result]):
    """Conversions visitor for serialization.

    A single resolved conversion is selected for a type: the first one whose
    source accepts it.
    """

    @staticmethod
    def _has_conversion(
        tp: AnyType, conversion: Optional[AnyConversion]
    ) -> Tuple[bool, Optional[Serialization]]:
        """Find the first conversion of ``conversion`` whose source fits ``tp``.

        Returns ``(True, None)`` for an identity match, ``(True, conv)`` with
        ``conv`` re-sourced to ``tp`` for any other match, and ``(False, None)``
        when no conversion matches.
        """
        for candidate in resolve_any_conversion(conversion):
            candidate = handle_identity_conversion(candidate, tp)
            if not is_subclass(tp, candidate.source):
                continue
            if is_identity(candidate):
                return True, None
            target = candidate.target
            if is_type_var(target) or any(
                is_type_var(arg) for arg in get_args2(target)
            ):
                # Replace the target's type variables using the substitution
                # obtained by matching the conversion source against ``tp``.
                substitution, _ = subtyping_substitution(candidate.source, tp)
                candidate = replace(
                    candidate, target=substitute_type_vars(target, substitution)
                )
            return True, ResolvedConversion(replace(candidate, source=tp))
        return False, None

    def _annotated_conversion(
        self, annotation: ConversionMetadata
    ) -> Optional[AnyConversion]:
        """Return the serialization side of the conversion metadata."""
        return annotation.serialization

    def _visit_conversion(
        self,
        tp: AnyType,
        conversion: Serialization,
        dynamic: bool,
        next_conversion: Optional[AnyConversion],
    ) -> Result:
        """Visit the conversion's target with its sub-conversion in effect."""
        inner = sub_conversion(conversion, next_conversion)
        return self.visit_with_conv(conversion.target, inner)