tokencrawler/.venv/lib/python3.9/site-packages/apischema/visitor.py
2022-03-17 22:16:30 +01:00

218 lines
7.1 KiB
Python

import warnings
from dataclasses import ( # type: ignore
Field,
InitVar,
_FIELDS,
_FIELD_CLASSVAR,
make_dataclass,
)
from enum import Enum
from types import MappingProxyType
from typing import (
Any,
Collection,
Generic,
Mapping,
Sequence,
Tuple,
Type,
TypeVar,
Union,
)
from apischema.types import (
AnyType,
COLLECTION_TYPES,
MAPPING_TYPES,
OrderedDict,
PRIMITIVE_TYPES,
)
from apischema.typing import (
get_args,
get_origin,
get_type_hints,
is_annotated,
is_literal,
is_named_tuple,
is_type_var,
is_typed_dict,
is_union,
required_keys,
resolve_type_hints,
)
from apischema.utils import PREFIX, get_origin_or_type, has_type_vars, is_dataclass
try:
from apischema.typing import Annotated
except ImportError:
Annotated = ... # type: ignore
# Origin that ``get_origin`` reports for a parametrized ``Tuple``; ``Visitor.visit``
# compares the origin of the visited type against it to recognize tuple types.
TUPLE_TYPE = get_origin(Tuple[Any])
def dataclass_types_and_fields(
    tp: AnyType,
) -> Tuple[Mapping[str, AnyType], Sequence[Field], Sequence[Field]]:
    """Resolve the type hints of a dataclass and split its fields by kind.

    Args:
        tp: A dataclass, or a parametrized form of one (its origin is used to
            reach the dataclass fields).

    Returns:
        A 3-tuple ``(types, fields, init_fields)``:

        * ``types`` -- read-only mapping of resolved type hints; for ``InitVar``
          fields the entry holds the wrapped type instead of the ``InitVar``.
        * ``fields`` -- the regular fields, in declaration order.
        * ``init_fields`` -- the ``InitVar`` pseudo-fields, in declaration order.

        Class variables are skipped entirely.

    Raises:
        TypeError: when a field hint is the bare ``InitVar`` class and either no
            ``init_var`` metadata is attached to the field or the type given by
            that metadata contains type variables.
    """
    from apischema.metadata.keys import INIT_VAR_METADATA

    cls = get_origin_or_type(tp)
    assert is_dataclass(cls)
    types = resolve_type_hints(tp)
    dc_fields = getattr(cls, _FIELDS)
    regular, init_only = [], []
    for field in dc_fields.values():
        assert isinstance(field, Field)
        if field._field_type == _FIELD_CLASSVAR:  # type: ignore
            continue
        name = field.name
        hint = types[name]
        if isinstance(hint, InitVar):
            # The hint is an ``InitVar`` instance which carries the wrapped type.
            types[name] = hint.type  # type: ignore
        elif hint is InitVar:
            # The hint is the bare ``InitVar`` class, so the wrapped type is
            # unknown; it has to be supplied through the field metadata.
            metadata = dc_fields[name].metadata
            if INIT_VAR_METADATA not in metadata:
                raise TypeError("Before 3.8, InitVar requires init_var metadata")
            # Declare the metadata-provided annotation on a throwaway subclass so
            # that it gets resolved in the namespace of the dataclass.
            probe_field = (PREFIX, metadata[INIT_VAR_METADATA], ...)
            probe_cls = make_dataclass(  # type: ignore
                "Tmp", [probe_field], bases=(cls,)
            )
            types[name] = get_type_hints(probe_cls, include_extras=True)[PREFIX]
            if has_type_vars(types[name]):
                raise TypeError("Generic InitVar are not supported before 3.8")
        else:
            regular.append(field)
            continue
        init_only.append(field)
    # Use immutable return because of cache
    return MappingProxyType(types), tuple(regular), tuple(init_only)
class Unsupported(TypeError):
    """``TypeError`` that carries the type it was raised for.

    Attributes:
        type: the type passed to the constructor.
    """

    def __init__(self, tp: AnyType):
        self.type = tp

    @property
    def cls(self) -> AnyType:
        """Deprecated alias of ``type``; emits a ``DeprecationWarning``."""
        warnings.warn(
            "Unsupported.cls is deprecated, use Unsupported.type instead",
            DeprecationWarning,
            # Attribute the warning to the code reading ``.cls`` rather than to
            # this module, so that it is reported at the line to migrate.
            stacklevel=2,
        )
        return self.type
Result = TypeVar("Result", covariant=True)


