diff --git a/example_project/app/schema.py b/example_project/app/schema.py index b08f965..76660ec 100644 --- a/example_project/app/schema.py +++ b/example_project/app/schema.py @@ -43,6 +43,8 @@ class Query(graphene.ObjectType): + node = relay.Node.Field() + all_postal_codes = DjangoListField(PostalCodeType) all_developers = DjangoListField(DeveloperType) all_property_managers = DjangoListField(PropertyManagerType) diff --git a/query_optimizer/ast.py b/query_optimizer/ast.py index dc0c511..043fa39 100644 --- a/query_optimizer/ast.py +++ b/query_optimizer/ast.py @@ -7,7 +7,7 @@ from django.db.models import Field, ForeignKey, Model from graphene import Connection, ObjectType, PageInfo from graphene.relay.node import AbstractNode -from graphene.types.definitions import GrapheneObjectType, GrapheneUnionType +from graphene.types.definitions import GrapheneInterfaceType, GrapheneObjectType, GrapheneUnionType from graphene.utils.str_converters import to_snake_case from graphene_django import DjangoObjectType from graphql import ( @@ -16,6 +16,7 @@ FragmentSpreadNode, GraphQLField, GraphQLOutputType, + GraphQLSchema, InlineFragmentNode, SelectionNode, ) @@ -182,14 +183,17 @@ def handle_fragment_spread(self, field_type: GrapheneObjectType, fragment_spread selections = get_selections(fragment_definition) return self.handle_selections(field_type, selections) - def handle_inline_fragment(self, field_type: GrapheneUnionType, inline_fragment: InlineFragmentNode) -> None: - fragment_type = get_fragment_type(field_type, inline_fragment) + def handle_inline_fragment( + self, + field_type: GrapheneUnionType | GrapheneInterfaceType, + inline_fragment: InlineFragmentNode, + ) -> None: + fragment_type = get_fragment_type(field_type, inline_fragment, self.info.schema) fragment_model: type[Model] = fragment_type.graphene_type._meta.model - if fragment_model != self.model: - return None - - selections = get_selections(inline_fragment) - return self.handle_selections(fragment_type, selections) + if fragment_model == self.model: + selections = get_selections(inline_fragment) + return self.handle_selections(fragment_type, selections) + return None def get_graphene_type(self, field_type: GrapheneObjectType, field_node: FieldNode) -> GrapheneType: graphql_field = get_field_def(self.info.schema, field_type, field_node) @@ -266,14 +270,27 @@ def is_to_one(field: Field) -> TypeGuard[ToOneField]: return bool(field.many_to_one or field.one_to_one) -def get_fragment_type(field_type: GrapheneUnionType, inline_fragment: InlineFragmentNode) -> GrapheneObjectType: +def get_fragment_type( + field_type: GrapheneUnionType | GrapheneInterfaceType, + inline_fragment: InlineFragmentNode, + schema: GraphQLSchema, +) -> GrapheneObjectType: fragment_type_name = inline_fragment.type_condition.name.value - gen = (t for t in field_type.types if t.name == fragment_type_name) - fragment_type: Optional[GrapheneObjectType] = next(gen, None) - if fragment_type is None: # pragma: no cover - msg = f"Fragment type '{fragment_type_name}' not found in union '{field_type}'" - raise OptimizerError(msg) + # For unions, fetch the type from in the union. + if isinstance(field_type, GrapheneUnionType): + gen = (t for t in field_type.types if t.name == fragment_type_name) + fragment_type: Optional[GrapheneObjectType] = next(gen, None) + if fragment_type is None: # pragma: no cover + msg = f"Fragment type '{fragment_type_name}' not found in union '{field_type}'" + raise OptimizerError(msg) + + # For interfaces, fetch the type from in the schema. + else: + fragment_type: Optional[GrapheneObjectType] = schema.get_type(fragment_type_name) + if fragment_type is None: # pragma: no cover + msg = f"Fragment type '{fragment_type_name}' not found in schema." + raise OptimizerError(msg) return fragment_type diff --git a/tests/test_relay_node.py b/tests/test_relay_node.py index 94fc4fc..a6edf57 100644 --- a/tests/test_relay_node.py +++ b/tests/test_relay_node.py @@ -10,6 +10,36 @@ ] +def test_relay__global_node(graphql_client): + apartment = ApartmentFactory.create(building__name="1") + global_id = to_global_id(str(ApartmentNode), apartment.pk) + + query = """ + query { + node(id: "%s") { + ... on ApartmentNode { + building { + name + } + } + } + } + """ % (global_id,) + + response = graphql_client(query) + assert response.no_errors, response.errors + + # 1 query for fetching apartment and related buildings + assert response.queries.count == 1, response.queries.log + + assert response.queries[0] == has( + 'FROM "app_apartment"', + 'INNER JOIN "app_building"', + ) + + assert response.content == {"building": {"name": "1"}} + + def test_relay__node(graphql_client): apartment = ApartmentFactory.create(building__name="1") global_id = to_global_id(str(ApartmentNode), apartment.pk) @@ -32,6 +62,7 @@ def test_relay__node(graphql_client): assert response.queries[0] == has( 'FROM "app_apartment"', + 'INNER JOIN "app_building"', ) assert response.content == {"building": {"name": "1"}}