303 lines
9.5 KiB
Python
303 lines
9.5 KiB
Python
"""Kind of typing_extensions for this package"""
|
|
__all__ = ["get_args", "get_origin", "get_type_hints"]
|
|
|
|
import sys
|
|
from types import ModuleType, new_class
|
|
from typing import ( # type: ignore
|
|
Any,
|
|
Callable,
|
|
Collection,
|
|
Dict,
|
|
Generic,
|
|
Set,
|
|
Tuple,
|
|
Type,
|
|
TypeVar,
|
|
Union,
|
|
)
|
|
|
|
|
|
class _FakeType:
    """Stand-in class bound to typing internals that could not be resolved.

    This module assigns it to names such as ``_AnnotatedAlias`` or
    ``_TypedDictMeta`` when the real class is unavailable, so those names can
    still be used as the second argument of ``isinstance``.
    """
|
|
|
|
|
|
if sys.version_info >= (3, 9): # pragma: no cover
|
|
from typing import Annotated, TypedDict, get_type_hints, get_origin, get_args
|
|
else: # pragma: no cover
|
|
try:
|
|
from typing_extensions import Annotated, TypedDict
|
|
except ImportError:
|
|
if sys.version_info >= (3, 8):
|
|
from typing import TypedDict
|
|
try:
|
|
from typing_extensions import get_type_hints as gth
|
|
except ImportError:
|
|
from typing import get_type_hints as _gth
|
|
|
|
def gth(obj, globalns=None, localns=None, include_extras=False):  # type: ignore
    """Call the stdlib ``get_type_hints`` imported as ``_gth``.

    ``include_extras`` is accepted so that the signature matches the
    ``typing_extensions`` version, but it is not forwarded to ``_gth``.
    """
    hints = _gth(obj, globalns, localns)
    return hints
|
|
|
|
def get_type_hints(  # type: ignore
    obj, globalns=None, localns=None, include_extras=False
):
    """Wrap ``gth`` with a default global namespace and a ``unicode`` alias.

    When ``obj`` is neither a class nor a module and no ``globalns`` is given,
    the ``__wrapped__`` chain is followed to its end and the ``__globals__`` of
    the innermost object (if any) is used as global namespace. ``localns`` is
    always extended with ``unicode`` mapped to ``str``; entries supplied by the
    caller take precedence.
    """
    # TODO This has been fixed in recent 3.7 and 3.8
    # Workaround for https://bugs.python.org/issue37838
    lookup_globals = globalns is None and not isinstance(obj, (type, ModuleType))
    if lookup_globals:
        innermost = obj
        while hasattr(innermost, "__wrapped__"):
            innermost = innermost.__wrapped__
        globalns = getattr(innermost, "__globals__", None)
    merged_locals = {"unicode": str}
    merged_locals.update(localns or {})
    return gth(obj, globalns, merged_locals, include_extras)
|
|
|
|
try:
|
|
from typing_extensions import get_origin, get_args
|
|
except ImportError:
|
|
|
|
def _assemble_tree(tree: Tuple[Any]) -> Any:
    """Rebuild a typing object from a nested ``(origin, *args)`` tuple tree.

    Anything that is not a tuple is a leaf and is returned unchanged. For an
    ``Annotated`` node, the first argument is rebuilt recursively and the
    second one is unpacked as the metadata; for any other node, every argument
    is rebuilt recursively and the origin is subscripted with the result.
    """
    if not isinstance(tree, tuple):
        return tree
    head, *children = tree  # type: ignore
    if head is Annotated:
        return Annotated[(_assemble_tree(children[0]), *children[1])]
    return head[tuple(_assemble_tree(child) for child in children)]
|
|
|
|
def get_origin(tp):  # type: ignore
    """Fallback ``get_origin``: return the unsubscripted form of ``tp`` or None.

    ``Annotated`` aliases yield ``Annotated`` (None when their ``__args__`` is
    None), bare ``Generic`` yields itself, and everything else yields its
    ``__origin__`` attribute when it has one.
    """
    # In Python 3.6: List[Collection[T]][int].__args__ == int != Collection[int]
    if hasattr(tp, "_subs_tree"):
        tp = _assemble_tree(tp._subs_tree())
    if isinstance(tp, _AnnotatedAlias):
        return Annotated if tp.__args__ is not None else None
    return Generic if tp is Generic else getattr(tp, "__origin__", None)
