# tokencrawler/.venv/lib/python3.9/site-packages/apischema/utils.py
# 2022-03-17 22:16:30 +01:00
#
# 452 lines
# 12 KiB
# Python

import collections.abc
import inspect
import re
import sys
import warnings
from contextlib import contextmanager, suppress
from dataclasses import dataclass, is_dataclass
from enum import Enum
from functools import wraps
from types import MappingProxyType
from typing import (
AbstractSet,
Any,
Awaitable,
Callable,
Collection,
Container,
Dict,
Generic,
Hashable,
Iterable,
Iterator,
List,
Mapping,
MutableMapping,
NoReturn,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
Union,
cast,
)
from apischema.types import (
AnyType,
COLLECTION_TYPES,
MAPPING_TYPES,
OrderedDict,
PRIMITIVE_TYPES,
)
from apischema.typing import (
_collect_type_vars,
generic_mro,
get_args,
get_origin,
get_type_hints,
is_annotated,
is_type_var,
is_union,
typing_origin,
)
# ``Annotated`` may be unavailable; fall back to an Ellipsis sentinel so the
# module still imports (it is only subscripted after ``is_annotated`` succeeds).
try:
    from apischema.typing import Annotated
except ImportError:
    Annotated = ...  # type: ignore

# Common prefix string (its consumers are outside this section).
PREFIX = "_apischema_"

# Generic type variables shared by the helpers below.
T, U = TypeVar("T"), TypeVar("U")
def identity(x: T) -> T:
    """Return *x* unchanged."""
    result = x
    return result
# A lazy value: any zero-argument callable producing a T.
Lazy = Callable[[], T]


@dataclass(frozen=True)  # being a dataclass gives instances value equality
class LazyValue(Generic[T]):
    """Lazy wrapper around an already-known value.

    Calling an instance simply returns its ``default`` field.
    """

    default: T

    def __call__(self) -> T:
        value = self.default
        return value
# Equivalent to the former ``<= (3, 7)`` check: the full ``sys.version_info``
# tuple of any 3.7.x release compares greater than ``(3, 7)``, so only
# Python 3.6 takes this branch.
if sys.version_info < (3, 7):  # pragma: no cover
    is_dataclass_ = is_dataclass

    def is_dataclass(obj) -> bool:
        """Like ``dataclasses.is_dataclass``, but False for objects carrying a
        non-None ``__origin__`` (e.g. subscripted generic aliases)."""
        if not is_dataclass_(obj):
            return False
        return getattr(obj, "__origin__", None) is None
def is_hashable(obj: Any) -> bool:
    """Return True if *obj* is an instance of ``collections.abc.Hashable``."""
    hashable_abc = collections.abc.Hashable
    return isinstance(obj, hashable_abc)
def opt_or(opt: Optional[T], default: U) -> Union[T, U]:
    """Return *opt*, or *default* when *opt* is None."""
    if opt is None:
        return default
    return opt
def to_hashable(data: Union[None, int, float, str, bool, list, dict]) -> Hashable:
    """Recursively convert JSON-like data into a hashable equivalent.

    Lists become tuples. Dicts become sorted tuples of converted
    ``(key, value)`` pairs. Any other value is returned unchanged.
    """
    if isinstance(data, dict):
        pairs = [(to_hashable(key), to_hashable(value)) for key, value in data.items()]
        return tuple(sorted(pairs))
    if isinstance(data, list):
        return tuple(to_hashable(item) for item in data)
    return data  # type: ignore
# "_" followed by a lowercase letter or digit (snake_case word boundary).
SNAKE_CASE_REGEX = re.compile(r"_([a-z\d])")
# A lowercase letter or digit followed by an uppercase letter (camelCase boundary).
CAMEL_CASE_REGEX = re.compile(r"[a-z\d]([A-Z])")


def to_camel_case(s: str) -> str:
    """Convert snake_case to camelCase (``"foo_bar"`` -> ``"fooBar"``)."""
    return SNAKE_CASE_REGEX.sub(lambda m: m.group(1).upper(), s)


