diff --git a/src/main/java/com/intuit/graphql/orchestrator/batch/VariableDefinitionFilter.java b/src/main/java/com/intuit/graphql/orchestrator/batch/VariableDefinitionFilter.java index 14ffcf2d..51c9c8a1 100644 --- a/src/main/java/com/intuit/graphql/orchestrator/batch/VariableDefinitionFilter.java +++ b/src/main/java/com/intuit/graphql/orchestrator/batch/VariableDefinitionFilter.java @@ -6,25 +6,20 @@ import graphql.analysis.QueryVisitorInlineFragmentEnvironment; import graphql.analysis.QueryVisitorStub; import graphql.language.Argument; -import graphql.language.AstTransformer; import graphql.language.Document; import graphql.language.Field; import graphql.language.FragmentDefinition; import graphql.language.FragmentSpread; import graphql.language.InlineFragment; import graphql.language.Node; -import graphql.language.NodeVisitorStub; import graphql.language.OperationDefinition; import graphql.language.Value; import graphql.language.VariableReference; import graphql.schema.GraphQLObjectType; import graphql.schema.GraphQLSchema; -import graphql.util.TraversalControl; -import graphql.util.TraverserContext; import lombok.Getter; import java.util.ArrayList; -import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Set; @@ -36,8 +31,6 @@ */ public class VariableDefinitionFilter { - private static AstTransformer astTransformer = new AstTransformer(); - /** * Traverses a GraphQL Node and returns all VariableReference names used in all nodes in the graph. * @@ -50,17 +43,17 @@ public class VariableDefinitionFilter { * reference indicator prefix '$' will be <b>excluded</b> in the result. */ public Set<String> getVariableReferencesFromNode(GraphQLSchema graphQLSchema, GraphQLObjectType rootType, - Map<String, FragmentDefinition> fragmentsByName, Map<String, Object> variables, Node<?> rootNode) { + Map<String, FragmentDefinition> fragmentsByName, Map<String, Object> variables, Node<?> rootNode) { final VariableReferenceVisitor variableReferenceVisitor = new VariableReferenceVisitor(); //need to utilize a better pattern for creating mockable QueryTraverser/QueryTransformer QueryTraverser queryTraverser = QueryTraverser.newQueryTraverser() - .schema(graphQLSchema) - .rootParentType(rootType) //need to support also for subscription - .fragmentsByName(fragmentsByName) - .variables(variables) - .root(rootNode) - .build(); + .schema(graphQLSchema) + .rootParentType(rootType) //need to support also for subscription + .fragmentsByName(fragmentsByName) + .variables(variables) + .root(rootNode) + .build(); queryTraverser.visitPreOrder(variableReferenceVisitor); @@ -75,28 +68,16 @@ public Set<String> getVariableReferencesFromNode(GraphQLSchema graphQLSchema, Gr Set<VariableReference> additionalReferences = operationDirectiveVariableReferences(operationDefinitions); - Stream<VariableReference> variableReferenceStream; - if((variableReferenceVisitor.getVariableReferences().size() + additionalReferences.size()) != variables.size()) { - NodeTraverser nodeTraverser = new NodeTraverser(); - astTransformer.transform(rootNode, nodeTraverser); - - variableReferenceStream = Stream.of(variableReferenceVisitor.getVariableReferences(), - additionalReferences, - nodeTraverser.getVariableReferenceExtractor().getVariableReferences()) - .flatMap(Collection::stream); - } else { - variableReferenceStream = Stream.concat(variableReferenceVisitor.getVariableReferences().stream(), additionalReferences.stream()); - } - return variableReferenceStream.map(VariableReference::getName).collect(Collectors.toSet()); - + return Stream.concat(variableReferenceVisitor.getVariableReferences().stream(), additionalReferences.stream()) + .map(VariableReference::getName).collect(Collectors.toSet()); } private Set<VariableReference> operationDirectiveVariableReferences(List<OperationDefinition> operationDefinitions) { final List<Value> values = operationDefinitions.stream() - .flatMap(operationDefinition -> operationDefinition.getDirectives().stream()) - .flatMap(directive -> directive.getArguments().stream()) - .map(Argument::getValue) - .collect(Collectors.toList()); + .flatMap(operationDefinition -> operationDefinition.getDirectives().stream()) + .flatMap(directive -> directive.getArguments().stream()) + .map(Argument::getValue) + .collect(Collectors.toList()); VariableReferenceExtractor extractor = new VariableReferenceExtractor(); extractor.captureVariableReferences(values); @@ -138,7 +119,7 @@ public void visitField(final QueryVisitorFieldEnvironment env) { } final Stream<Argument> directiveArgumentStream = field.getDirectives().stream() - .flatMap(directive -> directive.getArguments().stream()); + .flatMap(directive -> directive.getArguments().stream()); final Stream<Argument> fieldArgumentStream = field.getArguments().stream(); @@ -154,7 +135,7 @@ public void visitInlineFragment(final QueryVisitorInlineFragmentEnvironment env) } Stream<Argument> arguments = env.getInlineFragment().getDirectives().stream() - .flatMap(directive -> directive.getArguments().stream()); + .flatMap(directive -> directive.getArguments().stream()); captureVariableReferences(arguments); } @@ -169,8 +150,8 @@ public void visitFragmentSpread(final QueryVisitorFragmentSpreadEnvironment env) } final Stream<Argument> allArguments = Stream.concat( - fragmentDefinition.getDirectives().stream(), - fragmentSpread.getDirectives().stream() + fragmentDefinition.getDirectives().stream(), + fragmentSpread.getDirectives().stream() ).flatMap(directive -> directive.getArguments().stream()); captureVariableReferences(allArguments); @@ -178,24 +159,9 @@ public void visitFragmentSpread(final QueryVisitorFragmentSpreadEnvironment env) private void captureVariableReferences(Stream<Argument> arguments) { final List<Value> values = arguments.map(Argument::getValue) - .collect(Collectors.toList()); + .collect(Collectors.toList()); variableReferenceExtractor.captureVariableReferences(values); } } - - static class NodeTraverser extends NodeVisitorStub { - - @Getter - private final VariableReferenceExtractor variableReferenceExtractor = new VariableReferenceExtractor(); - - public TraversalControl visitArgument(Argument node, TraverserContext<Node> context) { - return this.visitNode(node, context); - } - - public TraversalControl visitVariableReference(VariableReference node, TraverserContext<Node> context) { - variableReferenceExtractor.captureVariableReference(node); - return this.visitValue(node, context); - } - } } diff --git a/src/main/java/com/intuit/graphql/orchestrator/batch/VariableReferenceExtractor.java b/src/main/java/com/intuit/graphql/orchestrator/batch/VariableReferenceExtractor.java index ca20637c..218530ae 100644 --- a/src/main/java/com/intuit/graphql/orchestrator/batch/VariableReferenceExtractor.java +++ b/src/main/java/com/intuit/graphql/orchestrator/batch/VariableReferenceExtractor.java @@ -19,14 +19,10 @@ public Set<VariableReference> getVariableReferences() { public void captureVariableReferences(List<Value> values) { for (final Value value : values) { - captureVariableReference(value); + doSwitch(value); } } - public void captureVariableReference(Value value) { - doSwitch(value); - } - private void doSwitch(Value value) { if (value instanceof ArrayValue) { handleArrayValue((ArrayValue) value); diff --git a/src/test/groovy/com/intuit/graphql/orchestrator/batch/VariableDefinitionFilterSpec.groovy b/src/test/groovy/com/intuit/graphql/orchestrator/batch/VariableDefinitionFilterSpec.groovy index 67e11fc5..d4258f94 100644 --- a/src/test/groovy/com/intuit/graphql/orchestrator/batch/VariableDefinitionFilterSpec.groovy +++ b/src/test/groovy/com/intuit/graphql/orchestrator/batch/VariableDefinitionFilterSpec.groovy @@ -47,17 +47,6 @@ class VariableDefinitionFilterSpec extends Specification { directive @field_directive_argument(arg: InputObject) on FIELD_DEFINITION ''' - private String schema2 = ''' - type Query { person: Person } - - type Person { - address : Address - id: String - } - - type Address { city: String state: String zip: String } - ''' - private GraphQLSchema graphQLSchema private VariableDefinitionFilter variableDefinitionFilter @@ -74,12 +63,6 @@ class VariableDefinitionFilterSpec extends Specification { RuntimeWiring.newRuntimeWiring().build()) } - private GraphQLSchema getSchema2() { - return new SchemaGenerator() - .makeExecutableSchema(new SchemaParser().parse(schema2), - RuntimeWiring.newRuntimeWiring().build()) - } - private Map<String, FragmentDefinition> getFragmentsByName(Document document) { return document.getDefinitionsOfType(FragmentDefinition.class).stream() .inject([:]) {map, it -> map << [(it.getName()): it]} @@ -196,62 +179,6 @@ class VariableDefinitionFilterSpec extends Specification { results.containsAll("int_arg", "string_arg") } - def "variable References In Built in Query Directive includes"() { - given: - String query = ''' - query($includeContext: Boolean!) { - consumer { - liabilities(arg: 1) @include(if: $includeContext) { - totalDebt(arg: 1) - } - income - } - } - ''' - - Document document = parser.parseDocument(query) - HashMap<String, Object> variables = new HashMap<>() - variables.put("includeContext", false) - - when: - final Set<String> results = variableDefinitionFilter - .getVariableReferencesFromNode(graphQLSchema, graphQLSchema.getQueryType(), Collections.emptyMap(), - variables, document) - - then: - results.size() == 1 - - results.containsAll("includeContext") - } - - def "variable References In Built in Query Directive skip"() { - given: - String query = ''' - query($includeContext: Boolean!) { - consumer { - liabilities(arg: 1) @skip(if: $includeContext) { - totalDebt(arg: 1) - } - income - } - } - ''' - - Document document = parser.parseDocument(query) - HashMap<String, Object> variables = new HashMap<>() - variables.put("includeContext", true) - - when: - final Set<String> results = variableDefinitionFilter - .getVariableReferencesFromNode(graphQLSchema, graphQLSchema.getQueryType(), Collections.emptyMap(), - variables, document) - - then: - results.size() == 1 - - results.containsAll("includeContext") - } - def "test Negative Cases"() { given: final String negativeTestCaseQuery = "query { consumer { liabilities { totalDebt(arg: 1234) } } }"