260 lines
7.7 KiB
Python
260 lines
7.7 KiB
Python
from dataclasses import Field, InitVar, MISSING, dataclass, field
|
|
from enum import Enum, auto
|
|
from types import FunctionType
|
|
from typing import (
|
|
Any,
|
|
Callable,
|
|
Iterable,
|
|
Mapping,
|
|
MutableMapping,
|
|
NoReturn,
|
|
Optional,
|
|
Pattern,
|
|
Sequence,
|
|
TYPE_CHECKING,
|
|
Union,
|
|
cast,
|
|
)
|
|
|
|
from apischema.cache import CacheAwareDict
|
|
from apischema.conversions.conversions import AnyConversion
|
|
from apischema.metadata.implem import (
|
|
ConversionMetadata,
|
|
SkipMetadata,
|
|
ValidatorsMetadata,
|
|
)
|
|
from apischema.metadata.keys import (
|
|
ALIAS_METADATA,
|
|
ALIAS_NO_OVERRIDE_METADATA,
|
|
CONVERSION_METADATA,
|
|
DEFAULT_AS_SET_METADATA,
|
|
FALL_BACK_ON_DEFAULT_METADATA,
|
|
FLATTEN_METADATA,
|
|
NONE_AS_UNDEFINED_METADATA,
|
|
ORDERING_METADATA,
|
|
POST_INIT_METADATA,
|
|
PROPERTIES_METADATA,
|
|
REQUIRED_METADATA,
|
|
SCHEMA_METADATA,
|
|
SKIP_METADATA,
|
|
VALIDATORS_METADATA,
|
|
)
|
|
from apischema.types import AnyType, ChainMap, NoneType, UndefinedType
|
|
from apischema.typing import get_args, is_annotated
|
|
from apischema.utils import (
|
|
LazyValue,
|
|
empty_dict,
|
|
get_args2,
|
|
is_union_of,
|
|
keep_annotations,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from apischema.ordering import Ordering
|
|
from apischema.schemas import Schema
|
|
from apischema.validation.validators import Validator
|
|
|
|
|
|
class FieldKind(Enum):
    """Kind of an ObjectField; ``NORMAL`` is the default of ``ObjectField.kind``.

    NOTE(review): READ_ONLY / WRITE_ONLY presumably restrict a field to one
    direction of (de)serialization -- confirm against the code consuming them.
    """

    # Explicit values match what ``auto()`` assigns (1, 2, 3 in order).
    NORMAL = 1
    READ_ONLY = 2
    WRITE_ONLY = 3
|
|
|
|
|
|
# Sentinel meaning "no default given" for ObjectField's init-only `default`.
# dataclasses.MISSING cannot be reused here: used as a dataclass field default,
# it would be interpreted as the field having no default at all.
MISSING_DEFAULT = object()
|
|
|
|
|
|
@dataclass(frozen=True)
class ObjectField:
    """Immutable description of a single field of an object type.

    Holds the field ``name`` and ``type``, whether it is ``required``, its own
    ``metadata``, how its default value is produced and its ``kind``. Most of
    the properties below are derived from :attr:`full_metadata`.

    Because the dataclass is frozen, ``__post_init__`` has to use
    ``object.__setattr__`` to normalize attributes after construction.
    """

    name: str
    type: AnyType
    required: bool = True
    metadata: Mapping[str, Any] = field(default_factory=lambda: empty_dict)
    # Init-only (InitVar). MISSING_DEFAULT stands for "no default given"; for a
    # non-required field without default_factory, __post_init__ turns the
    # provided value into ``default_factory``.
    default: InitVar[Any] = MISSING_DEFAULT
    default_factory: Optional[Callable[[], Any]] = None
    kind: FieldKind = FieldKind.NORMAL

    def __post_init__(self, default: Any):
        """Normalize ``required``, ``default_factory`` and ``type``.

        Raises:
            ValueError: if the field is not required and neither ``default``
                nor ``default_factory`` was provided.
        """
        # REQUIRED_METADATA (from metadata or Annotated args) forces required.
        if REQUIRED_METADATA in self.full_metadata:
            object.__setattr__(self, "required", True)
        # dataclasses.MISSING (presumably forwarded from a dataclasses.Field
        # without default_factory) is normalized to None.
        if self.default_factory is MISSING:
            object.__setattr__(self, "default_factory", None)
        # A non-required field must be able to produce a default: wrap the
        # plain default value in LazyValue so get_default() can call it.
        if not self.required and self.default_factory is None:
            if default is MISSING_DEFAULT:
                raise ValueError("Missing default for non-required ObjectField")
            object.__setattr__(self, "default_factory", LazyValue(default))
        # With none_as_undefined, NoneType is removed from a union type;
        # keep_annotations presumably re-applies the original Annotated
        # wrapper so full_metadata is unchanged.
        if self.none_as_undefined and is_union_of(self.type, NoneType):
            new_type = Union[tuple(a for a in get_args2(self.type) if a != NoneType)]  # type: ignore
            object.__setattr__(self, "type", keep_annotations(new_type, self.type))

    @property
    def full_metadata(self) -> Mapping[str, Any]:
        """``metadata`` layered over the Mapping arguments of an Annotated type.

        Returns ``metadata`` itself when ``type`` is not Annotated. Otherwise
        ``metadata`` comes first, followed by the Annotated Mapping arguments
        in reverse order -- so, assuming ChainMap searches first to last like
        ``collections.ChainMap``, later Annotated arguments win over earlier
        ones and ``metadata`` wins over all of them.
        """
        if not is_annotated(self.type):
            return self.metadata
        return ChainMap(
            self.metadata,
            *(
                arg
                for arg in reversed(get_args(self.type)[1:])
                if isinstance(arg, Mapping)
            ),
        )

    @property
    def additional_properties(self) -> bool:
        """True when PROPERTIES_METADATA is present with the value None."""
        return self.full_metadata.get(PROPERTIES_METADATA, ...) is None

    @property
    def alias(self) -> str:
        """ALIAS_METADATA value, defaulting to the field name."""
        return self.full_metadata.get(ALIAS_METADATA, self.name)

    @property
    def override_alias(self) -> bool:
        """False when ALIAS_NO_OVERRIDE_METADATA is present."""
        return ALIAS_NO_OVERRIDE_METADATA not in self.full_metadata

    @property
    def _conversion(self) -> Optional[ConversionMetadata]:
        """ConversionMetadata from the field's own ``metadata`` only.

        NOTE(review): Annotated metadata is deliberately not consulted here
        (nor in ``schema``, ``skip`` and ``validators``); presumably it is
        handled elsewhere from the type itself -- confirm before changing.
        """
        return self.metadata.get(CONVERSION_METADATA)

    @property
    def default_as_set(self) -> bool:
        """Whether DEFAULT_AS_SET_METADATA is present."""
        return DEFAULT_AS_SET_METADATA in self.full_metadata

    @property
    def deserialization(self) -> Optional[AnyConversion]:
        """Deserialization conversion of the field, or None."""
        conversion = self._conversion
        return conversion.deserialization if conversion is not None else None

    @property
    def fall_back_on_default(self) -> bool:
        """FALL_BACK_ON_DEFAULT_METADATA is present and a default exists."""
        return (
            FALL_BACK_ON_DEFAULT_METADATA in self.full_metadata
            and self.default_factory is not None
        )

    @property
    def flattened(self) -> bool:
        """Whether FLATTEN_METADATA is present."""
        return FLATTEN_METADATA in self.full_metadata

    def get_default(self) -> Any:
        """Return a default value by calling ``default_factory``.

        Raises:
            RuntimeError: if the field is required.
        """
        if self.required:
            raise RuntimeError("Field is required")
        # __post_init__ guarantees a factory for every non-required field.
        assert self.default_factory is not None
        return self.default_factory()  # type: ignore

    @property
    def is_aggregate(self) -> bool:
        """Flattened, additional-properties or pattern-properties field."""
        return (
            self.flattened
            or self.additional_properties
            or self.pattern_properties is not None
        )

    @property
    def none_as_undefined(self) -> bool:
        """Whether NONE_AS_UNDEFINED_METADATA is present."""
        return NONE_AS_UNDEFINED_METADATA in self.full_metadata

    @property
    def ordering(self) -> Optional["Ordering"]:
        """ORDERING_METADATA value, or None."""
        return self.full_metadata.get(ORDERING_METADATA)

    @property
    def post_init(self) -> bool:
        """Whether POST_INIT_METADATA is present."""
        return POST_INIT_METADATA in self.full_metadata

    @property
    def pattern_properties(self) -> Union[Pattern, "ellipsis", None]:  # noqa: F821
        """Raw PROPERTIES_METADATA value, or None when absent."""
        return self.full_metadata.get(PROPERTIES_METADATA)

    @property
    def schema(self) -> Optional["Schema"]:
        """SCHEMA_METADATA from the field's own ``metadata``, or None."""
        return self.metadata.get(SCHEMA_METADATA)

    @property
    def serialization(self) -> Optional[AnyConversion]:
        """Serialization conversion of the field, or None."""
        conversion = self._conversion
        return conversion.serialization if conversion is not None else None

    @property
    def skip(self) -> SkipMetadata:
        """SKIP_METADATA from ``metadata``, or a fresh default SkipMetadata."""
        return self.metadata.get(SKIP_METADATA, SkipMetadata())

    def skippable(self, default: bool, none: bool) -> bool:
        """Whether the field may be left out of serialized output.

        Args:
            default: when true, any field that has a default is skippable.
            none: when true, a field whose type is a union with NoneType is
                skippable.
        """
        skip = self.skip  # evaluated once; the property may build a new object
        return bool(
            skip.serialization_if
            or self.undefined
            or (
                self.default_factory is not None
                and (skip.serialization_default or default)
            )
            or self.none_as_undefined
            or (none and is_union_of(self.type, NoneType))
        )

    @property
    def undefined(self) -> bool:
        """Whether the field type is a union including UndefinedType."""
        return is_union_of(self.type, UndefinedType)

    @property
    def validators(self) -> Sequence["Validator"]:
        """Validators from the field's own ``metadata``, or an empty tuple."""
        if VALIDATORS_METADATA in self.metadata:
            return cast(
                ValidatorsMetadata, self.metadata[VALIDATORS_METADATA]
            ).validators
        else:
            return ()
|
|
|
|
|
|
# A reference to an object field: its name, an ObjectField, or a
# dataclasses.Field (resolved to a name by get_field_name).
FieldOrName = Union[str, ObjectField, Field]
|
|
|
|
|
|
def _bad_field(obj: Any, methods: bool) -> NoReturn:
|
|
method_types = "property/types.FunctionType" if methods else ""
|
|
raise TypeError(
|
|
f"Expected dataclasses.Field/apischema.ObjectField/str{method_types}, found {obj}"
|
|
)
|
|
|
|
|
|
def check_field_or_name(field_or_name: Any, *, methods: bool = False):
    """Raise TypeError unless ``field_or_name`` is an accepted field reference.

    Accepted: str, ObjectField or dataclasses.Field; with ``methods=True``,
    also property and plain function objects.
    """
    accepted: tuple = (str, ObjectField, Field)
    if methods:
        accepted = (*accepted, property, FunctionType)
    if not isinstance(field_or_name, accepted):
        _bad_field(field_or_name, methods)
|
|
|
|
|
|
def get_field_name(field_or_name: Any, *, methods: bool = False) -> str:
    """Return the field name designated by ``field_or_name``.

    Fields (dataclasses.Field / ObjectField) give their ``name``; strings are
    returned as-is. With ``methods=True``, a property with a getter gives the
    getter's ``__name__`` and a plain function its ``__name__``.

    Raises:
        TypeError: (via ``_bad_field``) for any other value.
    """
    if isinstance(field_or_name, (Field, ObjectField)):
        return field_or_name.name
    if isinstance(field_or_name, str):
        return field_or_name
    if methods:
        if isinstance(field_or_name, property) and field_or_name.fget is not None:
            return field_or_name.fget.__name__
        if isinstance(field_or_name, FunctionType):
            return field_or_name.__name__
    _bad_field(field_or_name, methods)
|
|
|
|
|
|
# Registry mapping a class to a zero-argument callable returning its
# ObjectFields; written by set_object_fields, read by default_object_fields.
# NOTE(review): CacheAwareDict presumably ties this registry to apischema's
# cache invalidation -- confirm in apischema.cache.
_class_fields: MutableMapping[
    type, Callable[[], Sequence[ObjectField]]
] = CacheAwareDict({})
|
|
|
|
|
|
def set_object_fields(
    cls: type,
    fields: Union[Iterable[ObjectField], Callable[[], Sequence[ObjectField]], None],
):
    """Register (or unregister) the ObjectFields of ``cls``.

    Args:
        cls: Class whose fields are registered.
        fields: ``None`` removes any registration for ``cls``; a callable is
            stored as-is (called later by ``default_object_fields``); any other
            iterable is snapshotted into a tuple immediately.
    """
    if fields is None:
        # `...` as pop default: unregistering an absent class is a no-op.
        _class_fields.pop(cls, ...)
        return
    if not callable(fields):
        snapshot = tuple(fields)
        _class_fields[cls] = lambda fields=snapshot: fields  # type: ignore
        return
    _class_fields[cls] = fields
|
|
|
|
|
|
def default_object_fields(cls: type) -> Optional[Sequence[ObjectField]]:
    """Return the ObjectFields registered for ``cls``, or None if unregistered."""
    if cls not in _class_fields:
        return None
    fields_getter = _class_fields[cls]
    return fields_getter()
|