Skip to content

Commit

Permalink
Update validation rules
Browse files Browse the repository at this point in the history
Ensure consistent behaviour with graphql-js, matching expected
test outputs. In many cases this has been a change to quoting and
formatting of identifiers, however other changes to application of
rules have also been applied to satisfy they required behaviours.

Rule names exercised by these tests are updated to match what has
changed in graphql-js. Other rules however have been left as-is, with
none of the tests from upstream validating those. As additional tests
are enabled and behaviour brought inline with that, these should be
updated at that time.
  • Loading branch information
dackroyd committed Apr 2, 2024
1 parent b084162 commit ad43f96
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 55 deletions.
4 changes: 2 additions & 2 deletions graphql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5211,8 +5211,8 @@ func TestCircularFragmentMaxDepth(t *testing.T) {
}
`,
ExpectedErrors: []*gqlerrors.QueryError{{
Message: `Cannot spread fragment "X" within itself via Y.`,
Rule: "NoFragmentCycles",
Message: `Cannot spread fragment "X" within itself via "Y".`,
Rule: "NoFragmentCyclesRule",
Locations: []gqlerrors.Location{
{Line: 7, Column: 20},
{Line: 10, Column: 20},
Expand Down
2 changes: 1 addition & 1 deletion internal/common/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func ResolveType(t ast.Type, resolver Resolver) (ast.Type, *errors.QueryError) {
refT := resolver(t.Name)
if refT == nil {
err := errors.Errorf("Unknown type %q.", t.Name)
err.Rule = "KnownTypeNames"
err.Rule = "KnownTypeNamesRule"
err.Locations = []errors.Location{t.Loc}
return nil, err
}
Expand Down
160 changes: 114 additions & 46 deletions internal/validation/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type varSet map[*ast.InputValueDefinition]struct{}

type selectionPair struct{ a, b ast.Selection }

type nameSet map[string]errors.Location
type nameSet map[string][]errors.Location

type fieldInfo struct {
sf *ast.FieldDefinition
Expand Down Expand Up @@ -68,7 +68,7 @@ func newContext(s *ast.Schema, doc *ast.ExecutableDefinition, maxDepth int) *con
func Validate(s *ast.Schema, doc *ast.ExecutableDefinition, variables map[string]interface{}, maxDepth int) []*errors.QueryError {
c := newContext(s, doc, maxDepth)

opNames := make(nameSet)
opNames := make(nameSet, len(doc.Operations))
fragUsedBy := make(map[*ast.FragmentDefinition][]*ast.OperationDefinition)
for _, op := range doc.Operations {
c.usedVars[op] = make(varSet)
Expand All @@ -81,21 +81,22 @@ func Validate(s *ast.Schema, doc *ast.ExecutableDefinition, variables map[string
}

if op.Name.Name == "" && len(doc.Operations) != 1 {
c.addErr(op.Loc, "LoneAnonymousOperation", "This anonymous operation must be the only defined operation.")
}
if op.Name.Name != "" {
validateName(c, opNames, op.Name, "UniqueOperationNames", "operation")
c.addErr(op.Loc, "LoneAnonymousOperationRule", "This anonymous operation must be the only defined operation.")
}

validateDirectives(opc, string(op.Type), op.Directives)
if n := op.Name.Name; n != "" {
opNames[n] = append(opNames[n], op.Name.Loc)
}

varNames := make(nameSet)
varNames := make(nameSet, len(op.Vars))
for _, v := range op.Vars {
validateName(c, varNames, v.Name, "UniqueVariableNames", "variable")
varNames[v.Name.Name] = append(varNames[v.Name.Name], v.Name.Loc)

validateDirectives(opc, "VARIABLE_DEFINITION", v.Directives)

t := resolveType(c, v.Type)
if !canBeInput(t) {
c.addErr(v.TypeLoc, "VariablesAreInputTypes", "Variable %q cannot be non-input type %q.", "$"+v.Name.Name, t)
c.addErr(v.TypeLoc, "VariablesAreInputTypesRule", "Variable %q cannot be non-input type %q.", "$"+v.Name.Name, t)
}
validateValue(opc, v, variables[v.Name.Name], t)

Expand All @@ -114,6 +115,12 @@ func Validate(s *ast.Schema, doc *ast.ExecutableDefinition, variables map[string
}
}

validateDirectives(opc, string(op.Type), op.Directives)

for n, locs := range varNames {
validateName(c, locs, n, "UniqueVariableNamesRule", "variable")
}

var entryPoint ast.NamedType
switch op.Type {
case query.Query:
Expand All @@ -135,18 +142,23 @@ func Validate(s *ast.Schema, doc *ast.ExecutableDefinition, variables map[string
}
}

fragNames := make(nameSet)
for n, locs := range opNames {
validateName(c, locs, n, "UniqueOperationNamesRule", "operation")
}

fragNames := make(nameSet, len(doc.Fragments))
fragVisited := make(map[*ast.FragmentDefinition]struct{})
for _, frag := range doc.Fragments {
opc := &opContext{c, fragUsedBy[frag]}

validateName(c, fragNames, frag.Name, "UniqueFragmentNames", "fragment")
fragNames[frag.Name.Name] = append(fragNames[frag.Name.Name], frag.Name.Loc)

validateDirectives(opc, "FRAGMENT_DEFINITION", frag.Directives)

t := unwrapType(resolveType(c, &frag.On))
// continue even if t is nil
if t != nil && !canBeFragment(t) {
c.addErr(frag.On.Loc, "FragmentsOnCompositeTypes", "Fragment %q cannot condition on non composite type %q.", frag.Name.Name, t)
c.addErr(frag.On.Loc, "FragmentsOnCompositeTypesRule", "Fragment %q cannot condition on non composite type %q.", frag.Name.Name, t)
continue
}

Expand All @@ -157,9 +169,13 @@ func Validate(s *ast.Schema, doc *ast.ExecutableDefinition, variables map[string
}
}

for n, locs := range fragNames {
validateName(c, locs, n, "UniqueFragmentNamesRule", "fragment")
}

for _, frag := range doc.Fragments {
if len(fragUsedBy[frag]) == 0 {
c.addErr(frag.Loc, "NoUnusedFragments", "Fragment %q is never used.", frag.Name.Name)
c.addErr(frag.Loc, "NoUnusedFragmentsRule", "Fragment %q is never used.", frag.Name.Name)
}
}

Expand All @@ -173,7 +189,7 @@ func Validate(s *ast.Schema, doc *ast.ExecutableDefinition, variables map[string
if op.Name.Name != "" {
opSuffix = fmt.Sprintf(" in operation %q", op.Name.Name)
}
c.addErr(v.Loc, "NoUnusedVariables", "Variable %q is never used%s.", "$"+v.Name.Name, opSuffix)
c.addErr(v.Loc, "NoUnusedVariablesRule", "Variable %q is never used%s.", "$"+v.Name.Name, opSuffix)
}
}
}
Expand Down Expand Up @@ -331,7 +347,7 @@ func validateSelection(c *opContext, sel ast.Selection, t ast.NamedType) {
f = fields(t).Get(fieldName)
if f == nil && t != nil {
suggestion := makeSuggestion("Did you mean", fields(t).Names(), fieldName)
c.addErr(sel.Alias.Loc, "FieldsOnCorrectType", "Cannot query field %q on type %q.%s", fieldName, t, suggestion)
c.addErr(sel.Alias.Loc, "FieldsOnCorrectTypeRule", "Cannot query field %q on type %q.%s", fieldName, t, suggestion)
}
}
c.fieldMap[sel] = fieldInfo{sf: f, parent: t}
Expand All @@ -349,10 +365,10 @@ func validateSelection(c *opContext, sel ast.Selection, t ast.NamedType) {
ft = f.Type
sf := hasSubfields(ft)
if sf && sel.SelectionSet == nil {
c.addErr(sel.Alias.Loc, "ScalarLeafs", "Field %q of type %q must have a selection of subfields. Did you mean \"%s { ... }\"?", fieldName, ft, fieldName)
c.addErr(sel.Alias.Loc, "ScalarLeafsRule", "Field %q of type %q must have a selection of subfields. Did you mean \"%s { ... }\"?", fieldName, ft, fieldName)
}
if !sf && sel.SelectionSet != nil {
c.addErr(sel.SelectionSetLoc, "ScalarLeafs", "Field %q must not have a selection since type %q has no subfields.", fieldName, ft)
c.addErr(sel.SelectionSetLoc, "ScalarLeafsRule", "Field %q must not have a selection since type %q has no subfields.", fieldName, ft)
}
}
if sel.SelectionSet != nil {
Expand All @@ -370,7 +386,7 @@ func validateSelection(c *opContext, sel ast.Selection, t ast.NamedType) {
// continue even if t is nil
}
if t != nil && !canBeFragment(t) {
c.addErr(sel.On.Loc, "FragmentsOnCompositeTypes", "Fragment cannot condition on non composite type %q.", t)
c.addErr(sel.On.Loc, "FragmentsOnCompositeTypesRule", "Fragment cannot condition on non composite type %q.", t)
return
}
validateSelectionSet(c, sel.Selections, unwrapType(t))
Expand All @@ -379,7 +395,7 @@ func validateSelection(c *opContext, sel ast.Selection, t ast.NamedType) {
validateDirectives(c, "FRAGMENT_SPREAD", sel.Directives)
frag := c.doc.Fragments.Get(sel.Name.Name)
if frag == nil {
c.addErr(sel.Name.Loc, "KnownFragmentNames", "Unknown fragment %q.", sel.Name.Name)
c.addErr(sel.Name.Loc, "KnownFragmentNamesRule", "Unknown fragment %q.", sel.Name.Name)
return
}
fragTyp := c.schema.Types[frag.On.Name]
Expand Down Expand Up @@ -475,7 +491,7 @@ func detectFragmentCycleSel(c *context, sel ast.Selection, fragVisited map[*ast.
if len(cyclePath) > 1 {
names := make([]string, len(cyclePath)-1)
for i, frag := range cyclePath[:len(cyclePath)-1] {
names[i] = frag.Name.Name
names[i] = fmt.Sprintf("%q", frag.Name.Name)
}
via = " via " + strings.Join(names, ", ")
}
Expand All @@ -484,7 +500,7 @@ func detectFragmentCycleSel(c *context, sel ast.Selection, fragVisited map[*ast.
for i, frag := range cyclePath {
locs[i] = frag.Loc
}
c.addErrMultiLoc(locs, "NoFragmentCycles", "Cannot spread fragment %q within itself%s.", frag.Name.Name, via)
c.addErrMultiLoc(locs, "NoFragmentCyclesRule", "Cannot spread fragment %q within itself%s.", frag.Name.Name, via)
return
}

Expand Down Expand Up @@ -523,7 +539,7 @@ func (c *context) validateOverlap(a, b ast.Selection, reasons *[]string, locs *[
if reasons2, locs2 := c.validateFieldOverlap(a, b); len(reasons2) != 0 {
locs2 = append(locs2, a.Alias.Loc, b.Alias.Loc)
if reasons == nil {
c.addErrMultiLoc(locs2, "OverlappingFieldsCanBeMerged", "Fields %q conflict because %s. Use different aliases on the fields to fetch both if this was intentional.", a.Alias.Name, strings.Join(reasons2, " and "))
c.addErrMultiLoc(locs2, "OverlappingFieldsCanBeMergedRule", "Fields %q conflict because %s. Use different aliases on the fields to fetch both if this was intentional.", a.Alias.Name, strings.Join(reasons2, " and "))
return
}
for _, r := range reasons2 {
Expand Down Expand Up @@ -573,7 +589,7 @@ func (c *context) validateFieldOverlap(a, b *ast.Field) ([]string, []errors.Loca
if asf := c.fieldMap[a].sf; asf != nil {
if bsf := c.fieldMap[b].sf; bsf != nil {
if !typesCompatible(asf.Type, bsf.Type) {
return []string{fmt.Sprintf("they return conflicting types %s and %s", asf.Type, bsf.Type)}, nil
return []string{fmt.Sprintf("they return conflicting types %q and %q", asf.Type, bsf.Type)}, nil
}
}
}
Expand All @@ -582,7 +598,7 @@ func (c *context) validateFieldOverlap(a, b *ast.Field) ([]string, []errors.Loca
bt := c.fieldMap[b].parent
if at == nil || bt == nil || at == bt {
if a.Name.Name != b.Name.Name {
return []string{fmt.Sprintf("%s and %s are different fields", a.Name.Name, b.Name.Name)}, nil
return []string{fmt.Sprintf("%q and %q are different fields", a.Name.Name, b.Name.Name)}, nil
}

if argumentsConflict(a.Arguments, b.Arguments) {
Expand Down Expand Up @@ -651,18 +667,17 @@ func resolveType(c *context, t ast.Type) ast.Type {
}

func validateDirectives(c *opContext, loc string, directives ast.DirectiveList) {
directiveNames := make(nameSet)
directiveNames := make(nameSet, len(directives))
for _, d := range directives {
dirName := d.Name.Name
validateNameCustomMsg(c.context, directiveNames, d.Name, "UniqueDirectivesPerLocation", func() string {
return fmt.Sprintf("The directive %q can only be used once at this location.", dirName)
})

directiveNames[dirName] = append(directiveNames[dirName], d.Name.Loc)

validateArgumentLiterals(c, d.Arguments)

dd, ok := c.schema.Directives[dirName]
if !ok {
c.addErr(d.Name.Loc, "KnownDirectives", "Unknown directive %q.", dirName)
c.addErr(d.Name.Loc, "KnownDirectivesRule", "Unknown directive %q.", "@"+dirName)
continue
}

Expand All @@ -674,28 +689,57 @@ func validateDirectives(c *opContext, loc string, directives ast.DirectiveList)
}
}
if !locOK {
c.addErr(d.Name.Loc, "KnownDirectives", "Directive %q may not be used on %s.", dirName, loc)
c.addErr(d.Name.Loc, "KnownDirectivesRule", "Directive %q may not be used on %s.", "@"+dirName, loc)
}

validateArgumentTypes(c, d.Arguments, dd.Arguments, d.Name.Loc,
func() string { return fmt.Sprintf("directive %q", "@"+dirName) },
func() string { return fmt.Sprintf("Directive %q", "@"+dirName) },
)
}

for n := range directiveNames {
dd, ok := c.schema.Directives[n]
if !ok {
// Invalid directive will have been flagged already
continue
}

if dd.Repeatable {
continue
}

ds := directiveNames[n]
if len(ds) <= 1 {
continue
}

for _, loc := range ds[1:] {
// Duplicate directive errors are inconsistent with the behaviour for other types in graphql-js
// Instead of reporting a single error with all locations, errors are reported for each duplicate after the first declaration
// with the original location, and the duplicate. Behaviour is replicated here, as we use those tests to validate the implementation
validateNameCustomMsg(c.context, []errors.Location{ds[0], loc}, "UniqueDirectivesPerLocationRule", func() string {
return fmt.Sprintf("The directive %q can only be used once at this location.", "@"+n)
})
}
}
}

func validateName(c *context, set nameSet, name ast.Ident, rule string, kind string) {
validateNameCustomMsg(c, set, name, rule, func() string {
return fmt.Sprintf("There can be only one %s named %q.", kind, name.Name)
func validateName(c *context, locs []errors.Location, name string, rule string, kind string) {
validateNameCustomMsg(c, locs, rule, func() string {
if kind == "variable" {
return fmt.Sprintf("There can be only one %s named %q.", kind, "$"+name)
}

return fmt.Sprintf("There can be only one %s named %q.", kind, name)
})
}

func validateNameCustomMsg(c *context, set nameSet, name ast.Ident, rule string, msg func() string) {
if loc, ok := set[name.Name]; ok {
c.addErrMultiLoc([]errors.Location{loc, name.Loc}, rule, msg())
func validateNameCustomMsg(c *context, locs []errors.Location, rule string, msg func() string) {
if len(locs) > 1 {
c.addErrMultiLoc(locs, rule, msg())
return
}
set[name.Name] = name.Loc
}

func validateArgumentTypes(c *opContext, args ast.ArgumentList, argDecls ast.ArgumentsDefinition, loc errors.Location, owner1, owner2 func() string) {
Expand All @@ -713,28 +757,48 @@ func validateArgumentTypes(c *opContext, args ast.ArgumentList, argDecls ast.Arg
for _, decl := range argDecls {
if _, ok := decl.Type.(*ast.NonNull); ok {
if _, ok := args.Get(decl.Name.Name); !ok {
c.addErr(loc, "ProvidedNonNullArguments", "%s argument %q of type %q is required but not provided.", owner2(), decl.Name.Name, decl.Type)
if decl.Default != nil {
continue
}

c.addErr(loc, "ProvidedRequiredArgumentsRule", "%s argument %q of type %q is required, but it was not provided.", owner2(), decl.Name.Name, decl.Type)
}
}
}
}

func validateArgumentLiterals(c *opContext, args ast.ArgumentList) {
argNames := make(nameSet)
argNames := make(nameSet, len(args))
for _, arg := range args {
validateName(c.context, argNames, arg.Name, "UniqueArgumentNames", "argument")
validateLiteral(c, arg.Value)

argNames[arg.Name.Name] = append(argNames[arg.Name.Name], arg.Name.Loc)
}

for n, locs := range argNames {
validateName(c.context, locs, n, "UniqueArgumentNamesRule", "argument")
}
}

func validateLiteral(c *opContext, l ast.Value) {
switch l := l.(type) {
case *ast.ObjectValue:
fieldNames := make(nameSet)
fieldNames := make(nameSet, len(l.Fields))
for _, f := range l.Fields {
validateName(c.context, fieldNames, f.Name, "UniqueInputFieldNames", "input field")
fieldNames[f.Name.Name] = append(fieldNames[f.Name.Name], f.Name.Loc)
validateLiteral(c, f.Value)
}

for n, locs := range fieldNames {
if len(locs) <= 1 {
continue
}

// Similar to for directives, duplicates here aren't all reported together but using an error for each duplicate
for _, loc := range locs[1:] {
validateName(c.context, []errors.Location{locs[0], loc}, n, "UniqueInputFieldNamesRule", "input field")
}
}
case *ast.ListValue:
for _, entry := range l.Values {
validateLiteral(c, entry)
Expand All @@ -750,7 +814,7 @@ func validateLiteral(c *opContext, l ast.Value) {
c.opErrs[op] = append(c.opErrs[op], &errors.QueryError{
Message: fmt.Sprintf("Variable %q is not defined%s.", "$"+l.Name, byOp),
Locations: []errors.Location{l.Loc, op.Loc},
Rule: "NoUndefinedVariables",
Rule: "NoUndefinedVariablesRule",
})
continue
}
Expand All @@ -766,10 +830,12 @@ func validateValueType(c *opContext, v ast.Value, t ast.Type) (bool, string) {
if v2 := op.Vars.Get(v.Name); v2 != nil {
t2, err := common.ResolveType(v2.Type, c.schema.Resolve)
if _, ok := t2.(*ast.NonNull); !ok && v2.Default != nil {
t2 = &ast.NonNull{OfType: t2}
if _, ok := v2.Default.(*ast.NullValue); !ok {
t2 = &ast.NonNull{OfType: t2}
}
}
if err == nil && !typeCanBeUsedAs(t2, t) {
c.addErrMultiLoc([]errors.Location{v2.Loc, v.Loc}, "VariablesInAllowedPosition", "Variable %q of type %q used in position expecting type %q.", "$"+v.Name, t2, t)
c.addErrMultiLoc([]errors.Location{v2.Loc, v.Loc}, "VariablesInAllowedPositionRule", "Variable %q of type %q used in position expecting type %q.", "$"+v.Name, t2, t)
}
}
}
Expand Down Expand Up @@ -918,6 +984,8 @@ func canBeInput(t ast.Type) bool {
return canBeInput(t.OfType)
case *ast.NonNull:
return canBeInput(t.OfType)
case nil:
return true
default:
return false
}
Expand Down
Loading

0 comments on commit ad43f96

Please sign in to comment.