46 lines
1.7 KiB
Python
46 lines
1.7 KiB
Python
from typing import Iterator, Mapping, Sequence, Type
|
|
|
|
from apischema.conversions.conversions import DefaultConversion
|
|
from apischema.conversions.visitor import DeserializationVisitor
|
|
from apischema.objects import ObjectField
|
|
from apischema.objects.visitor import DeserializationObjectVisitor
|
|
from apischema.types import AnyType
|
|
from apischema.utils import get_origin_or_type
|
|
from apischema.visitor import Unsupported
|
|
|
|
|
|
class InitFlattenedAliasVisitor(
    DeserializationObjectVisitor[Iterator[str]], DeserializationVisitor[Iterator[str]]
):
    """Deserialization visitor whose visit result is an iterator of alias strings.

    Object types contribute the aliases of their fields, flattened fields are
    expanded recursively, and mapping types contribute no alias at all.
    """

    def mapping(
        self, cls: Type[Mapping], key_type: AnyType, value_type: AnyType
    ) -> Iterator[str]:
        """Return an empty iterator: a mapping type provides no alias."""
        return iter(())

    def object(self, tp: AnyType, fields: Sequence[ObjectField]) -> Iterator[str]:
        """Yield the aliases reachable through ``fields``.

        A flattened field is replaced by the aliases computed for its own
        type; any other field yields its alias unless it is an aggregate
        field, which is skipped.
        """
        for current in fields:
            if not current.flattened:
                if not current.is_aggregate:
                    yield current.alias
                continue
            # Flattened field: delegate to the module-level helper, which
            # visits the field's type with a fresh visitor.
            owner = get_origin_or_type(tp)
            yield from get_deserialization_flattened_aliases(
                owner, current, self.default_conversion
            )

    def _visited_union(self, results: Sequence[Iterator[str]]) -> Iterator[str]:
        """Return the single result of a visited union.

        Raises:
            NotImplementedError: if ``results`` does not hold exactly one item.
        """
        if len(results) == 1:
            return results[0]
        raise NotImplementedError
|
|
|
|
|
|
def get_deserialization_flattened_aliases(
    cls: Type, field: ObjectField, default_conversion: DefaultConversion
) -> Iterator[str]:
    """Yield the aliases obtained by visiting the type of a flattened field.

    This is a generator: the assertion and the visit only run once iteration
    starts, so errors surface while the result is being consumed.

    Args:
        cls: class owning ``field``; only its ``__name__`` is used, in the
            error message.
        field: the field to expand; it must be flattened.
        default_conversion: passed to the ``InitFlattenedAliasVisitor``
            constructor.

    Raises:
        TypeError: if the visit raises ``NotImplementedError`` or
            ``Unsupported``; the original exception context is suppressed.
    """
    assert field.flattened
    try:
        visitor = InitFlattenedAliasVisitor(default_conversion)
        aliases = visitor.visit_with_conv(field.type, field.deserialization)
        yield from aliases
    except (NotImplementedError, Unsupported):
        message = f"Flattened field {cls.__name__}.{field.name} must have an object type"
        raise TypeError(message) from None
|