Skip to content

Commit

Permalink
Support the Node interface
Browse files Browse the repository at this point in the history
  • Loading branch information
MrThearMan committed Oct 9, 2024
1 parent 72ef714 commit 64afd18
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 14 deletions.
2 changes: 2 additions & 0 deletions example_project/app/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@


class Query(graphene.ObjectType):
node = relay.Node.Field()

all_postal_codes = DjangoListField(PostalCodeType)
all_developers = DjangoListField(DeveloperType)
all_property_managers = DjangoListField(PropertyManagerType)
Expand Down
45 changes: 31 additions & 14 deletions query_optimizer/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from django.db.models import Field, ForeignKey, Model
from graphene import Connection, ObjectType, PageInfo
from graphene.relay.node import AbstractNode
from graphene.types.definitions import GrapheneObjectType, GrapheneUnionType
from graphene.types.definitions import GrapheneInterfaceType, GrapheneObjectType, GrapheneUnionType
from graphene.utils.str_converters import to_snake_case
from graphene_django import DjangoObjectType
from graphql import (
Expand All @@ -16,6 +16,7 @@
FragmentSpreadNode,
GraphQLField,
GraphQLOutputType,
GraphQLSchema,
InlineFragmentNode,
SelectionNode,
)
Expand Down Expand Up @@ -182,14 +183,17 @@ def handle_fragment_spread(self, field_type: GrapheneObjectType, fragment_spread
selections = get_selections(fragment_definition)
return self.handle_selections(field_type, selections)

def handle_inline_fragment(self, field_type: GrapheneUnionType, inline_fragment: InlineFragmentNode) -> None:
fragment_type = get_fragment_type(field_type, inline_fragment)
def handle_inline_fragment(
self,
field_type: GrapheneUnionType | GrapheneInterfaceType,
inline_fragment: InlineFragmentNode,
) -> None:
fragment_type = get_fragment_type(field_type, inline_fragment, self.info.schema)
fragment_model: type[Model] = fragment_type.graphene_type._meta.model
if fragment_model != self.model:
return None

selections = get_selections(inline_fragment)
return self.handle_selections(fragment_type, selections)
if fragment_model == self.model:
selections = get_selections(inline_fragment)
return self.handle_selections(fragment_type, selections)
return None

def get_graphene_type(self, field_type: GrapheneObjectType, field_node: FieldNode) -> GrapheneType:
graphql_field = get_field_def(self.info.schema, field_type, field_node)
Expand Down Expand Up @@ -266,14 +270,27 @@ def is_to_one(field: Field) -> TypeGuard[ToOneField]:
return bool(field.many_to_one or field.one_to_one)


def get_fragment_type(field_type: GrapheneUnionType, inline_fragment: InlineFragmentNode) -> GrapheneObjectType:
def get_fragment_type(
field_type: GrapheneUnionType | GrapheneInterfaceType,
inline_fragment: InlineFragmentNode,
schema: GraphQLSchema,
) -> GrapheneObjectType:
fragment_type_name = inline_fragment.type_condition.name.value
gen = (t for t in field_type.types if t.name == fragment_type_name)
fragment_type: Optional[GrapheneObjectType] = next(gen, None)

if fragment_type is None: # pragma: no cover
msg = f"Fragment type '{fragment_type_name}' not found in union '{field_type}'"
raise OptimizerError(msg)
# For unions, fetch the type from in the union.
if isinstance(field_type, GrapheneUnionType):
gen = (t for t in field_type.types if t.name == fragment_type_name)
fragment_type: Optional[GrapheneObjectType] = next(gen, None)
if fragment_type is None: # pragma: no cover
msg = f"Fragment type '{fragment_type_name}' not found in union '{field_type}'"
raise OptimizerError(msg)

# For interfaces, fetch the type from in the schema.
else:
fragment_type: Optional[GrapheneObjectType] = schema.get_type(fragment_type_name)
if fragment_type is None: # pragma: no cover
msg = f"Fragment type '{fragment_type_name}' not found in schema."
raise OptimizerError(msg)

return fragment_type

Expand Down
31 changes: 31 additions & 0 deletions tests/test_relay_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,36 @@
]


def test_relay__global_node(graphql_client):
apartment = ApartmentFactory.create(building__name="1")
global_id = to_global_id(str(ApartmentNode), apartment.pk)

query = """
query {
node(id: "%s") {
... on ApartmentNode {
building {
name
}
}
}
}
""" % (global_id,)

response = graphql_client(query)
assert response.no_errors, response.errors

# 1 query for fetching apartment and related buildings
assert response.queries.count == 1, response.queries.log

assert response.queries[0] == has(
'FROM "app_apartment"',
'INNER JOIN "app_building"',
)

assert response.content == {"building": {"name": "1"}}


def test_relay__node(graphql_client):
apartment = ApartmentFactory.create(building__name="1")
global_id = to_global_id(str(ApartmentNode), apartment.pk)
Expand All @@ -32,6 +62,7 @@ def test_relay__node(graphql_client):

assert response.queries[0] == has(
'FROM "app_apartment"',
'INNER JOIN "app_building"',
)

assert response.content == {"building": {"name": "1"}}
Expand Down

0 comments on commit 64afd18

Please sign in to comment.