136 lines
5.1 KiB
Python
136 lines
5.1 KiB
Python
from dataclasses import Field, MISSING, field, make_dataclass
|
|
from functools import wraps
|
|
from inspect import Parameter, signature
|
|
from typing import (
|
|
Awaitable,
|
|
Callable,
|
|
ClassVar,
|
|
Collection,
|
|
Iterator,
|
|
List,
|
|
NewType,
|
|
Optional,
|
|
Tuple,
|
|
Type,
|
|
TypeVar,
|
|
)
|
|
|
|
from graphql.pyutils import camel_to_snake
|
|
|
|
from apischema.aliases import alias
|
|
from apischema.graphql.schema import Mutation as Mutation_
|
|
from apischema.schemas import Schema
|
|
from apischema.serialization.serialized_methods import ErrorHandler
|
|
from apischema.type_names import type_name
|
|
from apischema.types import AnyType, Undefined
|
|
from apischema.typing import get_type_hints
|
|
from apischema.utils import is_async, is_union_of, wrap_generic_init_subclass
|
|
|
|
# Field name used for the client mutation id, both as the input field alias
# and as the attribute added to payload classes.
CLIENT_MUTATION_ID = "client_mutation_id"

# Marker type: a `mutate` parameter whose annotation is (a union of)
# ClientMutationId is treated as the client mutation id parameter.
ClientMutationId = NewType("ClientMutationId", str)
# NOTE(review): type_name(None) presumably prevents ClientMutationId from
# getting its own named schema type — confirm against apischema.type_names.
type_name(None)(ClientMutationId)

# TypeVar bound (by forward reference) to the Mutation base class.
M = TypeVar("M", bound="Mutation")
|
|
|
|
|
|
def _has_default(spec: object) -> bool:
    """Tell whether a ``make_dataclass`` field spec provides a default.

    ``spec`` is the third item of a ``(name, type, spec)`` tuple: either a
    ``dataclasses.Field`` or a plain default value, ``MISSING`` meaning
    "no default".
    """
    if isinstance(spec, Field):
        return spec.default is not MISSING or spec.default_factory is not MISSING
    return spec is not MISSING


