tokencrawler/.venv/lib/python3.9/site-packages/apischema/graphql/relay/mutations.py
2022-03-17 22:16:30 +01:00

136 lines
5.1 KiB
Python

from dataclasses import Field, MISSING, field, make_dataclass
from functools import wraps
from inspect import Parameter, signature
from typing import (
Awaitable,
Callable,
ClassVar,
Collection,
Iterator,
List,
NewType,
Optional,
Tuple,
Type,
TypeVar,
)
from graphql.pyutils import camel_to_snake
from apischema.aliases import alias
from apischema.graphql.schema import Mutation as Mutation_
from apischema.schemas import Schema
from apischema.serialization.serialized_methods import ErrorHandler
from apischema.type_names import type_name
from apischema.types import AnyType, Undefined
from apischema.typing import get_type_hints
from apischema.utils import is_async, is_union_of, wrap_generic_init_subclass
ClientMutationId = NewType("ClientMutationId", str)
type_name(None)(ClientMutationId)
CLIENT_MUTATION_ID = "client_mutation_id"
M = TypeVar("M", bound="Mutation")
class Mutation:
_error_handler: ClassVar[ErrorHandler] = Undefined
_schema: ClassVar[Optional[Schema]] = None
_client_mutation_id: ClassVar[Optional[bool]] = None
_mutation: ClassVar[Mutation_] # set in __init_subclass__
# Mutate is not defined to prevent Mypy warning about signature of superclass
mutate: ClassVar[Callable]
@wrap_generic_init_subclass
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if not hasattr(cls, "mutate"):
return
if not isinstance(cls.__dict__["mutate"], (classmethod, staticmethod)):
raise TypeError(f"{cls.__name__}.mutate must be a classmethod/staticmethod")
mutate = getattr(cls, "mutate")
type_name(f"{cls.__name__}Payload")(cls)
types = get_type_hints(mutate, localns={cls.__name__: cls}, include_extras=True)
async_mutate = is_async(mutate, types)
fields: List[Tuple[str, AnyType, Field]] = []
cmi_param = None
for param_name, param in signature(mutate).parameters.items():
if param.kind is Parameter.POSITIONAL_ONLY:
raise TypeError("Positional only parameters are not supported")
if param.kind in {Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY}:
if param_name not in types:
raise TypeError("Mutation parameters must be typed")
field_type = types[param_name]
field_ = MISSING if param.default is Parameter.empty else param.default
if is_union_of(field_type, ClientMutationId):
cmi_param = param_name
if cls._client_mutation_id is False:
if field_ is MISSING:
raise TypeError(
"Cannot have a ClientMutationId parameter"
" when _client_mutation_id = False"
)
continue
elif cls._client_mutation_id is True:
field_ = MISSING
field_ = field(default=field_, metadata=alias(CLIENT_MUTATION_ID))
fields.append((param_name, field_type, field_))
field_names = [name for (name, _, _) in fields]
if cmi_param is None and cls._client_mutation_id is not False:
fields.append(
(
CLIENT_MUTATION_ID,
ClientMutationId
if cls._client_mutation_id
else Optional[ClientMutationId],
MISSING if cls._client_mutation_id else None,
)
)
cmi_param = CLIENT_MUTATION_ID
input_cls = make_dataclass(f"{cls.__name__}Input", fields)
def wrapper(input):
return mutate(**{name: getattr(input, name) for name in field_names})
wrapper.__annotations__["input"] = input_cls
wrapper.__annotations__["return"] = Awaitable[cls] if async_mutate else cls
if cls._client_mutation_id is not False:
cls.__annotations__[CLIENT_MUTATION_ID] = input_cls.__annotations__[
cmi_param
]
setattr(cls, CLIENT_MUTATION_ID, field(init=False))
wrapped = wrapper
if async_mutate:
async def wrapper(input):
result = await wrapped(input)
setattr(result, CLIENT_MUTATION_ID, getattr(input, cmi_param))
return result
else:
def wrapper(input):
result = wrapped(input)
setattr(result, CLIENT_MUTATION_ID, getattr(input, cmi_param))
return result
wrapper = wraps(wrapped)(wrapper)
cls._mutation = Mutation_(
function=wrapper,
alias=camel_to_snake(cls.__name__),
schema=cls._schema,
error_handler=cls._error_handler,
)
def _mutations(cls: Type[Mutation] = Mutation) -> Iterator[Type[Mutation]]:
for base in cls.__subclasses__():
if hasattr(base, "_mutation"):
yield base
yield from _mutations(base)
def mutations() -> Collection[Mutation_]:
return [mut._mutation for mut in _mutations()]