336 lines
10 KiB
Python
336 lines
10 KiB
Python
from collections import defaultdict
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
from functools import lru_cache
|
|
from inspect import Parameter, signature
|
|
from typing import (
|
|
Any,
|
|
Awaitable,
|
|
Callable,
|
|
Collection,
|
|
Dict,
|
|
Iterator,
|
|
Mapping,
|
|
MutableMapping,
|
|
Optional,
|
|
Sequence,
|
|
Tuple,
|
|
Type,
|
|
TypeVar,
|
|
overload,
|
|
)
|
|
|
|
import graphql
|
|
|
|
from apischema import UndefinedType
|
|
from apischema.aliases import Aliaser
|
|
from apischema.cache import CacheAwareDict, cache
|
|
from apischema.conversions import Conversion
|
|
from apischema.conversions.conversions import AnyConversion, DefaultConversion
|
|
from apischema.deserialization import deserialization_method
|
|
from apischema.methods import method_registerer
|
|
from apischema.objects import ObjectField
|
|
from apischema.ordering import Ordering
|
|
from apischema.schemas import Schema
|
|
from apischema.serialization import (
|
|
PassThroughOptions,
|
|
SerializationMethod,
|
|
SerializationMethodVisitor,
|
|
)
|
|
from apischema.serialization.serialized_methods import (
|
|
ErrorHandler,
|
|
SerializedMethod,
|
|
_get_methods,
|
|
serialized as register_serialized,
|
|
)
|
|
from apischema.types import AnyType, NoneType, Undefined
|
|
from apischema.typing import is_type
|
|
from apischema.utils import (
|
|
awaitable_origin,
|
|
deprecate_kwargs,
|
|
empty_dict,
|
|
get_args2,
|
|
get_origin_or_type2,
|
|
identity,
|
|
is_async,
|
|
is_union_of,
|
|
keep_annotations,
|
|
)
|
|
from apischema.validation.errors import ValidationError
|
|
|
|
|
|
class PartialSerializationMethodVisitor(SerializationMethodVisitor):
    """Serialization visitor producing "partial" serialization methods.

    Enum and object types are returned untouched (``identity``), the
    ``_factory`` hook always yields ``identity``, and ``UndefinedType`` is
    mapped to ``None``. Every other type goes through the base visitor.
    Caching is disabled for this visitor (``use_cache = False``).
    """

    use_cache = False

    @property
    def _factory(self) -> Callable[[type], SerializationMethod]:
        def always_identity(_: type) -> SerializationMethod:
            return identity

        return always_identity

    def enum(self, cls: Type[Enum]) -> SerializationMethod:
        # Enum members are passed through without conversion.
        return identity

    def object(self, tp: AnyType, fields: Sequence[ObjectField]) -> SerializationMethod:
        # Object values are passed through without field-by-field serialization.
        return identity

    def visit(self, tp: AnyType) -> SerializationMethod:
        if tp is not UndefinedType:
            return super().visit(tp)
        # An Undefined value is serialized as None.
        return lambda obj: None
|
|
|
|
|
|
@cache
def partial_serialization_method_factory(
    aliaser: Aliaser,
    conversion: Optional[AnyConversion],
    default_conversion: DefaultConversion,
) -> Callable[[AnyType], SerializationMethod]:
    """Return a per-type factory of partial serialization methods.

    The returned ``factory(tp)`` builds a ``PartialSerializationMethodVisitor``
    configured with ``aliaser`` and ``default_conversion`` (and all other
    options disabled), then visits ``tp`` with ``conversion``. Results of
    ``factory`` are memoized per ``tp`` with ``functools.lru_cache``; the
    outer function itself is memoized by apischema's ``cache`` decorator.
    """

    @lru_cache()
    def factory(tp: AnyType) -> SerializationMethod:
        visitor = PartialSerializationMethodVisitor(
            aliaser=aliaser,
            default_conversion=default_conversion,
            pass_through_options=PassThroughOptions(),
            additional_properties=False,
            check_type=False,
            exclude_defaults=False,
            exclude_none=False,
            exclude_unset=False,
            fall_back_on_any=False,
        )
        return visitor.visit_with_conv(tp, conversion)

    return factory
