diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index 0bf254343d..e72bb9c7a4 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -224,6 +224,11 @@ func (c *SubscriptionClient) generateHandlerIDHash(ctx *resolve.Context, options } } } + if options.Body.Extensions != nil { + if _, err := xxh.Write(options.Body.Extensions); err != nil { + return 0, err + } + } return xxh.Sum64(), nil } @@ -254,6 +259,13 @@ func (c *SubscriptionClient) newWSConnectionHandler(reqCtx context.Context, opti return nil, err } + if options.Body.Extensions != nil { + connectionInitMessage, err = jsonparser.Set(connectionInitMessage, options.Body.Extensions, "payload", "extensions") + if err != nil { + return nil, err + } + } + // init + ack err = conn.Write(reqCtx, websocket.MessageText, connectionInitMessage) if err != nil { diff --git a/v2/pkg/engine/resolve/context.go b/v2/pkg/engine/resolve/context.go index e9e54d6fa3..e7f37fea5f 100644 --- a/v2/pkg/engine/resolve/context.go +++ b/v2/pkg/engine/resolve/context.go @@ -12,6 +12,7 @@ type Context struct { Request Request RenameTypeNames []RenameTypeName RequestTracingOptions RequestTraceOptions + Extensions []byte } type Request struct { diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index ed7b2b1c62..0421c77279 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -718,7 +718,13 @@ WithNextItem: return nil } -func (l *Loader) executeSourceLoad(ctx context.Context, disallowSingleFlight bool, source DataSource, input []byte, out io.Writer, trace *DataSourceLoadTrace) error { +func (l *Loader) executeSourceLoad(ctx context.Context, disallowSingleFlight bool, source DataSource, input []byte, out io.Writer, trace *DataSourceLoadTrace) (err error) { + if l.ctx.Extensions != nil { + input, err = jsonparser.Set(input, l.ctx.Extensions, "body", "extensions") + if err != nil { + return errors.WithStack(err) + } + } if l.traceOptions.Enable { trace.Path = l.renderPath() if !l.traceOptions.ExcludeInput { @@ -825,7 +831,7 @@ func (l *Loader) executeSourceLoad(ctx context.Context, disallowSingleFlight boo } keyGen := pool.Hash64.Get() defer pool.Hash64.Put(keyGen) - _, err := keyGen.Write(input) + _, err = keyGen.Write(input) if err != nil { return errors.WithStack(err) } diff --git a/v2/pkg/engine/resolve/loader_test.go b/v2/pkg/engine/resolve/loader_test.go index 7649514d8f..902a074a15 100644 --- a/v2/pkg/engine/resolve/loader_test.go +++ b/v2/pkg/engine/resolve/loader_test.go @@ -11,7 +11,7 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/astjson" ) -func TestV2Loader_LoadGraphQLResponseData(t *testing.T) { +func TestLoader_LoadGraphQLResponseData(t *testing.T) { ctrl := gomock.NewController(t) productsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://products","body":{"query":"query{topProducts{name __typename upc}}"}}`, @@ -299,7 +299,296 @@ func TestV2Loader_LoadGraphQLResponseData(t *testing.T) { assert.Equal(t, expected, out.String()) } -func BenchmarkV2Loader_LoadGraphQLResponseData(b *testing.B) { +func TestLoader_LoadGraphQLResponseDataWithExtensions(t *testing.T) { + ctrl := gomock.NewController(t) + productsService := mockedDS(t, ctrl, + `{"method":"POST","url":"http://products","body":{"query":"query{topProducts{name __typename upc}}","extensions":{"foo":"bar"}}}`, + `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}`) + + reviewsService := mockedDS(t, ctrl, + `{"method":"POST","url":"http://reviews","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {reviews {body author {__typename id}}}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]},"extensions":{"foo":"bar"}}}`, + `{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}`) + + stockService := mockedDS(t, ctrl, + `{"method":"POST","url":"http://stock","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {stock}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]},"extensions":{"foo":"bar"}}}`, + `{"_entities":[{"stock":8},{"stock":2},{"stock":5}]}`) + + usersService := mockedDS(t, ctrl, + `{"method":"POST","url":"http://users","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on User {name}}}","variables":{"representations":[{"__typename":"User","id":"1"},{"__typename":"User","id":"2"}]},"extensions":{"foo":"bar"}}}`, + `{"_entities":[{"name":"user-1"},{"name":"user-2"}]}`) + response := &GraphQLResponse{ + Data: &Object{ + Fetch: &SingleFetch{ + InputTemplate: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`{"method":"POST","url":"http://products","body":{"query":"query{topProducts{name __typename upc}}"}}`), + SegmentType: StaticSegmentType, + }, + }, + }, + FetchConfiguration: FetchConfiguration{ + DataSource: productsService, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data"}, + }, + }, + }, + Fields: []*Field{ + { + Name: []byte("topProducts"), + Value: &Array{ + Path: []string{"topProducts"}, + Item: &Object{ + Fetch: &ParallelFetch{ + Fetches: []Fetch{ + &BatchEntityFetch{ + Input: BatchInput{ + Header: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`{"method":"POST","url":"http://reviews","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {reviews {body author {__typename id}}}}}","variables":{"representations":[`), + SegmentType: StaticSegmentType, + }, + }, + }, + Items: []InputTemplate{ + { + Segments: []TemplateSegment{ + { + SegmentType: VariableSegmentType, + VariableKind: ResolvableObjectVariableKind, + Renderer: NewGraphQLVariableResolveRenderer(&Object{ + Fields: []*Field{ + { + Name: []byte("__typename"), + Value: &String{ + Path: []string{"__typename"}, + }, + }, + { + Name: []byte("upc"), + Value: &String{ + Path: []string{"upc"}, + }, + }, + }, + }), + }, + }, + }, + }, + Separator: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`,`), + SegmentType: StaticSegmentType, + }, + }, + }, + Footer: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`]}}}`), + SegmentType: StaticSegmentType, + }, + }, + }, + }, + DataSource: reviewsService, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data", "_entities"}, + }, + }, + &BatchEntityFetch{ + Input: BatchInput{ + Header: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`{"method":"POST","url":"http://stock","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {stock}}}","variables":{"representations":[`), + SegmentType: StaticSegmentType, + }, + }, + }, + Items: []InputTemplate{ + { + Segments: []TemplateSegment{ + { + SegmentType: VariableSegmentType, + VariableKind: ResolvableObjectVariableKind, + Renderer: NewGraphQLVariableResolveRenderer(&Object{ + Fields: []*Field{ + { + Name: []byte("__typename"), + Value: &String{ + Path: []string{"__typename"}, + }, + }, + { + Name: []byte("upc"), + Value: &String{ + Path: []string{"upc"}, + }, + }, + }, + }), + }, + }, + }, + }, + Separator: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`,`), + SegmentType: StaticSegmentType, + }, + }, + }, + Footer: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`]}}}`), + SegmentType: StaticSegmentType, + }, + }, + }, + }, + DataSource: stockService, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data", "_entities"}, + }, + }, + }, + }, + Fields: []*Field{ + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + }, + }, + { + Name: []byte("stock"), + Value: &Integer{ + Path: []string{"stock"}, + }, + }, + { + Name: []byte("reviews"), + Value: &Array{ + Path: []string{"reviews"}, + Item: &Object{ + Fields: []*Field{ + { + Name: []byte("body"), + Value: &String{ + Path: []string{"body"}, + }, + }, + { + Name: []byte("author"), + Value: &Object{ + Path: []string{"author"}, + Fetch: &BatchEntityFetch{ + Input: BatchInput{ + Header: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`{"method":"POST","url":"http://users","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on User {name}}}","variables":{"representations":[`), + SegmentType: StaticSegmentType, + }, + }, + }, + Items: []InputTemplate{ + { + Segments: []TemplateSegment{ + { + SegmentType: VariableSegmentType, + VariableKind: ResolvableObjectVariableKind, + Renderer: NewGraphQLVariableResolveRenderer(&Object{ + Fields: []*Field{ + { + Name: []byte("__typename"), + Value: &String{ + Path: []string{"__typename"}, + }, + }, + { + Name: []byte("id"), + Value: &String{ + Path: []string{"id"}, + }, + }, + }, + }), + }, + }, + }, + }, + Separator: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`,`), + SegmentType: StaticSegmentType, + }, + }, + }, + Footer: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`]}}}`), + SegmentType: StaticSegmentType, + }, + }, + }, + }, + DataSource: usersService, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data", "_entities"}, + }, + }, + Fields: []*Field{ + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + ctx := &Context{ + ctx: context.Background(), + Extensions: []byte(`{"foo":"bar"}`), + } + resolvable := &Resolvable{ + storage: &astjson.JSON{}, + } + loader := &Loader{} + err := resolvable.Init(ctx, nil, ast.OperationTypeQuery) + assert.NoError(t, err) + err = loader.LoadGraphQLResponseData(ctx, response, resolvable) + assert.NoError(t, err) + ctrl.Finish() + out := &bytes.Buffer{} + err = resolvable.storage.PrintNode(resolvable.storage.Nodes[resolvable.storage.RootNode], out) + assert.NoError(t, err) + expected := `{"errors":[],"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}}` + assert.Equal(t, expected, out.String()) +} + +func BenchmarkLoader_LoadGraphQLResponseData(b *testing.B) { productsService := FakeDataSource(`{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}}`) reviewsService := FakeDataSource(`{"data":{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}}`) @@ -591,18 +880,3 @@ func BenchmarkV2Loader_LoadGraphQLResponseData(b *testing.B) { } } } - -var ( - DefaultPostProcessingConfiguration = PostProcessingConfiguration{ - SelectResponseDataPath: []string{"data"}, - SelectResponseErrorsPath: []string{"errors"}, - } - EntitiesPostProcessingConfiguration = PostProcessingConfiguration{ - SelectResponseDataPath: []string{"data", "_entities"}, - SelectResponseErrorsPath: []string{"errors"}, - } - SingleEntityPostProcessingConfiguration = PostProcessingConfiguration{ - SelectResponseDataPath: []string{"data", "_entities", "[0]"}, - SelectResponseErrorsPath: []string{"errors"}, - } -) diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index ad1e2c584c..4d448e93be 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -7,6 +7,7 @@ import ( "io" "sync" + "github.com/buger/jsonparser" "github.com/pkg/errors" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/pool" @@ -79,7 +80,7 @@ func (r *Resolver) ResolveGraphQLResponse(ctx *Context, response *GraphQLRespons return t.resolvable.Resolve(ctx.ctx, response.Data, writer) } -func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQLSubscription, writer FlushWriter) error { +func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQLSubscription, writer FlushWriter) (err error) { if subscription.Trigger.Source == nil { msg := []byte(`{"errors":[{"message":"no data source found"}]}`) @@ -95,6 +96,10 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ subscriptionInput := make([]byte, len(rendered)) copy(subscriptionInput, rendered) + if ctx.Extensions != nil { + subscriptionInput, err = jsonparser.Set(subscriptionInput, ctx.Extensions, "body", "extensions") + } + c, cancel := context.WithCancel(ctx.Context()) defer cancel() resolverDone := r.ctx.Done() diff --git a/v2/pkg/engine/resolve/resolve_test.go b/v2/pkg/engine/resolve/resolve_test.go index 4bbf7e4894..f6e87ae25d 100644 --- a/v2/pkg/engine/resolve/resolve_test.go +++ b/v2/pkg/engine/resolve/resolve_test.go @@ -3805,9 +3805,17 @@ func FakeStream(cancelFunc func(), messageFunc func(count int) (message string, type _fakeStream struct { cancel context.CancelFunc messageFunc func(counter int) (message string, ok bool) + onStart func(input []byte) +} + +func (f *_fakeStream) SetOnStart(fn func(input []byte)) { + f.onStart = fn } func (f *_fakeStream) Start(ctx *Context, input []byte, next chan<- []byte) error { + if f.onStart != nil { + f.onStart(input) + } go func() { time.Sleep(time.Millisecond) count := 0 @@ -3834,6 +3842,14 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { plan := &GraphQLSubscription{ Trigger: GraphQLSubscriptionTrigger{ Source: stream, + InputTemplate: InputTemplate{ + Segments: []TemplateSegment{ + { + SegmentType: StaticSegmentType, + Data: []byte(`{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`), + }, + }, + }, PostProcessing: PostProcessingConfiguration{ SelectResponseDataPath: []string{"data"}, SelectResponseErrorsPath: []string{"errors"}, @@ -3897,7 +3913,6 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { }) t.Run("should successfully get result from upstream", func(t *testing.T) { - t.Skip("TODO: This test hangs with the race detector enabled") c, cancel := context.WithCancel(context.Background()) defer cancel() @@ -3918,6 +3933,35 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { assert.Equal(t, `{"data":{"counter":1}}`, out.flushed[1]) assert.Equal(t, `{"data":{"counter":2}}`, out.flushed[2]) }) + + t.Run("should propagate extensions to stream", func(t *testing.T) { + c, cancel := context.WithCancel(context.Background()) + defer cancel() + + fakeStream := FakeStream(cancel, func(count int) (message string, ok bool) { + return fmt.Sprintf(`{"data":{"counter":%d}}`, count), true + }) + + resolver, plan, out := setup(c, fakeStream) + + ctx := Context{ + ctx: c, + Extensions: []byte(`{"foo":"bar"}`), + } + + var inputResult string + + fakeStream.SetOnStart(func(input []byte) { + inputResult = string(input) + }) + err := resolver.ResolveGraphQLSubscription(&ctx, plan, out) + assert.NoError(t, err) + assert.Equal(t, 3, len(out.flushed)) + assert.Equal(t, `{"data":{"counter":0}}`, out.flushed[0]) + assert.Equal(t, `{"data":{"counter":1}}`, out.flushed[1]) + assert.Equal(t, `{"data":{"counter":2}}`, out.flushed[2]) + assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }","extensions":{"foo":"bar"}}}`, inputResult) + }) } func Benchmark_ResolveGraphQLResponse(b *testing.B) {