from dataclasses import dataclass, field as field_, replace
from enum import Enum
from functools import wraps
from inspect import Parameter, iscoroutinefunction
from itertools import chain
from typing import (
    Any,
    AsyncIterable,
    AsyncIterator,
    Callable,
    Collection,
    Dict,
    Generic,
    Iterable,
    List,
    Mapping,
    NewType,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
)

import graphql

from apischema import settings
from apischema.aliases import Aliaser
from apischema.conversions.conversions import AnyConversion, DefaultConversion
from apischema.conversions.visitor import (
    Conv,
    Deserialization,
    DeserializationVisitor,
    Serialization,
    SerializationVisitor,
)
from apischema.graphql.interfaces import get_interfaces, is_interface
from apischema.graphql.resolvers import (
    Resolver,
    get_resolvers,
    none_error_handler,
    partial_serialization_method_factory,
    resolver_parameters,
    resolver_resolve,
)
from apischema.json_schema.schema import get_field_schema, get_method_schema, get_schema
from apischema.metadata.keys import SCHEMA_METADATA
from apischema.objects import ObjectField
from apischema.objects.visitor import (
    DeserializationObjectVisitor,
    ObjectVisitor,
    SerializationObjectVisitor,
)
from apischema.ordering import Ordering, sort_by_order
from apischema.recursion import RecursiveConversionsVisitor
from apischema.schemas import Schema, merge_schema
from apischema.serialization import SerializationMethod, serialize
from apischema.serialization.serialized_methods import ErrorHandler
from apischema.type_names import TypeName, TypeNameFactory, get_type_name
from apischema.types import AnyType, NoneType, OrderedDict, Undefined, UndefinedType
from apischema.typing import get_args, get_origin, is_annotated
from apischema.utils import (
    Lazy,
    as_predicate,
    context_setter,
    deprecate_kwargs,
    empty_dict,
    get_args2,
    get_origin2,
    get_origin_or_type,
    identity,
    is_union_of,
    to_camel_case,
)

# Scalar used for types that have no name and no structured GraphQL equivalent
# (see `SchemaBuilder.any` and `SchemaBuilder.mapping`).
JsonScalar = graphql.GraphQLScalarType("JSON")
if graphql.version_info >= (3, 1, 2):
    JsonScalar.specified_by_url = (
        "http://www.ecma-international.org/publications/files/ECMA-ST/ECMA-404.pdf"
    )
# Python primitive classes mapped to their built-in GraphQL scalar.
GRAPHQL_PRIMITIVE_TYPES = {
    int: graphql.GraphQLInt,
    float: graphql.GraphQLFloat,
    str: graphql.GraphQLString,
    bool: graphql.GraphQLBoolean,
}

# Marker type: `SchemaBuilder.visit_conversion` maps it to the builder's `id_type`.
ID = NewType("ID", str)


class MissingName(Exception):
    """Exception type kept as part of the module namespace.

    It is not raised by any code visible in this file.
    """

    pass


class Nullable(Exception):
    """Exception type kept as part of the module namespace.

    It is not raised by any code visible in this file.
    """

    pass


T = TypeVar("T")
# Either a value or a zero-argument callable producing it.
Thunk = Union[Callable[[], T], T]

TypeThunk = Thunk[graphql.GraphQLType]


def exec_thunk(thunk: TypeThunk, *, non_null=None) -> Any:
    """Resolve a type thunk and optionally adjust its nullability.

    `thunk` is returned as is when it already is a `GraphQLType`, otherwise it is
    called. With `non_null=True` the result is wrapped in `GraphQLNonNull` (unless
    already wrapped); with `non_null=False` an outer `GraphQLNonNull` is removed;
    with `non_null=None` the result is left untouched.
    """
    result = thunk if isinstance(thunk, graphql.GraphQLType) else thunk()
    if non_null is True and not isinstance(result, graphql.GraphQLNonNull):
        return graphql.GraphQLNonNull(result)  # type: ignore
    if non_null is False and isinstance(result, graphql.GraphQLNonNull):
        return result.of_type
    return result


def get_parameter_schema(
    func: Callable, parameter: Parameter, field: ObjectField
) -> Optional[Schema]:
    """Merge the base schema of a resolver parameter with the schema of `field`.

    The base schema comes from `settings.base_schema.parameter`, called with the
    function, the parameter and the field alias.
    """
    # `settings` is already imported at module level; no local import is needed.
    return merge_schema(
        settings.base_schema.parameter(func, parameter, field.alias), field.schema
    )


def merged_schema(
    schema: Optional[Schema], tp: Optional[AnyType]
) -> Tuple[Optional[Schema], Mapping[str, Any]]:
    """Merge `schema` with the schema metadata found in the annotations of `tp`.

    When `tp` is annotated, its annotations are walked from last to first; schema
    metadata is merged until a `TypeNameFactory` annotation is met.

    Returns the merged schema together with the dict it has been merged into
    (empty when there is no schema).
    """
    if is_annotated(tp):
        for annotation in reversed(get_args(tp)[1:]):
            if isinstance(annotation, TypeNameFactory):
                break
            elif isinstance(annotation, Mapping) and SCHEMA_METADATA in annotation:
                schema = merge_schema(annotation[SCHEMA_METADATA], schema)
    schema_dict: Dict[str, Any] = {}
    if schema is not None:
        schema.merge_into(schema_dict)
    return schema, schema_dict


def get_description(
    schema: Optional[Schema], tp: Optional[AnyType] = None
) -> Optional[str]:
    """Return the "description" entry of the merged schema, or None if absent."""
    _, schema_dict = merged_schema(schema, tp)
    return schema_dict.get("description")


