diff --git a/hack/middleware/asset_middleware.go b/hack/middleware/asset_middleware.go index 4750739bd..b8fb8177f 100644 --- a/hack/middleware/asset_middleware.go +++ b/hack/middleware/asset_middleware.go @@ -17,6 +17,10 @@ import ( type AssetUrlMiddleware struct { } +func (a *AssetUrlMiddleware) PrepareSchema(ctx context.Context, l *lookup.Lookup, w *lookup.Walker, parser *parser.Parser, mod *parser.ManualAstMod) error { + return nil +} + func (a *AssetUrlMiddleware) OnRequest(ctx context.Context, l *lookup.Lookup, w *lookup.Walker, parser *parser.Parser, mod *parser.ManualAstMod) error { w.SetLookup(l) diff --git a/pkg/lexer/lexer.go b/pkg/lexer/lexer.go index 713cd111f..19325a900 100644 --- a/pkg/lexer/lexer.go +++ b/pkg/lexer/lexer.go @@ -47,6 +47,23 @@ func (l *Lexer) SetTypeSystemInput(input []byte) error { return nil } +func (l *Lexer) ExtendTypeSystemInput(input []byte) error { + + if len(l.input) != l.typeSystemEndPosition { + return fmt.Errorf("ExtendTypeSystemInput: you must not extend the type system input after setting the executable input") + } + + actual := len(l.input) + len(input) + if actual > maxInput { + return fmt.Errorf("ExtendTypeSystemInput: input size must not be > %d, got: %d", maxInput, actual) + } + + l.input = append(l.input, input...) + l.typeSystemEndPosition = len(l.input) + + return nil +} + func (l *Lexer) ResetTypeSystemInput() { l.input = l._storage[:0] l.inputPosition = 0 diff --git a/pkg/lexer/lexer_test.go b/pkg/lexer/lexer_test.go index 16c99c6e4..525b3a391 100644 --- a/pkg/lexer/lexer_test.go +++ b/pkg/lexer/lexer_test.go @@ -442,6 +442,75 @@ baz mustPeekAndRead(keyword.FLOAT, "13.37"), ) }) + t.Run("extend type system input", func(t *testing.T) { + t.Run("invalid flow", func(t *testing.T) { + l := NewLexer() + err := l.SetTypeSystemInput([]byte("foo")) + if err != nil { + t.Fatal(err) + } + err = l.SetExecutableInput([]byte("bar")) + if err != nil { + t.Fatal(err) + } + err = l.ExtendTypeSystemInput([]byte("baz")) + if err == nil { + t.Fatal("want err") + } + }) + t.Run("valid flow", func(t *testing.T) { + l := NewLexer() + err := l.SetTypeSystemInput([]byte("foo")) + if err != nil { + t.Fatal(err) + } + + foo := l.Read() + if string(l.ByteSlice(foo.Literal)) != "foo" { + t.Fatal("want foo") + } + + err = l.ExtendTypeSystemInput([]byte(" bar")) + if err != nil { + t.Fatal(err) + } + + bar := l.Read() + if string(l.ByteSlice(bar.Literal)) != "bar" { + t.Fatal("want bar") + } + + err = l.ExtendTypeSystemInput([]byte(" baz")) + if err != nil { + t.Fatal(err) + } + + baz := l.Read() + if string(l.ByteSlice(baz.Literal)) != "baz" { + t.Fatal("want baz") + } + + err = l.SetExecutableInput([]byte("bal bat")) + if err != nil { + t.Fatal(err) + } + + bal := l.Read() + if string(l.ByteSlice(bal.Literal)) != "bal" { + t.Fatal("want bal") + } + + err = l.SetTypeSystemInput([]byte("foo2")) + if err != nil { + t.Fatal(err) + } + + foo2 := l.Read() + if string(l.ByteSlice(foo2.Literal)) != "foo2" { + t.Fatal("want foo2") + } + }) + }) } var introspectionQuery = `query IntrospectionQuery { diff --git a/pkg/middleware/context_middleware.go b/pkg/middleware/context_middleware.go index 9717caf66..4c7514c49 100644 --- a/pkg/middleware/context_middleware.go +++ b/pkg/middleware/context_middleware.go @@ -49,6 +49,19 @@ query myDocuments { type ContextMiddleware struct { } +var contextMiddlewareSchemaExtension = []byte(` +directive @addArgumentFromContext( + name: String! + contextKey: String! +) on FIELD_DEFINITION`) + +func (a *ContextMiddleware) PrepareSchema(ctx context.Context, l *lookup.Lookup, w *lookup.Walker, parser *parser.Parser, mod *parser.ManualAstMod) error { + + err := parser.ExtendTypeSystemDefinition(contextMiddlewareSchemaExtension) + + return err +} + func (a *ContextMiddleware) OnResponse(ctx context.Context, response *[]byte, l *lookup.Lookup, w *lookup.Walker, parser *parser.Parser, mod *parser.ManualAstMod) (err error) { return nil } diff --git a/pkg/middleware/context_middleware_test.go b/pkg/middleware/context_middleware_test.go index 4a58a68b4..a752dc51a 100644 --- a/pkg/middleware/context_middleware_test.go +++ b/pkg/middleware/context_middleware_test.go @@ -107,13 +107,6 @@ func TestContextMiddleware(t *testing.T) { } const publicSchema = ` -directive @addArgumentFromContext( - name: String! - contextKey: String! -) on FIELD_DEFINITION - -scalar String - schema { query: Query } diff --git a/pkg/middleware/graphql_middleware.go b/pkg/middleware/graphql_middleware.go index 7c5965930..4a7862cc8 100644 --- a/pkg/middleware/graphql_middleware.go +++ b/pkg/middleware/graphql_middleware.go @@ -6,7 +6,17 @@ import ( "github.com/jensneuse/graphql-go-tools/pkg/parser" ) +// GraphqlMiddleware is the interface to be implemented when writing middlewares type GraphqlMiddleware interface { + // PrepareSchema is used to bring the schema in a valid state + // Example usages might be: + // - adding necessary directives to the schema, e.g. adding the context directive so that the context middleware works + // - define the graphql internal scalar types so that the validation middleware can do its thing + PrepareSchema(ctx context.Context, l *lookup.Lookup, w *lookup.Walker, parser *parser.Parser, mod *parser.ManualAstMod) error + // OnRequest is the handler func for a request from the client + // this can be used to transform the query and/or variables before sending it to the backend OnRequest(ctx context.Context, l *lookup.Lookup, w *lookup.Walker, parser *parser.Parser, mod *parser.ManualAstMod) error + // OnResponse is the handler func for the response from the backend server + // this can be used to transform the response before sending the result back to the client OnResponse(ctx context.Context, response *[]byte, l *lookup.Lookup, w *lookup.Walker, parser *parser.Parser, mod *parser.ManualAstMod) (err error) } diff --git a/pkg/middleware/invoke_middleware.go b/pkg/middleware/invoke_middleware.go index 670c682c8..7c0914417 100644 --- a/pkg/middleware/invoke_middleware.go +++ b/pkg/middleware/invoke_middleware.go @@ -3,41 +3,29 @@ package middleware import ( "bytes" "context" - "github.com/jensneuse/graphql-go-tools/pkg/lookup" - "github.com/jensneuse/graphql-go-tools/pkg/parser" - "github.com/jensneuse/graphql-go-tools/pkg/printer" ) // InvokeMiddleware is a one off middleware invocation helper // This should only be used for testing as it's a waste of resources // It makes use of panics to don't use this in production! func InvokeMiddleware(middleware GraphqlMiddleware, ctx context.Context, schema, request string) (result string, err error) { - parse := parser.NewParser() - if err = parse.ParseTypeSystemDefinition([]byte(schema)); err != nil { - return - } - if err = parse.ParseExecutableDefinition([]byte(request)); err != nil { + + invoker := NewInvoker(middleware) + err = invoker.SetSchema([]byte(schema)) + if err != nil { return } - astPrint := printer.New() - look := lookup.New(parse) - walk := lookup.NewWalker(1024, 8) - mod := parser.NewManualAstMod(parse) - walk.SetLookup(look) - if err = middleware.OnRequest(ctx, look, walk, parse, mod); err != nil { + err = invoker.InvokeMiddleWares(ctx, []byte(request)) + if err != nil { return } - walk.SetLookup(look) - walk.WalkExecutable() - - astPrint.SetInput(parse, look, walk) buff := bytes.Buffer{} - if err = astPrint.PrintExecutableSchema(&buff); err != nil { + err = invoker.RewriteRequest(&buff) + if err != nil { return } - result = buff.String() - return + return buff.String(), err } diff --git a/pkg/middleware/invoker.go b/pkg/middleware/invoker.go index 0fe9de4fb..e571db2f3 100644 --- a/pkg/middleware/invoker.go +++ b/pkg/middleware/invoker.go @@ -39,6 +39,11 @@ func (i *Invoker) SetSchema(schema []byte) error { func (i *Invoker) InvokeMiddleWares(ctx context.Context, request []byte) (err error) { + err = i.middlewaresPrepareSchema(ctx) + if err != nil { + return err + } + err = i.parse.ParseExecutableDefinition(request) if err != nil { return err @@ -46,7 +51,7 @@ func (i *Invoker) InvokeMiddleWares(ctx context.Context, request []byte) (err er i.walk.SetLookup(i.look) - return i.invokeMiddleWares(ctx) + return i.middlewaresOnRequest(ctx) } func (i *Invoker) RewriteRequest(w io.Writer) error { @@ -55,7 +60,17 @@ func (i *Invoker) RewriteRequest(w io.Writer) error { return i.astPrint.PrintExecutableSchema(w) } -func (i *Invoker) invokeMiddleWares(ctx context.Context) error { +func (i *Invoker) middlewaresPrepareSchema(ctx context.Context) error { + for j := range i.middleWares { + err := i.middleWares[j].PrepareSchema(ctx, i.look, i.walk, i.parse, i.mod) + if err != nil { + return err + } + } + return nil +} + +func (i *Invoker) middlewaresOnRequest(ctx context.Context) error { for j := range i.middleWares { err := i.middleWares[j].OnRequest(ctx, i.look, i.walk, i.parse, i.mod) if err != nil { diff --git a/pkg/middleware/validation_middleware.go b/pkg/middleware/validation_middleware.go index 6e42a0d4a..11be736ca 100644 --- a/pkg/middleware/validation_middleware.go +++ b/pkg/middleware/validation_middleware.go @@ -12,6 +12,32 @@ import ( type ValidationMiddleware struct { } +var validationMiddlewareSchemaExtension = []byte(` +scalar Int +scalar Float +scalar String +scalar Boolean +scalar ID +directive @include( +if: Boolean! +) on FIELD | FRAGMENT_SPREAD | INLINE_FRAGMENT +directive @skip( + if: Boolean! +) on FIELD | FRAGMENT_SPREAD | INLINE_FRAGMENT +directive @deprecated( + reason: String = "No longer supported" +) on FIELD_DEFINITION | ENUM_VALUE +`) + +// PrepareSchema adds the base scalar and directive types to the schema so that the user doesn't have to add them +// if we omit these definitions from the schema definition the validation will fail +func (v *ValidationMiddleware) PrepareSchema(ctx context.Context, l *lookup.Lookup, w *lookup.Walker, parser *parser.Parser, mod *parser.ManualAstMod) error { + + err := parser.ExtendTypeSystemDefinition(validationMiddlewareSchemaExtension) + + return err +} + func (v *ValidationMiddleware) OnRequest(ctx context.Context, l *lookup.Lookup, w *lookup.Walker, parser *parser.Parser, mod *parser.ManualAstMod) error { w.SetLookup(l) diff --git a/pkg/middleware/validation_middleware_test.go b/pkg/middleware/validation_middleware_test.go index 8f5cf00ed..f58c8c425 100644 --- a/pkg/middleware/validation_middleware_test.go +++ b/pkg/middleware/validation_middleware_test.go @@ -7,16 +7,31 @@ import ( func TestValidationMiddleware(t *testing.T) { t.Run("valid", func(t *testing.T) { query := `query myDocuments {documents {sensitiveInformation}}` - _, err := InvokeMiddleware(&ValidationMiddleware{}, nil, publicSchema, query) + _, err := InvokeMiddleware(&ValidationMiddleware{}, nil, validationMiddlewarePublicSchema, query) if err != nil { t.Fatal(err) } }) t.Run("invalid", func(t *testing.T) { query := `query myDocuments {documents {fieldNotExists}}` - _, err := InvokeMiddleware(&ValidationMiddleware{}, nil, publicSchema, query) + _, err := InvokeMiddleware(&ValidationMiddleware{}, nil, validationMiddlewarePublicSchema, query) if err == nil { t.Fatal("want err") } }) } + +const validationMiddlewarePublicSchema = ` +schema { + query: Query +} + +type Query { + documents: [Document] +} + +type Document implements Node { + owner: String + sensitiveInformation: String +} +` diff --git a/pkg/parser/executabledefinition_parser.go b/pkg/parser/executabledefinition_parser.go index c52c7f5e1..02cfabe00 100644 --- a/pkg/parser/executabledefinition_parser.go +++ b/pkg/parser/executabledefinition_parser.go @@ -5,9 +5,7 @@ import ( "github.com/jensneuse/graphql-go-tools/pkg/lexing/keyword" ) -func (p *Parser) parseExecutableDefinition() (executableDefinition document.ExecutableDefinition, err error) { - - executableDefinition = p.makeExecutableDefinition() +func (p *Parser) parseExecutableDefinition() (err error) { for { next := p.l.Peek(true) @@ -15,23 +13,23 @@ func (p *Parser) parseExecutableDefinition() (executableDefinition document.Exec switch next { case keyword.CURLYBRACKETOPEN: - err := p.parseAnonymousOperation(&executableDefinition) + err := p.parseAnonymousOperation(&p.ParsedDefinitions.ExecutableDefinition) if err != nil { - return executableDefinition, err + return err } case keyword.FRAGMENT: - err := p.parseFragmentDefinition(&executableDefinition.FragmentDefinitions) + err := p.parseFragmentDefinition(&p.ParsedDefinitions.ExecutableDefinition.FragmentDefinitions) if err != nil { - return executableDefinition, err + return err } case keyword.QUERY, keyword.MUTATION, keyword.SUBSCRIPTION: - err := p.parseOperationDefinition(&executableDefinition.OperationDefinitions) + err := p.parseOperationDefinition(&p.ParsedDefinitions.ExecutableDefinition.OperationDefinitions) if err != nil { - return executableDefinition, err + return err } default: diff --git a/pkg/parser/parser.go b/pkg/parser/parser.go index 175ef92e7..a64c3fd2f 100644 --- a/pkg/parser/parser.go +++ b/pkg/parser/parser.go @@ -147,6 +147,7 @@ type cacheStats struct { // Lexer is the interface used by the Parser to lex tokens type Lexer interface { SetTypeSystemInput(input []byte) error + ExtendTypeSystemInput(input []byte) error ResetTypeSystemInput() SetExecutableInput(input []byte) error AppendBytes(input []byte) (err error) @@ -261,12 +262,26 @@ func (p *Parser) ParseTypeSystemDefinition(input []byte) (err error) { return } - p.ParsedDefinitions.TypeSystemDefinition, err = p.parseTypeSystemDefinition() + p.initTypeSystemDefinition() + err = p.parseTypeSystemDefinition() p.setCacheStats() return err } +func (p *Parser) ExtendTypeSystemDefinition(input []byte) (err error) { + err = p.l.ExtendTypeSystemInput(input) + if err != nil { + return + } + err = p.parseTypeSystemDefinition() + if err != nil { + return + } + p.setCacheStats() + return +} + // ParseExecutableDefinition parses an ExecutableDefinition from an io.Reader func (p *Parser) ParseExecutableDefinition(input []byte) (err error) { p.resetExecutableCaches() @@ -275,7 +290,8 @@ func (p *Parser) ParseExecutableDefinition(input []byte) (err error) { return } - p.ParsedDefinitions.ExecutableDefinition, err = p.parseExecutableDefinition() + p.initExecutableDefinition() + err = p.parseExecutableDefinition() return err } @@ -354,8 +370,8 @@ func (p *Parser) makeInputObjectTypeDefinition() document.InputObjectTypeDefinit } } -func (p *Parser) initTypeSystemDefinition(definition *document.TypeSystemDefinition) { - definition.SchemaDefinition = document.SchemaDefinition{ +func (p *Parser) initTypeSystemDefinition() { + p.ParsedDefinitions.TypeSystemDefinition.SchemaDefinition = document.SchemaDefinition{ DirectiveSet: -1, } } @@ -418,11 +434,9 @@ func (p *Parser) makeFragmentSpread() document.FragmentSpread { } } -func (p *Parser) makeExecutableDefinition() document.ExecutableDefinition { - return document.ExecutableDefinition{ - FragmentDefinitions: p.IndexPoolGet(), - OperationDefinitions: p.IndexPoolGet(), - } +func (p *Parser) initExecutableDefinition() { + p.ParsedDefinitions.ExecutableDefinition.OperationDefinitions = p.IndexPoolGet() + p.ParsedDefinitions.ExecutableDefinition.FragmentDefinitions = p.IndexPoolGet() } func (p *Parser) makeListValue(index *int) document.ListValue { diff --git a/pkg/parser/parser_mock_test.go b/pkg/parser/parser_mock_test.go index bcbc8fe00..c3a7a8752 100644 --- a/pkg/parser/parser_mock_test.go +++ b/pkg/parser/parser_mock_test.go @@ -38,6 +38,7 @@ func (m *MockLexer) EXPECT() *MockLexerMockRecorder { // SetTypeSystemInput mocks base method func (m *MockLexer) SetTypeSystemInput(input []byte) error { + m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SetTypeSystemInput", input) ret0, _ := ret[0].(error) return ret0 @@ -45,21 +46,39 @@ func (m *MockLexer) SetTypeSystemInput(input []byte) error { // SetTypeSystemInput indicates an expected call of SetTypeSystemInput func (mr *MockLexerMockRecorder) SetTypeSystemInput(input interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTypeSystemInput", reflect.TypeOf((*MockLexer)(nil).SetTypeSystemInput), input) } +// ExtendTypeSystemInput mocks base method +func (m *MockLexer) ExtendTypeSystemInput(input []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ExtendTypeSystemInput", input) + ret0, _ := ret[0].(error) + return ret0 +} + +// ExtendTypeSystemInput indicates an expected call of ExtendTypeSystemInput +func (mr *MockLexerMockRecorder) ExtendTypeSystemInput(input interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExtendTypeSystemInput", reflect.TypeOf((*MockLexer)(nil).ExtendTypeSystemInput), input) +} + // ResetTypeSystemInput mocks base method func (m *MockLexer) ResetTypeSystemInput() { + m.ctrl.T.Helper() m.ctrl.Call(m, "ResetTypeSystemInput") } // ResetTypeSystemInput indicates an expected call of ResetTypeSystemInput func (mr *MockLexerMockRecorder) ResetTypeSystemInput() *gomock.Call { + mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetTypeSystemInput", reflect.TypeOf((*MockLexer)(nil).ResetTypeSystemInput)) } // SetExecutableInput mocks base method func (m *MockLexer) SetExecutableInput(input []byte) error { + m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SetExecutableInput", input) ret0, _ := ret[0].(error) return ret0 @@ -67,11 +86,13 @@ func (m *MockLexer) SetExecutableInput(input []byte) error { // SetExecutableInput indicates an expected call of SetExecutableInput func (mr *MockLexerMockRecorder) SetExecutableInput(input interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetExecutableInput", reflect.TypeOf((*MockLexer)(nil).SetExecutableInput), input) } // AppendBytes mocks base method func (m *MockLexer) AppendBytes(input []byte) error { + m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AppendBytes", input) ret0, _ := ret[0].(error) return ret0 @@ -79,11 +100,13 @@ func (m *MockLexer) AppendBytes(input []byte) error { // AppendBytes indicates an expected call of AppendBytes func (mr *MockLexerMockRecorder) AppendBytes(input interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendBytes", reflect.TypeOf((*MockLexer)(nil).AppendBytes), input) } // Read mocks base method func (m *MockLexer) Read() token.Token { + m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Read") ret0, _ := ret[0].(token.Token) return ret0 @@ -91,11 +114,13 @@ func (m *MockLexer) Read() token.Token { // Read indicates an expected call of Read func (mr *MockLexerMockRecorder) Read() *gomock.Call { + mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockLexer)(nil).Read)) } // Peek mocks base method func (m *MockLexer) Peek(ignoreWhitespace bool) keyword.Keyword { + m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Peek", ignoreWhitespace) ret0, _ := ret[0].(keyword.Keyword) return ret0 @@ -103,11 +128,13 @@ func (m *MockLexer) Peek(ignoreWhitespace bool) keyword.Keyword { // Peek indicates an expected call of Peek func (mr *MockLexerMockRecorder) Peek(ignoreWhitespace interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Peek", reflect.TypeOf((*MockLexer)(nil).Peek), ignoreWhitespace) } // ByteSlice mocks base method func (m *MockLexer) ByteSlice(reference document.ByteSliceReference) document.ByteSlice { + m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ByteSlice", reference) ret0, _ := ret[0].(document.ByteSlice) return ret0 @@ -115,11 +142,13 @@ func (m *MockLexer) ByteSlice(reference document.ByteSliceReference) document.By // ByteSlice indicates an expected call of ByteSlice func (mr *MockLexerMockRecorder) ByteSlice(reference interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ByteSlice", reflect.TypeOf((*MockLexer)(nil).ByteSlice), reference) } // TextPosition mocks base method func (m *MockLexer) TextPosition() position.Position { + m.ctrl.T.Helper() ret := m.ctrl.Call(m, "TextPosition") ret0, _ := ret[0].(position.Position) return ret0 @@ -127,5 +156,6 @@ func (m *MockLexer) TextPosition() position.Position { // TextPosition indicates an expected call of TextPosition func (mr *MockLexerMockRecorder) TextPosition() *gomock.Call { + mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TextPosition", reflect.TypeOf((*MockLexer)(nil).TextPosition)) } diff --git a/pkg/parser/parser_testhelper_test.go b/pkg/parser/parser_testhelper_test.go index 6acfeec23..ed4f80678 100644 --- a/pkg/parser/parser_testhelper_test.go +++ b/pkg/parser/parser_testhelper_test.go @@ -682,22 +682,44 @@ func mustParseEnumTypeDefinition(rules ...rule) checkFunc { } } +func mustParseAddedExecutableDefinition(input string, fragments []ruleSet, operations []ruleSet) checkFunc { + return func(parser *Parser, i int) { + + err := parser.ParseExecutableDefinition([]byte(input)) + if err != nil { + panic(err) + } + + for i, set := range fragments { + fragmentIndex := parser.ParsedDefinitions.ExecutableDefinition.FragmentDefinitions[i] + fragment := parser.ParsedDefinitions.FragmentDefinitions[fragmentIndex] + set.eval(fragment, parser, i) + } + + for i, set := range operations { + opIndex := parser.ParsedDefinitions.ExecutableDefinition.OperationDefinitions[i] + operation := parser.ParsedDefinitions.OperationDefinitions[opIndex] + set.eval(operation, parser, i) + } + } +} + func mustParseExecutableDefinition(fragments []ruleSet, operations []ruleSet) checkFunc { return func(parser *Parser, i int) { - definition, err := parser.parseExecutableDefinition() + err := parser.parseExecutableDefinition() if err != nil { panic(err) } for i, set := range fragments { - fragmentIndex := definition.FragmentDefinitions[i] + fragmentIndex := parser.ParsedDefinitions.ExecutableDefinition.FragmentDefinitions[i] fragment := parser.ParsedDefinitions.FragmentDefinitions[fragmentIndex] set.eval(fragment, parser, i) } for i, set := range operations { - opIndex := definition.OperationDefinitions[i] + opIndex := parser.ParsedDefinitions.ExecutableDefinition.OperationDefinitions[i] operation := parser.ParsedDefinitions.OperationDefinitions[opIndex] set.eval(operation, parser, i) } @@ -926,26 +948,37 @@ func mustParseScalarTypeDefinition(rules ...ruleSet) checkFunc { func mustParseTypeSystemDefinition(rules ruleSet) checkFunc { return func(parser *Parser, i int) { - definition, err := parser.parseTypeSystemDefinition() + err := parser.parseTypeSystemDefinition() + if err != nil { + panic(err) + } + + evalRules(parser.ParsedDefinitions.TypeSystemDefinition, parser, rules, i) + } +} + +func mustExtendTypeSystemDefinition(extension string, rules ruleSet) checkFunc { + return func(parser *Parser, i int) { + + err := parser.ExtendTypeSystemDefinition([]byte(extension)) if err != nil { panic(err) } - evalRules(definition, parser, rules, i) + evalRules(parser.ParsedDefinitions.TypeSystemDefinition, parser, rules, i) } } func mustParseSchemaDefinition(rules ...rule) checkFunc { return func(parser *Parser, i int) { - var typeSystemDefinition document.TypeSystemDefinition - parser.initTypeSystemDefinition(&typeSystemDefinition) - err := parser.parseSchemaDefinition(&typeSystemDefinition.SchemaDefinition) + parser.initTypeSystemDefinition() + err := parser.parseSchemaDefinition(&parser.ParsedDefinitions.TypeSystemDefinition.SchemaDefinition) if err != nil { panic(err) } for k, rule := range rules { - rule(typeSystemDefinition.SchemaDefinition, parser, k, i) + rule(parser.ParsedDefinitions.TypeSystemDefinition.SchemaDefinition, parser, k, i) } } } diff --git a/pkg/parser/typesystemdefinition_parser.go b/pkg/parser/typesystemdefinition_parser.go index 71d01c54c..3bde57339 100644 --- a/pkg/parser/typesystemdefinition_parser.go +++ b/pkg/parser/typesystemdefinition_parser.go @@ -1,14 +1,11 @@ package parser import ( - "github.com/jensneuse/graphql-go-tools/pkg/document" "github.com/jensneuse/graphql-go-tools/pkg/lexing/keyword" "github.com/jensneuse/graphql-go-tools/pkg/lexing/token" ) -func (p *Parser) parseTypeSystemDefinition() (definition document.TypeSystemDefinition, err error) { - - p.initTypeSystemDefinition(&definition) +func (p *Parser) parseTypeSystemDefinition() (err error) { var hasDescription bool var description token.Token @@ -18,7 +15,7 @@ func (p *Parser) parseTypeSystemDefinition() (definition document.TypeSystemDefi switch next { case keyword.EOF: - return definition, err + return case keyword.STRING, keyword.COMMENT: descriptionToken := p.l.Read() description = descriptionToken @@ -26,68 +23,68 @@ func (p *Parser) parseTypeSystemDefinition() (definition document.TypeSystemDefi continue case keyword.SCHEMA: - if definition.SchemaDefinition.IsDefined() { + if p.ParsedDefinitions.TypeSystemDefinition.SchemaDefinition.IsDefined() { invalid := p.l.Read() - return definition, newErrInvalidType(invalid.TextPosition, "parseTypeSystemDefinition", "not a re-assignment of SchemaDefinition", "multiple SchemaDefinition assignments") + return newErrInvalidType(invalid.TextPosition, "parseTypeSystemDefinition", "not a re-assignment of SchemaDefinition", "multiple SchemaDefinition assignments") } - err = p.parseSchemaDefinition(&definition.SchemaDefinition) + err = p.parseSchemaDefinition(&p.ParsedDefinitions.TypeSystemDefinition.SchemaDefinition) if err != nil { - return definition, err + return err } case keyword.SCALAR: err := p.parseScalarTypeDefinition(hasDescription, description) if err != nil { - return definition, err + return err } case keyword.TYPE: err := p.parseObjectTypeDefinition(hasDescription, description) if err != nil { - return definition, err + return err } case keyword.INTERFACE: err := p.parseInterfaceTypeDefinition(hasDescription, description) if err != nil { - return definition, err + return err } case keyword.UNION: err := p.parseUnionTypeDefinition(hasDescription, description) if err != nil { - return definition, err + return err } case keyword.ENUM: err := p.parseEnumTypeDefinition(hasDescription, description) if err != nil { - return definition, err + return err } case keyword.INPUT: err := p.parseInputObjectTypeDefinition(hasDescription, description) if err != nil { - return definition, err + return err } case keyword.DIRECTIVE: err := p.parseDirectiveDefinition(hasDescription, description) if err != nil { - return definition, err + return err } default: invalid := p.l.Read() - return definition, newErrInvalidType(invalid.TextPosition, "parseTypeSystemDefinition", "eof/string/schema/scalar/type/interface/union/directive/input/enum", invalid.Keyword.String()) + return newErrInvalidType(invalid.TextPosition, "parseTypeSystemDefinition", "eof/string/schema/scalar/type/interface/union/directive/input/enum", invalid.Keyword.String()) } hasDescription = false diff --git a/pkg/parser/typesystemdefinition_parser_test.go b/pkg/parser/typesystemdefinition_parser_test.go index a9b941cfc..bd8b787c3 100644 --- a/pkg/parser/typesystemdefinition_parser_test.go +++ b/pkg/parser/typesystemdefinition_parser_test.go @@ -214,4 +214,69 @@ func TestParser_parseTypeSystemDefinition(t *testing.T) { t.Run("invalid keyword", func(t *testing.T) { run(`unknown {}`, mustPanic(mustParseTypeSystemDefinition(node()))) }) + + // type system definition extension + t.Run("extend with scalars", func(t *testing.T) { + run(`schema {}`, + mustPanic(mustParseTypeSystemDefinition( + node( + hasScalarTypeSystemDefinitions( + node( + hasName("String"), + ), + ), + ), + )), + mustExtendTypeSystemDefinition(` + scalar String`, node( + hasScalarTypeSystemDefinitions( + node( + hasName("String"), + ), + ), + )), + mustExtendTypeSystemDefinition(` + scalar JSON`, node( + hasScalarTypeSystemDefinitions( + node( + hasName("String"), + ), + node( + hasName("JSON"), + ), + ), + )), + ) + }) + t.Run("extend after setting executable definition should fail", func(t *testing.T) { + run(`schema {}`, + mustParseTypeSystemDefinition( + node(), + ), + mustParseAddedExecutableDefinition("{foo}", nil, nil), + mustPanic(mustExtendTypeSystemDefinition(` + scalar String`, node( + hasScalarTypeSystemDefinitions( + node( + hasName("String"), + ), + ), + ))), + ) + }) + t.Run("extend after setting executable definition should fail reverse", func(t *testing.T) { + run(`schema {}`, + mustParseTypeSystemDefinition( + node(), + ), + mustExtendTypeSystemDefinition(` + scalar String`, node( + hasScalarTypeSystemDefinitions( + node( + hasName("String"), + ), + ), + )), + ) + }) } diff --git a/pkg/proxy/http/proxy.go b/pkg/proxy/http/proxy.go index be2462ad0..72c8483dc 100644 --- a/pkg/proxy/http/proxy.go +++ b/pkg/proxy/http/proxy.go @@ -79,9 +79,9 @@ func (pr *ProxyRequest) DispatchRequest(buff *bytes.Buffer) (io.ReadCloser, erro } request := http.Request{ Method: "POST", - URL: &pr.Config.BackendURL, + URL: &pr.Config.BackendURL, Header: headers, - Body: ioutil.NopCloser(bytes.NewReader(out.Bytes())), + Body: ioutil.NopCloser(bytes.NewReader(out.Bytes())), } request.Header.Set("Content-Type", "application/json") diff --git a/pkg/proxy/http/proxy_integration_test.go b/pkg/proxy/http/proxy_integration_test.go index 6724f4871..be6767b78 100644 --- a/pkg/proxy/http/proxy_integration_test.go +++ b/pkg/proxy/http/proxy_integration_test.go @@ -39,7 +39,6 @@ func TestProxyIntegration(t *testing.T) { t.Fatalf("Expected:\n%s\ngot\n%s\n\n", privateAuthHeader, authHeader) } - body, err := ioutil.ReadAll(r.Body) if err != nil { t.Error(err) @@ -67,8 +66,8 @@ func TestProxyIntegration(t *testing.T) { headers := make(http.Header) headers.Set("Authorization", privateAuthHeader) schemaProvider := proxy.NewStaticRequestConfigProvider(proxy.RequestConfig{ - Schema: &schema, - BackendURL: *backendURL, + Schema: &schema, + BackendURL: *backendURL, BackendHeaders: headers, }) @@ -144,4 +143,4 @@ const publicQuery = `{"query":"query myDocuments {documents {sensitiveInformatio const privateQuery = `{"query":"query myDocuments {documents(user:\"jsmith@example.org\") {sensitiveInformation}}"} ` -const privateAuthHeader = "testAuth" \ No newline at end of file +const privateAuthHeader = "testAuth"