Skip to content

Commit

Permalink
Support for include and skip directive with references
Browse files Browse the repository at this point in the history
  • Loading branch information
kmoore-intuit committed Dec 7, 2023
1 parent 1409d7f commit f386387
Show file tree
Hide file tree
Showing 4 changed files with 396 additions and 18 deletions.
Original file line number Diff line number Diff line change
@@ -1,17 +1,5 @@
package com.intuit.graphql.orchestrator.batch;

import static com.intuit.graphql.orchestrator.resolverdirective.FieldResolverDirectiveUtil.hasResolverDirective;
import static com.intuit.graphql.orchestrator.utils.QueryPathUtils.getNodesAsPathList;
import static com.intuit.graphql.orchestrator.utils.QueryPathUtils.pathListToFQN;
import static com.intuit.graphql.orchestrator.utils.RenameDirectiveUtil.convertGraphqlFieldWithOriginalName;
import static com.intuit.graphql.orchestrator.utils.RenameDirectiveUtil.getRenameKey;
import static graphql.introspection.Introspection.TypeNameMetaFieldDef;
import static graphql.schema.FieldCoordinates.coordinates;
import static graphql.util.TreeTransformerUtil.changeNode;
import static graphql.util.TreeTransformerUtil.deleteNode;
import static java.util.Objects.nonNull;
import static java.util.Objects.requireNonNull;

import com.intuit.graphql.orchestrator.authorization.FieldAuthorization;
import com.intuit.graphql.orchestrator.authorization.FieldAuthorizationEnvironment;
import com.intuit.graphql.orchestrator.authorization.FieldAuthorizationResult;
Expand Down Expand Up @@ -40,17 +28,31 @@
import graphql.schema.GraphQLUnionType;
import graphql.util.TraversalControl;
import graphql.util.TraverserContext;
import lombok.Builder;
import lombok.NonNull;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.collections4.MapUtils;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.Builder;
import lombok.NonNull;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.collections4.MapUtils;

import static com.intuit.graphql.orchestrator.resolverdirective.FieldResolverDirectiveUtil.hasResolverDirective;
import static com.intuit.graphql.orchestrator.utils.QueryDirectivesUtil.shouldIgnoreNode;
import static com.intuit.graphql.orchestrator.utils.QueryPathUtils.getNodesAsPathList;
import static com.intuit.graphql.orchestrator.utils.QueryPathUtils.pathListToFQN;
import static com.intuit.graphql.orchestrator.utils.RenameDirectiveUtil.convertGraphqlFieldWithOriginalName;
import static com.intuit.graphql.orchestrator.utils.RenameDirectiveUtil.getRenameKey;
import static graphql.introspection.Introspection.TypeNameMetaFieldDef;
import static graphql.schema.FieldCoordinates.coordinates;
import static graphql.util.TreeTransformerUtil.changeNode;
import static graphql.util.TreeTransformerUtil.deleteNode;
import static java.util.Objects.nonNull;
import static java.util.Objects.requireNonNull;