|
|
|
|
|
|
def unwrap_awaitable(tp: AnyType) -> AnyType:
    """Strip an awaitable wrapper from ``tp``, if any.

    When the origin of ``tp`` is the awaitable origin, return its first type
    argument (``Any`` if it has none), with ``tp``'s annotations preserved via
    ``keep_annotations``. Any other type is returned unchanged.
    """
    if get_origin_or_type2(tp) == awaitable_origin:
        type_args = get_args2(tp)
        awaited = type_args[0] if type_args else Any
        return keep_annotations(awaited, tp)
    return tp
|
|
|
|
|
|
@dataclass(frozen=True)
class Resolver(SerializedMethod):
    """A ``SerializedMethod`` carrying the metadata of a GraphQL resolver.

    Error and return types are unwrapped from awaitables, so an async
    resolver reports the type of the awaited value.
    """

    # Parameters exposed as GraphQL arguments (the leading parameter of the
    # decorated function is not included by the ``resolver`` decorator).
    parameters: Sequence[Parameter]
    # Per-parameter metadata, keyed by Python parameter name.
    parameters_metadata: Mapping[str, Mapping]

    def error_type(self) -> AnyType:
        base_error_type = super().error_type()
        return unwrap_awaitable(base_error_type)

    def return_type(self, return_type: AnyType) -> AnyType:
        awaited_type = unwrap_awaitable(return_type)
        return super().return_type(awaited_type)
|
|
|
|
|
|
# Resolver registry: owner type -> {alias: Resolver}. The per-owner inner dict
# is created on first access by ``defaultdict``; the outer mapping is
# apischema's ``CacheAwareDict`` (assumed to keep apischema caches consistent
# with new registrations — not verifiable from this file).
_resolvers: MutableMapping[Type, Dict[str, Resolver]]
_resolvers = CacheAwareDict(defaultdict(dict))
|
|
|
|
|
|
def get_resolvers(tp: AnyType) -> Collection[Tuple[Resolver, Mapping[str, AnyType]]]:
    """Return the resolvers applicable to ``tp``.

    Each item pairs a ``Resolver`` with a mapping from names to resolved types.
    The lookup in the module's ``_resolvers`` registry is delegated to
    ``_get_methods``.
    """
    registry = _resolvers
    return _get_methods(tp, registry)
|
|
|
|
|
|
def none_error_handler(
    __error: Exception, __obj: Any, __info: graphql.GraphQLResolveInfo, **kwargs
) -> None:
    """Discard the error and return ``None``.

    Used by the ``resolver`` decorator when ``error_handler=None`` is given,
    so a failing resolver yields ``None`` instead of raising.
    """
|
|
|
|
|
|
def resolver_parameters(
    resolver: Callable, *, check_first: bool
) -> Iterator[Parameter]:
    """Yield the keyword-passable parameters of ``resolver``, in order.

    Variadic parameters (``*args`` / ``**kwargs``) are skipped. Every yielded
    parameter must be annotated, except the very first parameter of the
    signature when ``check_first`` is False (presumably the ``self``-like
    argument of a method).

    Raises:
        TypeError: if ``resolver`` has a positional-only parameter, or if a
            parameter that must be annotated is not.
    """
    keyword_kinds = (Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY)
    for position, param in enumerate(signature(resolver).parameters.values()):
        if param.kind is Parameter.POSITIONAL_ONLY:
            raise TypeError("Resolver can not have positional only parameters")
        if param.kind not in keyword_kinds:
            continue
        exempt = position == 0 and not check_first
        if param.annotation is Parameter.empty and not exempt:
            raise TypeError("Resolver parameters must be typed")
        yield param
|
|
|
|
|
|
# The decorator returns exactly what it receives: a plain callable or a property.
MethodOrProp = TypeVar("MethodOrProp", Callable, property)


@overload
def resolver(__method_or_property: MethodOrProp) -> MethodOrProp:
    ...


@overload
def resolver(
    alias: Optional[str] = None,
    *,
    conversion: Optional[AnyConversion] = None,
    error_handler: ErrorHandler = Undefined,
    order: Optional[Ordering] = None,
    schema: Optional[Schema] = None,
    parameters_metadata: Optional[Mapping[str, Mapping]] = None,
    serialized: bool = False,
    owner: Optional[Type] = None,
) -> Callable[[MethodOrProp], MethodOrProp]:
    ...