def get_deprecated(
    schema: Optional[Schema], tp: Optional[AnyType] = None
) -> Optional[str]:
    """Return the deprecation reason of the merged schema, or None.

    None is returned when the merged schema dict has no truthy "deprecated" entry.
    Otherwise the chain of schemas (`schema.child`) is walked: the first string
    `annotations.deprecated` is returned as the reason, and any other truthy value
    (or reaching the end of the chain) gives `graphql.DEFAULT_DEPRECATION_REASON`.
    """
    schema, schema_dict = merged_schema(schema, tp)
    if not schema_dict.get("deprecated", False):
        return None
    while schema is not None:
        if schema.annotations is not None:
            if isinstance(schema.annotations.deprecated, str):
                return schema.annotations.deprecated
            elif schema.annotations.deprecated:
                return graphql.DEFAULT_DEPRECATION_REASON
        schema = schema.child
    return graphql.DEFAULT_DEPRECATION_REASON


@dataclass(frozen=True)
class ResolverField:
    """Data needed to build the GraphQL field of a resolver."""

    resolver: Resolver
    # Types indexed by parameter name, plus the "return" key (see `_resolver`).
    types: Mapping[str, AnyType]
    parameters: Sequence[Parameter]
    # Metadata indexed by parameter name.
    metadata: Mapping[str, Mapping]
    # Passed as the `subscribe` argument of `graphql.GraphQLField`.
    subscribe: Optional[Callable] = None


IdPredicate = Callable[[AnyType], bool]
UnionNameFactory = Callable[[Sequence[str]], str]

GraphQLTp = TypeVar("GraphQLTp", graphql.GraphQLInputType, graphql.GraphQLOutputType)

FactoryFunction = Callable[[Optional[str], Optional[str]], GraphQLTp]


@dataclass(frozen=True)
class TypeFactory(Generic[GraphQLTp]):
    """Deferred construction of a GraphQL type from a name and a description."""

    factory: FactoryFunction[GraphQLTp]
    name: Optional[str] = None
    description: Optional[str] = None

    # non_null cannot be a field because it cannot be forwarded to factories called
    # in wrapping factories (e.g. recursive wrapper)

    def merge(
        self, type_name: TypeName = TypeName(), schema: Optional[Schema] = None
    ) -> "TypeFactory[GraphQLTp]":
        """Return a factory whose name/description are overridden when provided.

        The GraphQL name of `type_name` and the description of `schema` take
        precedence over the current ones when they are truthy.
        """
        if type_name == TypeName() and schema is None:
            return self
        return replace(
            self,
            name=type_name.graphql or self.name,
            description=get_description(schema) or self.description,
        )

    @property
    def type(self) -> GraphQLTp:
        """Build the GraphQL type with the current name and description."""
        return self.factory(self.name, self.description)  # type: ignore

    @property
    def raw_type(self) -> GraphQLTp:
        """Build the GraphQL type and strip an outer `GraphQLNonNull`, if any."""
        tp = self.type
        return tp.of_type if isinstance(tp, graphql.GraphQLNonNull) else tp


def unwrap_name(name: Optional[str], tp: AnyType) -> str:
    """Return `name`, raising `TypeError` mentioning `tp` when it is None."""
    if name is None:
        raise TypeError(f"Missing name for {tp}")
    return name


Method = TypeVar("Method", bound=Callable[..., TypeFactory])


def cache_type(method: Method) -> Method:
    """Decorate a `SchemaBuilder` method so that named types are built only once.

    The factory of the returned `TypeFactory` is replaced by one which wraps its
    result in `GraphQLNonNull` and, for named types, stores it in the builder's
    `_cache_by_name`. A cached type is reused only when the method has been called
    with equal arguments; otherwise the entry is rebuilt and overwritten.
    """

    @wraps(method)
    def wrapper(self: "SchemaBuilder", *args, **kwargs):
        factory = method(self, *args, **kwargs)

        @wraps(factory.factory)  # type: ignore
        def name_cache(
            name: Optional[str], description: Optional[str]
        ) -> graphql.GraphQLNonNull:
            # Unnamed types are never cached
            if name is None:
                return graphql.GraphQLNonNull(factory.factory(name, description))  # type: ignore
            # Method is in cache key because scalar types will have the same method,
            # and then be shared by both visitors, while input/output types will have
            # their own cache entry.
            if (name, method, description) in self._cache_by_name:
                tp, cached_args = self._cache_by_name[(name, method, description)]
                if cached_args == (args, kwargs):
                    return tp
            tp = graphql.GraphQLNonNull(factory.factory(name, description))  # type: ignore
            # Don't put args in cache key in order to avoid hashable issue
            self._cache_by_name[(name, method, description)] = (tp, (args, kwargs))
            return tp

        return replace(factory, factory=name_cache)

    return cast(Method, wrapper)


