142 lines
3.8 KiB
Python
142 lines
3.8 KiB
Python
from collections import defaultdict
|
|
from dataclasses import dataclass
|
|
from typing import (
|
|
Any,
|
|
Callable,
|
|
Collection,
|
|
Dict,
|
|
List,
|
|
Mapping,
|
|
MutableMapping,
|
|
Optional,
|
|
Sequence,
|
|
TypeVar,
|
|
overload,
|
|
)
|
|
|
|
from apischema.cache import CacheAwareDict
|
|
from apischema.metadata.keys import ORDERING_METADATA
|
|
from apischema.types import MetadataMixin
|
|
from apischema.utils import stop_signature_abuse
|
|
|
|
# Type variable restricted to classes; it lets the class decorators returned by
# `order` be annotated as returning the very class they receive.
Cls = TypeVar("Cls", bound=type)
|
|
|
|
|
|
@dataclass(frozen=True)
class Ordering(MetadataMixin):
    """Immutable description of where an element should be placed.

    At most one of the three attributes is expected to carry the constraint:

    * ``order``  -- an integer rank; elements are grouped and sorted by it.
    * ``after``  -- a field (or field name) this element must follow.
    * ``before`` -- a field (or field name) this element must precede.

    ``key`` is a plain class attribute (it has no annotation, so it is not a
    dataclass field); it is presumably the metadata key consumed by
    ``MetadataMixin`` -- confirm against that class.
    """

    key = ORDERING_METADATA
    order: Optional[int] = None
    after: Optional[Any] = None
    before: Optional[Any] = None

    def __post_init__(self):
        """Validate the ``after`` and ``before`` references when they are set."""
        # Imported at call time rather than at module import time
        # (presumably to avoid a circular import -- TODO confirm).
        from apischema.objects.fields import check_field_or_name

        # Same validation order as the attributes: `after` first, then `before`.
        for reference in (self.after, self.before):
            if reference is not None:
                check_field_or_name(reference, methods=True)
|
|
|
|
|
|
# Per-class ordering overrides: maps a class to the {field: Ordering} mapping
# registered for it by the `order(...)` class decorator, and is read back by
# `get_order_overriding`.
# NOTE(review): CacheAwareDict comes from apischema.cache; how it interacts with
# caching is not visible from this module -- confirm there.
_order_overriding: MutableMapping[type, Mapping[Any, Ordering]] = CacheAwareDict({})
|
|
|
|
|
|
@overload
def order(__value: int) -> Ordering:
    ...


@overload
def order(*, after: Any) -> Ordering:
    ...


@overload
def order(*, before: Any) -> Ordering:
    ...


@overload
def order(__fields: Sequence[Any]) -> Callable[[Cls], Cls]:
    ...


@overload
def order(__override: Mapping[Any, Ordering]) -> Callable[[Cls], Cls]:
    ...


def order(__arg=None, *, before=None, after=None):
    """Build an ``Ordering`` or a class decorator registering ordering overrides.

    Exactly one of ``__arg``, ``before`` and ``after`` must be given; any other
    combination is reported through ``stop_signature_abuse()``.

    * ``order(3)`` / ``order(after=x)`` / ``order(before=x)`` return an
      ``Ordering`` carrying that single constraint.
    * ``order([a, b, c])`` is turned into the mapping
      ``{b: order(after=a), c: order(after=b)}`` and handled like a mapping.
    * ``order({field: Ordering, ...})`` returns a decorator that stores the
      mapping in ``_order_overriding`` under the decorated class and returns
      the class unchanged. Every mapping value must be an ``Ordering``.

    Any other type for ``__arg`` is reported through ``stop_signature_abuse()``.
    """
    provided = sum(
        1 for candidate in (__arg, before, after) if candidate is not None
    )
    if provided != 1:
        stop_signature_abuse()

    overriding = __arg
    if isinstance(overriding, Sequence):
        # Chain consecutive fields: each one is ordered after its predecessor.
        overriding = {
            field: order(after=previous)
            for previous, field in zip(overriding, overriding[1:])
        }

    if isinstance(overriding, Mapping):
        if any(not isinstance(value, Ordering) for value in overriding.values()):
            stop_signature_abuse()

        def decorator(cls: Cls) -> Cls:
            _order_overriding[cls] = overriding
            return cls

        return decorator

    # Past this point `overriding` is still the untouched `__arg`.
    if __arg is None or isinstance(__arg, int):
        return Ordering(__arg, after, before)
    stop_signature_abuse()
