218 lines
7.1 KiB
Python
218 lines
7.1 KiB
Python
import warnings
|
|
from dataclasses import ( # type: ignore
|
|
Field,
|
|
InitVar,
|
|
_FIELDS,
|
|
_FIELD_CLASSVAR,
|
|
make_dataclass,
|
|
)
|
|
from enum import Enum
|
|
from types import MappingProxyType
|
|
from typing import (
|
|
Any,
|
|
Collection,
|
|
Generic,
|
|
Mapping,
|
|
Sequence,
|
|
Tuple,
|
|
Type,
|
|
TypeVar,
|
|
Union,
|
|
)
|
|
|
|
from apischema.types import (
|
|
AnyType,
|
|
COLLECTION_TYPES,
|
|
MAPPING_TYPES,
|
|
OrderedDict,
|
|
PRIMITIVE_TYPES,
|
|
)
|
|
from apischema.typing import (
|
|
get_args,
|
|
get_origin,
|
|
get_type_hints,
|
|
is_annotated,
|
|
is_literal,
|
|
is_named_tuple,
|
|
is_type_var,
|
|
is_typed_dict,
|
|
is_union,
|
|
required_keys,
|
|
resolve_type_hints,
|
|
)
|
|
from apischema.utils import PREFIX, get_origin_or_type, has_type_vars, is_dataclass
|
|
|
|
# ``Annotated`` may not be importable (``apischema.typing`` raising ImportError);
# bind a harmless ``Ellipsis`` placeholder so this module still imports.
try:
    from apischema.typing import Annotated
except ImportError:
    Annotated = Ellipsis  # type: ignore

# Runtime origin of ``Tuple[...]`` annotations, used to recognise tuple types.
TUPLE_TYPE = get_origin(Tuple[Any])
|
|
|
|
|
|
def dataclass_types_and_fields(
    tp: AnyType,
) -> Tuple[Mapping[str, AnyType], Sequence[Field], Sequence[Field]]:
    """Split a dataclass into resolved field types, regular fields and InitVars.

    ``tp`` is a dataclass, or anything ``get_origin_or_type`` maps to one (e.g. a
    parametrized generic dataclass). Its hints come from ``resolve_type_hints``.
    ClassVar fields are skipped. For InitVar fields, ``types[name]`` is replaced
    by the wrapped type rather than the InitVar marker.

    Returns:
        ``(types, fields, init_vars)``: a read-only mapping of field name to
        type, a tuple of regular ``Field`` objects, and a tuple of InitVar
        ``Field`` objects. All are immutable.

    Raises:
        TypeError: (Python < 3.8 path) when an InitVar field has no
            ``init_var`` metadata, or when its type contains type variables.
    """
    from apischema.metadata.keys import INIT_VAR_METADATA

    cls = get_origin_or_type(tp)
    assert is_dataclass(cls)
    types = resolve_type_hints(tp)
    fields, init_fields = [], []
    for field in getattr(cls, _FIELDS).values():
        assert isinstance(field, Field)
        if field._field_type == _FIELD_CLASSVAR:  # type: ignore
            continue
        field_type = types[field.name]
        if isinstance(field_type, InitVar):
            # Python >= 3.8: ``InitVar[T]`` is an InitVar instance carrying T.
            types[field.name] = field_type.type  # type: ignore
            init_fields.append(field)
        elif field_type is InitVar:
            # Python < 3.8: the parameter of ``InitVar[T]`` is lost, so T must
            # be provided through metadata.
            metadata = field.metadata
            if INIT_VAR_METADATA not in metadata:
                raise TypeError("Before 3.8, InitVar requires init_var metadata")
            # Resolve T by declaring it on a temporary subclass. The ``...``
            # default lets the extra field follow fields that have defaults.
            init_field = (PREFIX, metadata[INIT_VAR_METADATA], ...)
            # A frozen dataclass cannot be subclassed by a non-frozen one, so
            # mirror the parent's ``frozen`` flag on the temporary class.
            frozen = cls.__dataclass_params__.frozen  # type: ignore
            tmp_cls = make_dataclass(  # type: ignore
                "Tmp", [init_field], bases=(cls,), frozen=frozen
            )
            types[field.name] = get_type_hints(tmp_cls, include_extras=True)[PREFIX]
            if has_type_vars(types[field.name]):
                raise TypeError("Generic InitVar are not supported before 3.8")
            init_fields.append(field)
        else:
            fields.append(field)
    # Use immutable return because of cache
    return MappingProxyType(types), tuple(fields), tuple(init_fields)