|
|
|
|
def get_args(tp):  # type: ignore
    """Fallback ``get_args``: return the type arguments of ``tp`` as a tuple.

    * ``Annotated`` aliases yield ``(type, *metadata)``, or ``()`` when their
      ``__args__`` is None.
    * Callable aliases with an explicit parameter list yield
      ``([param, ...], return_type)``; ``Callable[..., R]`` keeps its flat
      ``(Ellipsis, R)`` form.
    * Objects without arguments (including bare ``Callable``) yield ``()``.
    """
    import collections.abc

    # In Python 3.6: List[Collection[T]][int].__args__ == int != Collection[int]
    if hasattr(tp, "_subs_tree"):
        tp = _assemble_tree(tp._subs_tree())
    if isinstance(tp, _AnnotatedAlias):
        return () if tp.__args__ is None else (tp.__args__[0], *tp.__metadata__)
    # __args__ can be None in 3.6 inside __set_name__
    res = getattr(tp, "__args__", ()) or ()
    origin = get_origin(tp)
    # The origin of a Callable alias is typing.Callable or collections.abc.Callable
    # depending on the Python version, so both have to be recognised
    is_callable = origin is Callable or origin is collections.abc.Callable
    # `res` is empty for a bare Callable: it must be checked before indexing
    if is_callable and res and res[0] is not Ellipsis:
        res = (list(res[:-1]), res[-1])
    return res
|
|
|
|
|
|
if sys.version_info >= (3, 8): # pragma: no cover
|
|
from typing import Literal, Protocol # noqa: F401
|
|
else: # pragma: no cover
|
|
try:
|
|
from typing_extensions import Literal, Protocol # noqa: F401
|
|
except ImportError:
|
|
pass
|
|
|
|
# Bind `ForwardRef` and the private type-variable collector `_collect_type_vars`
# whatever the running Python version is.
if sys.version_info >= (3, 7):
    from typing import ForwardRef  # type: ignore

    try:
        from typing import _collect_type_vars  # type: ignore
    except ImportError:
        # NOTE(review): this helper is private and is not available under this
        # name on every CPython release; the alternative names below are tried
        # so that importing this module does not fail -- confirm against the
        # typing.py of the targeted interpreters.
        try:
            from typing import (  # type: ignore
                _collect_type_parameters as _collect_type_vars,
            )
        except ImportError:
            from typing import (  # type: ignore
                _collect_parameters as _collect_type_vars,
            )
else:
    from typing import _type_vars, _ForwardRef

    _collect_type_vars = _type_vars

    def ForwardRef(arg, is_argument=True):
        """Build a ``_ForwardRef``; ``is_argument`` is accepted and ignored."""
        return _ForwardRef(arg)
|
|
|
|
|
|
def _strip_annotations(t):
    """Identity fallback, used when no real implementation can be imported."""
    return t


# Prefer the implementation of typing, then the one of typing_extensions; the
# identity function defined above is kept when neither of them provides it.
try:
    from typing import _strip_annotations  # type: ignore  # noqa: F811
except ImportError:
    try:
        from typing_extensions import _strip_annotations  # type: ignore  # noqa: F811
    except ImportError:
        pass
|
|
|
|
|
|
def _generic_mro(result, tp):
    """Record in ``result`` the form of ``tp`` and, recursively, of its bases.

    ``result`` maps an origin (or ``tp`` itself when it has no origin) to the
    type as it was given. When the origin declares ``__orig_bases__``, the type
    variables collected from those bases are paired with the arguments of
    ``tp``; every base whose origin is not registered yet is subscripted with
    that pairing applied to its own ``__parameters__`` and then processed the
    same way.
    """
    key = get_origin(tp)
    if key is None:
        key = tp
    result[key] = tp
    if not hasattr(key, "__orig_bases__"):
        return
    type_vars = _collect_type_vars(key.__orig_bases__)
    bound = dict(zip(type_vars, get_args(tp)))
    for base in key.__orig_bases__:
        if get_origin(base) in result:
            continue
        base_params = getattr(base, "__parameters__", ())
        if base_params:
            # Type variables without a bound value are left as they are
            base = base[tuple(bound.get(param, param) for param in base_params)]
        _generic_mro(result, base)
|
|
|
|
|
|
# Sentinel entries used to avoid subscripting the bare Generic and Protocol.
# Protocol is only registered when a previous import made the name available.
BASE_GENERIC_MRO = {Generic: Generic}
if "Protocol" in globals():
    BASE_GENERIC_MRO[Protocol] = Protocol
