import collections.abc
import inspect
import re
import sys
import warnings
from contextlib import contextmanager, suppress
from dataclasses import dataclass, is_dataclass
from enum import Enum
from functools import wraps
from types import MappingProxyType
from typing import (
    AbstractSet,
    Any,
    Awaitable,
    Callable,
    Collection,
    Container,
    Dict,
    Generic,
    Hashable,
    Iterable,
    Iterator,
    List,
    Mapping,
    MutableMapping,
    NoReturn,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
)

from apischema.types import (
    AnyType,
    COLLECTION_TYPES,
    MAPPING_TYPES,
    OrderedDict,
    PRIMITIVE_TYPES,
)
from apischema.typing import (
    _collect_type_vars,
    generic_mro,
    get_args,
    get_origin,
    get_type_hints,
    is_annotated,
    is_type_var,
    is_union,
    typing_origin,
)

try:
    from apischema.typing import Annotated
except ImportError:
    Annotated = ...  # type: ignore

PREFIX = "_apischema_"

T = TypeVar("T")
U = TypeVar("U")


def identity(x: T) -> T:
    """Return *x* unchanged."""
    return x


# Zero-argument callable producing a T.
Lazy = Callable[[], T]


@dataclass(frozen=True)  # being a dataclass enables equality checks between instances
class LazyValue(Generic[T]):
    """``Lazy`` implementation that always returns ``default``.

    Unlike a lambda, two ``LazyValue`` with equal ``default`` compare equal.
    """

    default: T

    def __call__(self) -> T:
        return self.default


if sys.version_info < (3, 7):  # pragma: no cover
    is_dataclass_ = is_dataclass

    def is_dataclass(obj) -> bool:
        """``dataclasses.is_dataclass`` that rejects parametrized generic aliases.

        Before 3.7, a parametrized alias of a generic dataclass carries a
        non-None ``__origin__`` and would otherwise be reported as a dataclass.
        """
        return is_dataclass_(obj) and getattr(obj, "__origin__", None) is None


def is_hashable(obj: Any) -> bool:
    """Return whether *obj* is an instance of ``collections.abc.Hashable``.

    This checks the type only: e.g. a tuple containing a list passes even
    though hashing it would fail.
    """
    return isinstance(obj, collections.abc.Hashable)


def opt_or(opt: Optional[T], default: U) -> Union[T, U]:
    """Return *opt* unless it is ``None``, in which case return *default*."""
    return opt if opt is not None else default


def to_hashable(data: Union[None, int, float, str, bool, list, dict]) -> Hashable:
    """Recursively convert JSON-like *data* to a hashable equivalent.

    Lists become tuples; dicts become sorted tuples of ``(key, value)`` pairs
    (so their converted keys must be mutually comparable); other values are
    returned as is.
    """
    if isinstance(data, list):
        return tuple(map(to_hashable, data))
    if isinstance(data, dict):
        return tuple(sorted((to_hashable(k), to_hashable(v)) for k, v in data.items()))
    return data  # type: ignore


SNAKE_CASE_REGEX = re.compile(r"_([a-z\d])")
CAMEL_CASE_REGEX = re.compile(r"[a-z\d]([A-Z])")


def to_camel_case(s: str) -> str:
    """Convert snake_case to camelCase (``"foo_bar"`` -> ``"fooBar"``).

    Each ``_`` followed by a lowercase letter or digit is removed and that
    character is upper-cased.
    """
    return SNAKE_CASE_REGEX.sub(lambda m: m.group(1).upper(), s)


def to_snake_case(s: str) -> str:
    """Convert camelCase to snake_case (``"fooBar"`` -> ``"foo_bar"``).

    An ``_`` is inserted between a lowercase letter/digit and a following
    uppercase letter, which is lower-cased.
    """
    # The regex match spans the preceding lowercase/digit character too, so it
    # must be re-emitted, otherwise it would be dropped from the result.
    return CAMEL_CASE_REGEX.sub(
        lambda m: m.group(0)[0] + "_" + m.group(1).lower(), s
    )