class SchemaBuilder(
    RecursiveConversionsVisitor[Conv, TypeFactory[GraphQLTp]],
    ObjectVisitor[TypeFactory[GraphQLTp]],
):
    """Base visitor turning Python types into `TypeFactory` instances.

    Object and union handling is left to the input/output subclasses.
    """

    types: Tuple[Type[graphql.GraphQLType], ...]

    def __init__(
        self,
        aliaser: Aliaser,
        enum_aliaser: Aliaser,
        enum_schemas: Mapping[Enum, Schema],
        default_conversion: DefaultConversion,
        id_type: graphql.GraphQLScalarType,
        is_id: Optional[IdPredicate],
    ):
        """Store the builder configuration and create an empty type cache.

        A missing `is_id` predicate is replaced by one always returning False.
        """
        super().__init__(default_conversion)
        self.aliaser = aliaser
        self.enum_aliaser = enum_aliaser
        self.enum_schemas = enum_schemas
        self.id_type = id_type
        self.is_id = is_id or (lambda t: False)
        # Filled by `cache_type`: (name, method, description) -> (type, call args)
        self._cache_by_name: Dict[
            Tuple[str, Callable, Optional[str]],
            Tuple[graphql.GraphQLNonNull, Tuple[tuple, dict]],
        ] = {}

    def _recursive_result(
        self, lazy: Lazy[TypeFactory[GraphQLTp]]
    ) -> TypeFactory[GraphQLTp]:
        """Wrap a lazily computed factory, evaluated only when the type is built.

        Name and description given to the wrapper take precedence over those of
        the lazily computed factory.
        """

        def factory(name: Optional[str], description: Optional[str]) -> GraphQLTp:
            cached_fact = lazy()
            return cached_fact.factory(  # type: ignore
                name or cached_fact.name, description or cached_fact.description
            )

        return TypeFactory(factory)

    def annotated(
        self, tp: AnyType, annotations: Sequence[Any]
    ) -> TypeFactory[GraphQLTp]:
        """Merge type name and schema annotations into the factory of `tp`.

        Annotations are read from last to first. The first `TypeNameFactory` met
        gives the type name and a second one stops the loop; schema metadata is
        merged only once a type name has been met.
        """
        factory = super().annotated(tp, annotations)
        type_name = False
        for annotation in reversed(annotations):
            if isinstance(annotation, TypeNameFactory):
                if type_name:
                    break
                type_name = True
                factory = factory.merge(annotation.to_type_name(tp))
            if isinstance(annotation, Mapping):
                if type_name:
                    factory = factory.merge(schema=annotation.get(SCHEMA_METADATA))
        return factory  # type: ignore

    @cache_type
    def any(self) -> TypeFactory[GraphQLTp]:
        """Return `JsonScalar` when unnamed, else a scalar with the given name."""

        def factory(
            name: Optional[str], description: Optional[str]
        ) -> graphql.GraphQLScalarType:
            if name is None:
                return JsonScalar
            else:
                return graphql.GraphQLScalarType(name, description=description)

        return TypeFactory(factory)

    @cache_type
    def collection(
        self, cls: Type[Collection], value_type: AnyType
    ) -> TypeFactory[GraphQLTp]:
        """Return a `GraphQLList` of the visited `value_type`."""
        return TypeFactory(lambda *_: graphql.GraphQLList(self.visit(value_type).type))

    @cache_type
    def enum(self, cls: Type[Enum]) -> TypeFactory[GraphQLTp]:
        """Return a `GraphQLEnumType` built from the members of `cls`.

        Member names are passed through `enum_aliaser`; description and
        deprecation reason of each value come from `enum_schemas`. The type must
        be named (`TypeError` otherwise).
        """

        def factory(
            name: Optional[str], description: Optional[str]
        ) -> graphql.GraphQLEnumType:
            return graphql.GraphQLEnumType(
                unwrap_name(name, cls),
                {
                    self.enum_aliaser(name): graphql.GraphQLEnumValue(
                        member,
                        get_description(self.enum_schemas.get(member)),
                        get_deprecated(self.enum_schemas.get(member)),
                    )
                    for name, member in cls.__members__.items()
                },
                description=description,
            )

        return TypeFactory(factory)

    @cache_type
    def literal(self, values: Sequence[Any]) -> TypeFactory[GraphQLTp]:
        """Return a `GraphQLEnumType` for a `Literal` made only of strings.

        Raises:
            TypeError: if one of the values is not a string.
        """
        from apischema.typing import Literal

        if not all(isinstance(v, str) for v in values):
            raise TypeError("apischema GraphQL only support Literal of strings")

        def factory(
            name: Optional[str], description: Optional[str]
        ) -> graphql.GraphQLEnumType:
            return graphql.GraphQLEnumType(
                unwrap_name(name, Literal[tuple(values)]),  # type: ignore
                dict(zip(map(self.enum_aliaser, values), values)),
                description=description,
            )

        return TypeFactory(factory)

    @cache_type
    def mapping(
        self, cls: Type[Mapping], key_type: AnyType, value_type: AnyType
    ) -> TypeFactory[GraphQLTp]:
        """Return `JsonScalar` when unnamed, else a scalar with the given name."""

        def factory(
            name: Optional[str], description: Optional[str]
        ) -> graphql.GraphQLScalarType:
            if name is not None:
                return graphql.GraphQLScalarType(name, description=description)
            else:
                return JsonScalar

        return TypeFactory(factory)

    def object(
        self, tp: AnyType, fields: Sequence[ObjectField]
    ) -> TypeFactory[GraphQLTp]:
        """Build an object type; implemented by the input/output subclasses."""
        raise NotImplementedError

    @cache_type
    def primitive(self, cls: Type) -> TypeFactory[GraphQLTp]:
        """Return the built-in scalar of `cls`, or a named scalar when named."""

        def factory(
            name: Optional[str], description: Optional[str]
        ) -> graphql.GraphQLScalarType:
            assert cls is not NoneType
            if name is not None:
                return graphql.GraphQLScalarType(name, description=description)
            else:
                return GRAPHQL_PRIMITIVE_TYPES[cls]

        return TypeFactory(factory)

    def tuple(self, types: Sequence[AnyType]) -> TypeFactory[GraphQLTp]:
        """Tuples have no GraphQL representation here: always raises `TypeError`."""
        raise TypeError("Tuple are not supported")

    def union(self, alternatives: Sequence[AnyType]) -> TypeFactory[GraphQLTp]:
        """Build the factory of a union, handling nullability.

        `NoneType` alternatives are dropped before visiting. A single remaining
        result is used directly, otherwise `_visited_union` is called. When the
        union contains `NoneType` or `UndefinedType`, the outer `GraphQLNonNull`
        of the built type is removed.
        """
        factories = self._union_results(
            (alt for alt in alternatives if alt is not NoneType)
        )
        if len(factories) == 1:
            factory = factories[0]
        else:
            factory = self._visited_union(factories)
        if NoneType in alternatives or UndefinedType in alternatives:

            def nullable(name: Optional[str], description: Optional[str]) -> GraphQLTp:
                res = factory.factory(name, description)  # type: ignore
                return res.of_type if isinstance(res, graphql.GraphQLNonNull) else res

            return replace(factory, factory=nullable)
        else:
            return factory

    def visit_conversion(
        self,
        tp: AnyType,
        conversion: Optional[Conv],
        dynamic: bool,
        next_conversion: Optional[AnyConversion] = None,
    ) -> TypeFactory[GraphQLTp]:
        """Visit `tp`, handling ID types and merging its type name and schema.

        `ID` itself, and types matching `is_id` when the conversion is not
        dynamic, give a non-null `id_type`. For non-dynamic conversions, the type
        name and schema of `tp` are merged, as well as the schema of its origin
        when `tp` has type arguments.
        """
        if not dynamic and self.is_id(tp) or tp == ID:
            return TypeFactory(lambda *_: graphql.GraphQLNonNull(self.id_type))
        factory = super().visit_conversion(tp, conversion, dynamic, next_conversion)
        if not dynamic:
            factory = factory.merge(get_type_name(tp), get_schema(tp))
            if get_args(tp):
                factory = factory.merge(schema=get_schema(get_origin(tp)))
        return factory  # type: ignore