class Mutation:
    """Base class for GraphQL mutations taking an ``input`` argument.

    When a subclass defines ``mutate`` (as a classmethod or staticmethod),
    ``__init_subclass__`` builds a ``Mutation_`` and stores it in
    ``cls._mutation``:

    - the typed parameters of ``mutate`` become the fields of a generated
      ``<ClassName>Input`` dataclass, passed as the single ``input`` argument;
    - the subclass itself is registered as the ``<ClassName>Payload`` type and
      is the declared return type (awaitable if ``mutate`` is async);
    - the mutation alias is ``camel_to_snake(<ClassName>)``.

    Class-level settings read by ``__init_subclass__``:

    - ``_error_handler`` / ``_schema``: forwarded to ``Mutation_``.
    - ``_client_mutation_id``: ``False`` disables the client mutation id;
      ``True`` makes its input field required; ``None`` makes it optional
      (default ``None``) when added automatically, or keeps the default of
      the ``ClientMutationId`` parameter declared by ``mutate``.
    """

    # Forwarded to Mutation_ as error_handler / schema.
    _error_handler: ClassVar[ErrorHandler] = Undefined
    _schema: ClassVar[Optional[Schema]] = None
    # None: optional client_mutation_id; True: required; False: disabled.
    _client_mutation_id: ClassVar[Optional[bool]] = None
    _mutation: ClassVar[Mutation_]  # set in __init_subclass__

    # Mutate is not defined to prevent Mypy warning about signature of superclass
    mutate: ClassVar[Callable]

    @wrap_generic_init_subclass
    def __init_subclass__(cls, **kwargs):
        """Build ``cls._mutation`` from ``cls.mutate``, if ``mutate`` exists.

        Raises:
            TypeError: if ``mutate`` is not a classmethod/staticmethod, has a
                positional-only or unannotated parameter, or declares a
                ``ClientMutationId`` parameter without default while
                ``_client_mutation_id`` is ``False``.
        """
        super().__init_subclass__(**kwargs)
        if not hasattr(cls, "mutate"):
            return
        # Find the raw (undecorated-by-descriptor) attribute along the MRO:
        # looking only in cls.__dict__ raised KeyError for subclasses that
        # inherit `mutate` from a parent mutation.
        raw_mutate = next(
            (vars(klass)["mutate"] for klass in cls.__mro__ if "mutate" in vars(klass)),
            None,
        )
        if not isinstance(raw_mutate, (classmethod, staticmethod)):
            raise TypeError(f"{cls.__name__}.mutate must be a classmethod/staticmethod")
        mutate = getattr(cls, "mutate")
        type_name(f"{cls.__name__}Payload")(cls)
        types = get_type_hints(mutate, localns={cls.__name__: cls}, include_extras=True)
        async_mutate = is_async(mutate, types)
        fields: List[Tuple[str, AnyType, Field]] = []
        cmi_param = None
        for param_name, param in signature(mutate).parameters.items():
            if param.kind is Parameter.POSITIONAL_ONLY:
                raise TypeError("Positional only parameters are not supported")
            # *args / **kwargs are not mapped to input fields.
            if param.kind in {Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY}:
                if param_name not in types:
                    raise TypeError("Mutation parameters must be typed")
                field_type = types[param_name]
                field_ = MISSING if param.default is Parameter.empty else param.default
                if is_union_of(field_type, ClientMutationId):
                    cmi_param = param_name
                    if cls._client_mutation_id is False:
                        if field_ is MISSING:
                            raise TypeError(
                                "Cannot have a ClientMutationId parameter"
                                " when _client_mutation_id = False"
                            )
                        # Defaulted parameter: keep it out of the input so
                        # mutate falls back on its own default.
                        continue
                    elif cls._client_mutation_id is True:
                        field_ = MISSING
                    # Aliased to client_mutation_id whatever the parameter name.
                    field_ = field(default=field_, metadata=alias(CLIENT_MUTATION_ID))
                fields.append((param_name, field_type, field_))
        # Only these names are passed back to mutate; an automatically added
        # client_mutation_id field below is not.
        field_names = [name for (name, _, _) in fields]
        if cmi_param is None and cls._client_mutation_id is not False:
            fields.append(
                (
                    CLIENT_MUTATION_ID,
                    ClientMutationId
                    if cls._client_mutation_id
                    else Optional[ClientMutationId],
                    MISSING if cls._client_mutation_id else None,
                )
            )
            cmi_param = CLIENT_MUTATION_ID
        # make_dataclass rejects a required field following a defaulted one
        # (e.g. a required client_mutation_id appended after defaulted
        # parameters, or keyword-only parameters declared default-first).
        # The stable sort puts required fields first and keeps the relative
        # order otherwise, so layouts that were already valid are unchanged.
        fields.sort(key=lambda item: _has_default(item[2]))
        input_cls = make_dataclass(f"{cls.__name__}Input", fields)

        def wrapper(input):
            return mutate(**{name: getattr(input, name) for name in field_names})

        wrapper.__annotations__["input"] = input_cls
        wrapper.__annotations__["return"] = Awaitable[cls] if async_mutate else cls
        if cls._client_mutation_id is not False:
            # Give the payload a non-init client_mutation_id field typed like
            # the input one, and copy the input value onto each result.
            # NOTE(review): on Python < 3.10, if cls declares no annotations of
            # its own, cls.__annotations__ resolves to a base class's dict and
            # this writes into it — confirm this case is not hit.
            cls.__annotations__[CLIENT_MUTATION_ID] = input_cls.__annotations__[
                cmi_param
            ]
            setattr(cls, CLIENT_MUTATION_ID, field(init=False))
            wrapped = wrapper

            if async_mutate:

                async def wrapper(input):
                    result = await wrapped(input)
                    setattr(result, CLIENT_MUTATION_ID, getattr(input, cmi_param))
                    return result

            else:

                def wrapper(input):
                    result = wrapped(input)
                    setattr(result, CLIENT_MUTATION_ID, getattr(input, cmi_param))
                    return result

            # Copies __annotations__ (input / return) onto the outer wrapper.
            wrapper = wraps(wrapped)(wrapper)

        cls._mutation = Mutation_(
            function=wrapper,
            alias=camel_to_snake(cls.__name__),
            schema=cls._schema,
            error_handler=cls._error_handler,
        )
|
|
|
|
|
|
def _mutations(cls: Type[Mutation] = Mutation) -> Iterator[Type[Mutation]]:
    """Yield every subclass of ``cls`` (at any depth) that has a ``_mutation``.

    Subclasses are visited depth-first, each class before its own subclasses,
    following ``__subclasses__()`` order. A class reachable through several
    bases (multiple inheritance) is visited, and possibly yielded, only once,
    so the same mutation is never reported twice.
    """
    seen = set()
    # Pushing children in reverse makes pops follow __subclasses__() order,
    # i.e. the same pre-order as a recursive walk.
    stack = list(reversed(cls.__subclasses__()))
    while stack:
        sub = stack.pop()
        if sub in seen:
            continue
        seen.add(sub)
        if hasattr(sub, "_mutation"):
            yield sub
        stack.extend(reversed(sub.__subclasses__()))
|
|
|
|
|
|
def mutations() -> Collection[Mutation_]:
    """Return, as a list, the ``_mutation`` of every ``Mutation`` subclass
    found by ``_mutations()``, in the order it yields them."""
    collected = []
    for mutation_cls in _mutations():
        collected.append(mutation_cls._mutation)
    return collected
|