def to_snake_case(s: str) -> str:
    """Convert camelCase to snake_case (``"fooBar"`` -> ``"foo_bar"``)."""
    # The match spans the lowercase/digit character AND the uppercase one, so
    # the first character must be written back; emitting only "_" + lower()
    # would turn "fooBar" into "fo_bar".
    return CAMEL_CASE_REGEX.sub(
        lambda m: m.group(0)[0] + "_" + m.group(1).lower(), s
    )


def to_pascal_case(s: str) -> str:
    """Convert snake_case to PascalCase: camel-case *s*, then uppercase its
    first character. The empty string is returned unchanged."""
    camel = to_camel_case(s)
    return camel[0].upper() + camel[1:] if camel else camel
# Field spec tuple: ``(name, type)`` or ``(name, type, extra)`` — presumably the
# shape accepted by ``dataclasses.make_dataclass``; confirm at call sites.
MakeDataclassField = Union[Tuple[str, AnyType], Tuple[str, AnyType, Any]]


def merge_opts(
    func: Callable[[T, T], T]
) -> Callable[[Optional[T], Optional[T]], Optional[T]]:
    """Lift a binary merge function to optional operands.

    The returned function calls *func* only when both operands are non-None.
    Otherwise it returns the other operand, which may itself be None.
    """

    def wrapper(opt1, opt2):
        if opt1 is not None and opt2 is not None:
            return func(opt1, opt2)
        return opt2 if opt1 is None else opt1

    return wrapper
K = TypeVar("K")
V = TypeVar("V")


@merge_opts
def merge_opts_mapping(m1: Mapping[K, V], m2: Mapping[K, V]) -> Mapping[K, V]:
    """Merge two optional mappings into a new dict; *m2* wins on key conflicts.

    Through ``merge_opts``, a None operand yields the other operand unchanged.
    """
    combined = {**m1}
    combined.update(m2)
    return combined
def has_type_vars(tp: AnyType) -> bool:
    """Return True if *tp* is a TypeVar or has non-empty ``__parameters__``."""
    if is_type_var(tp):
        return True
    return bool(getattr(tp, "__parameters__", ()))
# Stand-in annotation for "a TypeVar": TypeVar is not supported as a type.
TV = AnyType

# Fallback parameters returned by get_parameters for types exposing neither
# ``__parameters__`` nor ``__orig_bases__``; 10 should be enough for all
# builtin types.
_type_vars = [TypeVar("T" + str(i)) for i in range(10)]
def get_parameters(tp: AnyType) -> Iterable[TV]:
    """Return the type parameters of *tp*.

    Precedence is as follows. First, ``tp.__parameters__`` if the attribute
    exists. Next, the type variables collected from ``tp.__orig_bases__``.
    Then ``(tp,)`` when *tp* is itself a TypeVar. Otherwise the module-level
    ``_type_vars`` fallback.
    """
    if hasattr(tp, "__parameters__"):
        return tp.__parameters__
    if hasattr(tp, "__orig_bases__"):
        return _collect_type_vars(tp.__orig_bases__)
    return (tp,) if is_type_var(tp) else _type_vars
def substitute_type_vars(tp: AnyType, substitution: Mapping[TV, AnyType]) -> AnyType:
    """Replace type variables in *tp* according to *substitution*.

    - A bare TypeVar is looked up in *substitution*. If it is missing, it
      becomes the Union of its constraints, or Any when it has none (a bound
      is not consulted).
    - A type with non-empty ``__parameters__`` is re-subscripted with each
      parameter substituted. Parameters absent from *substitution* are kept,
      and unions are rebuilt through ``Union`` itself.
    - Anything else is returned unchanged.
    """
    if is_type_var(tp):
        try:
            return substitution[tp]
        except KeyError:
            constraints = tp.__constraints__
            return Union[constraints] if constraints else Any
    params = getattr(tp, "__parameters__", ())
    if not params:
        return tp
    generic = Union if is_union(tp) else tp
    return generic[tuple(substitution.get(p, p) for p in params)]