FieldType = TypeVar("FieldType", graphql.GraphQLInputField, graphql.GraphQLField)


class BaseField(Generic[FieldType]):
    """Common interface of the fields collected while visiting an object."""

    name: str
    ordering: Optional[Ordering]

    def items(self) -> Iterable[Tuple[str, FieldType]]:
        """Yield the (alias, GraphQL field) pairs contributed by this field."""
        raise NotImplementedError


@dataclass
class NormalField(BaseField[FieldType]):
    """A field contributing a single GraphQL field, built lazily."""

    alias: str
    name: str
    field: Lazy[FieldType]
    ordering: Optional[Ordering]

    def items(self) -> Iterable[Tuple[str, FieldType]]:
        """Yield the alias with the GraphQL field built by calling `field`."""
        yield self.alias, self.field()


@dataclass
class FlattenedField(BaseField[FieldType]):
    """A flattened field, contributing all the fields of its object type."""

    name: str
    ordering: Optional[Ordering]
    type: TypeFactory

    def items(self) -> Iterable[Tuple[str, FieldType]]:
        """Yield the fields of the flattened type.

        Raises:
            FlattenedError: when iterated, if the type of the field is not an
                object, interface or input object type.
        """
        tp = self.type.raw_type
        if not isinstance(
            tp,
            (
                graphql.GraphQLObjectType,
                graphql.GraphQLInterfaceType,
                graphql.GraphQLInputObjectType,
            ),
        ):
            raise FlattenedError(self)
        yield from tp.fields.items()


class FlattenedError(Exception):
    """Raised by `FlattenedField.items` for a flattened field of non-object type."""

    def __init__(self, field: FlattenedField):
        self.field = field


def merge_fields(cls: type, fields: Sequence[BaseField]) -> Dict[str, FieldType]:
    """Sort `fields` and gather all their GraphQL fields in an ordered mapping.

    Raises:
        TypeError: if a flattened field does not have an object type.
    """
    try:
        sorted_fields = sort_by_order(
            cls, fields, lambda f: f.name, lambda f: f.ordering
        )
        # `items()` are generators, so FlattenedError is only raised while they are
        # consumed: the mapping must be built inside the `try` for the error to be
        # converted.
        return OrderedDict(chain.from_iterable(map(lambda f: f.items(), sorted_fields)))
    except FlattenedError as err:
        raise TypeError(
            f"Flattened field {cls.__name__}.{err.field.name}"
            f" must have an object type"
        ) from err