|
|
|
|
|
|
class Unsupported(TypeError):
    """Raised for a type that a visitor cannot handle.

    The offending type is stored in ``type``. It is also forwarded to
    ``TypeError.__init__``, so it appears in ``str(exc)``/``exc.args`` and the
    exception can be pickled or copied (both rebuild it from ``args``).
    """

    def __init__(self, tp: AnyType):
        # Without forwarding ``tp``, ``args`` would be empty: the message would be
        # blank and ``cls(*args)`` (used by pickle/copy) would fail.
        super().__init__(tp)
        self.type = tp

    @property
    def cls(self) -> AnyType:
        """Deprecated alias of ``type``; emits a DeprecationWarning."""
        warnings.warn(
            "Unsupported.cls is deprecated, use Unsupported.type instead",
            DeprecationWarning,
            # Point the warning at the caller accessing ``.cls``.
            stacklevel=2,
        )
        return self.type
|
|
|
|
|
|
# Return type produced by the ``Visitor`` hook methods.
Result = TypeVar("Result", covariant=True)


class Visitor(Generic[Result]):
    """Dispatch on the structure of a type annotation.

    ``visit`` inspects a type and calls the matching hook (``collection``,
    ``dataclass``, ``union``, ...). Most hooks raise ``NotImplementedError`` and
    are meant to be overridden. ``annotated``, ``new_type`` and ``subprimitive``
    delegate by default, and ``unsupported`` raises ``Unsupported``.
    """

    def annotated(self, tp: AnyType, annotations: Sequence[Any]) -> Result:
        """Visit ``Annotated[tp, *annotations]``; by default visit ``tp``.

        Raises ``Unsupported`` if the ``Unsupported`` class is one of the
        annotations.
        """
        if Unsupported in annotations:
            raise Unsupported(Annotated[(tp, *annotations)])  # type: ignore
        return self.visit(tp)

    def any(self) -> Result:
        """Visit ``Any`` (also used for TypeVars without constraints)."""
        raise NotImplementedError

    def collection(self, cls: Type[Collection], value_type: AnyType) -> Result:
        """Visit a collection type ``cls`` with element type ``value_type``."""
        raise NotImplementedError

    def dataclass(
        self,
        tp: AnyType,
        types: Mapping[str, AnyType],
        fields: Sequence[Field],
        init_vars: Sequence[Field],
    ) -> Result:
        """Visit a dataclass given its field types, fields and InitVar fields."""
        raise NotImplementedError

    def enum(self, cls: Type[Enum]) -> Result:
        """Visit an ``Enum`` subclass."""
        raise NotImplementedError

    def literal(self, values: Sequence[Any]) -> Result:
        """Visit a ``Literal`` type given its allowed values."""
        raise NotImplementedError

    def mapping(
        self, cls: Type[Mapping], key_type: AnyType, value_type: AnyType
    ) -> Result:
        """Visit a mapping type ``cls`` with the given key and value types."""
        raise NotImplementedError

    def named_tuple(
        self, tp: AnyType, types: Mapping[str, AnyType], defaults: Mapping[str, Any]
    ) -> Result:
        """Visit a named tuple given its field types and ``_field_defaults``."""
        raise NotImplementedError

    def new_type(self, tp: AnyType, super_type: AnyType) -> Result:
        """Visit a ``NewType``; by default visit its supertype."""
        return self.visit(super_type)

    def primitive(self, cls: Type) -> Result:
        """Visit one of the ``PRIMITIVE_TYPES``."""
        raise NotImplementedError

    def subprimitive(self, cls: Type, superclass: Type) -> Result:
        """Visit a subclass of a primitive; by default visit the primitive."""
        return self.primitive(superclass)

    def tuple(self, types: Sequence[AnyType]) -> Result:
        """Visit a fixed-length tuple given its element types."""
        raise NotImplementedError

    def typed_dict(
        self, tp: AnyType, types: Mapping[str, AnyType], required_keys: Collection[str]
    ) -> Result:
        """Visit a ``TypedDict`` given its key types and required keys."""
        raise NotImplementedError

    def union(self, alternatives: Sequence[AnyType]) -> Result:
        """Visit a ``Union`` given its alternatives."""
        raise NotImplementedError

    def unsupported(self, tp: AnyType) -> Result:
        """Handle a type no other hook matched; raises ``Unsupported``."""
        raise Unsupported(tp)

    def visit(self, tp: AnyType) -> Result:
        """Dispatch ``tp`` to the hook matching its structure."""
        origin, args = get_origin_or_type(tp), get_args(tp)
        if args:
            if is_annotated(tp):
                return self.annotated(args[0], args[1:])
            if is_union(origin):
                return self.union(args[0]) if len(args) == 1 else self.union(args)
            if origin is TUPLE_TYPE:
                # ``Tuple[X, ...]`` falls through to the collection handling.
                if len(args) < 2 or args[1] is not ...:
                    return self.tuple(args)
            if origin in COLLECTION_TYPES:
                return self.collection(origin, args[0])
            if origin in MAPPING_TYPES:
                return self.mapping(origin, args[0], args[1])
            if is_literal(tp):  # pragma: no cover py37+
                return self.literal(args)
        if origin in PRIMITIVE_TYPES:
            return self.primitive(origin)
        if is_dataclass(origin):
            return self.dataclass(tp, *dataclass_types_and_fields(tp))  # type: ignore
        if hasattr(origin, "__supertype__"):
            return self.new_type(origin, origin.__supertype__)
        if origin is Any:
            return self.any()
        if origin in COLLECTION_TYPES:
            return self.collection(origin, Any)
        if origin in MAPPING_TYPES:
            return self.mapping(origin, Any, Any)
        if isinstance(origin, type):
            if issubclass(origin, Enum):
                return self.enum(origin)
            for primitive in PRIMITIVE_TYPES:
                if issubclass(origin, primitive):
                    return self.subprimitive(origin, primitive)
            # NamedTuple
            if is_named_tuple(origin):
                # Since Python 3.10 every class exposes ``__annotations__`` (an
                # empty dict when nothing is annotated). Test for a non-empty
                # mapping so a plain ``collections.namedtuple`` reaches the
                # ``Any`` fallback instead of getting no field types.
                if getattr(origin, "__annotations__", None):
                    types = resolve_type_hints(origin)
                elif hasattr(origin, "_field_types"):  # pragma: no cover
                    # ``_field_types`` is typing.NamedTuple's attribute before
                    # Python 3.9. It uses a single underscore, so the name is not
                    # subject to class-private name mangling.
                    types = origin._field_types  # type: ignore
                else:
                    types = OrderedDict((f, Any) for f in origin._fields)  # type: ignore # noqa: E501
                return self.named_tuple(
                    origin, types, origin._field_defaults  # type: ignore
                )
        if is_literal(origin):  # pragma: no cover py36
            return self.literal(origin.__values__)  # type: ignore
        if is_typed_dict(origin):
            return self.typed_dict(
                origin, resolve_type_hints(origin), required_keys(origin)
            )
        if is_type_var(origin):
            if origin.__constraints__:
                return self.visit(Union[origin.__constraints__])
            else:
                return self.any()
        return self.unsupported(tp)
|