|
|
|
|
|
|
def generic_mro(tp):
    """Return the MRO of ``tp`` with generic bases in their parametrized form.

    A plain class (no origin and no ``__orig_bases__``) simply yields its
    ``__mro__``. Otherwise the MRO of the origin (or of ``tp`` itself when it
    has no origin) is returned as a tuple in which every class registered by
    ``_generic_mro`` is replaced by the form recorded for it.

    Raises:
        TypeError: if ``tp`` is neither a class nor a generic alias.
    """
    origin = get_origin(tp)
    is_generic = origin is not None or hasattr(tp, "__orig_bases__")
    if not is_generic:
        if isinstance(tp, type):
            return tp.__mro__
        raise TypeError(f"{tp!r} is not a type or a generic alias")
    resolved = BASE_GENERIC_MRO.copy()
    _generic_mro(resolved, tp)
    root = tp if origin is None else origin
    return tuple([resolved.get(klass, klass) for klass in root.__mro__])
|
|
|
|
|
|
def resolve_type_hints(obj: Any) -> Dict[str, Any]:
    """Return the type hints of ``obj``, substituting inherited type variables.

    ``obj`` may be a class or a parametrized generic class; anything else is
    passed directly to ``get_type_hints`` (with ``include_extras=True``).

    For classes, the generic MRO is walked from the most basic class to the
    most derived one. Only the annotations declared by each class itself are
    considered, and the type variables they mention are replaced using the
    arguments that the class receives in the MRO.
    """
    if not isinstance(get_origin(obj) or obj, type):
        return get_type_hints(obj, include_extras=True)
    resolved: Dict[str, Any] = {}
    for base in reversed(generic_mro(obj)):
        base_cls = get_origin(base) or base
        # Annotations declared in the class body, not the inherited ones
        own_annotations = getattr(base_cls, "__dict__", {}).get("__annotations__", {})
        type_var_map = dict(
            zip(getattr(base_cls, "__parameters__", ()), get_args(base))
        )
        base_hints = get_type_hints(base_cls, include_extras=True)
        for name, hint in base_hints.items():
            if name not in own_annotations:
                continue
            if isinstance(hint, TypeVar):
                hint = type_var_map.get(hint, hint)
            elif getattr(hint, "__parameters__", ()):
                generic = Union if is_union(hint) else hint
                hint = generic[
                    tuple(type_var_map.get(param, param) for param in hint.__parameters__)
                ]
            resolved[name] = hint
    return resolved
|
|
|
|
|
|
# Runtime classes of typing constructs, obtained by building a sample of each
_T = TypeVar("_T")
_GenericAlias: Any = type(Generic[_T])

try:
    _AnnotatedAlias: Any = type(Annotated[_T, ...])
except NameError:
    # Annotated could not be imported above
    _AnnotatedAlias = _FakeType

try:

    class _TypedDictImplem(TypedDict):
        """Empty TypedDict whose only purpose is to expose its metaclass."""

    _TypedDictMeta: Any = type(_TypedDictImplem)
    _LiteralMeta: Any = type(Literal)
except NameError:
    # TypedDict or Literal could not be imported above: both names get the
    # placeholder class
    _LiteralMeta = _FakeType  # type: ignore
    _TypedDictMeta = _FakeType  # type: ignore
|
|
|
|
|
|
def is_new_type(tp: Any) -> bool:
    """Tell whether ``tp`` has a ``__supertype__`` attribute."""
    missing = object()
    return getattr(tp, "__supertype__", missing) is not missing
|
|
|
|
|
|
def is_annotated(tp: Any) -> bool:
    """Tell whether the origin of ``tp`` is ``Annotated``.

    ``Annotated`` is taken from ``typing``, else from ``typing_extensions``;
    the answer is False when neither of them provides it.
    """
    try:
        from typing import Annotated  # type: ignore
    except ImportError:
        try:
            from typing_extensions import Annotated  # type: ignore
        except ImportError:
            return False
    return get_origin(tp) == Annotated
|
|
|
|
|
|
def is_literal(tp: Any) -> bool:
    """Tell whether ``tp`` is ``Literal`` or a parametrized ``Literal[...]``.

    ``Literal`` is taken from ``typing``, else from ``typing_extensions``; the
    answer is False when neither of them provides it.
    """
    try:
        from typing import Literal
    except ImportError:
        try:
            from typing_extensions import Literal  # type: ignore
        except ImportError:
            return False
    if tp is Literal or get_origin(tp) == Literal:
        return True
    # py36 only: there, a parametrized Literal has no origin and can only be
    # recognised through its class. On later versions the class of ``Literal``
    # can be shared with other special forms (e.g. ``_SpecialForm`` in 3.8, which
    # is also the class of ``Any`` and ``Union``), hence the version guard.
    return sys.version_info < (3, 7) and isinstance(tp, type(Literal))