class InputSchemaBuilder(
    SchemaBuilder[Deserialization, graphql.GraphQLInputType],
    DeserializationVisitor[TypeFactory[graphql.GraphQLInputType]],
    DeserializationObjectVisitor[TypeFactory[graphql.GraphQLInputType]],
):
    """Schema builder for GraphQL input types (deserialization side)."""

    types = graphql.type.definition.graphql_input_types

    def _field(
        self, tp: AnyType, field: ObjectField
    ) -> Lazy[graphql.GraphQLInputField]:
        """Return a thunk building the `GraphQLInputField` of `field`.

        The field type is made optional when its default is None/Undefined or
        when the default cannot be serialized; otherwise the serialized default
        is used as the GraphQL default value.
        """
        field_type = field.type
        field_default = graphql.Undefined if field.required else field.get_default()
        default: Any = graphql.Undefined
        # Don't put `null` default + handle Undefined as None
        # Identity checks are used instead of set membership, because the latter
        # would hash the default and fail for unhashable ones (e.g. a list).
        if field_default is None or field_default is Undefined:
            field_type = Optional[field_type]
        elif field_default is not graphql.Undefined:
            try:
                default = serialize(
                    field_type,
                    field_default,
                    aliaser=self.aliaser,
                    conversion=field.deserialization,
                )
            except Exception:
                # Best effort: an unserializable default is simply not exposed
                field_type = Optional[field_type]
        factory = self.visit_with_conv(field_type, field.deserialization)
        return lambda: graphql.GraphQLInputField(
            factory.type,  # type: ignore
            default_value=default,
            description=get_description(get_field_schema(tp, field), field.type),
        )

    @cache_type
    def object(
        self, tp: AnyType, fields: Sequence[ObjectField]
    ) -> TypeFactory[graphql.GraphQLInputType]:
        """Return the factory of the `GraphQLInputObjectType` of `tp`.

        Aggregate fields are ignored unless flattened. The "Input" suffix is
        appended to the type name when missing.
        """
        visited_fields: List[BaseField] = []
        for field in fields:
            if not field.is_aggregate:
                normal_field = NormalField(
                    self.aliaser(field.alias),
                    field.name,
                    self._field(tp, field),
                    field.ordering,
                )
                visited_fields.append(normal_field)
            elif field.flattened:
                flattened_fields = FlattenedField(
                    field.name,
                    field.ordering,
                    self.visit_with_conv(field.type, field.deserialization),
                )
                visited_fields.append(flattened_fields)

        def factory(
            name: Optional[str], description: Optional[str]
        ) -> graphql.GraphQLInputObjectType:
            name = unwrap_name(name, tp)
            if not name.endswith("Input"):
                name += "Input"
            return graphql.GraphQLInputObjectType(
                name,
                lambda: merge_fields(get_origin_or_type(tp), visited_fields),
                description,
            )

        return TypeFactory(factory)

    def _visited_union(
        self, results: Sequence[TypeFactory]
    ) -> TypeFactory[graphql.GraphQLInputType]:
        """Return the single result, raising `TypeError` for a real union."""
        # Check must be done here too because _union_result is used by visit_conversion
        if len(results) != 1:
            raise TypeError("Union are not supported for input")
        return results[0]


Func = TypeVar("Func", bound=Callable)


