tokencrawler/.venv/lib/python3.9/site-packages/apischema/graphql/resolvers.py
2022-03-17 22:16:30 +01:00

336 lines
10 KiB
Python

from collections import defaultdict
from dataclasses import dataclass
from enum import Enum
from functools import lru_cache
from inspect import Parameter, signature
from typing import (
Any,
Awaitable,
Callable,
Collection,
Dict,
Iterator,
Mapping,
MutableMapping,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
overload,
)
import graphql
from apischema import UndefinedType
from apischema.aliases import Aliaser
from apischema.cache import CacheAwareDict, cache
from apischema.conversions import Conversion
from apischema.conversions.conversions import AnyConversion, DefaultConversion
from apischema.deserialization import deserialization_method
from apischema.methods import method_registerer
from apischema.objects import ObjectField
from apischema.ordering import Ordering
from apischema.schemas import Schema
from apischema.serialization import (
PassThroughOptions,
SerializationMethod,
SerializationMethodVisitor,
)
from apischema.serialization.serialized_methods import (
ErrorHandler,
SerializedMethod,
_get_methods,
serialized as register_serialized,
)
from apischema.types import AnyType, NoneType, Undefined
from apischema.typing import is_type
from apischema.utils import (
awaitable_origin,
deprecate_kwargs,
empty_dict,
get_args2,
get_origin_or_type2,
identity,
is_async,
is_union_of,
keep_annotations,
)
from apischema.validation.errors import ValidationError
class PartialSerializationMethodVisitor(SerializationMethodVisitor):
    """Serialization visitor whose enum and object handlers are pass-throughs.

    ``enum`` and ``object`` both answer with ``identity``, so values of those
    kinds are handed back untouched.  ``visit`` special-cases
    ``UndefinedType`` and otherwise defers to ``SerializationMethodVisitor``.
    """

    # NOTE(review): presumably disables the parent visitor's caching --
    # confirm against SerializationMethodVisitor.
    use_cache = False

    @property
    def _factory(self) -> Callable[[type], SerializationMethod]:
        """Return a factory that yields ``identity`` whatever type it is given."""
        return lambda _: identity

    def enum(self, cls: Type[Enum]) -> SerializationMethod:
        """Return ``identity``: enum values are left as they are."""
        return identity

    def object(self, tp: AnyType, fields: Sequence[ObjectField]) -> SerializationMethod:
        """Return ``identity``: objects are left as they are, *fields* is unused."""
        return identity

    def visit(self, tp: AnyType) -> SerializationMethod:
        """Return the method for *tp*; ``UndefinedType`` always maps to ``None``."""
        if tp is not UndefinedType:
            return super().visit(tp)
        # Whatever the value is, an undefined result is emitted as None.
        return lambda obj: None
@cache
def partial_serialization_method_factory(
    aliaser: Aliaser,
    conversion: Optional[AnyConversion],
    default_conversion: DefaultConversion,
) -> Callable[[AnyType], SerializationMethod]:
    """Return a memoized ``type -> partial serialization method`` factory.

    Args:
        aliaser: Aliaser handed to the visitor.
        conversion: Conversion applied when visiting the requested type.
        default_conversion: Default conversion handed to the visitor.

    Returns:
        A function, wrapped in ``lru_cache``, that builds the serialization
        method of a type with a fresh ``PartialSerializationMethodVisitor``.
    """

    @lru_cache()
    def method_for(tp: AnyType) -> SerializationMethod:
        # Every optional serialization feature is switched off here.
        visitor = PartialSerializationMethodVisitor(
            additional_properties=False,
            aliaser=aliaser,
            check_type=False,
            default_conversion=default_conversion,
            exclude_defaults=False,
            exclude_none=False,
            exclude_unset=False,
            fall_back_on_any=False,
            pass_through_options=PassThroughOptions(),
        )
        return visitor.visit_with_conv(tp, conversion)

    return method_for
def unwrap_awaitable(tp: AnyType) -> AnyType:
    """Strip an awaitable wrapper from *tp*, if there is one.

    When the origin of *tp* equals ``awaitable_origin``, the result is its
    first type argument (``Any`` when it has no arguments), passed through
    ``keep_annotations`` together with the original *tp*.  Any other type is
    returned unchanged.
    """
    is_awaitable = get_origin_or_type2(tp) == awaitable_origin
    if not is_awaitable:
        return tp
    type_args = get_args2(tp)
    inner = type_args[0] if type_args else Any
    return keep_annotations(inner, tp)
@dataclass(frozen=True)
class Resolver(SerializedMethod):
    """``SerializedMethod`` carrying the parameters of its function.

    Both type accessors strip an awaitable wrapper with ``unwrap_awaitable``.
    """

    # Signature parameters of the resolver function.
    parameters: Sequence[Parameter]
    # Per-parameter metadata, keyed by parameter name.
    parameters_metadata: Mapping[str, Mapping]

    def error_type(self) -> AnyType:
        """Return the parent's error type with any awaitable wrapper removed."""
        inherited = super().error_type()
        return unwrap_awaitable(inherited)

    def return_type(self, return_type: AnyType) -> AnyType:
        """Unwrap *return_type* first, then defer to the parent implementation."""
        unwrapped = unwrap_awaitable(return_type)
        return super().return_type(unwrapped)