def to_pascal_case(s: str) -> str:
    """Convert snake_case to PascalCase (camelCase with first letter upper-cased)."""
    camel = to_camel_case(s)
    return camel[0].upper() + camel[1:] if camel else camel


# Field spec shape: (name, type) or (name, type, field) — presumably the items
# passed to ``dataclasses.make_dataclass``.
MakeDataclassField = Union[Tuple[str, AnyType], Tuple[str, AnyType, Any]]


def merge_opts(
    func: Callable[[T, T], T]
) -> Callable[[Optional[T], Optional[T]], Optional[T]]:
    """Lift a binary merge function to optional arguments.

    The returned function returns the other argument when one of them is
    ``None``, and ``func(opt1, opt2)`` when both are set.
    """

    def wrapper(opt1, opt2):
        if opt1 is None:
            return opt2
        if opt2 is None:
            return opt1
        return func(opt1, opt2)

    return wrapper


K = TypeVar("K")
V = TypeVar("V")


@merge_opts
def merge_opts_mapping(m1: Mapping[K, V], m2: Mapping[K, V]) -> Mapping[K, V]:
    """Merge two optional mappings; on key conflicts, values of *m2* win."""
    return {**m1, **m2}


def has_type_vars(tp: AnyType) -> bool:
    """Return whether *tp* is a TypeVar or has non-empty ``__parameters__``."""
    return is_type_var(tp) or bool(getattr(tp, "__parameters__", ()))


TV = AnyType  # TypeVar is not supported as a type
# 10 should be enough for all builtin types
_type_vars = [TypeVar(f"T{i}") for i in range(10)]


def get_parameters(tp: AnyType) -> Iterable[TV]:
    """Return the type parameters of *tp*.

    Uses ``__parameters__`` if present, else the type variables collected from
    ``__orig_bases__``; a TypeVar is its own single parameter; anything else
    falls back to a generic pool of 10 placeholder TypeVars.
    """
    if hasattr(tp, "__parameters__"):
        return tp.__parameters__
    elif hasattr(tp, "__orig_bases__"):
        return _collect_type_vars(tp.__orig_bases__)
    elif is_type_var(tp):
        return (tp,)
    else:
        return _type_vars


def substitute_type_vars(tp: AnyType, substitution: Mapping[TV, AnyType]) -> AnyType:
    """Replace type variables in *tp* according to *substitution*.

    An unmapped TypeVar becomes the Union of its constraints, or ``Any`` if it
    has none. Parameters of a generic type missing from *substitution* are
    left unchanged.
    """
    if is_type_var(tp):
        try:
            return substitution[tp]
        except KeyError:
            return Union[tp.__constraints__] if tp.__constraints__ else Any
    elif getattr(tp, "__parameters__", ()):
        return (Union if is_union(tp) else tp)[
            tuple(substitution.get(p, p) for p in tp.__parameters__)
        ]
    else:
        return tp


Func = TypeVar("Func", bound=Callable)


def typed_wraps(wrapped: Func) -> Callable[[Callable], Func]:
    """``functools.wraps`` cast so type checkers see the wrapped function's type."""
    return cast(Func, wraps(wrapped))


def is_subclass(tp: AnyType, base: AnyType) -> bool:
    """``issubclass`` on the origins of *tp* and *base*.

    Also true when both origins are equal, and false (instead of raising) when
    either origin is not a class.
    """
    tp, base = get_origin_or_type(tp), get_origin_or_type(base)
    return tp == base or (
        isinstance(tp, type) and isinstance(base, type) and issubclass(tp, base)
    )


def _annotated(tp: AnyType) -> AnyType:
    """Strip an ``Annotated`` wrapper, returning the underlying type."""
    return get_args(tp)[0] if is_annotated(tp) else tp


def get_origin_or_type(tp: AnyType) -> AnyType:
    """Return ``get_origin(tp)``, or *tp* itself when it has no origin."""
    origin = get_origin(tp)
    return origin if origin is not None else tp


def get_origin2(tp: AnyType) -> Optional[Type]:
    """``get_origin`` applied after stripping ``Annotated``."""
    return get_origin(_annotated(tp))


def get_args2(tp: AnyType) -> Tuple[AnyType, ...]:
    """``get_args`` applied after stripping ``Annotated``."""
    return get_args(_annotated(tp))