class Visitor(Generic[Result]):
    """Dispatch a type to the handler method matching its kind.

    ``visit`` inspects a type and calls one of the handler methods below.
    Unless stated otherwise, handlers raise ``NotImplementedError`` and are
    meant to be overridden by subclasses.
    """

    def annotated(self, tp: AnyType, annotations: Sequence[Any]) -> Result:
        """Handle ``Annotated[tp, *annotations]``.

        Raises ``Unsupported`` (carrying the rebuilt ``Annotated`` type) when the
        ``Unsupported`` class is one of the annotations; otherwise visits ``tp``.
        """
        if Unsupported in annotations:
            raise Unsupported(Annotated[(tp, *annotations)])  # type: ignore
        return self.visit(tp)

    def any(self) -> Result:
        """Handle ``Any`` (also used for unconstrained type variables)."""
        raise NotImplementedError

    def collection(self, cls: Type[Collection], value_type: AnyType) -> Result:
        """Handle a collection type ``cls`` whose items are ``value_type``."""
        raise NotImplementedError

    def dataclass(
        self,
        tp: AnyType,
        types: Mapping[str, AnyType],
        fields: Sequence[Field],
        init_vars: Sequence[Field],
    ) -> Result:
        """Handle a dataclass.

        ``types``, ``fields`` and ``init_vars`` are the three values returned by
        ``dataclass_types_and_fields(tp)``.
        """
        raise NotImplementedError

    def enum(self, cls: Type[Enum]) -> Result:
        """Handle an ``Enum`` subclass."""
        raise NotImplementedError

    def literal(self, values: Sequence[Any]) -> Result:
        """Handle a ``Literal`` type with the given values."""
        raise NotImplementedError

    def mapping(
        self, cls: Type[Mapping], key_type: AnyType, value_type: AnyType
    ) -> Result:
        """Handle a mapping type ``cls`` with the given key and value types."""
        raise NotImplementedError

    def named_tuple(
        self, tp: AnyType, types: Mapping[str, AnyType], defaults: Mapping[str, Any]
    ) -> Result:
        """Handle a named tuple class with its field types and field defaults."""
        raise NotImplementedError

    def new_type(self, tp: AnyType, super_type: AnyType) -> Result:
        """Handle a ``NewType``; by default, visit its super type."""
        return self.visit(super_type)

    def primitive(self, cls: Type) -> Result:
        """Handle one of the ``PRIMITIVE_TYPES``."""
        raise NotImplementedError

    def subprimitive(self, cls: Type, superclass: Type) -> Result:
        """Handle a subclass of a primitive type.

        By default, it is handled as the primitive ``superclass`` itself.
        """
        return self.primitive(superclass)

    def tuple(self, types: Sequence[AnyType]) -> Result:
        """Handle a tuple type whose items have the given types."""
        raise NotImplementedError

    def typed_dict(
        self, tp: AnyType, types: Mapping[str, AnyType], required_keys: Collection[str]
    ) -> Result:
        """Handle a ``TypedDict`` with its key types and required keys."""
        raise NotImplementedError

    def union(self, alternatives: Sequence[AnyType]) -> Result:
        """Handle a union of several alternative types."""
        raise NotImplementedError

    def unsupported(self, tp: AnyType) -> Result:
        """Fallback for types matching no handler; raises ``Unsupported(tp)``."""
        raise Unsupported(tp)

    @staticmethod
    def _named_tuple_types(origin: Any) -> Mapping[str, AnyType]:
        """Return the mapping of field name to field type of a named tuple class."""
        if hasattr(origin, "__annotations__"):
            return resolve_type_hints(origin)
        # Legacy ``typing.NamedTuple`` classes expose their types as
        # ``_field_types``. The name is passed as a string on purpose: a
        # double-underscore attribute access written inside this class body
        # would be name-mangled and never match the attribute checked for.
        if hasattr(origin, "_field_types"):  # pragma: no cover
            return getattr(origin, "_field_types")
        # Untyped named tuple: every field is ``Any``.
        return OrderedDict((f, Any) for f in origin._fields)  # pragma: no cover

    def visit(self, tp: AnyType) -> Result:
        """Dispatch ``tp`` to the matching handler and return its result.

        Parametrized types are examined first (annotated, union, tuple,
        collection, mapping, literal), then unparametrized ones. The
        ``unsupported`` handler is called when nothing matches.
        """
        origin, args = get_origin_or_type(tp), get_args(tp)
        if args:
            if is_annotated(tp):
                return self.annotated(args[0], args[1:])
            if is_union(origin):
                # ``union`` expects a sequence of alternatives; a union with a
                # single member is equivalent to that member, so visit it.
                return self.visit(args[0]) if len(args) == 1 else self.union(args)
            if origin is TUPLE_TYPE:
                # ``Tuple[X, ...]`` is not handled here: it falls through to the
                # collection check below.
                if len(args) < 2 or args[1] is not ...:
                    return self.tuple(args)
            if origin in COLLECTION_TYPES:
                return self.collection(origin, args[0])
            if origin in MAPPING_TYPES:
                return self.mapping(origin, args[0], args[1])
            if is_literal(tp):  # pragma: no cover py37+
                return self.literal(args)
        if origin in PRIMITIVE_TYPES:
            return self.primitive(origin)
        if is_dataclass(origin):
            return self.dataclass(tp, *dataclass_types_and_fields(tp))  # type: ignore
        if hasattr(origin, "__supertype__"):
            return self.new_type(origin, origin.__supertype__)
        if origin is Any:
            return self.any()
        # Unparametrized collections/mappings hold ``Any``.
        if origin in COLLECTION_TYPES:
            return self.collection(origin, Any)
        if origin in MAPPING_TYPES:
            return self.mapping(origin, Any, Any)
        if isinstance(origin, type):
            if issubclass(origin, Enum):
                return self.enum(origin)
            for primitive in PRIMITIVE_TYPES:
                if issubclass(origin, primitive):
                    return self.subprimitive(origin, primitive)
            # NamedTuple
            if is_named_tuple(origin):
                return self.named_tuple(
                    origin,
                    self._named_tuple_types(origin),
                    origin._field_defaults,  # type: ignore
                )
        if is_literal(origin):  # pragma: no cover py36
            return self.literal(origin.__values__)  # type: ignore
        if is_typed_dict(origin):
            return self.typed_dict(
                origin, resolve_type_hints(origin), required_keys(origin)
            )
        if is_type_var(origin):
            # A constrained type variable is visited as the union of its
            # constraints; an unconstrained one is handled as ``Any``.
            if origin.__constraints__:
                return self.visit(Union[origin.__constraints__])
            else:
                return self.any()
        return self.unsupported(tp)