# Type variable standing for "some callable", used to keep decorator types precise.
Func = TypeVar("Func", bound=Callable)


def typed_wraps(wrapped: Func) -> Callable[[Callable], Func]:
    """Typed variant of ``functools.wraps``.

    At runtime this is exactly ``functools.wraps(wrapped)``. The cast only
    tells type checkers that the decorated function has *wrapped*'s type.
    """
    decorator = wraps(wrapped)
    return cast(Func, decorator)
def is_subclass(tp: AnyType, base: AnyType) -> bool:
    """Generic-aware ``issubclass``.

    Both arguments are first reduced to their origin, or kept as is when they
    have none. They are then compared for equality. Otherwise ``issubclass``
    is used, but only when both reduced values are classes.
    """
    tp_cls = get_origin_or_type(tp)
    base_cls = get_origin_or_type(base)
    if tp_cls == base_cls:
        return True
    if not (isinstance(tp_cls, type) and isinstance(base_cls, type)):
        return False
    return issubclass(tp_cls, base_cls)
def _annotated(tp: AnyType) -> AnyType:
    """Strip an ``Annotated[...]`` wrapper (return its first argument);
    other types are returned unchanged."""
    if not is_annotated(tp):
        return tp
    return get_args(tp)[0]
def get_origin_or_type(tp: AnyType) -> AnyType:
    """Return ``get_origin(tp)``, or *tp* itself when it has no origin."""
    origin = get_origin(tp)
    if origin is None:
        return tp
    return origin
def get_origin2(tp: AnyType) -> Optional[Type]:
    """Like ``get_origin``, but looks through an ``Annotated`` wrapper first."""
    unwrapped = _annotated(tp)
    return get_origin(unwrapped)
def get_args2(tp: AnyType) -> Tuple[AnyType, ...]:
    """Like ``get_args``, but looks through an ``Annotated`` wrapper first."""
    unwrapped = _annotated(tp)
    return get_args(unwrapped)
def get_origin_or_type2(tp: AnyType) -> AnyType:
    """Like ``get_origin_or_type``, but looks through an ``Annotated``
    wrapper first."""
    return get_origin_or_type(_annotated(tp))
def keep_annotations(tp: AnyType, annotated: AnyType) -> AnyType:
    """Re-apply *annotated*'s ``Annotated`` metadata onto *tp*.

    When *annotated* is not an ``Annotated`` type, *tp* is returned as is.
    """
    if not is_annotated(annotated):
        return tp
    metadata = get_args(annotated)[1:]
    return Annotated[(tp, *metadata)]
def with_parameters(tp: AnyType) -> AnyType:
    """Subscript *tp* with its own ``__parameters__`` (e.g. ``C`` -> ``C[T]``).

    Types without parameters are returned unchanged.
    """
    params = getattr(tp, "__parameters__", ())
    if not params:
        return tp
    return tp[params]
def is_union_of(tp: AnyType, of: AnyType) -> bool:
    """Return True if *tp* equals *of*, or is a (possibly Annotated) Union
    having *of* among its arguments."""
    if tp == of:
        return True
    return is_union(get_origin_or_type2(tp)) and of in get_args2(tp)
# Generic "origins" used by replace_builtins to rebuild collection/mapping aliases.
if sys.version_info >= (3, 7):
    LIST_ORIGIN = typing_origin(list)
    SET_ORIGIN = typing_origin(set)
    TUPLE_ORIGIN = typing_origin(tuple)
    DICT_ORIGIN = typing_origin(dict)
else:  # Python 3.6: the typing aliases are used directly
    LIST_ORIGIN = List
    SET_ORIGIN = Set
    TUPLE_ORIGIN = Tuple
    DICT_ORIGIN = Dict