def get_origin_or_type2(tp: AnyType) -> AnyType:
    """``get_origin_or_type`` applied after stripping ``Annotated``."""
    tp2 = _annotated(tp)
    origin = get_origin(tp2)
    return origin if origin is not None else tp2


def keep_annotations(tp: AnyType, annotated: AnyType) -> AnyType:
    """Wrap *tp* with the ``Annotated`` metadata of *annotated*, if it has any."""
    return Annotated[(tp, *get_args(annotated)[1:])] if is_annotated(annotated) else tp


def with_parameters(tp: AnyType) -> AnyType:
    """Parametrize a generic with its own parameters (``Foo`` -> ``Foo[T]``)."""
    return tp[tp.__parameters__] if getattr(tp, "__parameters__", ()) else tp


def is_union_of(tp: AnyType, of: AnyType) -> bool:
    """Return whether *tp* is *of*, or a (possibly Annotated) Union containing it."""
    return tp == of or (is_union(get_origin_or_type2(tp)) and of in get_args2(tp))


if sys.version_info < (3, 7):
    LIST_ORIGIN = List
    SET_ORIGIN = Set
    TUPLE_ORIGIN = Tuple
    DICT_ORIGIN = Dict
else:
    LIST_ORIGIN = typing_origin(list)
    SET_ORIGIN = typing_origin(set)
    TUPLE_ORIGIN = typing_origin(tuple)
    DICT_ORIGIN = typing_origin(dict)


def replace_builtins(tp: AnyType) -> AnyType:
    """Recursively normalize the origins of a generic type.

    Set-like collections map to ``SET_ORIGIN``; fixed-size tuples to
    ``TUPLE_ORIGIN``; other collections, including variadic ``Tuple[X, ...]``,
    to ``LIST_ORIGIN``; mappings to ``DICT_ORIGIN``; unions stay ``Union``;
    anything else goes through ``typing_origin``. ``Annotated`` metadata is
    preserved. Types without an origin are returned unchanged.
    """
    origin = get_origin2(tp)
    if origin is None:
        return tp
    args = tuple(map(replace_builtins, get_args2(tp)))
    replacement: Any
    if origin in COLLECTION_TYPES:
        if issubclass(origin, collections.abc.Set):
            replacement = SET_ORIGIN
        elif not issubclass(origin, tuple):
            replacement = LIST_ORIGIN
        elif len(args) >= 2 and args[1] is ...:
            # Variadic tuple: homogeneous sequence of args[0]; drop the Ellipsis
            # because the list origin takes a single parameter.
            replacement, args = LIST_ORIGIN, args[:1]
        else:
            replacement = TUPLE_ORIGIN
    elif origin in MAPPING_TYPES:
        replacement = DICT_ORIGIN
    elif is_union(origin):
        replacement = Union
    else:
        replacement = typing_origin(origin)
    res = replacement[args] if args else replacement
    return keep_annotations(res, tp)


def sort_by_annotations_position(
    cls: Type, elts: Collection[T], key: Callable[[T], str]
) -> List[T]:
    """Sort *elts* by the position of ``key(elt)`` in *cls*'s annotations.

    Annotations are gathered along the MRO from the most basic class down, so
    base-class fields come first. Elements whose key is not an annotation name
    are placed last (stable order among themselves).
    """
    annotations: Dict[str, Any] = OrderedDict()
    for base in reversed(cls.__mro__):
        annotations.update(getattr(base, "__annotations__", ()))
    positions = {name: i for i, name in enumerate(annotations)}
    return sorted(elts, key=lambda elt: positions.get(key(elt), len(positions)))


def stop_signature_abuse() -> NoReturn:
    """Raise ``TypeError("Stop signature abuse")``."""
    raise TypeError("Stop signature abuse")


# Shared read-only empty mapping.
empty_dict: Mapping[str, Any] = MappingProxyType({})

# Origins considered interchangeable by ``subtyping_substitution``.
ITERABLE_TYPES = {
    *COLLECTION_TYPES,
    *MAPPING_TYPES,
    Iterable,
    collections.abc.Iterable,
    Container,
    collections.abc.Container,
}