# Registry of resolvers: owner type -> alias -> Resolver.
_resolvers: MutableMapping[Type, Dict[str, Resolver]] = CacheAwareDict(
    defaultdict(dict)
)


def get_resolvers(tp: AnyType) -> Collection[Tuple[Resolver, Mapping[str, AnyType]]]:
    """Return what ``_get_methods`` finds for *tp* in the resolver registry.

    Each item pairs a ``Resolver`` with a mapping of names to types.
    """
    found = _get_methods(tp, _resolvers)
    return found
def none_error_handler(
    __error: Exception, __obj: Any, __info: graphql.GraphQLResolveInfo, **kwargs
) -> None:
    """Error handler that ignores all of its arguments and yields ``None``.

    ``resolver`` substitutes it when ``error_handler=None`` is passed.
    """
def resolver_parameters(
    resolver: Callable, *, check_first: bool
) -> Iterator[Parameter]:
    """Yield the named parameters of *resolver*, checking they are typed.

    ``*args`` and ``**kwargs`` style parameters are skipped silently.

    Args:
        resolver: Callable whose signature is inspected.
        check_first: Whether the first parameter of the signature must carry
            an annotation too; when false only the later ones must.

    Yields:
        Every positional-or-keyword and keyword-only parameter, in order.

    Raises:
        TypeError: On a positional-only parameter, or on a parameter that is
            required to be annotated but is not.
    """
    named_kinds = (Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY)
    for position, param in enumerate(signature(resolver).parameters.values()):
        if param.kind is Parameter.POSITIONAL_ONLY:
            raise TypeError("Resolver can not have positional only parameters")
        if param.kind not in named_kinds:
            continue
        # Only the very first signature entry may be exempted from typing.
        must_be_typed = check_first or position > 0
        if must_be_typed and param.annotation is Parameter.empty:
            raise TypeError("Resolver parameters must be typed")
        yield param
# Type of what the decorator receives and hands back: a callable or a property.
MethodOrProp = TypeVar("MethodOrProp", Callable, property)


@overload
def resolver(__method_or_property: MethodOrProp) -> MethodOrProp:
    ...


@overload
def resolver(
    alias: str = None,
    *,
    conversion: AnyConversion = None,
    error_handler: ErrorHandler = Undefined,
    order: Optional[Ordering] = None,
    schema: Schema = None,
    parameters_metadata: Mapping[str, Mapping] = None,
    serialized: bool = False,
    owner: Type = None,
) -> Callable[[MethodOrProp], MethodOrProp]:
    ...


@deprecate_kwargs({"conversions": "conversion"})
def resolver(
    __arg=None,
    *,
    alias: str = None,
    conversion: AnyConversion = None,
    error_handler: ErrorHandler = Undefined,
    order: Optional[Ordering] = None,
    schema: Schema = None,
    parameters_metadata: Mapping[str, Mapping] = None,
    serialized: bool = False,
    owner: Type = None,
):
    """Register a function or property as a resolver.

    Can be applied directly (``@resolver``) or called with arguments first
    (``@resolver("alias", ...)``); a string given positionally is taken as
    the alias.

    Args:
        __arg: The decorated method/property, or the alias when a string.
        alias: Name under which the resolver is registered; falls back to
            the name supplied by ``method_registerer``.
        conversion: Conversion stored on the resolver.
        error_handler: ``Undefined`` (default) stores no handler, ``None``
            stores ``none_error_handler``, anything else is stored as is.
        order: Ordering stored on the resolver.
        schema: Schema stored on the resolver.
        parameters_metadata: Metadata per parameter name; defaults to ``{}``.
        serialized: Also register the function through ``register_serialized``.
        owner: Type the resolver is registered for.

    Returns:
        Whatever ``method_registerer`` returns for the given arguments.

    Raises:
        TypeError: If ``serialized`` is set and the function is async, or its
            registration as a serialized method fails (the original exception
            is chained as ``__cause__``); also propagated from
            ``resolver_parameters``.
    """

    def register(func: Callable, owner: Type, alias2: str):
        # An explicit alias wins over the name proposed by the registerer.
        alias2 = alias or alias2
        # The first parameter is dropped; it only needs typing without owner.
        _, *parameters = resolver_parameters(func, check_first=owner is None)
        error_handler2 = error_handler
        if error_handler2 is None:
            error_handler2 = none_error_handler
        elif error_handler2 is Undefined:
            error_handler2 = None
        registered = Resolver(
            func,
            alias2,
            conversion,
            error_handler2,
            order,
            schema,
            parameters,
            parameters_metadata or {},
        )
        _resolvers[owner][alias2] = registered
        if serialized:
            if is_async(func):
                raise TypeError("Async resolver cannot be used as a serialized method")
            try:
                # The raw ``error_handler`` is forwarded, not ``error_handler2``.
                register_serialized(
                    alias=alias2,
                    conversion=conversion,
                    schema=schema,
                    error_handler=error_handler,
                    owner=owner,
                )(func)
            except Exception as err:
                # Chain explicitly so the real cause is kept in the traceback.
                raise TypeError(
                    "Resolver cannot be used as a serialized method"
                ) from err

    if isinstance(__arg, str):
        alias = __arg
        __arg = None
    return method_registerer(__arg, owner, register)
