tokencrawler/.venv/lib/python3.9/site-packages/apischema/ordering.py
2022-03-17 22:16:30 +01:00

142 lines
3.8 KiB
Python

from collections import defaultdict
from dataclasses import dataclass
from typing import (
Any,
Callable,
Collection,
Dict,
List,
Mapping,
MutableMapping,
Optional,
Sequence,
TypeVar,
overload,
)
from apischema.cache import CacheAwareDict
from apischema.metadata.keys import ORDERING_METADATA
from apischema.types import MetadataMixin
from apischema.utils import stop_signature_abuse
# Type variable restricted to class objects (``bound=type``); it lets the class
# decorator returned by ``order`` be typed as returning the class it was given.
Cls = TypeVar("Cls", bound=type)
@dataclass(frozen=True)
class Ordering(MetadataMixin):
    """Immutable description of where an element goes relative to the others.

    At most one of the three attributes is expected to be set; when
    ``sort_by_order`` reads an instance it checks ``order``, then ``after``,
    then ``before`` and uses the first one that is not ``None``.
    """

    # NOTE(review): presumably the metadata key consumed by ``MetadataMixin``;
    # confirm against ``apischema.types``.
    key = ORDERING_METADATA
    # Explicit numeric rank; ``sort_by_order`` emits lower ranks first.
    order: Optional[int] = None
    # Field (or field name) that this element must be placed right after.
    after: Optional[Any] = None
    # Field (or field name) that this element must be placed right before.
    before: Optional[Any] = None

    def __post_init__(self):
        """Run ``check_field_or_name`` on every anchor that was provided."""
        # Imported here rather than at module level, as in the original code
        # (presumably to avoid an import cycle; not verifiable from this file).
        from apischema.objects.fields import check_field_or_name

        # ``after`` is checked before ``before``; unset anchors are skipped.
        for anchor in (self.after, self.before):
            if anchor is not None:
                check_field_or_name(anchor, methods=True)
# Registry of per-class ordering overrides: the decorator returned by ``order``
# stores its mapping here under the decorated class, and
# ``get_order_overriding`` reads it back while walking a class's MRO.
_order_overriding: MutableMapping[type, Mapping[Any, Ordering]] = CacheAwareDict({})
@overload
def order(__value: int) -> Ordering:
    ...


@overload
def order(*, after: Any) -> Ordering:
    ...


@overload
def order(*, before: Any) -> Ordering:
    ...


@overload
def order(__fields: Sequence[Any]) -> Callable[[Cls], Cls]:
    ...


@overload
def order(__override: Mapping[Any, Ordering]) -> Callable[[Cls], Cls]:
    ...


def order(__arg=None, *, before=None, after=None):
    """Build an ``Ordering`` or a class decorator registering ordering overrides.

    Exactly one of the positional argument, ``before`` and ``after`` is
    expected; any other combination triggers ``stop_signature_abuse()``
    (presumably raising -- see ``apischema.utils``).

    Depending on what was supplied:

    * an ``int``, ``after=...`` or ``before=...`` gives the corresponding
      ``Ordering`` instance;
    * a ``Sequence`` is turned into a mapping placing each item after the one
      preceding it (the first item gets no entry), then handled as a mapping;
    * a ``Mapping`` whose values are all ``Ordering`` instances gives a
      decorator that records the mapping in ``_order_overriding`` for the
      decorated class and returns that class unchanged.

    NOTE(review): ``str`` is itself a ``Sequence``, so a string positional
    argument is expanded character by character rather than rejected.
    """
    supplied = sum(arg is not None for arg in (__arg, before, after))
    if supplied != 1:
        stop_signature_abuse()

    if isinstance(__arg, Sequence):
        # Pair every item with its predecessor: item -> Ordering(after=previous).
        __arg = {
            field: order(after=prev) for prev, field in zip(__arg, __arg[1:])
        }

    if isinstance(__arg, Mapping):
        if any(not isinstance(val, Ordering) for val in __arg.values()):
            stop_signature_abuse()

        def decorator(cls: Cls) -> Cls:
            _order_overriding[cls] = __arg
            return cls

        return decorator

    if __arg is None or isinstance(__arg, int):
        return Ordering(__arg, after, before)
    # Positional argument of an unsupported type.
    stop_signature_abuse()