def subtyping_substitution(
    supertype: AnyType, subtype: AnyType
) -> Tuple[Mapping[AnyType, AnyType], Mapping[AnyType, AnyType]]:
    """Pair the type variables of *supertype* with the arguments of *subtype*.

    Walks ``generic_mro(subtype)`` and stops at the first base whose origin
    equals *supertype*'s origin, or where both origins are in
    ``ITERABLE_TYPES``; their arguments are then zipped positionally.

    Returns ``(supertype_to_subtype, subtype_to_supertype)``: the first maps
    *supertype* TypeVars to the matching base arguments, the second maps base
    TypeVars to the matching *supertype* arguments. Both are empty when
    *subtype* has no arguments and is not a class.
    """
    if not get_args(subtype) and not isinstance(subtype, type):
        return {}, {}
    supertype, subtype = with_parameters(supertype), with_parameters(subtype)
    supertype_to_subtype: Dict[AnyType, AnyType] = {}
    subtype_to_supertype: Dict[AnyType, AnyType] = {}
    super_origin = get_origin_or_type2(supertype)
    for base in generic_mro(subtype):
        base_origin = get_origin_or_type2(base)
        if base_origin == super_origin or (
            base_origin in ITERABLE_TYPES and super_origin in ITERABLE_TYPES
        ):
            for base_arg, super_arg in zip(get_args2(base), get_args2(supertype)):
                if is_type_var(super_arg):
                    supertype_to_subtype[super_arg] = base_arg
                if is_type_var(base_arg):
                    subtype_to_supertype[base_arg] = super_arg
            break
    return supertype_to_subtype, subtype_to_supertype


def literal_values(values: Sequence[Any]) -> Sequence[Any]:
    """Return *values* with Enum members replaced by their ``.value``.

    Raises:
        TypeError: if any resulting value is not an instance of
            ``PRIMITIVE_TYPES``.
    """
    primitive_values = [v.value if isinstance(v, Enum) else v for v in values]
    if any(not isinstance(v, PRIMITIVE_TYPES) for v in primitive_values):
        raise TypeError("Only primitive types are supported for Literal/Enum")
    return primitive_values


awaitable_origin = get_origin(Awaitable[Any])


def is_async(func: Callable, types: Optional[Mapping[str, AnyType]] = None) -> bool:
    """Return whether *func* is asynchronous.

    True if the innermost function reached through ``__wrapped__`` is a
    coroutine function, or if the ``"return"`` annotation (taken from *types*,
    or from ``get_type_hints(func)`` when *types* is None) is an
    ``Awaitable``. Type-hint resolution errors are treated as "no hints".
    """
    wrapped_func = func
    while hasattr(wrapped_func, "__wrapped__"):
        wrapped_func = wrapped_func.__wrapped__  # type: ignore
    if inspect.iscoroutinefunction(wrapped_func):
        return True
    if types is None:
        try:
            types = get_type_hints(func)
        except Exception:
            # Best effort: unresolvable annotations simply mean "not async".
            types = {}
    return get_origin_or_type2(types.get("return")) == awaitable_origin


@contextmanager
def context_setter(obj: Any):
    """Restore ``obj.__dict__`` to its content at entry when the block exits.

    Attributes added inside the block are removed and reassigned ones are
    reset. The snapshot is shallow: objects mutated in place are not restored.
    """
    dict_copy = obj.__dict__.copy()
    try:
        yield
    finally:
        obj.__dict__.clear()
        obj.__dict__.update(dict_copy)


def wrap_generic_init_subclass(init_subclass: Func) -> Func:
    """Adapt an ``__init_subclass__`` hook for generics on Python < 3.7.

    On 3.7+ the hook is returned unchanged. Before 3.7, parametrized generic
    aliases (classes with a non-None ``__origin__``) skip the wrapped hook and
    only ``super(cls).__init_subclass__`` is invoked for them.
    """
    if sys.version_info >= (3, 7):
        return init_subclass

    @wraps(init_subclass)
    def wrapper(cls, **kwargs):
        if getattr(cls, "__origin__", None) is not None:
            super(cls).__init_subclass__(**kwargs)
            return
        init_subclass(cls, **kwargs)

    return wrapper