T = TypeVar("T")
U = TypeVar("U")


def as_async(func: Callable[[T], U]) -> Callable[[Awaitable[T]], Awaitable[U]]:
    """Adapt *func* so that it takes an awaitable in place of a plain value.

    Returns:
        A coroutine function that awaits its argument, then returns the
        result of calling *func* on the awaited value.
    """

    async def await_then_apply(pending: Awaitable[T]) -> U:
        resolved = await pending
        return func(resolved)

    return await_then_apply
def resolver_resolve(
    resolver: Resolver,
    types: Mapping[str, AnyType],
    aliaser: Aliaser,
    default_deserialization: DefaultConversion,
    default_serialization: DefaultConversion,
    serialized: bool = True,
) -> Callable:
    """Build the callable that executes *resolver*.

    Args:
        resolver: Resolver whose function is wrapped.
        types: Types keyed by parameter name, plus the ``"return"`` entry.
        aliaser: Applied to parameter aliases and to error keys.
        default_deserialization: Default conversion for argument
            deserialization of non-enum types.
        default_serialization: Default conversion for result serialization.
        serialized: When false, the result is returned without serialization.

    Returns:
        ``resolve(__self, __info, **kwargs)``, which deserializes the keyword
        arguments, calls the resolver function and serializes its result.  It
        raises ``ValueError`` when arguments fail validation; an exception
        from the call goes to the resolver's error handler when it has one.
    """

    # graphql deserialization will give Enum objects instead of strings
    def handle_enum(tp: AnyType) -> Optional[AnyConversion]:
        if not (is_type(tp) and issubclass(tp, Enum)):
            return default_deserialization(tp)
        return Conversion(identity, source=Any, target=tp)

    # One (alias, name, deserializer, optional, required) tuple per argument.
    param_specs = []
    info_parameter = None
    for param in resolver.parameters:
        param_type = types[param.name]
        if is_union_of(param_type, graphql.GraphQLResolveInfo):
            # This parameter receives the resolve info, not a field argument.
            info_parameter = param.name
            continue
        param_field = ObjectField(
            param.name,
            param_type,
            param.default is Parameter.empty,
            resolver.parameters_metadata.get(param.name, empty_dict),
            param.default,
        )
        deserializer = deserialization_method(
            param_type,
            additional_properties=False,
            aliaser=aliaser,
            coerce=False,
            conversion=param_field.deserialization,
            default_conversion=handle_enum,
            fall_back_on_default=False,
            schema=param_field.schema,
        )
        opt_param = is_union_of(param_type, NoneType) or param.default is None
        spec = (
            aliaser(param_field.alias),
            param.name,
            deserializer,
            opt_param,
            param_field.required,
        )
        param_specs.append(spec)

    func, error_handler = resolver.func, resolver.error_handler
    method_factory = partial_serialization_method_factory(
        aliaser, resolver.conversion, default_serialization
    )

    serialize_result: Callable[[Any], Any]
    if serialized:
        serialize_result = (
            as_async(method_factory(types["return"]))
            if is_async(func)
            else method_factory(types["return"])
        )
    else:
        serialize_result = identity

    serialize_error: Optional[Callable[[Any], Any]] = None
    if error_handler is not None:
        serialize_error = (
            as_async(method_factory(resolver.error_type()))
            if is_async(error_handler)
            else method_factory(resolver.error_type())
        )

    def resolve(__self, __info, **kwargs):
        values = {}
        errors: Dict[str, ValidationError] = {}
        for alias, param_name, deserializer, opt_param, required in param_specs:
            if alias not in kwargs:
                if opt_param and required:
                    values[param_name] = None
                continue
            raw = kwargs[alias]
            # It is possible for the parameter to be non-optional in Python
            # type hints but optional in the generated schema. In this case
            # we should ignore it.
            # See: https://github.com/wyfo/apischema/pull/130#issuecomment-845497392
            if raw is None and not opt_param:
                assert not required
                continue
            try:
                values[param_name] = deserializer(raw)
            except ValidationError as err:
                # NOTE(review): keyed by the aliased Python name, not by
                # ``alias`` -- confirm this is intended for aliased fields.
                errors[aliaser(param_name)] = err

        if errors:
            raise ValueError(ValidationError(children=errors).errors)
        if info_parameter:
            values[info_parameter] = __info
        try:
            return serialize_result(func(__self, **values))
        except Exception as error:
            if error_handler is not None:
                assert serialize_error is not None
                return serialize_error(error_handler(error, __self, __info, **kwargs))
            raise

    return resolve