def get_order_overriding(cls: type) -> Mapping[str, Ordering]:
    """Merge the ordering overrides registered for ``cls`` and its bases.

    Every class of ``cls.__mro__`` that has an entry in ``_order_overriding``
    contributes its mapping. Keys are passed through
    ``get_field_name(..., methods=True)`` before being stored.
    """
    from apischema.objects.fields import get_field_name

    merged: Dict[str, Ordering] = {}
    # Go through the MRO backwards (base classes first) so that an override
    # declared on a more derived class replaces the one from its bases.
    for klass in reversed(cls.__mro__):
        if klass not in _order_overriding:
            continue
        for field, ordering in _order_overriding[klass].items():
            merged[get_field_name(field, methods=True)] = ordering
    return merged
# Generic type of the elements handled by ``sort_by_order``.
T = TypeVar("T")
def sort_by_order(
    cls: type,
    elts: Collection[T],
    name: Callable[[T], str],
    order: Callable[[T], Optional[Ordering]],
) -> Sequence[T]:
    """Arrange ``elts`` according to their ``Ordering``.

    Args:
        cls: class whose registered overrides (see ``get_order_overriding``)
            take precedence over what ``order`` returns.
        elts: elements to arrange.
        name: returns the name of an element; used both to look up overrides
            and to resolve ``after``/``before`` anchors.
        order: returns the element's own ``Ordering``, or ``None``.

    Returns:
        The elements grouped by ascending numeric rank (elements without any
        ordering have rank 0), keeping input order inside a rank. Elements
        declared *before* a name are emitted just ahead of the element bearing
        that name, elements declared *after* it just behind, recursively.
        When no anchor is used and there is a single rank, the list built for
        that rank is returned as is.

    Raises:
        NotImplementedError: if an ``Ordering`` has none of ``order``,
            ``after`` and ``before`` set.

    Note:
        Anchored elements are only reached through the element bearing the
        anchor name. If that name is never emitted (no such element, or
        anchors that only refer to each other), they are left out of the
        result.
    """
    from apischema.objects.fields import get_field_name

    overrides = get_order_overriding(cls)
    by_rank: Dict[int, List[T]] = defaultdict(list)
    followers: Dict[str, List[T]] = defaultdict(list)
    leaders: Dict[str, List[T]] = defaultdict(list)

    # First pass: drop each element in the bucket its ordering designates.
    for elt in elts:
        ordering = overrides.get(name(elt), order(elt))
        if ordering is None:
            bucket = by_rank[0]
        elif ordering.order is not None:
            bucket = by_rank[ordering.order]
        elif ordering.after is not None:
            bucket = followers[get_field_name(ordering.after, methods=True)]
        elif ordering.before is not None:
            bucket = leaders[get_field_name(ordering.before, methods=True)]
        else:
            raise NotImplementedError
        bucket.append(elt)

    # Fast path: nothing is anchored and every element shares one rank.
    if len(by_rank) == 1 and not (followers or leaders):
        (only_group,) = by_rank.values()
        return only_group

    arranged = []

    def place(elt: T):
        # Emit what must precede ``elt``, then ``elt``, then what must follow.
        anchor = name(elt)
        for preceding in leaders.get(anchor, ()):
            place(preceding)
        arranged.append(elt)
        for following in followers.get(anchor, ()):
            place(following)

    # Second pass: ranked elements are the roots, in ascending rank.
    for rank in sorted(by_rank):
        for elt in by_rank[rank]:
            place(elt)
    return arranged