|
|
|
|
|
|
def is_named_tuple(tp: Any) -> bool:
    """Tell whether ``tp`` is a tuple subclass that defines ``_fields``.

    Objects that cannot be given to ``issubclass`` (instances, generic aliases
    such as ``List[int]``, None, ...) are reported as False instead of raising
    ``TypeError``.
    """
    try:
        return issubclass(tp, tuple) and hasattr(tp, "_fields")
    except TypeError:
        # issubclass() rejects a first argument that is not a class
        return False
|
|
|
|
|
|
def _typed_dict_metaclasses() -> Tuple[type, ...]:
    """Return the metaclass of every importable ``TypedDict`` implementation."""
    implementations = []
    try:
        from typing import TypedDict as typing_typed_dict

        implementations.append(typing_typed_dict)
    except ImportError:
        pass  # Python version without typing.TypedDict
    try:
        from typing_extensions import TypedDict as extensions_typed_dict

        implementations.append(extensions_typed_dict)
    except ImportError:
        pass  # typing_extensions is not installed or has no TypedDict
    metaclasses = []
    for typed_dict in implementations:
        # The metaclass is read on a sample subclass created for that purpose
        metaclass = type(new_class("_TypedDictImplem", (typed_dict,)))
        if metaclass not in metaclasses:
            metaclasses.append(metaclass)
    return tuple(metaclasses)


# Computed once at import time instead of building a sample class on each call
_TYPED_DICT_METACLASSES = _typed_dict_metaclasses()


def is_typed_dict(tp: Any) -> bool:
    """Tell whether ``tp`` is a ``TypedDict`` class.

    Classes created from ``typing.TypedDict`` and from
    ``typing_extensions.TypedDict`` are both recognised; the answer is False
    when no implementation of ``TypedDict`` can be imported.
    """
    return isinstance(tp, _TYPED_DICT_METACLASSES)
|
|
|
|
|
|
def is_type_var(tp: Any) -> bool:
    """Tell whether ``tp`` is an instance of ``TypeVar``."""
    result = isinstance(tp, TypeVar)  # type: ignore
    return result
|
|
|
|
|
|
# `__required_keys__` is detected on the class instead of relying on
# sys.version_info, because it can also depend on the typing_extensions version
def required_keys(typed_dict: Type) -> Collection[str]:
    """Return the names of the required keys of a ``TypedDict`` class.

    ``__required_keys__`` is returned when the class has it. Otherwise the
    result is the union of the required keys of its TypedDict bases and, when
    the class is total, of the annotations that do not come from those bases.
    """
    assert is_typed_dict(typed_dict)
    if hasattr(typed_dict, "__required_keys__"):
        return typed_dict.__required_keys__
    inherited_names: Set = set()
    keys: Set[str] = set()
    parents = [
        base for base in typed_dict.__bases__ if isinstance(base, _TypedDictMeta)
    ]
    for parent in parents:
        inherited_names.update(parent.__annotations__)
        keys.update(required_keys(parent))
    if typed_dict.__total__:  # type: ignore
        # Keys declared by this class itself follow its own totality
        keys.update(typed_dict.__annotations__.keys() - inherited_names)
    return keys
|
|
|
|
|
|
# On py37/py38, get_origin of a typing alias wrapping a builtin generic returns
# the builtin type, which cannot be subscripted; `typing_origin` maps such an
# origin back to its typing alias.
if (3, 7) <= sys.version_info < (3, 9):
    import typing

    # Objects of typing without an `__origin__` are skipped, so that None is
    # never registered as the origin of an arbitrary typing object
    TYPING_ALIASES = {
        elt.__origin__: elt
        for elt in typing.__dict__.values()
        if getattr(elt, "__origin__", None) is not None
    }

    def typing_origin(origin: Any) -> Any:
        """Return the typing alias of ``origin``, or ``origin`` if it has none."""
        return TYPING_ALIASES.get(origin, origin)


else:

    def typing_origin(tp: Any) -> Any:
        """Return ``tp`` unchanged: no mapping is applied on this Python version."""
        return tp
|
|
|
|
|
|
def is_type(tp: Any) -> bool:
    """Tell whether ``tp`` is a class without type arguments.

    ``isinstance(tp, type)`` alone is not enough, because in py39
    ``isinstance(list[int], type)`` is True.
    """
    if not isinstance(tp, type):
        return False
    return not get_args(tp)
|
|
|
|
|
|
def is_union(tp: Any) -> bool:
    """Tell whether ``tp`` is ``Union`` or, when it exists, ``types.UnionType``."""
    try:
        from types import UnionType  # type: ignore
    except ImportError:
        # No types.UnionType on this Python version: only typing.Union is left
        return tp is Union
    return tp in (UnionType, Union)
|