tokencrawler/.venv/lib/python3.9/site-packages/apischema/graphql/schema.py

from dataclasses import dataclass, field as field_, replace
from enum import Enum
from functools import wraps
from inspect import Parameter, iscoroutinefunction
from itertools import chain
from typing import (
Any,
AsyncIterable,
AsyncIterator,
Callable,
Collection,
Dict,
Generic,
Iterable,
List,
Mapping,
NewType,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
cast,
)
import graphql
from apischema import settings
from apischema.aliases import Aliaser
from apischema.conversions.conversions import AnyConversion, DefaultConversion
from apischema.conversions.visitor import (
Conv,
Deserialization,
DeserializationVisitor,
Serialization,
SerializationVisitor,
)
from apischema.graphql.interfaces import get_interfaces, is_interface
from apischema.graphql.resolvers import (
Resolver,
get_resolvers,
none_error_handler,
partial_serialization_method_factory,
resolver_parameters,
resolver_resolve,
)
from apischema.json_schema.schema import get_field_schema, get_method_schema, get_schema
from apischema.metadata.keys import SCHEMA_METADATA
from apischema.objects import ObjectField
from apischema.objects.visitor import (
DeserializationObjectVisitor,
ObjectVisitor,
SerializationObjectVisitor,
)
from apischema.ordering import Ordering, sort_by_order
from apischema.recursion import RecursiveConversionsVisitor
from apischema.schemas import Schema, merge_schema
from apischema.serialization import SerializationMethod, serialize
from apischema.serialization.serialized_methods import ErrorHandler
from apischema.type_names import TypeName, TypeNameFactory, get_type_name
from apischema.types import AnyType, NoneType, OrderedDict, Undefined, UndefinedType
from apischema.typing import get_args, get_origin, is_annotated
from apischema.utils import (
Lazy,
as_predicate,
context_setter,
deprecate_kwargs,
empty_dict,
get_args2,
get_origin2,
get_origin_or_type,
identity,
is_union_of,
to_camel_case,
)
JsonScalar = graphql.GraphQLScalarType("JSON")
if graphql.version_info >= (3, 1, 2):
JsonScalar.specified_by_url = (
"http://www.ecma-international.org/publications/files/ECMA-ST/ECMA-404.pdf"
)
GRAPHQL_PRIMITIVE_TYPES = {
int: graphql.GraphQLInt,
float: graphql.GraphQLFloat,
str: graphql.GraphQLString,
bool: graphql.GraphQLBoolean,
}
ID = NewType("ID", str)
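# Fields and parameters annotated with this ID type are mapped to the GraphQL ID
# scalar. A minimal, hypothetical usage sketch (not part of this module):
#
#     from dataclasses import dataclass
#     from apischema.graphql import ID
#
#     @dataclass
#     class User:
#         id: ID  # exposed as a non-null GraphQLID in the generated schema
#         name: str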
class MissingName(Exception):
pass
class Nullable(Exception):
pass
T = TypeVar("T")
Thunk = Union[Callable[[], T], T]
TypeThunk = Thunk[graphql.GraphQLType]
def exec_thunk(thunk: TypeThunk, *, non_null=None) -> Any:
result = thunk if isinstance(thunk, graphql.GraphQLType) else thunk()
if non_null is True and not isinstance(result, graphql.GraphQLNonNull):
return graphql.GraphQLNonNull(result) # type: ignore
if non_null is False and isinstance(result, graphql.GraphQLNonNull):
return result.of_type
return result
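# Illustration of exec_thunk, derived from the code above: a GraphQLType can be
# passed directly or behind a zero-argument callable, and non_null wraps/unwraps
# GraphQLNonNull on demand, e.g.
#     exec_thunk(graphql.GraphQLString, non_null=True)   # GraphQLNonNull(String)
#     exec_thunk(lambda: graphql.GraphQLNonNull(JsonScalar), non_null=False)  # JSON scalar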
def get_parameter_schema(
func: Callable, parameter: Parameter, field: ObjectField
) -> Optional[Schema]:
from apischema import settings
return merge_schema(
settings.base_schema.parameter(func, parameter, field.alias), field.schema
)
def merged_schema(
schema: Optional[Schema], tp: Optional[AnyType]
) -> Tuple[Optional[Schema], Mapping[str, Any]]:
if is_annotated(tp):
for annotation in reversed(get_args(tp)[1:]):
if isinstance(annotation, TypeNameFactory):
break
elif isinstance(annotation, Mapping) and SCHEMA_METADATA in annotation:
schema = merge_schema(annotation[SCHEMA_METADATA], schema)
schema_dict: Dict[str, Any] = {}
if schema is not None:
schema.merge_into(schema_dict)
return schema, schema_dict
def get_description(
schema: Optional[Schema], tp: Optional[AnyType] = None
) -> Optional[str]:
_, schema_dict = merged_schema(schema, tp)
return schema_dict.get("description")
def get_deprecated(
schema: Optional[Schema], tp: Optional[AnyType] = None
) -> Optional[str]:
schema, schema_dict = merged_schema(schema, tp)
if not schema_dict.get("deprecated", False):
return None
while schema is not None:
if schema.annotations is not None:
if isinstance(schema.annotations.deprecated, str):
return schema.annotations.deprecated
elif schema.annotations.deprecated:
return graphql.DEFAULT_DEPRECATION_REASON
schema = schema.child
return graphql.DEFAULT_DEPRECATION_REASON
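# Deprecation sketch (assuming the public apischema `schema` helper): a field or
# enum member declared with `schema(deprecated=True)` gets
# graphql.DEFAULT_DEPRECATION_REASON, while `schema(deprecated="use newField instead")`
# surfaces that string as the deprecation reason.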
@dataclass(frozen=True)
class ResolverField:
resolver: Resolver
types: Mapping[str, AnyType]
parameters: Sequence[Parameter]
metadata: Mapping[str, Mapping]
subscribe: Optional[Callable] = None
IdPredicate = Callable[[AnyType], bool]
UnionNameFactory = Callable[[Sequence[str]], str]
GraphQLTp = TypeVar("GraphQLTp", graphql.GraphQLInputType, graphql.GraphQLOutputType)
FactoryFunction = Callable[[Optional[str], Optional[str]], GraphQLTp]
@dataclass(frozen=True)
class TypeFactory(Generic[GraphQLTp]):
factory: FactoryFunction[GraphQLTp]
name: Optional[str] = None
description: Optional[str] = None
# non_null cannot be a field because it could not be forwarded to factories
# called inside wrapping factories (e.g. the recursive wrapper)
def merge(
self, type_name: TypeName = TypeName(), schema: Optional[Schema] = None
) -> "TypeFactory[GraphQLTp]":
if type_name == TypeName() and schema is None:
return self
return replace(
self,
name=type_name.graphql or self.name,
description=get_description(schema) or self.description,
)
@property
def type(self) -> GraphQLTp:
return self.factory(self.name, self.description) # type: ignore
@property
def raw_type(self) -> GraphQLTp:
tp = self.type
return tp.of_type if isinstance(tp, graphql.GraphQLNonNull) else tp
def unwrap_name(name: Optional[str], tp: AnyType) -> str:
if name is None:
raise TypeError(f"Missing name for {tp}")
return name
Method = TypeVar("Method", bound=Callable[..., TypeFactory])
def cache_type(method: Method) -> Method:
@wraps(method)
def wrapper(self: "SchemaBuilder", *args, **kwargs):
factory = method(self, *args, **kwargs)
@wraps(factory.factory) # type: ignore
def name_cache(
name: Optional[str], description: Optional[str]
) -> graphql.GraphQLNonNull:
if name is None:
return graphql.GraphQLNonNull(factory.factory(name, description)) # type: ignore
# The method is part of the cache key because scalar types use the same method
# and are thus shared by both visitors, while input/output types each get
# their own cache entry.
if (name, method, description) in self._cache_by_name:
tp, cached_args = self._cache_by_name[(name, method, description)]
if cached_args == (args, kwargs):
return tp
tp = graphql.GraphQLNonNull(factory.factory(name, description)) # type: ignore
# args/kwargs are kept out of the cache key to avoid hashability issues
self._cache_by_name[(name, method, description)] = (tp, (args, kwargs))
return tp
return replace(factory, factory=name_cache)
return cast(Method, wrapper)
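# cache_type memoizes the built GraphQL type per (name, builder method, description):
# scalar types share one entry across the input and output visitors (same method),
# while input/output object types get distinct entries, avoiding duplicate named types.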
class SchemaBuilder(
RecursiveConversionsVisitor[Conv, TypeFactory[GraphQLTp]],
ObjectVisitor[TypeFactory[GraphQLTp]],
):
types: Tuple[Type[graphql.GraphQLType], ...]
def __init__(
self,
aliaser: Aliaser,
enum_aliaser: Aliaser,
enum_schemas: Mapping[Enum, Schema],
default_conversion: DefaultConversion,
id_type: graphql.GraphQLScalarType,
is_id: Optional[IdPredicate],
):
super().__init__(default_conversion)
self.aliaser = aliaser
self.enum_aliaser = enum_aliaser
self.enum_schemas = enum_schemas
self.id_type = id_type
self.is_id = is_id or (lambda t: False)
self._cache_by_name: Dict[
Tuple[str, Callable, Optional[str]],
Tuple[graphql.GraphQLNonNull, Tuple[tuple, dict]],
] = {}
def _recursive_result(
self, lazy: Lazy[TypeFactory[GraphQLTp]]
) -> TypeFactory[GraphQLTp]:
def factory(name: Optional[str], description: Optional[str]) -> GraphQLTp:
cached_fact = lazy()
return cached_fact.factory( # type: ignore
name or cached_fact.name, description or cached_fact.description
)
return TypeFactory(factory)
def annotated(
self, tp: AnyType, annotations: Sequence[Any]
) -> TypeFactory[GraphQLTp]:
factory = super().annotated(tp, annotations)
type_name = False
for annotation in reversed(annotations):
if isinstance(annotation, TypeNameFactory):
if type_name:
break
type_name = True
factory = factory.merge(annotation.to_type_name(tp))
if isinstance(annotation, Mapping):
if type_name:
factory = factory.merge(schema=annotation.get(SCHEMA_METADATA))
return factory # type: ignore
@cache_type
def any(self) -> TypeFactory[GraphQLTp]:
def factory(
name: Optional[str], description: Optional[str]
) -> graphql.GraphQLScalarType:
if name is None:
return JsonScalar
else:
return graphql.GraphQLScalarType(name, description=description)
return TypeFactory(factory)
@cache_type
def collection(
self, cls: Type[Collection], value_type: AnyType
) -> TypeFactory[GraphQLTp]:
return TypeFactory(lambda *_: graphql.GraphQLList(self.visit(value_type).type))
@cache_type
def enum(self, cls: Type[Enum]) -> TypeFactory[GraphQLTp]:
def factory(
name: Optional[str], description: Optional[str]
) -> graphql.GraphQLEnumType:
return graphql.GraphQLEnumType(
unwrap_name(name, cls),
{
self.enum_aliaser(name): graphql.GraphQLEnumValue(
member,
get_description(self.enum_schemas.get(member)),
get_deprecated(self.enum_schemas.get(member)),
)
for name, member in cls.__members__.items()
},
description=description,
)
return TypeFactory(factory)
@cache_type
def literal(self, values: Sequence[Any]) -> TypeFactory[GraphQLTp]:
from apischema.typing import Literal
if not all(isinstance(v, str) for v in values):
raise TypeError("apischema GraphQL only support Literal of strings")
def factory(
name: Optional[str], description: Optional[str]
) -> graphql.GraphQLEnumType:
return graphql.GraphQLEnumType(
unwrap_name(name, Literal[tuple(values)]), # type: ignore
dict(zip(map(self.enum_aliaser, values), values)),
description=description,
)
return TypeFactory(factory)
@cache_type
def mapping(
self, cls: Type[Mapping], key_type: AnyType, value_type: AnyType
) -> TypeFactory[GraphQLTp]:
def factory(
name: Optional[str], description: Optional[str]
) -> graphql.GraphQLScalarType:
if name is not None:
return graphql.GraphQLScalarType(name, description=description)
else:
return JsonScalar
return TypeFactory(factory)
def object(
self, tp: AnyType, fields: Sequence[ObjectField]
) -> TypeFactory[GraphQLTp]:
raise NotImplementedError
@cache_type
def primitive(self, cls: Type) -> TypeFactory[GraphQLTp]:
def factory(
name: Optional[str], description: Optional[str]
) -> graphql.GraphQLScalarType:
assert cls is not NoneType
if name is not None:
return graphql.GraphQLScalarType(name, description=description)
else:
return GRAPHQL_PRIMITIVE_TYPES[cls]
return TypeFactory(factory)
def tuple(self, types: Sequence[AnyType]) -> TypeFactory[GraphQLTp]:
raise TypeError("Tuple are not supported")
def union(self, alternatives: Sequence[AnyType]) -> TypeFactory[GraphQLTp]:
factories = self._union_results(
(alt for alt in alternatives if alt is not NoneType)
)
if len(factories) == 1:
factory = factories[0]
else:
factory = self._visited_union(factories)
if NoneType in alternatives or UndefinedType in alternatives:
def nullable(name: Optional[str], description: Optional[str]) -> GraphQLTp:
res = factory.factory(name, description) # type: ignore
return res.of_type if isinstance(res, graphql.GraphQLNonNull) else res
return replace(factory, factory=nullable)
else:
return factory
def visit_conversion(
self,
tp: AnyType,
conversion: Optional[Conv],
dynamic: bool,
next_conversion: Optional[AnyConversion] = None,
) -> TypeFactory[GraphQLTp]:
if not dynamic and self.is_id(tp) or tp == ID:
return TypeFactory(lambda *_: graphql.GraphQLNonNull(self.id_type))
factory = super().visit_conversion(tp, conversion, dynamic, next_conversion)
if not dynamic:
factory = factory.merge(get_type_name(tp), get_schema(tp))
if get_args(tp):
factory = factory.merge(schema=get_schema(get_origin(tp)))
return factory # type: ignore
FieldType = TypeVar("FieldType", graphql.GraphQLInputField, graphql.GraphQLField)
class BaseField(Generic[FieldType]):
name: str
ordering: Optional[Ordering]
def items(self) -> Iterable[Tuple[str, FieldType]]:
raise NotImplementedError
@dataclass
class NormalField(BaseField[FieldType]):
alias: str
name: str
field: Lazy[FieldType]
ordering: Optional[Ordering]
def items(self) -> Iterable[Tuple[str, FieldType]]:
yield self.alias, self.field()
@dataclass
class FlattenedField(BaseField[FieldType]):
name: str
ordering: Optional[Ordering]
type: TypeFactory
def items(self) -> Iterable[Tuple[str, FieldType]]:
tp = self.type.raw_type
if not isinstance(
tp,
(
graphql.GraphQLObjectType,
graphql.GraphQLInterfaceType,
graphql.GraphQLInputObjectType,
),
):
raise FlattenedError(self)
yield from tp.fields.items()
class FlattenedError(Exception):
def __init__(self, field: FlattenedField):
self.field = field
def merge_fields(cls: type, fields: Sequence[BaseField]) -> Dict[str, FieldType]:
try:
sorted_fields = sort_by_order(
cls, fields, lambda f: f.name, lambda f: f.ordering
)
except FlattenedError as err:
raise TypeError(
f"Flattened field {cls.__name__}.{err.field.name}"
f" must have an object type"
)
return OrderedDict(chain.from_iterable(map(lambda f: f.items(), sorted_fields)))
class InputSchemaBuilder(
SchemaBuilder[Deserialization, graphql.GraphQLInputType],
DeserializationVisitor[TypeFactory[graphql.GraphQLInputType]],
DeserializationObjectVisitor[TypeFactory[graphql.GraphQLInputType]],
):
types = graphql.type.definition.graphql_input_types
def _field(
self, tp: AnyType, field: ObjectField
) -> Lazy[graphql.GraphQLInputField]:
field_type = field.type
field_default = graphql.Undefined if field.required else field.get_default()
default: Any = graphql.Undefined
# A None (or Undefined, handled like None) default is not exposed as a `null`
# GraphQL default; the input type is made nullable instead
if field_default in {None, Undefined}:
field_type = Optional[field_type]
elif field_default is not graphql.Undefined:
try:
default = serialize(
field_type,
field_default,
aliaser=self.aliaser,
conversion=field.deserialization,
)
except Exception:
field_type = Optional[field_type]
factory = self.visit_with_conv(field_type, field.deserialization)
return lambda: graphql.GraphQLInputField(
factory.type, # type: ignore
default_value=default,
description=get_description(get_field_schema(tp, field), field.type),
)
@cache_type
def object(
self, tp: AnyType, fields: Sequence[ObjectField]
) -> TypeFactory[graphql.GraphQLInputType]:
visited_fields: List[BaseField] = []
for field in fields:
if not field.is_aggregate:
normal_field = NormalField(
self.aliaser(field.alias),
field.name,
self._field(tp, field),
field.ordering,
)
visited_fields.append(normal_field)
elif field.flattened:
flattened_fields = FlattenedField(
field.name,
field.ordering,
self.visit_with_conv(field.type, field.deserialization),
)
visited_fields.append(flattened_fields)
def factory(
name: Optional[str], description: Optional[str]
) -> graphql.GraphQLInputObjectType:
name = unwrap_name(name, tp)
if not name.endswith("Input"):
name += "Input"
return graphql.GraphQLInputObjectType(
name,
lambda: merge_fields(get_origin_or_type(tp), visited_fields),
description,
)
return TypeFactory(factory)
def _visited_union(
self, results: Sequence[TypeFactory]
) -> TypeFactory[graphql.GraphQLInputType]:
# The check must also be done here because _union_result is used by visit_conversion
if len(results) != 1:
raise TypeError("Union are not supported for input")
return results[0]
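# Named input object types reuse the visited type's name with an "Input" suffix
# appended when it is not already present (see the object() factory above).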
Func = TypeVar("Func", bound=Callable)
class OutputSchemaBuilder(
SchemaBuilder[Serialization, graphql.GraphQLOutputType],
SerializationVisitor[TypeFactory[graphql.GraphQLOutputType]],
SerializationObjectVisitor[TypeFactory[graphql.GraphQLOutputType]],
):
types = graphql.type.definition.graphql_output_types
def __init__(
self,
aliaser: Aliaser,
enum_aliaser: Aliaser,
enum_schemas: Mapping[Enum, Schema],
default_conversion: DefaultConversion,
id_type: graphql.GraphQLScalarType,
is_id: Optional[IdPredicate],
union_name_factory: UnionNameFactory,
default_deserialization: DefaultConversion,
):
super().__init__(
aliaser, enum_aliaser, enum_schemas, default_conversion, id_type, is_id
)
self.union_name_factory = union_name_factory
self.input_builder = InputSchemaBuilder(
self.aliaser,
self.enum_aliaser,
self.enum_schemas,
default_deserialization,
self.id_type,
self.is_id,
)
# Share the same cache for input_builder in order to share scalar types
self.input_builder._cache_by_name = self._cache_by_name
self.get_flattened: Optional[Callable[[Any], Any]] = None
def _field_serialization_method(self, field: ObjectField) -> SerializationMethod:
return partial_serialization_method_factory(
self.aliaser, field.serialization, self.default_conversion
)(Optional[field.type] if field.none_as_undefined else field.type)
def _wrap_resolve(self, resolve: Func) -> Func:
if self.get_flattened is None:
return resolve
else:
get_flattened = self.get_flattened
def resolve_wrapper(__obj, __info, **kwargs):
return resolve(get_flattened(__obj), __info, **kwargs)
return cast(Func, resolve_wrapper)
def _field(self, tp: AnyType, field: ObjectField) -> Lazy[graphql.GraphQLField]:
field_name = field.name
partial_serialize = self._field_serialization_method(field)
@self._wrap_resolve
def resolve(obj, _):
return partial_serialize(getattr(obj, field_name))
factory = self.visit_with_conv(field.type, field.serialization)
field_schema = get_field_schema(tp, field)
return lambda: graphql.GraphQLField(
factory.type,
None,
resolve,
description=get_description(field_schema, field.type),
deprecation_reason=get_deprecated(field_schema, field.type),
)
def _resolver(
self, tp: AnyType, field: ResolverField
) -> Lazy[graphql.GraphQLField]:
resolve = self._wrap_resolve(
resolver_resolve(
field.resolver,
field.types,
self.aliaser,
self.input_builder.default_conversion,
self.default_conversion,
)
)
args = None
if field.parameters is not None:
args = {}
for param in field.parameters:
default: Any = graphql.Undefined
param_type = field.types[param.name]
if is_union_of(param_type, graphql.GraphQLResolveInfo):
break
param_field = ObjectField(
param.name,
param_type,
param.default is Parameter.empty,
field.metadata.get(param.name, empty_dict),
default=param.default,
)
if param_field.required:
pass
# As above, a None/Undefined default is not exposed as a `null` GraphQL default;
# the parameter type is made nullable instead
# (see also https://github.com/python/typing/issues/775)
elif param.default in {None, Undefined}:
param_type = Optional[param_type]
# param.default == graphql.Undefined means the parameter is required
# even if it has a default
elif param.default not in {Parameter.empty, graphql.Undefined}:
try:
default = serialize(
param_type,
param.default,
fall_back_on_any=False,
check_type=True,
)
except Exception:
param_type = Optional[param_type]
arg_factory = self.input_builder.visit_with_conv(
param_type, param_field.deserialization
)
description = get_description(
get_parameter_schema(field.resolver.func, param, param_field),
param_field.type,
)
def arg_thunk(
arg_factory=arg_factory, default=default, description=description
) -> graphql.GraphQLArgument:
return graphql.GraphQLArgument(
arg_factory.type, default, description
)
args[self.aliaser(param_field.alias)] = arg_thunk
factory = self.visit_with_conv(field.types["return"], field.resolver.conversion)
field_schema = get_method_schema(tp, field.resolver)
return lambda: graphql.GraphQLField(
factory.type, # type: ignore
{name: arg() for name, arg in args.items()} if args else None,
resolve,
field.subscribe,
get_description(field_schema),
get_deprecated(field_schema),
)
def _visit_flattened(
self, field: ObjectField
) -> TypeFactory[graphql.GraphQLOutputType]:
get_prev_flattened = (
self.get_flattened if self.get_flattened is not None else identity
)
field_name = field.name
partial_serialize = self._field_serialization_method(field)
def get_flattened(obj):
return partial_serialize(getattr(get_prev_flattened(obj), field_name))
with context_setter(self):
self.get_flattened = get_flattened
return self.visit_with_conv(field.type, field.serialization)
@cache_type
def object(
self,
tp: AnyType,
fields: Sequence[ObjectField],
resolvers: Sequence[ResolverField] = (),
) -> TypeFactory[graphql.GraphQLOutputType]:
cls = get_origin_or_type(tp)
visited_fields: List[BaseField[graphql.GraphQLField]] = []
flattened_factories = []
for field in fields:
if not field.is_aggregate:
normal_field = NormalField(
self.aliaser(field.name),
field.name,
self._field(tp, field),
field.ordering,
)
visited_fields.append(normal_field)
elif field.flattened:
flattened_factory = self._visit_flattened(field)
flattened_factories.append(flattened_factory)
visited_fields.append(
FlattenedField(field.name, field.ordering, flattened_factory)
)
resolvers = list(resolvers)
for resolver, types in get_resolvers(tp):
resolver_field = ResolverField(
resolver, types, resolver.parameters, resolver.parameters_metadata
)
resolvers.append(resolver_field)
for resolver_field in resolvers:
normal_field = NormalField(
self.aliaser(resolver_field.resolver.alias),
resolver_field.resolver.func.__name__,
self._resolver(tp, resolver_field),
resolver_field.resolver.ordering,
)
visited_fields.append(normal_field)
interface_thunk = None
interfaces = list(map(self.visit, get_interfaces(cls)))
if interfaces or flattened_factories:
def interface_thunk() -> Collection[graphql.GraphQLInterfaceType]:
all_interfaces = {
cast(graphql.GraphQLInterfaceType, i.raw_type) for i in interfaces
}
for flattened_factory in flattened_factories:
flattened = cast(
Union[graphql.GraphQLObjectType, graphql.GraphQLInterfaceType],
flattened_factory.raw_type,
)
if isinstance(flattened, graphql.GraphQLObjectType):
all_interfaces.update(flattened.interfaces)
elif isinstance(flattened, graphql.GraphQLInterfaceType):
all_interfaces.add(flattened)
return sorted(all_interfaces, key=lambda i: i.name)
def factory(
name: Optional[str], description: Optional[str]
) -> Union[graphql.GraphQLObjectType, graphql.GraphQLInterfaceType]:
name = unwrap_name(name, cls)
if is_interface(cls):
return graphql.GraphQLInterfaceType(
name,
lambda: merge_fields(cls, visited_fields),
interface_thunk,
description=description,
)
else:
return graphql.GraphQLObjectType(
name,
lambda: merge_fields(cls, visited_fields),
interface_thunk,
is_type_of=lambda obj, _: isinstance(obj, cls),
description=description,
)
return TypeFactory(factory)
def typed_dict(
self, tp: Type, types: Mapping[str, AnyType], required_keys: Collection[str]
) -> TypeFactory[graphql.GraphQLOutputType]:
raise TypeError("TypedDict are not supported in output schema")
@cache_type
def _visited_union(
self, results: Sequence[TypeFactory]
) -> TypeFactory[graphql.GraphQLOutputType]:
def factory(
name: Optional[str], description: Optional[str]
) -> graphql.GraphQLOutputType:
types = [factory.raw_type for factory in results]
if name is None:
name = self.union_name_factory([t.name for t in types])
return graphql.GraphQLUnionType(name, types, description=description)
return TypeFactory(factory)
async_iterable_origins = set(map(get_origin, (AsyncIterable[Any], AsyncIterator[Any])))
_fake_type = cast(type, ...)
@dataclass(frozen=True)
class Operation(Generic[T]):
function: Callable[..., T]
alias: Optional[str] = None
conversion: Optional[AnyConversion] = None
error_handler: ErrorHandler = Undefined
order: Optional[Ordering] = None
schema: Optional[Schema] = None
parameters_metadata: Mapping[str, Mapping] = field_(default_factory=dict)
class Query(Operation):
pass
class Mutation(Operation):
pass
@dataclass(frozen=True)
class Subscription(Operation[AsyncIterable]):
resolver: Optional[Callable] = None
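# Operations may be passed to graphql_schema below either as plain callables or
# wrapped in Query/Mutation/Subscription to customize alias, conversion, error
# handling, ordering, schema or parameters metadata. A hedged sketch with made-up
# operation functions:
#
#     graphql_schema(
#         query=[get_user, Query(list_users, alias="users")],
#         subscription=[Subscription(on_event, error_handler=None)],
#     )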
Op = TypeVar("Op", bound=Operation)
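# operation_resolver normalizes a bare callable (or an Operation instance) into a
# Resolver: the wrapper drops the GraphQL root value received as first positional
# argument, awaits coroutine functions, and maps error_handler Undefined/None to
# "no handler"/none_error_handler respectively.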
def operation_resolver(operation: Union[Callable, Op], op_class: Type[Op]) -> Resolver:
if not isinstance(operation, op_class):
operation = op_class(operation) # type: ignore
error_handler: Optional[Callable]
if operation.error_handler is Undefined:
error_handler = None
elif operation.error_handler is None:
error_handler = none_error_handler
else:
error_handler = operation.error_handler
op = operation.function
if iscoroutinefunction(op):
async def wrapper(_, *args, **kwargs):
return await op(*args, **kwargs)
else:
def wrapper(_, *args, **kwargs):
return op(*args, **kwargs)
wrapper.__annotations__ = op.__annotations__
(*parameters,) = resolver_parameters(operation.function, check_first=True)
return Resolver(
wrapper,
operation.alias or operation.function.__name__,
operation.conversion,
error_handler,
operation.order,
operation.schema,
parameters,
operation.parameters_metadata,
)
@deprecate_kwargs({"union_ref": "union_name"})
def graphql_schema(
*,
query: Iterable[Union[Callable, Query]] = (),
mutation: Iterable[Union[Callable, Mutation]] = (),
subscription: Iterable[Union[Callable[..., AsyncIterable], Subscription]] = (),
types: Iterable[Type] = (),
directives: Optional[Collection[graphql.GraphQLDirective]] = None,
description: Optional[str] = None,
extensions: Optional[Dict[str, Any]] = None,
aliaser: Optional[Aliaser] = to_camel_case,
enum_aliaser: Optional[Aliaser] = str.upper,
enum_schemas: Optional[Mapping[Enum, Schema]] = None,
id_types: Union[Collection[AnyType], IdPredicate] = (),
id_encoding: Tuple[
Optional[Callable[[str], Any]], Optional[Callable[[Any], str]]
] = (None, None),
union_name: UnionNameFactory = "Or".join,
default_deserialization: DefaultConversion = None,
default_serialization: DefaultConversion = None,
) -> graphql.GraphQLSchema:
if aliaser is None:
aliaser = settings.aliaser
if enum_aliaser is None:
enum_aliaser = lambda s: s
if default_deserialization is None:
default_deserialization = settings.deserialization.default_conversion
if default_serialization is None:
default_serialization = settings.serialization.default_conversion
query_fields: List[ResolverField] = []
mutation_fields: List[ResolverField] = []
subscription_fields: List[ResolverField] = []
for operations, op_class, fields in [
(query, Query, query_fields),
(mutation, Mutation, mutation_fields),
]:
for operation in operations: # type: ignore
resolver = operation_resolver(operation, op_class)
resolver_field = ResolverField(
resolver,
resolver.types(),
resolver.parameters,
resolver.parameters_metadata,
)
fields.append(resolver_field)
for sub_op in subscription: # type: ignore
if not isinstance(sub_op, Subscription):
sub_op = Subscription(sub_op) # type: ignore
sub_parameters: Sequence[Parameter]
if sub_op.resolver is not None:
subscriber2 = operation_resolver(sub_op, Subscription)
_, *sub_parameters = resolver_parameters(sub_op.resolver, check_first=False)
resolver = Resolver(
sub_op.resolver,
sub_op.alias or sub_op.resolver.__name__,
sub_op.conversion,
subscriber2.error_handler,
sub_op.order,
sub_op.schema,
sub_parameters,
sub_op.parameters_metadata,
)
sub_types = resolver.types()
subscriber = replace(subscriber2, error_handler=None)
subscribe = resolver_resolve(
subscriber,
subscriber.types(),
aliaser,
default_deserialization,
default_serialization,
serialized=False,
)
else:
subscriber2 = operation_resolver(sub_op, Subscription)
resolver = Resolver(
lambda _: _,
subscriber2.alias,
sub_op.conversion,
subscriber2.error_handler,
sub_op.order,
sub_op.schema,
(),
{},
)
subscriber = replace(subscriber2, error_handler=None)
sub_parameters = subscriber.parameters
sub_types = subscriber.types()
if get_origin2(sub_types["return"]) not in async_iterable_origins:
raise TypeError(
"Subscriptions must return an AsyncIterable/AsyncIterator"
)
event_type = get_args2(sub_types["return"])[0]
subscribe = resolver_resolve(
subscriber,
sub_types,
aliaser,
default_deserialization,
default_serialization,
serialized=False,
)
sub_types = {**sub_types, "return": resolver.return_type(event_type)}
resolver_field = ResolverField(
resolver, sub_types, sub_parameters, sub_op.parameters_metadata, subscribe
)
subscription_fields.append(resolver_field)
is_id = as_predicate(id_types)
if id_encoding == (None, None):
id_type: graphql.GraphQLScalarType = graphql.GraphQLID
else:
id_deserializer, id_serializer = id_encoding
id_type = graphql.GraphQLScalarType(
name="ID",
serialize=id_serializer or graphql.GraphQLID.serialize,
parse_value=id_deserializer or graphql.GraphQLID.parse_value,
parse_literal=graphql.GraphQLID.parse_literal,
description=graphql.GraphQLID.description,
)
output_builder = OutputSchemaBuilder(
aliaser,
enum_aliaser,
enum_schemas or {},
default_serialization,
id_type,
is_id,
union_name,
default_deserialization,
)
def root_type(
name: str, fields: Sequence[ResolverField]
) -> Optional[graphql.GraphQLObjectType]:
if not fields:
return None
tp, type_name = type(name, (), {}), TypeName(graphql=name)
return output_builder.object(tp, (), fields).merge(type_name, None).raw_type # type: ignore
return graphql.GraphQLSchema(
query=root_type("Query", query_fields),
mutation=root_type("Mutation", mutation_fields),
subscription=root_type("Subscription", subscription_fields),
types=[output_builder.visit(cls).raw_type for cls in types], # type: ignore
directives=directives,
description=description,
extensions=extensions,
)
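# A minimal end-to-end sketch of graphql_schema, assuming only the public
# apischema.graphql API (the dataclass and operation below are illustrative):
#
#     from dataclasses import dataclass
#     from apischema.graphql import graphql_schema
#     import graphql
#
#     @dataclass
#     class User:
#         name: str
#
#     def user() -> User:
#         return User("Alice")
#
#     schema = graphql_schema(query=[user])
#     result = graphql.graphql_sync(schema, "{ user { name } }")
#     assert result.data == {"user": {"name": "Alice"}}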