@deprecate_kwargs({"conversions": "conversion"})
def resolver(
    __arg=None,
    *,
    alias: Optional[str] = None,
    conversion: Optional[AnyConversion] = None,
    error_handler: ErrorHandler = Undefined,
    order: Optional[Ordering] = None,
    schema: Optional[Schema] = None,
    parameters_metadata: Optional[Mapping[str, Mapping]] = None,
    serialized: bool = False,
    owner: Optional[Type] = None,
):
    """Register a method or property as a GraphQL resolver.

    Usable bare (``@resolver``), with an alias as sole positional argument
    (``@resolver("name")``), or with keyword options.

    Args:
        alias: field name; falls back to the name supplied by
            ``method_registerer`` (presumably the function's own name).
        conversion: conversion applied when serializing the result.
        error_handler: ``Undefined`` (default) installs no handler, so errors
            propagate; ``None`` installs ``none_error_handler`` (the field
            resolves to ``None``); any other callable is invoked as
            ``handler(error, obj, info, **graphql_arguments)`` and its result
            is returned instead.
        order, schema: stored on the ``Resolver``.
        parameters_metadata: per-parameter metadata keyed by parameter name.
        serialized: also register the function as an apischema serialized
            method. Async functions are rejected in that case.
        owner: forwarded to ``method_registerer``; when it is None, the first
            parameter of the function must be annotated too.

    Raises:
        TypeError: for positional-only or unannotated parameters, or when
            ``serialized=True`` is used with a function that cannot be a
            serialized method. Nothing is registered when this is raised for
            the serialized-method checks.
    """

    def register(func: Callable, owner: Type, alias2: str):
        alias2 = alias or alias2
        # The first parameter receives the parent object, not a GraphQL
        # argument, so it is dropped from the resolver parameters.
        _, *parameters = resolver_parameters(func, check_first=owner is None)
        error_handler2 = error_handler
        if error_handler2 is None:
            error_handler2 = none_error_handler
        elif error_handler2 is Undefined:
            error_handler2 = None
        # Reject invalid serialized usage before touching any registry, so a
        # failing decoration does not leave a half-registered resolver behind.
        if serialized and is_async(func):
            raise TypeError("Async resolver cannot be used as a serialized method")
        new_resolver = Resolver(
            func,
            alias2,
            conversion,
            error_handler2,
            order,
            schema,
            parameters,
            parameters_metadata or {},
        )
        if serialized:
            try:
                # The serialized-method registry receives the raw
                # error_handler argument, not the normalized one.
                register_serialized(
                    alias=alias2,
                    conversion=conversion,
                    schema=schema,
                    error_handler=error_handler,
                    owner=owner,
                )(func)
            except Exception as err:
                raise TypeError(
                    "Resolver cannot be used as a serialized method"
                ) from err
        _resolvers[owner][alias2] = new_resolver

    # ``@resolver("name")``: the positional argument is an alias, not the
    # decorated function.
    if isinstance(__arg, str):
        alias = __arg
        __arg = None
    return method_registerer(__arg, owner, register)
|
|
|
|
|
|
T = TypeVar("T")
U = TypeVar("U")


def as_async(func: Callable[[T], U]) -> Callable[[Awaitable[T]], Awaitable[U]]:
    """Lift a synchronous one-argument function to operate on awaitables.

    The returned coroutine function awaits its argument, then applies
    ``func`` to the awaited value and returns the result.
    """

    async def wrapper(awaitable: Awaitable[T]) -> U:
        awaited_value = await awaitable
        return func(awaited_value)

    return wrapper