|
|
|
|
|
|
def get_order_overriding(cls: type) -> Mapping[str, Ordering]:
    """Collect the ordering overrides that apply to ``cls``.

    The MRO of ``cls`` is walked from the most generic base down to ``cls``
    itself, so an override registered on a more derived class replaces one
    registered for the same field name on a base class. Keys of the returned
    mapping are the names produced by ``get_field_name(field, methods=True)``.
    """
    from apischema.objects.fields import get_field_name

    overriding: Dict[str, Ordering] = {}
    for sub_cls in reversed(cls.__mro__):
        if sub_cls not in _order_overriding:
            continue
        for field, ordering in _order_overriding[sub_cls].items():
            overriding[get_field_name(field, methods=True)] = ordering
    return overriding
|
|
|
|
|
|
# Type of the elements reordered by `sort_by_order`.
T = TypeVar("T")
|
|
|
|
|
|
def sort_by_order(
    cls: type,
    elts: Collection[T],
    name: Callable[[T], str],
    order: Callable[[T], Optional[Ordering]],
) -> Sequence[T]:
    """Return the elements of ``elts`` arranged by their ordering constraints.

    For each element the effective ordering is the override registered for
    ``name(elt)`` on ``cls`` (see ``get_order_overriding``) if there is one,
    otherwise ``order(elt)``.

    * No ordering: the element joins group ``0``.
    * ``ordering.order`` set: the element joins the group of that integer.
    * ``ordering.after`` / ``ordering.before`` set: the element is attached
      to the element bearing the referenced name and is emitted right
      after / right before it.

    Groups are emitted by ascending integer; inside a group the input order is
    kept. When there is a single group and no relative constraint, that
    group's list is returned directly.

    Elements that cannot be reached from any group -- their anchor names no
    element of ``elts``, or they only anchor to each other (cycle) or to
    themselves -- used to be silently dropped. They are now appended at the
    end, in input order, so the result always contains every input element.

    Raises:
        NotImplementedError: if an ``Ordering`` has none of ``order``,
            ``after`` and ``before`` set.
    """
    from apischema.objects.fields import get_field_name

    order_overriding = get_order_overriding(cls)
    groups: Dict[int, List[T]] = defaultdict(list)
    after: Dict[str, List[T]] = defaultdict(list)
    before: Dict[str, List[T]] = defaultdict(list)
    for elt in elts:
        ordering = order_overriding.get(name(elt), order(elt))
        if ordering is None:
            groups[0].append(elt)
        elif ordering.order is not None:
            groups[ordering.order].append(elt)
        elif ordering.after is not None:
            after[get_field_name(ordering.after, methods=True)].append(elt)
        elif ordering.before is not None:
            before[get_field_name(ordering.before, methods=True)].append(elt)
        else:
            raise NotImplementedError
    if not after and not before and len(groups) == 1:
        return next(iter(groups.values()))

    result: List[T] = []
    # Identities of the elements already emitted. `id` is used because the
    # elements are not required to be hashable; `elts` keeps them alive for the
    # duration of the call, so the ids are stable.
    placed = set()

    def add_to_result(elt: T, once: bool = False) -> None:
        # `once` is only enabled for the leftover sweep below, where anchor
        # cycles are possible and an element must not be visited twice. The
        # main traversal keeps the original, unguarded behavior.
        if once and id(elt) in placed:
            return
        placed.add(id(elt))
        elt_name = name(elt)
        # `.get` avoids inserting empty lists into the defaultdicts.
        for before_elt in before.get(elt_name, ()):
            add_to_result(before_elt, once)
        result.append(elt)
        for after_elt in after.get(elt_name, ()):
            add_to_result(after_elt, once)

    for value in sorted(groups):
        for elt in groups[value]:
            add_to_result(elt)

    # Elements with a dangling or cyclic anchor were never reached above;
    # keep them (and whatever is attached to them) instead of losing them.
    for elt in elts:
        if id(elt) not in placed:
            add_to_result(elt, once=True)
    return result
|