def replace_builtins(tp: AnyType) -> AnyType:
    """Recursively rewrite generic aliases onto canonical origins.

    Mapping, per origin of *tp* (Annotated wrappers are looked through and
    their metadata re-applied):
      - set-like collections -> ``SET_ORIGIN[...]``
      - tuples that are not ``Tuple[X, ...]`` -> ``TUPLE_ORIGIN[...]``
      - other collections, including variadic ``Tuple[X, ...]`` -> ``LIST_ORIGIN[X]``
      - mappings -> ``DICT_ORIGIN[...]``
      - unions -> ``Union[...]``
      - any other generic -> ``typing_origin(origin)[...]``
    Types without an origin are returned unchanged.
    """
    origin = get_origin2(tp)
    if origin is None:
        return tp
    args = tuple(map(replace_builtins, get_args2(tp)))
    replacement: Any
    if origin in COLLECTION_TYPES:
        if issubclass(origin, collections.abc.Set):
            replacement = SET_ORIGIN
        elif issubclass(origin, tuple) and (len(args) < 2 or args[1] is not ...):
            replacement = TUPLE_ORIGIN
        else:
            replacement = LIST_ORIGIN
            # A list has exactly one element type: drop the trailing ``...`` of a
            # variadic ``Tuple[X, ...]``, otherwise this would build ``List[X, ...]``.
            args = args[:1]
    elif origin in MAPPING_TYPES:
        replacement = DICT_ORIGIN
    elif is_union(origin):
        replacement = Union
    else:
        replacement = typing_origin(origin)
    res = replacement[args] if args else replacement
    return keep_annotations(res, tp)
def sort_by_annotations_position(
    cls: Type, elts: Collection[T], key: Callable[[T], str]
) -> List[T]:
    """Sort *elts* by the declaration position of the annotation ``key(elt)``.

    Annotations are gathered along ``cls.__mro__`` from the most basic class
    to *cls*. A name first annotated in a base class therefore precedes names
    added by subclasses. Elements whose key is not an annotation name go last,
    and the stable sort keeps their relative order.
    """
    annotations: Dict[str, Any] = OrderedDict()
    for klass in cls.__mro__[::-1]:
        annotations.update(getattr(klass, "__annotations__", ()))
    position_of = {name: index for index, name in enumerate(annotations)}
    missing = len(position_of)

    def sort_key(elt: T) -> int:
        return position_of.get(key(elt), missing)

    return sorted(elts, key=sort_key)
def stop_signature_abuse() -> NoReturn:
    """Unconditionally raise ``TypeError("Stop signature abuse")``."""
    message = "Stop signature abuse"
    raise TypeError(message)
# Shared read-only empty mapping (a MappingProxyType over an empty dict).
empty_dict: Mapping[str, Any] = MappingProxyType(dict())

# Origins considered "iterable"; subtyping_substitution treats any two of them
# as matching each other.
ITERABLE_TYPES = {
    *COLLECTION_TYPES,
    *MAPPING_TYPES,
    Iterable,
    collections.abc.Iterable,
    Container,
    collections.abc.Container,
}
def subtyping_substitution(
    supertype: AnyType, subtype: AnyType
) -> Tuple[Mapping[AnyType, AnyType], Mapping[AnyType, AnyType]]:
    """Pair up type variables between *supertype* and *subtype*.

    The function walks ``generic_mro(subtype)`` up to the first base whose
    origin equals the supertype's origin; two ITERABLE_TYPES origins also
    match. It then pairs that base's arguments with the supertype's
    arguments, position by position.

    Returns ``(supertype_to_subtype, subtype_to_supertype)``:
      - the first maps each TypeVar argument of the supertype to the matching
        base argument;
      - the second maps each TypeVar argument of the base to the matching
        supertype argument.
    Both are empty when *subtype* has no arguments and is not a class, or
    when no base matches.
    """
    if not get_args(subtype) and not isinstance(subtype, type):
        return {}, {}
    supertype, subtype = with_parameters(supertype), with_parameters(subtype)
    super_to_sub, sub_to_super = {}, {}
    super_origin = get_origin_or_type2(supertype)
    for base in generic_mro(subtype):
        base_origin = get_origin_or_type2(base)
        if not (
            base_origin == super_origin
            or (base_origin in ITERABLE_TYPES and super_origin in ITERABLE_TYPES)
        ):
            continue
        for base_arg, super_arg in zip(get_args2(base), get_args2(supertype)):
            if is_type_var(super_arg):
                super_to_sub[super_arg] = base_arg
            if is_type_var(base_arg):
                sub_to_super[base_arg] = super_arg
        break
    return super_to_sub, sub_to_super