/**
* This class modifies for query for a downstream provider.
Expand Down Expand Up @@ -91,6 +93,12 @@ public TraversalControl visitField(Field node, TraverserContext<Node> context) {
requireNonNull(fieldDefinition, "Failed to get Field Definition for " + node.getName());

context.setVar(GraphQLType.class, fieldDefinition.getType());

if(shouldIgnoreNode(node, this.queryVariables)) {
decreaseParentSelectionSetCount(context.getParentContext());
return deleteNode(context);
}

FieldAuthorizationResult fieldAuthorizationResult = authorize(node, fieldDefinition, parentType, context);
if (!fieldAuthorizationResult.isAllowed()) {
decreaseParentSelectionSetCount(context.getParentContext());
Expand All @@ -112,8 +120,10 @@ public TraversalControl visitField(Field node, TraverserContext<Node> context) {
GraphQLFieldDefinition fieldDefinition = getFieldDefinition(node.getName(), parentType);
requireNonNull(fieldDefinition, "Failed to get Field Definition for " + node.getName());

if (serviceMetadata.shouldModifyDownStreamQuery() && (hasResolverDirective(fieldDefinition)
|| isExternalField(parentType.getName(), node.getName()))) {
boolean shouldRemoveNode = (serviceMetadata.shouldModifyDownStreamQuery() && (hasResolverDirective(fieldDefinition)
|| isExternalField(parentType.getName(), node.getName())))
|| shouldIgnoreNode(node, this.queryVariables);
if (shouldRemoveNode) {
decreaseParentSelectionSetCount(context.getParentContext());
return deleteNode(context);
} else {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package com.intuit.graphql.orchestrator.utils;

import graphql.language.Argument;
import graphql.language.BooleanValue;
import graphql.language.Directive;
import graphql.language.Field;
import graphql.language.Value;
import graphql.language.VariableReference;

import java.util.Map;
import java.util.Optional;

public class QueryDirectivesUtil {

public static boolean shouldIgnoreNode(Field node, Map<String, Object> queryVariables) {
Optional<Directive> optionalIncludesDir = node.getDirectives("include").stream().findFirst();
Optional<Directive> optionalSkipDir = node.getDirectives("skip").stream().findFirst();
if(optionalIncludesDir.isPresent() || optionalSkipDir.isPresent()) {
if(optionalIncludesDir.isPresent() && (!getIfValue(optionalIncludesDir.get(), queryVariables))) {
return true;
}
return optionalSkipDir.isPresent() && (getIfValue(optionalSkipDir.get(), queryVariables));
}

return false;
}

private static boolean getIfValue(Directive directive, Map<String, Object> queryVariables){
Argument ifArg = directive.getArgument("if");
Value ifValue = ifArg.getValue();

boolean defaultValue = directive.getName().equals("skip");

if(ifValue instanceof VariableReference) {
String variableRefName = ((VariableReference) ifValue).getName();
return (boolean) queryVariables.getOrDefault(variableRefName, defaultValue);
} else if(ifValue instanceof BooleanValue) {
return ((BooleanValue) ifValue).isValue();
}
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,34 @@ class AuthDownstreamQueryRedactorVisitorSpec extends Specification {
}
"""

def skipQuery = """ query skipQuery(\$shouldSkip: Boolean) {
a {
b1 @skip(if: \$shouldSkip) {
c1 {
s1
}
}
b2 {
i1
}
}
}
"""

def includesQuery = """ query includesQuery(\$shouldInclude: Boolean) {
a {
b1 {
c1 {
s1
}
}
b2 @include(if: \$shouldInclude) {
i1
}
}
}
"""

static final Object TEST_AUTH_DATA = "TestAuthDataCanBeAnyObject"

Field mockField = Mock()
Expand Down Expand Up @@ -134,6 +162,171 @@ class AuthDownstreamQueryRedactorVisitorSpec extends Specification {
argumentValueResolver.resolve(_, _, _) >> Collections.emptyMap()
}

def "skip query with skip directive true removes selection set"() {
given:
Document document = new Parser().parseDocument(skipQuery)
OperationDefinition operationDefinition = document.getDefinitionsOfType(OperationDefinition.class).get(0)
Field rootField = SelectionSetUtil.getFieldByPath(Arrays.asList("a"), operationDefinition.getSelectionSet())
GraphQLFieldsContainer rootFieldParentType = (GraphQLFieldsContainer) testGraphQLSchema.getType("Query")

Map<String, Object> queryVariables = new HashMap<>()
queryVariables.put("shouldSkip", true)

AuthDownstreamQueryModifier specUnderTest = AuthDownstreamQueryModifier.builder()
.rootParentType((GraphQLFieldsContainer) rootFieldParentType)
.fieldAuthorization(mockFieldAuthorization)
.graphQLContext(mockGraphQLContext)
.queryVariables(queryVariables)
.graphQLSchema(testGraphQLSchema)
.selectionCollector(new SelectionCollector(fragmentsByName))
.serviceMetadata(mockServiceMetadata)
.authData(TEST_AUTH_DATA)
.build()

when:
Field transformedField = (Field) astTransformer.transform(rootField, specUnderTest)

then:
transformedField.getName() == "a"
Object[] selectionSet = transformedField.getSelectionSet()
.getSelections()
.asList()
selectionSet.size() == 1
((Field)selectionSet.first()).getName() == ("b2")

1 * mockFieldAuthorization.authorize(queryA) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT
1 * mockFieldAuthorization.authorize(aB2) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT
1 * mockFieldAuthorization.authorize(b2i1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT
mockRenamedMetadata.getOriginalFieldNamesByRenamedName() >> Collections.emptyMap()
mockServiceMetadata.getRenamedMetadata() >> mockRenamedMetadata
}

def "skip query with skip directive false keeps selection set"() {
given:
Document document = new Parser().parseDocument(skipQuery)
OperationDefinition operationDefinition = document.getDefinitionsOfType(OperationDefinition.class).get(0)
Field rootField = SelectionSetUtil.getFieldByPath(Arrays.asList("a"), operationDefinition.getSelectionSet())
GraphQLFieldsContainer rootFieldParentType = (GraphQLFieldsContainer) testGraphQLSchema.getType("Query")

Map<String, Object> queryVariables = new HashMap<>()
queryVariables.put("shouldSkip", false)

AuthDownstreamQueryModifier specUnderTest = AuthDownstreamQueryModifier.builder()
.rootParentType((GraphQLFieldsContainer) rootFieldParentType)
.fieldAuthorization(mockFieldAuthorization)
.graphQLContext(mockGraphQLContext)
.queryVariables(queryVariables)
.graphQLSchema(testGraphQLSchema)
.selectionCollector(new SelectionCollector(fragmentsByName))
.serviceMetadata(mockServiceMetadata)
.authData(TEST_AUTH_DATA)
.build()

when:
Field transformedField = (Field) astTransformer.transform(rootField, specUnderTest)

then:
transformedField.getName() == "a"
Object[] selectionSet = transformedField.getSelectionSet()
.getSelections()
.asList()
selectionSet.size() == 2
((Field)selectionSet[0]).getName() == "b1"
((Field)selectionSet[1]).getName() == "b2"

1 * mockFieldAuthorization.authorize(queryA) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT
1 * mockFieldAuthorization.authorize(aB1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT
1 * mockFieldAuthorization.authorize(b1C1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT
1 * mockFieldAuthorization.authorize(c1S1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT
1 * mockFieldAuthorization.authorize(aB2) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT
1 * mockFieldAuthorization.authorize(b2i1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT
mockRenamedMetadata.getOriginalFieldNamesByRenamedName() >> Collections.emptyMap()
mockServiceMetadata.getRenamedMetadata() >> mockRenamedMetadata
}

def "includes query with include directive true keeps selection set"() {
given:
Document document = new Parser().parseDocument(includesQuery)
OperationDefinition operationDefinition = document.getDefinitionsOfType(OperationDefinition.class).get(0)
Field rootField = SelectionSetUtil.getFieldByPath(Arrays.asList("a"), operationDefinition.getSelectionSet())
GraphQLFieldsContainer rootFieldParentType = (GraphQLFieldsContainer) testGraphQLSchema.getType("Query")

Map<String, Object> queryVariables = new HashMap<>()
queryVariables.put("shouldInclude", true)

AuthDownstreamQueryModifier specUnderTest = AuthDownstreamQueryModifier.builder()
.rootParentType((GraphQLFieldsContainer) rootFieldParentType)
.fieldAuthorization(mockFieldAuthorization)
.graphQLContext(mockGraphQLContext)
.queryVariables(queryVariables)
.graphQLSchema(testGraphQLSchema)
.selectionCollector(new SelectionCollector(fragmentsByName))
.serviceMetadata(mockServiceMetadata)
.authData(TEST_AUTH_DATA)
.build()

when:
Field transformedField = (Field) astTransformer.transform(rootField, specUnderTest)

then:
transformedField.getName() == "a"
Object[] selectionSet = transformedField.getSelectionSet()
.getSelections()
.asList()
selectionSet.size() == 2
((Field)selectionSet[0]).getName() == "b1"
((Field)selectionSet[1]).getName() == "b2"

1 * mockFieldAuthorization.authorize(queryA) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT
1 * mockFieldAuthorization.authorize(aB1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT
1 * mockFieldAuthorization.authorize(b1C1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT
1 * mockFieldAuthorization.authorize(c1S1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT
1 * mockFieldAuthorization.authorize(aB2) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT
1 * mockFieldAuthorization.authorize(b2i1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT
mockRenamedMetadata.getOriginalFieldNamesByRenamedName() >> Collections.emptyMap()
mockServiceMetadata.getRenamedMetadata() >> mockRenamedMetadata
}

def "includes query with include directive false removes selection set"() {
given:
Document document = new Parser().parseDocument(includesQuery)
OperationDefinition operationDefinition = document.getDefinitionsOfType(OperationDefinition.class).get(0)
Field rootField = SelectionSetUtil.getFieldByPath(Arrays.asList("a"), operationDefinition.getSelectionSet())
GraphQLFieldsContainer rootFieldParentType = (GraphQLFieldsContainer) testGraphQLSchema.getType("Query")

Map<String, Object> queryVariables = new HashMap<>()
queryVariables.put("shouldInclude", false)

AuthDownstreamQueryModifier specUnderTest = AuthDownstreamQueryModifier.builder()
.rootParentType((GraphQLFieldsContainer) rootFieldParentType)
.fieldAuthorization(mockFieldAuthorization)
.graphQLContext(mockGraphQLContext)
.queryVariables(queryVariables)
.graphQLSchema(testGraphQLSchema)
.selectionCollector(new SelectionCollector(fragmentsByName))
.serviceMetadata(mockServiceMetadata)
.authData(TEST_AUTH_DATA)
.build()

when:
Field transformedField = (Field) astTransformer.transform(rootField, specUnderTest)

then:
transformedField.getName() == "a"
Object[] selectionSet = transformedField.getSelectionSet()
.getSelections()
.asList()
selectionSet.size() == 1
((Field)selectionSet[0]).getName() == "b1"

1 * mockFieldAuthorization.authorize(queryA) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT
1 * mockFieldAuthorization.authorize(aB1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT
1 * mockFieldAuthorization.authorize(b1C1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT
1 * mockFieldAuthorization.authorize(c1S1) >> FieldAuthorizationResult.ALLOWED_FIELD_AUTH_RESULT
mockRenamedMetadata.getOriginalFieldNamesByRenamedName() >> Collections.emptyMap()
mockServiceMetadata.getRenamedMetadata() >> mockRenamedMetadata
}

def "redact query, results to empty selection set"() {
given:

Expand Down
Loading

0 comments on commit f386387

Please sign in to comment.