|
|
|
|
|
|
def resolver_resolve(
    resolver: Resolver,
    types: Mapping[str, AnyType],
    aliaser: Aliaser,
    default_deserialization: DefaultConversion,
    default_serialization: DefaultConversion,
    serialized: bool = True,
) -> Callable:
    """Build the GraphQL resolve function for ``resolver``.

    ``types`` maps each parameter name of ``resolver`` (and the key
    ``"return"``) to its resolved type. Every parameter gets a deserializer,
    except one typed with ``GraphQLResolveInfo``, which receives the info
    object. The result, and the error handler's result if any, is serialized
    with the resolver's conversion unless ``serialized`` is False.

    The returned ``resolve(obj, info, **kwargs)`` takes the GraphQL arguments
    keyed by their aliased names. Deserialization failures are collected and
    raised together as a ``ValueError`` carrying
    ``ValidationError(children=...).errors``, keyed by GraphQL argument name.
    Errors raised by the resolver, including those raised while an async
    resolver is awaited, go to the resolver's error handler when it has one.
    """

    # graphql deserialization will give Enum objects instead of strings
    def handle_enum(tp: AnyType) -> Optional[AnyConversion]:
        if is_type(tp) and issubclass(tp, Enum):
            return Conversion(identity, source=Any, target=tp)
        return default_deserialization(tp)

    # One entry per GraphQL argument:
    # (graphql argument name, python parameter name, deserializer,
    #  whether None is acceptable, whether the parameter has no default)
    parameters = []
    info_parameter: Optional[str] = None
    for param in resolver.parameters:
        param_type = types[param.name]
        if is_union_of(param_type, graphql.GraphQLResolveInfo):
            info_parameter = param.name
            continue
        param_field = ObjectField(
            param.name,
            param_type,
            param.default is Parameter.empty,
            resolver.parameters_metadata.get(param.name, empty_dict),
            param.default,
        )
        deserializer = deserialization_method(
            param_type,
            additional_properties=False,
            aliaser=aliaser,
            coerce=False,
            conversion=param_field.deserialization,
            default_conversion=handle_enum,
            fall_back_on_default=False,
            schema=param_field.schema,
        )
        opt_param = is_union_of(param_type, NoneType) or param.default is None
        parameters.append(
            (
                aliaser(param_field.alias),
                param.name,
                deserializer,
                opt_param,
                param_field.required,
            )
        )
    func, error_handler = resolver.func, resolver.error_handler
    method_factory = partial_serialization_method_factory(
        aliaser, resolver.conversion, default_serialization
    )

    serialize_result: Callable[[Any], Any]
    if not serialized:
        serialize_result = identity
    elif is_async(resolver.func):
        serialize_result = as_async(method_factory(types["return"]))
    else:
        serialize_result = method_factory(types["return"])
    serialize_error: Optional[Callable[[Any], Any]]
    if error_handler is None:
        serialize_error = None
    elif is_async(error_handler):
        serialize_error = as_async(method_factory(resolver.error_type()))
    else:
        serialize_error = method_factory(resolver.error_type())

    # Calling a coroutine function never raises. Its exceptions only surface
    # when graphql awaits the returned awaitable, which happens after
    # ``resolve`` has returned and so outside its ``try`` block. Such results
    # are wrapped so that the error handler also covers async resolvers.
    guard_result = error_handler is not None and is_async(func)
    await_handled_error = error_handler is not None and is_async(error_handler)

    async def guarded(awaitable, obj, info, arguments):
        try:
            return await awaitable
        except Exception as error:
            assert error_handler is not None and serialize_error is not None
            handled = serialize_error(error_handler(error, obj, info, **arguments))
            if await_handled_error:
                # serialize_error is ``as_async(...)`` here: await its coroutine.
                handled = await handled
            return handled

    def resolve(__self, __info, **kwargs):
        values = {}
        errors: Dict[str, ValidationError] = {}
        for alias, param_name, deserializer, opt_param, required in parameters:
            if alias in kwargs:
                # It is possible for the parameter to be non-optional in Python
                # type hints but optional in the generated schema. In this case
                # we should ignore it.
                # See: https://github.com/wyfo/apischema/pull/130#issuecomment-845497392
                if not opt_param and kwargs[alias] is None:
                    assert not required
                    continue
                try:
                    values[param_name] = deserializer(kwargs[alias])
                except ValidationError as err:
                    # Key the error by the argument name the client used; it
                    # differs from aliaser(param_name) when metadata sets an alias.
                    errors[alias] = err
            elif opt_param and required:
                # A nullable parameter without default that was omitted is
                # passed as None.
                values[param_name] = None

        if errors:
            raise ValueError(ValidationError(children=errors).errors)
        if info_parameter is not None:
            values[info_parameter] = __info
        try:
            result = serialize_result(func(__self, **values))
        except Exception as error:
            if error_handler is None:
                raise
            assert serialize_error is not None
            return serialize_error(error_handler(error, __self, __info, **kwargs))
        if guard_result:
            return guarded(result, __self, __info, kwargs)
        return result

    return resolve
|