class OutputSchemaBuilder(
    SchemaBuilder[Serialization, graphql.GraphQLOutputType],
    SerializationVisitor[TypeFactory[graphql.GraphQLOutputType]],
    SerializationObjectVisitor[TypeFactory[graphql.GraphQLOutputType]],
):
    """Schema builder for GraphQL output types (serialization side).

    It owns an `InputSchemaBuilder`, used for resolver arguments.
    """

    types = graphql.type.definition.graphql_output_types

    def __init__(
        self,
        aliaser: Aliaser,
        enum_aliaser: Aliaser,
        enum_schemas: Mapping[Enum, Schema],
        default_conversion: DefaultConversion,
        id_type: graphql.GraphQLScalarType,
        is_id: Optional[IdPredicate],
        union_name_factory: UnionNameFactory,
        default_deserialization: DefaultConversion,
    ):
        """Initialize the builder and its companion input builder.

        `default_conversion` is the default serialization of this builder, while
        `default_deserialization` is given to the input builder.
        """
        super().__init__(
            aliaser, enum_aliaser, enum_schemas, default_conversion, id_type, is_id
        )
        self.union_name_factory = union_name_factory
        self.input_builder = InputSchemaBuilder(
            self.aliaser,
            self.enum_aliaser,
            self.enum_schemas,
            default_deserialization,
            self.id_type,
            self.is_id,
        )
        # Share the same cache for input_builder in order to share scalar types
        self.input_builder._cache_by_name = self._cache_by_name
        # Set while visiting a flattened field (see `_visit_flattened`)
        self.get_flattened: Optional[Callable[[Any], Any]] = None

    def _field_serialization_method(self, field: ObjectField) -> SerializationMethod:
        """Return the serialization method of the field type.

        The type is made optional when the field has `none_as_undefined` set.
        """
        return partial_serialization_method_factory(
            self.aliaser, field.serialization, self.default_conversion
        )(Optional[field.type] if field.none_as_undefined else field.type)

    def _wrap_resolve(self, resolve: Func) -> Func:
        """Make `resolve` receive the flattened object when visiting one.

        `resolve` is returned unchanged when no flattened field is being visited.
        """
        if self.get_flattened is None:
            return resolve
        else:
            # Bind the current getter, as the attribute changes during the visit
            get_flattened = self.get_flattened

            def resolve_wrapper(__obj, __info, **kwargs):
                return resolve(get_flattened(__obj), __info, **kwargs)

            return cast(Func, resolve_wrapper)

    def _field(self, tp: AnyType, field: ObjectField) -> Lazy[graphql.GraphQLField]:
        """Return a thunk building the `GraphQLField` of an object field.

        The field is resolved by getting the attribute `field.name` of the object
        and serializing it.
        """
        field_name = field.name
        partial_serialize = self._field_serialization_method(field)

        @self._wrap_resolve
        def resolve(obj, _):
            return partial_serialize(getattr(obj, field_name))

        factory = self.visit_with_conv(field.type, field.serialization)
        field_schema = get_field_schema(tp, field)
        return lambda: graphql.GraphQLField(
            factory.type,
            None,
            resolve,
            description=get_description(field_schema, field.type),
            deprecation_reason=get_deprecated(field_schema, field.type),
        )

    def _resolver(
        self, tp: AnyType, field: ResolverField
    ) -> Lazy[graphql.GraphQLField]:
        """Return a thunk building the `GraphQLField` of a resolver.

        Each resolver parameter gives a GraphQL argument, except those typed with
        `graphql.GraphQLResolveInfo`. An argument is made optional when its
        default is None/Undefined or cannot be serialized; otherwise the
        serialized default is used as the GraphQL default value.
        """
        resolve = self._wrap_resolve(
            resolver_resolve(
                field.resolver,
                field.types,
                self.aliaser,
                self.input_builder.default_conversion,
                self.default_conversion,
            )
        )
        args = None
        if field.parameters is not None:
            args = {}
            for param in field.parameters:
                default: Any = graphql.Undefined
                param_type = field.types[param.name]
                if is_union_of(param_type, graphql.GraphQLResolveInfo):
                    # The info parameter is not an argument; only skip it, as
                    # parameters declared after it must still be exposed.
                    continue
                param_field = ObjectField(
                    param.name,
                    param_type,
                    param.default is Parameter.empty,
                    field.metadata.get(param.name, empty_dict),
                    default=param.default,
                )
                # Identity checks are used below instead of set membership, because
                # the latter would hash the default and fail for unhashable ones.
                if param_field.required:
                    pass
                # Don't put `null` default + handle Undefined as None
                # also https://github.com/python/typing/issues/775
                elif param.default is None or param.default is Undefined:
                    param_type = Optional[param_type]
                # param.default == graphql.Undefined means the parameter is required
                # even if it has a default
                elif (
                    param.default is not Parameter.empty
                    and param.default is not graphql.Undefined
                ):
                    try:
                        default = serialize(
                            param_type,
                            param.default,
                            fall_back_on_any=False,
                            check_type=True,
                        )
                    except Exception:
                        # Best effort: an unserializable default is not exposed
                        param_type = Optional[param_type]
                arg_factory = self.input_builder.visit_with_conv(
                    param_type, param_field.deserialization
                )
                description = get_description(
                    get_parameter_schema(field.resolver.func, param, param_field),
                    param_field.type,
                )

                # Loop variables are bound as defaults to avoid late binding
                def arg_thunk(
                    arg_factory=arg_factory, default=default, description=description
                ) -> graphql.GraphQLArgument:
                    return graphql.GraphQLArgument(
                        arg_factory.type, default, description
                    )

                args[self.aliaser(param_field.alias)] = arg_thunk
        factory = self.visit_with_conv(field.types["return"], field.resolver.conversion)
        field_schema = get_method_schema(tp, field.resolver)
        return lambda: graphql.GraphQLField(
            factory.type,  # type: ignore
            {name: arg() for name, arg in args.items()} if args else None,
            resolve,
            field.subscribe,
            get_description(field_schema),
            get_deprecated(field_schema),
        )

    def _visit_flattened(
        self, field: ObjectField
    ) -> TypeFactory[graphql.GraphQLOutputType]:
        """Visit the type of a flattened field.

        During the visit, `get_flattened` is set to a getter which retrieves and
        serializes the flattened attribute, chained after the getter of the
        enclosing flattened field, if any.
        """
        get_prev_flattened = (
            self.get_flattened if self.get_flattened is not None else identity
        )
        field_name = field.name
        partial_serialize = self._field_serialization_method(field)

        def get_flattened(obj):
            return partial_serialize(getattr(get_prev_flattened(obj), field_name))

        # presumably `context_setter` restores the attributes of `self` on exit
        # - TODO confirm against apischema.utils
        with context_setter(self):
            self.get_flattened = get_flattened
            return self.visit_with_conv(field.type, field.serialization)

    @cache_type
    def object(
        self,
        tp: AnyType,
        fields: Sequence[ObjectField],
        resolvers: Sequence[ResolverField] = (),
    ) -> TypeFactory[graphql.GraphQLOutputType]:
        """Return the factory of the object (or interface) type of `tp`.

        Fields come from the object fields (aggregate ones being ignored unless
        flattened), from the given `resolvers` and from those returned by
        `get_resolvers(tp)`. Interfaces are those of the class, plus those
        brought by flattened fields, sorted by name.
        """
        cls = get_origin_or_type(tp)
        visited_fields: List[BaseField[graphql.GraphQLField]] = []
        flattened_factories = []
        for field in fields:
            if not field.is_aggregate:
                normal_field = NormalField(
                    # Use the field alias, as done by InputSchemaBuilder.object, so
                    # that explicit aliases are honored in output types too
                    self.aliaser(field.alias),
                    field.name,
                    self._field(tp, field),
                    field.ordering,
                )
                visited_fields.append(normal_field)
            elif field.flattened:
                flattened_factory = self._visit_flattened(field)
                flattened_factories.append(flattened_factory)
                visited_fields.append(
                    FlattenedField(field.name, field.ordering, flattened_factory)
                )
        resolvers = list(resolvers)
        for resolver, types in get_resolvers(tp):
            resolver_field = ResolverField(
                resolver, types, resolver.parameters, resolver.parameters_metadata
            )
            resolvers.append(resolver_field)
        for resolver_field in resolvers:
            normal_field = NormalField(
                self.aliaser(resolver_field.resolver.alias),
                resolver_field.resolver.func.__name__,
                self._resolver(tp, resolver_field),
                resolver_field.resolver.ordering,
            )
            visited_fields.append(normal_field)

        interface_thunk = None
        interfaces = list(map(self.visit, get_interfaces(cls)))
        if interfaces or flattened_factories:

            def interface_thunk() -> Collection[graphql.GraphQLInterfaceType]:
                all_interfaces = {
                    cast(graphql.GraphQLInterfaceType, i.raw_type) for i in interfaces
                }
                for flattened_factory in flattened_factories:
                    flattened = cast(
                        Union[graphql.GraphQLObjectType, graphql.GraphQLInterfaceType],
                        flattened_factory.raw_type,
                    )
                    # A flattened object brings its interfaces, while a flattened
                    # interface is itself implemented
                    if isinstance(flattened, graphql.GraphQLObjectType):
                        all_interfaces.update(flattened.interfaces)
                    elif isinstance(flattened, graphql.GraphQLInterfaceType):
                        all_interfaces.add(flattened)
                return sorted(all_interfaces, key=lambda i: i.name)

        def factory(
            name: Optional[str], description: Optional[str]
        ) -> Union[graphql.GraphQLObjectType, graphql.GraphQLInterfaceType]:
            name = unwrap_name(name, cls)
            if is_interface(cls):
                return graphql.GraphQLInterfaceType(
                    name,
                    lambda: merge_fields(cls, visited_fields),
                    interface_thunk,
                    description=description,
                )
            else:
                return graphql.GraphQLObjectType(
                    name,
                    lambda: merge_fields(cls, visited_fields),
                    interface_thunk,
                    is_type_of=lambda obj, _: isinstance(obj, cls),
                    description=description,
                )

        return TypeFactory(factory)

    def typed_dict(
        self, tp: Type, types: Mapping[str, AnyType], required_keys: Collection[str]
    ) -> TypeFactory[graphql.GraphQLOutputType]:
        """TypedDict is rejected for output: always raises `TypeError`."""
        raise TypeError("TypedDict are not supported in output schema")

    @cache_type
    def _visited_union(
        self, results: Sequence[TypeFactory]
    ) -> TypeFactory[graphql.GraphQLOutputType]:
        """Return the factory of the `GraphQLUnionType` of the given results.

        When unnamed, the union is named by `union_name_factory` from the names
        of its types.
        """

        def factory(
            name: Optional[str], description: Optional[str]
        ) -> graphql.GraphQLOutputType:
            types = [factory.raw_type for factory in results]
            if name is None:
                name = self.union_name_factory([t.name for t in types])
            return graphql.GraphQLUnionType(name, types, description=description)

        return TypeFactory(factory)


