148 lines
4.5 KiB
Python
148 lines
4.5 KiB
Python
from collections import defaultdict
|
|
from enum import Enum
|
|
from typing import (
|
|
Any,
|
|
Collection,
|
|
Dict,
|
|
Mapping,
|
|
Optional,
|
|
Sequence,
|
|
Tuple,
|
|
Type,
|
|
TypeVar,
|
|
)
|
|
|
|
from apischema.conversions.conversions import AnyConversion, DefaultConversion
|
|
from apischema.conversions.visitor import (
|
|
ConversionsVisitor,
|
|
DeserializationVisitor,
|
|
SerializationVisitor,
|
|
)
|
|
from apischema.json_schema.conversions_resolver import WithConversionsResolver
|
|
from apischema.objects import ObjectField
|
|
from apischema.objects.visitor import (
|
|
DeserializationObjectVisitor,
|
|
ObjectVisitor,
|
|
SerializationObjectVisitor,
|
|
)
|
|
from apischema.type_names import TypeNameFactory, get_type_name
|
|
from apischema.types import AnyType
|
|
from apischema.utils import is_hashable, replace_builtins
|
|
from apischema.visitor import Unsupported
|
|
|
|
try:
|
|
from apischema.typing import Annotated
|
|
except ImportError:
|
|
Annotated = ... # type: ignore
|
|
|
|
# Maps a reference name to the type registered under that name and the number
# of times the reference has been counted (see RefsExtractor._incr_ref).
Refs = Dict[str, Tuple[AnyType, int]]
|
|
|
|
|
|
class Recursive(Exception):
    """Exception subclass that adds nothing on top of ``Exception``.

    NOTE(review): presumably signals a recursive type; no raise site is
    visible in this module -- confirm against callers.
    """
|
|
|
|
|
|
# Generic type variable; not referenced by any definition visible in this module.
T = TypeVar("T")
|
|
|
|
|
|
class RefsExtractor(ConversionsVisitor, ObjectVisitor, WithConversionsResolver):
    """Visitor that tallies how often each named reference is met in a type.

    The tallies are written into the ``refs`` mapping handed to the
    constructor, keyed by reference name, as ``(type, count)`` pairs.
    """

    def __init__(self, default_conversion: DefaultConversion, refs: Refs):
        super().__init__(default_conversion)
        # The mapping passed in is stored as-is and mutated while visiting.
        self.refs = refs
        # Per (type, conversion) count of ``visit_conversion`` calls in progress.
        self._rec_guard: Dict[Tuple[AnyType, Optional[AnyConversion]], int] = (
            defaultdict(int)
        )

    def _incr_ref(self, ref: Optional[str], tp: AnyType) -> bool:
        """Count one more occurrence of ``ref`` for ``tp``.

        Returns ``True`` when ``ref`` had already been counted before this
        call, ``False`` otherwise (including when ``ref`` is ``None``).

        Raises ``ValueError`` when ``ref`` is already registered for a type
        that differs from ``tp`` once both go through ``replace_builtins``.
        """
        if ref is None:
            return False
        registered, seen = self.refs.get(ref, (tp, 0))
        if replace_builtins(registered) != replace_builtins(tp):
            raise ValueError(
                f"Types {tp} and {self.refs[ref][0]} share same reference '{ref}'"
            )
        self.refs[ref] = (registered, seen + 1)
        return seen > 0

    def annotated(self, tp: AnyType, annotations: Sequence[Any]):
        """Count the references named by ``TypeNameFactory`` annotations.

        Annotations are scanned from last to first.  Scanning stops, and the
        parent visit is skipped, as soon as a reference turns out to have
        been counted before; otherwise the call is delegated to the parent
        ``annotated``.
        """
        for offset, marker in enumerate(reversed(annotations)):
            if not isinstance(marker, TypeNameFactory):
                continue
            ref = marker.to_type_name(tp).json_schema
            if not isinstance(ref, str):
                continue
            # Keep the annotations up to and including the factory itself.
            kept = annotations[: len(annotations) - offset]
            annotated_tp = Annotated[(tp, *kept)]  # type: ignore
            if self._incr_ref(ref, annotated_tp):
                return
        return super().annotated(tp, annotations)

    def any(self):
        """Nothing to visit."""

    def collection(self, cls: Type[Collection], value_type: AnyType):
        """Visit the element type of the collection."""
        self.visit(value_type)

    def enum(self, cls: Type[Enum]):
        """Nothing to visit."""

    def literal(self, values: Sequence[Any]):
        """Nothing to visit."""

    def mapping(self, cls: Type[Mapping], key_type: AnyType, value_type: AnyType):
        """Visit the key type, then the value type."""
        self.visit(key_type)
        self.visit(value_type)

    def object(self, tp: AnyType, fields: Sequence[ObjectField]):
        """Visit every field type together with that field's conversion."""
        for member in fields:
            self.visit_with_conv(member.type, self._field_conversion(member))

    def primitive(self, cls: Type):
        """Nothing to visit."""

    def tuple(self, types: Sequence[AnyType]):
        """Visit each item type in order."""
        for item_type in types:
            self.visit(item_type)

    def _visited_union(self, results: Sequence):
        """Nothing to combine; always returns ``None``."""

    def visit_conversion(
        self,
        tp: AnyType,
        conversion: Optional[Any],
        dynamic: bool,
        next_conversion: Optional[AnyConversion] = None,
    ):
        """Count the references of ``tp`` and then visit it, guarding recursion.

        For a non-dynamic conversion, every type yielded by
        ``resolve_conversion(tp)`` has its reference counted; the visit stops
        right away when one of them had been counted before.  Unhashable
        types are delegated straight to the parent.  Otherwise the parent
        visit runs under a per ``(tp, conversion)`` depth counter.

        Raises ``TypeError`` when that counter already exceeds 2 on entry.
        """
        counted = []
        if not dynamic:
            for resolved in self.resolve_conversion(tp):
                counted.append(resolved)
                name = get_type_name(resolved).json_schema
                if self._incr_ref(name, resolved):
                    return
        if not is_hashable(tp):
            # Cannot be used as a guard key, so no recursion tracking here.
            return super().visit_conversion(tp, conversion, dynamic, next_conversion)
        guard_key = (tp, self._conversion)
        # The threshold is 2 because the first type met in a recursive cycle
        # may have no ref (see test_recursive_by_conversion_schema).
        if self._rec_guard[guard_key] > 2:
            raise TypeError(f"Recursive type {tp} need a ref")
        self._rec_guard[guard_key] += 1
        try:
            super().visit_conversion(tp, conversion, dynamic, next_conversion)
        except Unsupported:
            # Drop the references that were counted for this visit.
            for resolved in counted:
                self.refs.pop(get_type_name(resolved).json_schema, ...)  # type: ignore
        finally:
            # The key is rebuilt from the current ``self._conversion`` here.
            self._rec_guard[(tp, self._conversion)] -= 1
|
|
|
|
|
|
class DeserializationRefsExtractor(
    RefsExtractor, DeserializationVisitor, DeserializationObjectVisitor
):
    """``RefsExtractor`` combined with the deserialization visitors.

    Adds no members of its own; all behavior comes from its bases.
    """
|
|
|
|
|
|
class SerializationRefsExtractor(
    RefsExtractor, SerializationVisitor, SerializationObjectVisitor
):
    """``RefsExtractor`` combined with the serialization visitors.

    Adds no members of its own; all behavior comes from its bases.
    """
|