def literal_values(values: Sequence[Any]) -> Sequence[Any]:
    """Return *values* as a list, with Enum members replaced by their ``.value``.

    Raises:
        TypeError: if any resulting value is not an instance of PRIMITIVE_TYPES.
    """
    unwrapped: List[Any] = []
    for value in values:
        if isinstance(value, Enum):
            value = value.value
        unwrapped.append(value)
    if not all(isinstance(value, PRIMITIVE_TYPES) for value in unwrapped):
        raise TypeError("Only primitive types are supported for Literal/Enum")
    return unwrapped
# Origin of ``Awaitable[...]``, compared against return annotations below.
awaitable_origin = get_origin(Awaitable[Any])


def is_async(func: Callable, types: Optional[Mapping[str, AnyType]] = None) -> bool:
    """Tell whether *func* should be treated as asynchronous.

    True when the innermost function reached through the ``__wrapped__`` chain
    is a coroutine function. Also True when *func*'s ``"return"`` type hint
    (looking through Annotated) has the same origin as ``Awaitable``.

    Args:
        func: callable to inspect.
        types: type hints of *func*. When omitted they are computed with
            ``get_type_hints(func)``; any failure there is treated as "no hints"
            (deliberate best effort).
    """
    wrapped_func = func
    while hasattr(wrapped_func, "__wrapped__"):
        wrapped_func = wrapped_func.__wrapped__  # type: ignore
    if inspect.iscoroutinefunction(wrapped_func):
        return True
    if types is None:
        try:
            types = get_type_hints(func)
        except Exception:
            types = {}
    return get_origin_or_type2(types.get("return")) == awaitable_origin
@contextmanager
def context_setter(obj: Any):
    """Context manager that restores ``obj.__dict__`` on exit.

    A shallow copy of the instance dict is taken on entry. On exit, whether
    normal or via an exception, the dict is cleared and refilled from that
    snapshot. This undoes attribute rebinding, additions and deletions, but
    not in-place mutation of attribute values.
    """
    snapshot = obj.__dict__.copy()
    try:
        yield
    finally:
        attrs = obj.__dict__
        attrs.clear()
        attrs.update(snapshot)
def wrap_generic_init_subclass(init_subclass: Func) -> Func:
    """Adapt an ``__init_subclass__`` implementation for Python < 3.7.

    On Python >= 3.7, *init_subclass* is returned unchanged. On older versions
    a wrapper is returned. For classes with a non-None ``__origin__``, the
    wrapper calls ``super(cls).__init_subclass__`` instead of *init_subclass*.
    Otherwise it delegates to *init_subclass*.
    """
    if sys.version_info >= (3, 7):
        return init_subclass

    @wraps(init_subclass)
    def wrapper(cls, **kwargs):
        # Presumably classes with __origin__ set are subscripted generics
        # created by the 3.6 typing metaclass — TODO confirm.
        if getattr(cls, "__origin__", None) is not None:
            # NOTE(review): ``super(cls)`` is a one-argument (unbound) super
            # object. Attribute lookup on it does not walk cls's MRO, so this
            # appears NOT to reach the parent's __init_subclass__ — verify on 3.6.
            super(cls).__init_subclass__(**kwargs)
            return
        init_subclass(cls, **kwargs)

    return wrapper
# On Python < 3.7, the hash of a generic class is changed by its metaclass
# after __init_subclass__, so classes registered in global dictionaries would
# no longer be accessible. ``type_dict_wrapper`` wraps such a dictionary to fix
# this issue; on Python >= 3.7 it returns the dictionary itself.