# Origins accepted as return type of a subscription without resolver
async_iterable_origins = set(map(get_origin, (AsyncIterable[Any], AsyncIterator[Any])))

_fake_type = cast(type, ...)


@dataclass(frozen=True)
class Operation(Generic[T]):
    """A GraphQL operation: a function with its schema-related options."""

    function: Callable[..., T]
    # `operation_resolver` falls back on the function name when no alias is given
    alias: Optional[str] = None
    conversion: Optional[AnyConversion] = None
    # Undefined gives no error handler, None gives `none_error_handler`
    # (see `operation_resolver`)
    error_handler: ErrorHandler = Undefined
    order: Optional[Ordering] = None
    schema: Optional[Schema] = None
    parameters_metadata: Mapping[str, Mapping] = field_(default_factory=dict)


class Query(Operation):
    """Operation of the "Query" root type."""

    pass


class Mutation(Operation):
    """Operation of the "Mutation" root type."""

    pass


@dataclass(frozen=True)
class Subscription(Operation[AsyncIterable]):
    """Operation of the "Subscription" root type, with an optional resolver."""

    resolver: Optional[Callable] = None


Op = TypeVar("Op", bound=Operation)


def operation_resolver(operation: Union[Callable, Op], op_class: Type[Op]) -> Resolver:
    """Build the `Resolver` of an operation.

    A bare callable is first wrapped in `op_class`. The operation function is
    wrapped in order to ignore the first positional argument it is called with,
    the wrapper being a coroutine function when the operation is one.
    """
    if not isinstance(operation, op_class):
        operation = op_class(operation)  # type: ignore
    error_handler: Optional[Callable]
    if operation.error_handler is Undefined:
        error_handler = None
    elif operation.error_handler is None:
        error_handler = none_error_handler
    else:
        error_handler = operation.error_handler
    op = operation.function
    if iscoroutinefunction(op):

        async def wrapper(_, *args, **kwargs):
            return await op(*args, **kwargs)

    else:

        def wrapper(_, *args, **kwargs):
            return op(*args, **kwargs)

    # Annotations of the operation are copied onto its wrapper
    wrapper.__annotations__ = op.__annotations__

    (*parameters,) = resolver_parameters(operation.function, check_first=True)
    return Resolver(
        wrapper,
        operation.alias or operation.function.__name__,
        operation.conversion,
        error_handler,
        operation.order,
        operation.schema,
        parameters,
        operation.parameters_metadata,
    )


