85 lines
2.6 KiB
Python
85 lines
2.6 KiB
Python
from dataclasses import dataclass
|
|
from typing import Generic, Optional, Sequence, Type, TypeVar
|
|
|
|
from apischema.type_names import get_type_name, type_name
|
|
from apischema.types import NoneType
|
|
from apischema.typing import generic_mro, get_args, get_origin
|
|
from apischema.utils import get_args2, is_union_of, wrap_generic_init_subclass
|
|
|
|
# Type parameters shared by the generic Relay connection classes below:
# Cursor_ is the cursor type of edges/page info, Node_ the edge's node type.
Cursor_, Node_ = TypeVar("Cursor_"), TypeVar("Node_")
|
|
|
|
|
|
def get_node_name(tp):
    """Return the GraphQL name registered for a connection's node type.

    If ``tp`` is a union involving ``NoneType`` (e.g. ``Optional[X]``) and
    has type arguments, it is first replaced by its first argument that is
    not ``NoneType``.

    :param tp: the node type (possibly optional).
    :return: ``get_type_name(tp).graphql`` for the (unwrapped) type.
    :raises TypeError: if no GraphQL name is registered for that type.
    """
    if is_union_of(tp, NoneType):
        union_args = get_args2(tp)
        # NOTE(review): presumably is_union_of also matches NoneType itself,
        # which has no args; this length check keeps next() from failing.
        if len(union_args) != 0:
            tp = next(member for member in union_args if member is not NoneType)
    graphql_name = get_type_name(tp).graphql
    if graphql_name is not None:
        return graphql_name
    raise TypeError(
        f"Node {tp} must have a ref registered to be used with connection"
    )
|
|
|
|
|
|
def edge_name(tp: Type["Edge"], *args) -> str:
    """Compute the GraphQL type name of an ``Edge`` (sub)class.

    ``tp`` is parametrized with ``args`` when any are given. The first base
    in its generic MRO whose origin is ``Edge`` supplies the node type (its
    first type argument). The result is that node's registered name
    followed by ``"Edge"``.

    :raises NotImplementedError: if no base with origin ``Edge`` is found.
    """
    target = tp[tuple(args)] if args else tp  # type: ignore
    edge_bases = (base for base in generic_mro(target) if get_origin(base) == Edge)
    edge_base = next(edge_bases, None)
    if edge_base is None:
        raise NotImplementedError
    node_type = get_args(edge_base)[0]
    return f"{get_node_name(node_type)}Edge"
|
|
|
|
|
|
@type_name(graphql=edge_name)
@dataclass
class Edge(Generic[Node_, Cursor_]):
    """Generic Relay edge pairing a node with its cursor.

    The GraphQL type name is produced by ``edge_name``. Every subclass also
    gets that naming rule registered in ``__init_subclass__``.
    """

    node: Node_  # the item carried by this edge
    cursor: Cursor_  # cursor value associated with ``node``

    @wrap_generic_init_subclass
    def __init_subclass__(cls, **kwargs):
        # Run the cooperative parent hooks first, then register the same
        # GraphQL naming rule on the new subclass.
        # NOTE(review): wrap_generic_init_subclass presumably adapts this
        # hook for generic subclassing; see apischema.utils.
        super().__init_subclass__(**kwargs)
        register_graphql_name = type_name(graphql=edge_name)
        register_graphql_name(cls)
|
|
|
|
|
|
@type_name(graphql=lambda *_: "PageInfo")
@dataclass
class PageInfo(Generic[Cursor_]):
    """Relay page information for a connection.

    The GraphQL type name is always ``"PageInfo"``, whatever the cursor type.
    """

    has_previous_page: bool = False
    has_next_page: bool = False
    # Cursors of the first / last edge of the page; None when unknown.
    start_cursor: Optional[Cursor_] = None
    end_cursor: Optional[Cursor_] = None

    @staticmethod
    def from_edges(
        edges: Optional[Sequence[Optional[Edge[Node_, Cursor_]]]],
        has_previous_page: bool = False,
        has_next_page: bool = False,
    ) -> "PageInfo":
        """Build a ``PageInfo`` taking its cursors from the first/last edges.

        :param edges: the page's edges. The sequence may be ``None`` or
            empty, and individual edges may be ``None``.
        :param has_previous_page: copied into the result.
        :param has_next_page: copied into the result.
        :return: a new ``PageInfo``. ``start_cursor`` / ``end_cursor`` are
            ``None`` when there are no edges or the first / last edge is
            ``None``.
        """
        start_cursor, end_cursor = None, None
        # Truthiness rather than ``is not None``: an empty page must not reach
        # ``edges[0]``, which would raise IndexError.
        if edges:
            first_edge, last_edge = edges[0], edges[-1]
            if first_edge is not None:
                start_cursor = first_edge.cursor
            if last_edge is not None:
                end_cursor = last_edge.cursor
        return PageInfo(has_previous_page, has_next_page, start_cursor, end_cursor)
|
|
|
|
|
|
def connection_name(tp: Type["Connection"], *args) -> str:
    """Compute the GraphQL type name of a ``Connection`` (sub)class.

    ``tp`` is parametrized with ``args`` when any are given. The first base
    in its generic MRO whose origin is ``Connection`` supplies the node type
    (its first type argument). The result is that node's registered name
    followed by ``"Connection"``.

    :raises NotImplementedError: if no base with origin ``Connection`` is found.
    """
    target = tp[tuple(args)] if args else tp  # type: ignore
    connection_bases = (
        base for base in generic_mro(target) if get_origin(base) == Connection
    )
    connection_base = next(connection_bases, None)
    if connection_base is None:
        raise NotImplementedError
    node_type = get_args(connection_base)[0]
    return f"{get_node_name(node_type)}Connection"
|
|
|
|
|
|
# Type parameter for the edge type of a Connection. It is bound to Edge, so
# type checkers only accept Edge or its subclasses here.
Edge_ = TypeVar("Edge_", bound=Edge)
|
|
|
|
|
|
@type_name(graphql=connection_name)
@dataclass
class Connection(Generic[Node_, Cursor_, Edge_]):
    """Generic Relay connection: a page of edges plus its ``PageInfo``.

    The GraphQL type name is produced by ``connection_name``. Every subclass
    also gets that naming rule registered in ``__init_subclass__``.
    """

    # The page's edges; the sequence itself and individual edges may be None.
    edges: Optional[Sequence[Optional[Edge_]]]
    page_info: PageInfo[Cursor_]

    @wrap_generic_init_subclass
    def __init_subclass__(cls, **kwargs):
        # Run the cooperative parent hooks first, then register the same
        # GraphQL naming rule on the new subclass.
        super().__init_subclass__(**kwargs)
        register_graphql_name = type_name(graphql=connection_name)
        register_graphql_name(cls)
|