class KeyWrapper:
    """Dictionary key wrapper used by the Python < 3.7 ``type_dict_wrapper``.

    Keys whose ``__origin__`` attribute is ``None`` are hashed by identity,
    since their own hash may change after registration. Other keys use their
    own hash. Two wrappers are equal when their wrapped keys are equal.
    """

    def __init__(self, key):
        self.key = key

    def __eq__(self, other):
        # Previously ``self.key == self.key``: always true, so distinct keys
        # with equal hashes collided, and a wrapper equaled any object.
        if not isinstance(other, KeyWrapper):
            return NotImplemented
        return self.key == other.key

    def __hash__(self):
        return hash(
            id(self.key)
            if getattr(self.key, "__origin__", ...) is None
            else self.key
        )


if sys.version_info < (3, 7):
    K = TypeVar("K")
    V = TypeVar("V")

    class type_dict_wrapper(MutableMapping[K, V]):
        """MutableMapping view storing its keys wrapped in ``KeyWrapper``."""

        def __init__(self, wrapped: Dict[K, V]):
            self.wrapped = cast(Dict[KeyWrapper, V], wrapped)

        def __delitem__(self, key: K) -> None:
            del self.wrapped[KeyWrapper(key)]

        def __getitem__(self, key: K) -> V:
            return self.wrapped[KeyWrapper(key)]

        def __iter__(self) -> Iterator[K]:
            # Snapshot the keys so that mutation during iteration is tolerated.
            return iter(wrapper.key for wrapper in list(self.wrapped))

        def __len__(self) -> int:
            return len(self.wrapped)

        def __setitem__(self, key: K, value: V):
            self.wrapped[KeyWrapper(key)] = value

else:
    M = TypeVar("M", bound=MutableMapping)

    def type_dict_wrapper(wrapped: M) -> M:
        """No wrapping needed on Python >= 3.7: return *wrapped* itself."""
        return wrapped
def deprecate_kwargs(
    parameters_map: Mapping[str, Optional[str]]
) -> Callable[[Func], Func]:
    """Decorator factory that emits DeprecationWarning for deprecated keywords.

    *parameters_map* maps each deprecated keyword to its replacement name, or
    to None when the parameter is simply dropped. When a deprecated keyword is
    passed, a warning is issued and the keyword is removed from the call. If
    it has a replacement, its value is forwarded under the new name, unless
    the caller already passed that name explicitly.

    Works on functions and on classes. For a class, ``__init__`` is wrapped in
    place and the class itself is returned.
    """

    def decorator(func: Func) -> Func:
        wrapped = func.__init__ if isinstance(func, type) else func  # type: ignore

        def wrapper(*args, **kwargs):
            for param, replacement in parameters_map.items():
                if param in kwargs:
                    instead = f", use '{replacement}' instead" if replacement else ""
                    # stacklevel=2 attributes the warning to the caller of the
                    # deprecated API instead of this wrapper, so per-module
                    # warning filters apply to the right code.
                    warnings.warn(
                        f"{func.__name__} parameter '{param}' is deprecated{instead}",
                        DeprecationWarning,
                        stacklevel=2,
                    )
                    arg = kwargs.pop(param)
                    if replacement:
                        kwargs[replacement] = kwargs.get(replacement, arg)
            return wrapped(*args, **kwargs)

        if isinstance(func, type):
            func.__init__ = wraps(func.__init__)(wrapper)  # type: ignore
            return cast(Func, func)
        else:
            return cast(Func, wraps(func)(wrapper))

    return decorator
def as_predicate(
    collection_or_predicate: Union[Collection[T], Callable[[T], bool]]
) -> Callable[[T], bool]:
    """Normalize a collection or a predicate into a predicate.

    Arguments that are not a ``Collection`` are returned unchanged. A
    collection that is not already a set is converted to a ``set`` when
    possible; if conversion raises (e.g. unhashable items), the collection is
    kept as is. The returned predicate tests membership and answers False if
    the membership test itself raises.
    """
    if not isinstance(collection_or_predicate, Collection):
        return collection_or_predicate
    members = collection_or_predicate
    if not isinstance(members, AbstractSet):
        with suppress(Exception):
            members = set(members)

    def wrapper(elt: T) -> bool:
        try:
            return elt in members
        except Exception:
            return False

    return wrapper