@deprecate_kwargs({"union_ref": "union_name"})
def graphql_schema(
    *,
    query: Iterable[Union[Callable, Query]] = (),
    mutation: Iterable[Union[Callable, Mutation]] = (),
    subscription: Iterable[Union[Callable[..., AsyncIterable], Subscription]] = (),
    types: Iterable[Type] = (),
    directives: Optional[Collection[graphql.GraphQLDirective]] = None,
    description: Optional[str] = None,
    extensions: Optional[Dict[str, Any]] = None,
    aliaser: Optional[Aliaser] = to_camel_case,
    enum_aliaser: Optional[Aliaser] = str.upper,
    enum_schemas: Optional[Mapping[Enum, Schema]] = None,
    id_types: Union[Collection[AnyType], IdPredicate] = (),
    id_encoding: Tuple[
        Optional[Callable[[str], Any]], Optional[Callable[[Any], str]]
    ] = (None, None),
    union_name: UnionNameFactory = "Or".join,
    default_deserialization: DefaultConversion = None,
    default_serialization: DefaultConversion = None,
) -> graphql.GraphQLSchema:
    """Build a `graphql.GraphQLSchema` from operations and additional types.

    Args:
        query: operations of the "Query" root type.
        mutation: operations of the "Mutation" root type.
        subscription: operations of the "Subscription" root type; without
            resolver, an operation must return an AsyncIterable/AsyncIterator.
        types: additional types added to the schema.
        directives: passed as is to `graphql.GraphQLSchema`.
        description: passed as is to `graphql.GraphQLSchema`.
        extensions: passed as is to `graphql.GraphQLSchema`.
        aliaser: aliaser of field/argument names; None gives `settings.aliaser`.
        enum_aliaser: aliaser of enum value names; None keeps names unchanged.
        enum_schemas: schemas of enum members, used for description/deprecation.
        id_types: types (or predicate) mapped to the GraphQL ID type.
        id_encoding: (deserializer, serializer) pair used to build a custom "ID"
            scalar; `graphql.GraphQLID` is used when both are None.
        union_name: factory of union names from the names of their types.
        default_deserialization: default conversion for inputs; None gives
            `settings.deserialization.default_conversion`.
        default_serialization: default conversion for outputs; None gives
            `settings.serialization.default_conversion`.

    Raises:
        TypeError: if a subscription without resolver does not return an
            AsyncIterable/AsyncIterator.
    """
    if aliaser is None:
        aliaser = settings.aliaser
    if enum_aliaser is None:
        enum_aliaser = lambda s: s
    if default_deserialization is None:
        default_deserialization = settings.deserialization.default_conversion
    if default_serialization is None:
        default_serialization = settings.serialization.default_conversion
    query_fields: List[ResolverField] = []
    mutation_fields: List[ResolverField] = []
    subscription_fields: List[ResolverField] = []
    for operations, op_class, fields in [
        (query, Query, query_fields),
        (mutation, Mutation, mutation_fields),
    ]:
        for operation in operations:  # type: ignore
            resolver = operation_resolver(operation, op_class)
            resolver_field = ResolverField(
                resolver,
                resolver.types(),
                resolver.parameters,
                resolver.parameters_metadata,
            )
            fields.append(resolver_field)
    for sub_op in subscription:  # type: ignore
        if not isinstance(sub_op, Subscription):
            sub_op = Subscription(sub_op)  # type: ignore
        sub_parameters: Sequence[Parameter]
        if sub_op.resolver is not None:
            # Explicit resolver: field types/parameters come from the resolver
            # (minus its first parameter), while the operation function is only
            # used to subscribe.
            subscriber2 = operation_resolver(sub_op, Subscription)
            _, *sub_parameters = resolver_parameters(sub_op.resolver, check_first=False)
            resolver = Resolver(
                sub_op.resolver,
                sub_op.alias or sub_op.resolver.__name__,
                sub_op.conversion,
                subscriber2.error_handler,
                sub_op.order,
                sub_op.schema,
                sub_parameters,
                sub_op.parameters_metadata,
            )
            sub_types = resolver.types()
            subscriber = replace(subscriber2, error_handler=None)
            subscribe = resolver_resolve(
                subscriber,
                subscriber.types(),
                aliaser,
                default_deserialization,
                default_serialization,
                serialized=False,
            )
        else:
            # No resolver: events are resolved with an identity function, and the
            # field types/parameters come from the operation function.
            subscriber2 = operation_resolver(sub_op, Subscription)
            resolver = Resolver(
                lambda _: _,
                subscriber2.alias,
                sub_op.conversion,
                subscriber2.error_handler,
                sub_op.order,
                sub_op.schema,
                (),
                {},
            )
            subscriber = replace(subscriber2, error_handler=None)
            sub_parameters = subscriber.parameters
            sub_types = subscriber.types()
            if get_origin2(sub_types["return"]) not in async_iterable_origins:
                raise TypeError(
                    "Subscriptions must return an AsyncIterable/AsyncIterator"
                )
            event_type = get_args2(sub_types["return"])[0]
            subscribe = resolver_resolve(
                subscriber,
                sub_types,
                aliaser,
                default_deserialization,
                default_serialization,
                serialized=False,
            )
            sub_types = {**sub_types, "return": resolver.return_type(event_type)}

        resolver_field = ResolverField(
            resolver, sub_types, sub_parameters, sub_op.parameters_metadata, subscribe
        )
        subscription_fields.append(resolver_field)

    is_id = as_predicate(id_types)
    if id_encoding == (None, None):
        id_type: graphql.GraphQLScalarType = graphql.GraphQLID
    else:
        # Custom "ID" scalar, falling back on GraphQLID for the missing function
        id_deserializer, id_serializer = id_encoding
        id_type = graphql.GraphQLScalarType(
            name="ID",
            serialize=id_serializer or graphql.GraphQLID.serialize,
            parse_value=id_deserializer or graphql.GraphQLID.parse_value,
            parse_literal=graphql.GraphQLID.parse_literal,
            description=graphql.GraphQLID.description,
        )

    output_builder = OutputSchemaBuilder(
        aliaser,
        enum_aliaser,
        enum_schemas or {},
        default_serialization,
        id_type,
        is_id,
        union_name,
        default_deserialization,
    )

    def root_type(
        name: str, fields: Sequence[ResolverField]
    ) -> Optional[graphql.GraphQLObjectType]:
        """Build a root object type from resolver fields, None if there are none."""
        if not fields:
            return None
        # An empty class is created on the fly to carry the root type name
        tp, type_name = type(name, (), {}), TypeName(graphql=name)
        return output_builder.object(tp, (), fields).merge(type_name, None).raw_type  # type: ignore

    return graphql.GraphQLSchema(
        query=root_type("Query", query_fields),
        mutation=root_type("Mutation", mutation_fields),
        subscription=root_type("Subscription", subscription_fields),
        types=[output_builder.visit(cls).raw_type for cls in types],  # type: ignore
        directives=directives,
        description=description,
        extensions=extensions,
    )