# Before Python 3.7, the metaclass of generic classes changes their hash after
# ``__init_subclass__`` has run, so classes registered in global dictionaries
# during that hook can no longer be found. This dictionary wrapper works around
# it by hashing such classes by identity.
if sys.version_info < (3, 7):
    K = TypeVar("K")
    V = TypeVar("V")

    class KeyWrapper:
        """Dict key hashing unparametrized generic classes by ``id``."""

        def __init__(self, key):
            self.key = key

        def __eq__(self, other):
            # Compare the wrapped keys; previously compared self.key to itself,
            # making every hash collision an equality.
            if not isinstance(other, KeyWrapper):
                return NotImplemented
            return self.key == other.key

        def __hash__(self):
            return hash(
                id(self.key) if getattr(self.key, "__origin__", ...) is None else self.key
            )

    class type_dict_wrapper(MutableMapping[K, V]):
        """MutableMapping view over *wrapped* whose keys are stored as KeyWrapper."""

        def __init__(self, wrapped: Dict[K, V]):
            self.wrapped = cast(Dict[KeyWrapper, V], wrapped)

        def __delitem__(self, key: K) -> None:
            del self.wrapped[KeyWrapper(key)]

        def __getitem__(self, key: K) -> V:
            return self.wrapped[KeyWrapper(key)]

        def __iter__(self) -> Iterator[K]:
            # Snapshot keys so the mapping may be modified while iterating.
            return iter(key_wrapper.key for key_wrapper in list(self.wrapped))

        def __len__(self) -> int:
            return len(self.wrapped)

        def __setitem__(self, key: K, value: V):
            self.wrapped[KeyWrapper(key)] = value

else:
    M = TypeVar("M", bound=MutableMapping)

    def type_dict_wrapper(wrapped: M) -> M:
        """Identity on Python >= 3.7, where no workaround is needed."""
        return wrapped


def deprecate_kwargs(
    parameters_map: Mapping[str, Optional[str]]
) -> Callable[[Func], Func]:
    """Decorator factory emitting DeprecationWarning for deprecated kwargs.

    *parameters_map* maps each deprecated keyword name to its replacement name,
    or to None when there is no replacement. When a deprecated keyword is
    passed, a warning is emitted and the keyword is removed; if it has a
    replacement that was not passed explicitly, its value is forwarded under
    the replacement name. Applied to a class, ``__init__`` is wrapped in place
    and the class is returned.
    """

    def decorator(func: Func) -> Func:
        wrapped = func.__init__ if isinstance(func, type) else func  # type: ignore

        def wrapper(*args, **kwargs):
            for param, replacement in parameters_map.items():
                if param in kwargs:
                    instead = f", use '{replacement}' instead" if replacement else ""
                    warnings.warn(
                        f"{func.__name__} parameter '{param}' is deprecated{instead}",
                        DeprecationWarning,
                        stacklevel=2,  # attribute the warning to the caller
                    )
                    arg = kwargs.pop(param)
                    if replacement:
                        # An explicitly passed replacement takes precedence.
                        kwargs[replacement] = kwargs.get(replacement, arg)
            return wrapped(*args, **kwargs)

        if isinstance(func, type):
            func.__init__ = wraps(func.__init__)(wrapper)  # type: ignore
            return cast(Func, func)
        else:
            return cast(Func, wraps(func)(wrapper))

    return decorator


def as_predicate(
    collection_or_predicate: Union[Collection[T], Callable[[T], bool]]
) -> Callable[[T], bool]:
    """Turn a collection into a membership predicate; return predicates as is.

    Non-set collections are converted to a set when possible (e.g. not when
    their elements are unhashable). The predicate returns False instead of
    raising if the membership test fails.
    """
    if not isinstance(collection_or_predicate, Collection):
        return collection_or_predicate
    collection = collection_or_predicate
    if not isinstance(collection, AbstractSet):
        with suppress(Exception):
            collection = set(collection)

    def wrapper(elt: T) -> bool:
        try:
            return elt in collection
        except Exception:
            return False

    return wrapper