From 72ca42a7e6ba7b3f281575d441a71c229a28ab8e Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Wed, 15 Oct 2025 07:02:39 +0200 Subject: [PATCH 01/57] feat: add astjson & ArenaResolveGraphQLResponse --- go.work.sum | 5 +- v2/go.mod | 8 +- v2/go.sum | 6 +- .../astnormalization/uploads/upload_finder.go | 2 +- .../grpc_datasource/grpc_datasource.go | 6 +- .../grpc_datasource/grpc_datasource_test.go | 5 +- .../grpc_datasource/json_builder.go | 174 +++++++++--------- v2/pkg/engine/resolve/context.go | 2 +- v2/pkg/engine/resolve/loader.go | 97 +++++----- v2/pkg/engine/resolve/loader_test.go | 18 +- v2/pkg/engine/resolve/resolvable.go | 37 ++-- .../resolvable_custom_field_renderer_test.go | 4 +- v2/pkg/engine/resolve/resolvable_test.go | 52 +++--- v2/pkg/engine/resolve/resolve.go | 34 +++- v2/pkg/engine/resolve/tainted_objects_test.go | 8 +- v2/pkg/engine/resolve/variables_renderer.go | 2 +- v2/pkg/fastjsonext/fastjsonext.go | 37 ++-- v2/pkg/fastjsonext/fastjsonext_test.go | 10 +- .../variablesvalidation.go | 2 +- 19 files changed, 273 insertions(+), 236 deletions(-) diff --git a/go.work.sum b/go.work.sum index 5f48a89a0d..9e675e2c37 100644 --- a/go.work.sum +++ b/go.work.sum @@ -247,6 +247,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/tidwall/sjson v1.0.4 h1:UcdIRXff12Lpnu3OLtZvnc03g4vH2suXDXhBwBqmzYg= github.com/tidwall/sjson v1.0.4/go.mod h1:bURseu1nuBkFpIES5cz6zBtjmYeOQmEESshn7VpF15Y= github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= @@ -268,6 
+270,8 @@ github.com/wundergraph/astjson v0.0.0-20241210135722-15ca0ac078f8/go.mod h1:eOTL github.com/wundergraph/cosmo/composition-go v0.0.0-20240404083832-79d2290084c6/go.mod h1:Ib+rknmwn4oZFN9SQ4VMP3uF/C/tEINEug5iPQxfrPc= github.com/wundergraph/cosmo/composition-go v0.0.0-20240729154441-b20b00e892c6/go.mod h1:WbKC2jd0g6BFsMpNDRVSoQyZ0QB6sWqpRfe0/1pTah4= github.com/wundergraph/cosmo/router v0.0.0-20240404083832-79d2290084c6/go.mod h1:LS+5qlr4fQVEW7JMXXI1sz7CH5cdnqx3BNc10p+UbW4= +github.com/wundergraph/go-arena v0.0.0-20251008210416-55cb97e6f68f h1:5snewyMaIpajTu4wj22L/DgrGimICqXtUVjkZInBH3Y= +github.com/wundergraph/go-arena v0.0.0-20251008210416-55cb97e6f68f/go.mod h1:ROOysEHWJjLQ8FSfNxZCziagb7Qw2nXY3/vgKRh7eWw= github.com/xdg/scram v1.0.3 h1:nTadYh2Fs4BK2xdldEa2g5bbaZp0/+1nJMMPtPxS/to= github.com/xdg/scram v1.0.3/go.mod h1:lB8K/P019DLNhemzwFU4jHLhdvlE6uDZjXFejJXr49I= github.com/xdg/stringprep v1.0.3 h1:cmL5Enob4W83ti/ZHuZLuKD/xqJfus4fVPwE+/BDm+4= @@ -438,7 +442,6 @@ google.golang.org/genproto/googleapis/api v0.0.0-20240102182953-50ed04b92917 h1: google.golang.org/genproto/googleapis/api v0.0.0-20240102182953-50ed04b92917/go.mod h1:CmlNWB9lSezaYELKS5Ym1r44VrrbPUa7JTvw+6MbpJ0= google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422 h1:GVIKPyP/kLIyVOgOnTwFOrvQaQUzOzGMCxgFUOEmm24= google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422/go.mod h1:b6h1vNKhxaSoEI+5jc3PJUCustfli/mRab7295pY7rw= -google.golang.org/grpc v1.68.1/go.mod h1:+q1XYFJjShcqn0QZHvCyeR4CXPA+llXIeUIfIe00waw= google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= google.golang.org/protobuf v1.36.3/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= diff --git a/v2/go.mod b/v2/go.mod index 8ff4759fb5..50365c0a9b 100644 --- a/v2/go.mod +++ b/v2/go.mod @@ -24,11 +24,12 @@ require ( github.com/r3labs/sse/v2 v2.8.1 
github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 github.com/sebdah/goldie/v2 v2.7.1 - github.com/stretchr/testify v1.10.0 + github.com/stretchr/testify v1.11.1 github.com/tidwall/gjson v1.17.0 github.com/tidwall/sjson v1.2.5 github.com/vektah/gqlparser/v2 v2.5.14 github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 + github.com/wundergraph/go-arena v0.0.1 go.uber.org/atomic v1.11.0 go.uber.org/goleak v1.3.0 go.uber.org/zap v1.26.0 @@ -70,3 +71,8 @@ require ( gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) + +replace ( + github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 => ../../wundergraph-projects/astjson + github.com/wundergraph/go-arena v0.0.1 => ../../wundergraph-projects/go-arena +) diff --git a/v2/go.sum b/v2/go.sum index a98384ae84..f2c6a7e004 100644 --- a/v2/go.sum +++ b/v2/go.sum @@ -115,8 +115,8 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.17.0 h1:/Jocvlh98kcTfpN2+JzGQWQcqrPQwDrVEMApx/M5ZwM= github.com/tidwall/gjson v1.17.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= @@ -129,8 +129,6 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson 
v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/vektah/gqlparser/v2 v2.5.14 h1:dzLq75BJe03jjQm6n56PdH1oweB8ana42wj7E4jRy70= github.com/vektah/gqlparser/v2 v2.5.14/go.mod h1:WQQjFc+I1YIzoPvZBhUQX7waZgg3pMLi0r8KymvAE2w= -github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 h1:8/D7f8gKxTBjW+SZK4mhxTTBVpxcqeBgWF1Rfmltbfk= -github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083/go.mod h1:eOTL6acwctsN4F3b7YE+eE2t8zcJ/doLm9sZzsxxxrE= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= diff --git a/v2/pkg/astnormalization/uploads/upload_finder.go b/v2/pkg/astnormalization/uploads/upload_finder.go index b69a8bef29..0fd2d44c14 100644 --- a/v2/pkg/astnormalization/uploads/upload_finder.go +++ b/v2/pkg/astnormalization/uploads/upload_finder.go @@ -74,7 +74,7 @@ func (v *UploadFinder) FindUploads(operation, definition *ast.Document, variable variables = []byte("{}") } - v.variables, err = astjson.ParseBytesWithoutCache(variables) + v.variables, err = astjson.ParseBytes(variables) if err != nil { return nil, err } diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go index 4d9babc602..78cdce9f79 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go @@ -101,7 +101,6 @@ func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) // make gRPC calls for index, invocation := range invocations { errGrp.Go(func() error { - a := astjson.Arena{} // Invoke the gRPC method - this will populate invocation.Output methodName := fmt.Sprintf("/%s/%s", invocation.ServiceName, invocation.MethodName) @@ -113,7 +112,7 @@ func (d *DataSource) Load(ctx context.Context, input []byte, 
out *bytes.Buffer) mu.Lock() defer mu.Unlock() - response, err := builder.marshalResponseJSON(&a, &invocation.Call.Response, invocation.Output) + response, err := builder.marshalResponseJSON(&invocation.Call.Response, invocation.Output) if err != nil { return err } @@ -135,8 +134,7 @@ func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) return nil } - a := astjson.Arena{} - root := a.NewObject() + root := astjson.ObjectValue(builder.jsonArena) for _, response := range responses { root, err = builder.mergeValues(root, response) if err != nil { diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go index 3ae711d512..f7340cec80 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go @@ -19,8 +19,6 @@ import ( protoref "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/types/dynamicpb" - "github.com/wundergraph/astjson" - "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" "github.com/wundergraph/graphql-go-tools/v2/pkg/grpctest" @@ -499,9 +497,8 @@ func TestMarshalResponseJSON(t *testing.T) { responseMessage := dynamicpb.NewMessage(responseMessageDesc) responseMessage.Mutable(responseMessageDesc.Fields().ByName("result")).List().Append(protoref.ValueOfMessage(productMessage)) - arena := astjson.Arena{} jsonBuilder := newJSONBuilder(nil, gjson.Result{}) - responseJSON, err := jsonBuilder.marshalResponseJSON(&arena, &response, responseMessage) + responseJSON, err := jsonBuilder.marshalResponseJSON(&response, responseMessage) require.NoError(t, err) require.Equal(t, `{"_entities":[{"__typename":"Product","id":"123","name_different":"test","price_different":123.45}]}`, responseJSON.String()) } diff --git a/v2/pkg/engine/datasource/grpc_datasource/json_builder.go 
b/v2/pkg/engine/datasource/grpc_datasource/json_builder.go index 7c1fc81d77..8fe71a3210 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/json_builder.go +++ b/v2/pkg/engine/datasource/grpc_datasource/json_builder.go @@ -11,6 +11,7 @@ import ( protoref "google.golang.org/protobuf/reflect/protoreflect" "github.com/wundergraph/astjson" + "github.com/wundergraph/go-arena" ) // Standard GraphQL response paths @@ -104,6 +105,7 @@ type jsonBuilder struct { mapping *GRPCMapping // Mapping configuration for GraphQL to gRPC translation variables gjson.Result // GraphQL variables containing entity representations indexMap indexMap // Entity index mapping for federation ordering + jsonArena arena.Arena } // newJSONBuilder creates a new JSON builder instance with the provided mapping @@ -114,6 +116,7 @@ func newJSONBuilder(mapping *GRPCMapping, variables gjson.Result) *jsonBuilder { mapping: mapping, variables: variables, indexMap: createRepresentationIndexMap(variables), + jsonArena: arena.NewMonotonicArena(), } } @@ -160,7 +163,7 @@ func (j *jsonBuilder) mergeValues(left *astjson.Value, right *astjson.Value) (*a if len(j.indexMap) == 0 { // No federation index map available - use simple merge // This path is taken for non-federated queries - root, _, err := astjson.MergeValues(left, right) + root, _, err := astjson.MergeValues(j.jsonArena, left, right) if err != nil { return nil, err } @@ -186,11 +189,10 @@ func (j *jsonBuilder) mergeValues(left *astjson.Value, right *astjson.Value) (*a // This function ensures that entities are placed in the correct positions in the final response // array based on their original representation order, which is critical for GraphQL federation. 
func (j *jsonBuilder) mergeEntities(left *astjson.Value, right *astjson.Value) (*astjson.Value, error) { - root := astjson.Arena{} // Create the response structure with _entities array - entities := root.NewObject() - entities.Set(entityPath, root.NewArray()) + entities := astjson.ObjectValue(j.jsonArena) + entities.Set(j.jsonArena, entityPath, astjson.ArrayValue(j.jsonArena)) arr := entities.Get(entityPath) // Extract entity arrays from both responses @@ -206,12 +208,12 @@ func (j *jsonBuilder) mergeEntities(left *astjson.Value, right *astjson.Value) ( // Merge left entities using index mapping to preserve order for index, lr := range leftRepresentations { - arr.SetArrayItem(j.indexMap.getResultIndex(lr, index), lr) + arr.SetArrayItem(j.jsonArena, j.indexMap.getResultIndex(lr, index), lr) } // Merge right entities using index mapping to preserve order for index, rr := range rightRepresentations { - arr.SetArrayItem(j.indexMap.getResultIndex(rr, index), rr) + arr.SetArrayItem(j.jsonArena, j.indexMap.getResultIndex(rr, index), rr) } return entities, nil @@ -220,12 +222,12 @@ func (j *jsonBuilder) mergeEntities(left *astjson.Value, right *astjson.Value) ( // marshalResponseJSON converts a protobuf message into a GraphQL-compatible JSON response. // This is the core marshaling function that handles all the complex type conversions, // including oneOf types, nested messages, lists, and scalar values. 
-func (j *jsonBuilder) marshalResponseJSON(arena *astjson.Arena, message *RPCMessage, data protoref.Message) (*astjson.Value, error) { +func (j *jsonBuilder) marshalResponseJSON(message *RPCMessage, data protoref.Message) (*astjson.Value, error) { if message == nil { - return arena.NewNull(), nil + return astjson.NullValue, nil } - root := arena.NewObject() + root := astjson.ObjectValue(j.jsonArena) // Handle protobuf oneOf types - these represent GraphQL union/interface types if message.IsOneOf() { @@ -259,14 +261,14 @@ func (j *jsonBuilder) marshalResponseJSON(arena *astjson.Arena, message *RPCMess if field.StaticValue != "" { if len(message.MemberTypes) == 0 { // Simple static value - use as-is - root.Set(field.AliasOrPath(), arena.NewString(field.StaticValue)) + root.Set(j.jsonArena, field.AliasOrPath(), astjson.StringValue(j.jsonArena, field.StaticValue)) continue } // Type-specific static value - match against member types for _, memberTypes := range message.MemberTypes { if memberTypes == string(data.Type().Descriptor().Name()) { - root.Set(field.AliasOrPath(), arena.NewString(memberTypes)) + root.Set(j.jsonArena, field.AliasOrPath(), astjson.StringValue(j.jsonArena, memberTypes)) break } } @@ -284,8 +286,8 @@ func (j *jsonBuilder) marshalResponseJSON(arena *astjson.Arena, message *RPCMess // Handle list fields (repeated in protobuf) if fd.IsList() { list := data.Get(fd).List() - arr := arena.NewArray() - root.Set(field.AliasOrPath(), arr) + arr := astjson.ArrayValue(j.jsonArena) + root.Set(j.jsonArena, field.AliasOrPath(), arr) if !list.IsValid() { // Invalid list - leave as empty array @@ -298,15 +300,15 @@ func (j *jsonBuilder) marshalResponseJSON(arena *astjson.Arena, message *RPCMess case protoref.MessageKind: // List of messages - recursively marshal each message message := list.Get(i).Message() - value, err := j.marshalResponseJSON(arena, field.Message, message) + value, err := j.marshalResponseJSON(field.Message, message) if err != nil { return nil, 
err } - arr.SetArrayItem(i, value) + arr.SetArrayItem(j.jsonArena, i, value) default: // List of scalar values - convert directly - j.setArrayItem(i, arena, arr, list.Get(i), fd) + j.setArrayItem(i, arr, list.Get(i), fd) } } @@ -318,24 +320,24 @@ func (j *jsonBuilder) marshalResponseJSON(arena *astjson.Arena, message *RPCMess msg := data.Get(fd).Message() if !msg.IsValid() { // Invalid message - set to null - root.Set(field.AliasOrPath(), arena.NewNull()) + root.Set(j.jsonArena, field.AliasOrPath(), astjson.NullValue) continue } // Handle special list wrapper types for complex nested lists if field.IsListType { - arr, err := j.flattenListStructure(arena, field.ListMetadata, msg, field.Message) + arr, err := j.flattenListStructure(field.ListMetadata, msg, field.Message) if err != nil { return nil, fmt.Errorf("unable to flatten list structure for field %q: %w", field.AliasOrPath(), err) } - root.Set(field.AliasOrPath(), arr) + root.Set(j.jsonArena, field.AliasOrPath(), arr) continue } // Handle optional scalar wrapper types (e.g., google.protobuf.StringValue) if field.IsOptionalScalar() { - err := j.resolveOptionalField(arena, root, field.AliasOrPath(), msg) + err := j.resolveOptionalField(root, field.AliasOrPath(), msg) if err != nil { return nil, err } @@ -344,27 +346,27 @@ func (j *jsonBuilder) marshalResponseJSON(arena *astjson.Arena, message *RPCMess } // Regular nested message - recursively marshal - value, err := j.marshalResponseJSON(arena, field.Message, msg) + value, err := j.marshalResponseJSON(field.Message, msg) if err != nil { return nil, err } if field.JSONPath == "" { // Field should be merged into parent object (flattened) - root, _, err = astjson.MergeValues(root, value) + root, _, err = astjson.MergeValues(j.jsonArena, root, value) if err != nil { return nil, err } } else { // Field should be nested under its own key - root.Set(field.AliasOrPath(), value) + root.Set(j.jsonArena, field.AliasOrPath(), value) } continue } // Handle scalar fields 
(string, int, bool, etc.) - j.setJSONValue(arena, root, field.AliasOrPath(), data, fd) + j.setJSONValue(root, field.AliasOrPath(), data, fd) } return root, nil @@ -374,34 +376,34 @@ func (j *jsonBuilder) marshalResponseJSON(arena *astjson.Arena, message *RPCMess // messages to support nullable and multi-dimensional lists. This is necessary because // protobuf doesn't directly support nullable list items or complex nesting scenarios // that GraphQL allows. -func (j *jsonBuilder) flattenListStructure(arena *astjson.Arena, md *ListMetadata, data protoref.Message, message *RPCMessage) (*astjson.Value, error) { +func (j *jsonBuilder) flattenListStructure(md *ListMetadata, data protoref.Message, message *RPCMessage) (*astjson.Value, error) { if md == nil { - return arena.NewNull(), errors.New("list metadata not found") + return astjson.NullValue, errors.New("list metadata not found") } // Validate metadata consistency if len(md.LevelInfo) < md.NestingLevel { - return arena.NewNull(), errors.New("nesting level data does not match the number of levels in the list metadata") + return astjson.NullValue, errors.New("nesting level data does not match the number of levels in the list metadata") } // Handle null data with proper nullability checking if !data.IsValid() { if md.LevelInfo[0].Optional { - return arena.NewNull(), nil + return astjson.NullValue, nil } - return arena.NewNull(), errors.New("cannot add null item to response for non nullable list") + return astjson.NullValue, errors.New("cannot add null item to response for non nullable list") } // Start recursive traversal of the nested list structure - root := arena.NewArray() - return j.traverseList(0, arena, root, md, data, message) + root := astjson.ArrayValue(j.jsonArena) + return j.traverseList(0, root, md, data, message) } // traverseList recursively traverses nested list wrapper structures to extract the actual // list data. 
This handles multi-dimensional lists like [[String]] or [[[User]]] by // unwrapping the protobuf message wrappers at each level. -func (j *jsonBuilder) traverseList(level int, arena *astjson.Arena, current *astjson.Value, md *ListMetadata, data protoref.Message, message *RPCMessage) (*astjson.Value, error) { +func (j *jsonBuilder) traverseList(level int, current *astjson.Value, md *ListMetadata, data protoref.Message, message *RPCMessage) (*astjson.Value, error) { if level > md.NestingLevel { return current, nil } @@ -409,11 +411,11 @@ func (j *jsonBuilder) traverseList(level int, arena *astjson.Arena, current *ast // List wrappers always use field number 1 in the generated protobuf fd := data.Descriptor().Fields().ByNumber(1) if fd == nil { - return arena.NewNull(), fmt.Errorf("field with number %d not found in message %q", 1, data.Descriptor().Name()) + return astjson.NullValue, fmt.Errorf("field with number %d not found in message %q", 1, data.Descriptor().Name()) } if fd.Kind() != protoref.MessageKind { - return arena.NewNull(), fmt.Errorf("field %q is not a message", fd.Name()) + return astjson.NullValue, fmt.Errorf("field %q is not a message", fd.Name()) } // Get the wrapper message containing the list @@ -421,16 +423,16 @@ func (j *jsonBuilder) traverseList(level int, arena *astjson.Arena, current *ast if !msg.IsValid() { // Handle null wrapper based on nullability rules if md.LevelInfo[level].Optional { - return arena.NewNull(), nil + return astjson.NullValue, nil } - return arena.NewArray(), fmt.Errorf("cannot add null item to response for non nullable list") + return astjson.ArrayValue(j.jsonArena), fmt.Errorf("cannot add null item to response for non nullable list") } // The actual list is always at field number 1 in the wrapper fd = msg.Descriptor().Fields().ByNumber(1) if !fd.IsList() { - return arena.NewNull(), fmt.Errorf("field %q is not a list", fd.Name()) + return astjson.NullValue, fmt.Errorf("field %q is not a list", fd.Name()) } // Handle 
intermediate nesting levels (not the final level) @@ -438,13 +440,13 @@ func (j *jsonBuilder) traverseList(level int, arena *astjson.Arena, current *ast list := msg.Get(fd).List() for i := 0; i < list.Len(); i++ { // Create nested array for next level - next := arena.NewArray() - val, err := j.traverseList(level+1, arena, next, md, list.Get(i).Message(), message) + next := astjson.ArrayValue(j.jsonArena) + val, err := j.traverseList(level+1, next, md, list.Get(i).Message(), message) if err != nil { return nil, err } - current.SetArrayItem(i, val) + current.SetArrayItem(j.jsonArena, i, val) } return current, nil @@ -455,22 +457,22 @@ func (j *jsonBuilder) traverseList(level int, arena *astjson.Arena, current *ast if !list.IsValid() { // Invalid list at final level - return empty array // Nullability is checked at the wrapper level, not the list level - return arena.NewArray(), nil + return astjson.ArrayValue(j.jsonArena), nil } // Process each item in the final list for i := 0; i < list.Len(); i++ { if message != nil { // List of complex objects - recursively marshal each item - val, err := j.marshalResponseJSON(arena, message, list.Get(i).Message()) + val, err := j.marshalResponseJSON(message, list.Get(i).Message()) if err != nil { return nil, err } - current.SetArrayItem(i, val) + current.SetArrayItem(j.jsonArena, i, val) } else { // List of scalar values - convert directly - j.setArrayItem(i, arena, current, list.Get(i), fd) + j.setArrayItem(i, current, list.Get(i), fd) } } @@ -480,7 +482,7 @@ func (j *jsonBuilder) traverseList(level int, arena *astjson.Arena, current *ast // resolveOptionalField extracts the value from optional scalar wrapper types like // google.protobuf.StringValue, google.protobuf.Int32Value, etc. These wrappers // are used to represent nullable scalar values in protobuf. 
-func (j *jsonBuilder) resolveOptionalField(arena *astjson.Arena, root *astjson.Value, name string, data protoref.Message) error { +func (j *jsonBuilder) resolveOptionalField(root *astjson.Value, name string, data protoref.Message) error { // Optional scalar wrappers always have a "value" field fd := data.Descriptor().Fields().ByName(protoref.Name("value")) if fd == nil { @@ -488,16 +490,16 @@ func (j *jsonBuilder) resolveOptionalField(arena *astjson.Arena, root *astjson.V } // Extract and set the wrapped value - j.setJSONValue(arena, root, name, data, fd) + j.setJSONValue(root, name, data, fd) return nil } // setJSONValue converts a protobuf field value to the appropriate JSON representation // and sets it on the provided JSON object. This handles all protobuf scalar types // and enum values with proper GraphQL mapping. -func (j *jsonBuilder) setJSONValue(arena *astjson.Arena, root *astjson.Value, name string, data protoref.Message, fd protoref.FieldDescriptor) { +func (j *jsonBuilder) setJSONValue(root *astjson.Value, name string, data protoref.Message, fd protoref.FieldDescriptor) { if !data.IsValid() { - root.Set(name, arena.NewNull()) + root.Set(j.jsonArena, name, astjson.NullValue) return } @@ -505,27 +507,27 @@ func (j *jsonBuilder) setJSONValue(arena *astjson.Arena, root *astjson.Value, na case protoref.BoolKind: boolValue := data.Get(fd).Bool() if boolValue { - root.Set(name, arena.NewTrue()) + root.Set(j.jsonArena, name, astjson.TrueValue(j.jsonArena)) } else { - root.Set(name, arena.NewFalse()) + root.Set(j.jsonArena, name, astjson.FalseValue(j.jsonArena)) } case protoref.StringKind: - root.Set(name, arena.NewString(data.Get(fd).String())) + root.Set(j.jsonArena, name, astjson.StringValue(j.jsonArena, data.Get(fd).String())) case protoref.Int32Kind: - root.Set(name, arena.NewNumberInt(int(data.Get(fd).Int()))) + root.Set(j.jsonArena, name, astjson.IntValue(j.jsonArena, int(data.Get(fd).Int()))) case protoref.Int64Kind: - root.Set(name, 
arena.NewNumberString(strconv.FormatInt(data.Get(fd).Int(), 10))) + root.Set(j.jsonArena, name, astjson.NumberValue(j.jsonArena, strconv.FormatInt(data.Get(fd).Int(), 10))) case protoref.Uint32Kind, protoref.Uint64Kind: - root.Set(name, arena.NewNumberString(strconv.FormatUint(data.Get(fd).Uint(), 10))) + root.Set(j.jsonArena, name, astjson.NumberValue(j.jsonArena, strconv.FormatUint(data.Get(fd).Uint(), 10))) case protoref.FloatKind, protoref.DoubleKind: - root.Set(name, arena.NewNumberFloat64(data.Get(fd).Float())) + root.Set(j.jsonArena, name, astjson.FloatValue(j.jsonArena, data.Get(fd).Float())) case protoref.BytesKind: - root.Set(name, arena.NewStringBytes(data.Get(fd).Bytes())) + root.Set(j.jsonArena, name, astjson.StringValueBytes(j.jsonArena, data.Get(fd).Bytes())) case protoref.EnumKind: enumDesc := fd.Enum() enumValueDesc := enumDesc.Values().ByNumber(data.Get(fd).Enum()) if enumValueDesc == nil { - root.Set(name, arena.NewNull()) + root.Set(j.jsonArena, name, astjson.NullValue) return } @@ -533,20 +535,20 @@ func (j *jsonBuilder) setJSONValue(arena *astjson.Arena, root *astjson.Value, na graphqlValue, ok := j.mapping.ResolveEnumValue(string(enumDesc.Name()), string(enumValueDesc.Name())) if !ok { // No mapping found - set to null - root.Set(name, arena.NewNull()) + root.Set(j.jsonArena, name, astjson.NullValue) return } - root.Set(name, arena.NewString(graphqlValue)) + root.Set(j.jsonArena, name, astjson.StringValue(j.jsonArena, graphqlValue)) } } // setArrayItem converts a protobuf list item value to JSON and sets it at the specified // array index. This is similar to setJSONValue but operates on array elements rather // than object properties, and works with protobuf Value types rather than Message types. 
-func (j *jsonBuilder) setArrayItem(index int, arena *astjson.Arena, array *astjson.Value, data protoref.Value, fd protoref.FieldDescriptor) { +func (j *jsonBuilder) setArrayItem(index int, array *astjson.Value, data protoref.Value, fd protoref.FieldDescriptor) { if !data.IsValid() { - array.SetArrayItem(index, arena.NewNull()) + array.SetArrayItem(j.jsonArena, index, astjson.NullValue) return } @@ -554,27 +556,27 @@ func (j *jsonBuilder) setArrayItem(index int, arena *astjson.Arena, array *astjs case protoref.BoolKind: boolValue := data.Bool() if boolValue { - array.SetArrayItem(index, arena.NewTrue()) + array.SetArrayItem(j.jsonArena, index, astjson.TrueValue(j.jsonArena)) } else { - array.SetArrayItem(index, arena.NewFalse()) + array.SetArrayItem(j.jsonArena, index, astjson.FalseValue(j.jsonArena)) } case protoref.StringKind: - array.SetArrayItem(index, arena.NewString(data.String())) + array.SetArrayItem(j.jsonArena, index, astjson.StringValue(j.jsonArena, data.String())) case protoref.Int32Kind: - array.SetArrayItem(index, arena.NewNumberInt(int(data.Int()))) + array.SetArrayItem(j.jsonArena, index, astjson.IntValue(j.jsonArena, int(data.Int()))) case protoref.Int64Kind: - array.SetArrayItem(index, arena.NewNumberString(strconv.FormatInt(data.Int(), 10))) + array.SetArrayItem(j.jsonArena, index, astjson.NumberValue(j.jsonArena, strconv.FormatInt(data.Int(), 10))) case protoref.Uint32Kind, protoref.Uint64Kind: - array.SetArrayItem(index, arena.NewNumberString(strconv.FormatUint(data.Uint(), 10))) + array.SetArrayItem(j.jsonArena, index, astjson.NumberValue(j.jsonArena, strconv.FormatUint(data.Uint(), 10))) case protoref.FloatKind, protoref.DoubleKind: - array.SetArrayItem(index, arena.NewNumberFloat64(data.Float())) + array.SetArrayItem(j.jsonArena, index, astjson.FloatValue(j.jsonArena, data.Float())) case protoref.BytesKind: - array.SetArrayItem(index, arena.NewStringBytes(data.Bytes())) + array.SetArrayItem(j.jsonArena, index, 
astjson.StringValueBytes(j.jsonArena, data.Bytes())) case protoref.EnumKind: enumDesc := fd.Enum() enumValueDesc := enumDesc.Values().ByNumber(data.Enum()) if enumValueDesc == nil { - array.SetArrayItem(index, arena.NewNull()) + array.SetArrayItem(j.jsonArena, index, astjson.NullValue) return } @@ -582,20 +584,19 @@ func (j *jsonBuilder) setArrayItem(index int, arena *astjson.Arena, array *astjs graphqlValue, ok := j.mapping.ResolveEnumValue(string(enumDesc.Name()), string(enumValueDesc.Name())) if !ok { // No mapping found - use null - array.SetArrayItem(index, arena.NewNull()) + array.SetArrayItem(j.jsonArena, index, astjson.NullValue) return } - array.SetArrayItem(index, arena.NewString(graphqlValue)) + array.SetArrayItem(j.jsonArena, index, astjson.StringValue(j.jsonArena, graphqlValue)) } } // toDataObject wraps a response value in the standard GraphQL data envelope. // This creates the top-level structure { "data": ... } that GraphQL clients expect. func (j *jsonBuilder) toDataObject(root *astjson.Value) *astjson.Value { - a := astjson.Arena{} - data := a.NewObject() - data.Set(dataPath, root) + data := astjson.ObjectValue(j.jsonArena) + data.Set(j.jsonArena, dataPath, root) return data } @@ -603,30 +604,27 @@ func (j *jsonBuilder) toDataObject(root *astjson.Value) *astjson.Value { // This includes the error message and gRPC status code information in the extensions // field, following GraphQL error specification standards. 
func (j *jsonBuilder) writeErrorBytes(err error) []byte { - a := astjson.Arena{} - defer a.Reset() - // Create standard GraphQL error structure - errorRoot := a.NewObject() - errorArray := a.NewArray() - errorRoot.Set(errorsPath, errorArray) + errorRoot := astjson.ObjectValue(j.jsonArena) + errorArray := astjson.ArrayValue(j.jsonArena) + errorRoot.Set(j.jsonArena, errorsPath, errorArray) // Create individual error object - errorItem := a.NewObject() - errorItem.Set("message", a.NewString(err.Error())) + errorItem := astjson.ObjectValue(j.jsonArena) + errorItem.Set(j.jsonArena, "message", astjson.StringValue(j.jsonArena, err.Error())) // Add gRPC status code information to extensions - extensions := a.NewObject() + extensions := astjson.ObjectValue(j.jsonArena) if st, ok := status.FromError(err); ok { // gRPC error - include the specific status code - extensions.Set("code", a.NewString(st.Code().String())) + extensions.Set(j.jsonArena, "code", astjson.StringValue(j.jsonArena, st.Code().String())) } else { // Generic error - default to INTERNAL status - extensions.Set("code", a.NewString(codes.Internal.String())) + extensions.Set(j.jsonArena, "code", astjson.StringValue(j.jsonArena, codes.Internal.String())) } - errorItem.Set("extensions", extensions) - errorArray.SetArrayItem(0, errorItem) + errorItem.Set(j.jsonArena, "extensions", extensions) + errorArray.SetArrayItem(j.jsonArena, 0, errorItem) return errorRoot.MarshalTo(nil) } diff --git a/v2/pkg/engine/resolve/context.go b/v2/pkg/engine/resolve/context.go index 65d2d6b900..e9958d24ef 100644 --- a/v2/pkg/engine/resolve/context.go +++ b/v2/pkg/engine/resolve/context.go @@ -146,7 +146,7 @@ func (c *Context) appendSubgraphErrors(errs ...error) { } type Request struct { - ID string + ID uint64 Header http.Header } diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 7a14d61dce..ad4e78e472 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -22,6 +22,7 
@@ import ( "golang.org/x/sync/errgroup" "github.com/wundergraph/astjson" + "github.com/wundergraph/go-arena" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" @@ -180,6 +181,8 @@ type Loader struct { validateRequiredExternalFields bool taintedObjs taintedObjects + + jsonArena arena.Arena } func (l *Loader) Free() { @@ -431,7 +434,7 @@ func selectItems(items []*astjson.Value, element FetchItemPathElement) []*astjso return selected } -func itemsData(items []*astjson.Value) *astjson.Value { +func itemsData(a arena.Arena, items []*astjson.Value) *astjson.Value { if len(items) == 0 { return astjson.NullValue } @@ -442,7 +445,7 @@ func itemsData(items []*astjson.Value) *astjson.Value { // however, itemsData can be called concurrently, so this might result in a race arr := astjson.MustParseBytes([]byte(`[]`)) for i, item := range items { - arr.SetArrayItem(i, item) + arr.SetArrayItem(a, i, item) } return arr } @@ -552,7 +555,7 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson return l.renderErrorsFailedToFetch(fetchItem, res, emptyGraphQLResponse) } - response, err := astjson.ParseBytesWithoutCache(res.out.Bytes()) + response, err := astjson.ParseBytesWithArena(l.jsonArena, res.out.Bytes()) if err != nil { // Fall back to status code if parsing fails and non-2XX if (res.statusCode > 0 && res.statusCode < 200) || res.statusCode >= 300 { @@ -633,7 +636,7 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson return nil } if len(items) == 1 && res.batchStats == nil { - items[0], _, err = astjson.MergeValuesWithPath(items[0], responseData, res.postProcessing.MergePath...) + items[0], _, err = astjson.MergeValuesWithPath(l.jsonArena, items[0], responseData, res.postProcessing.MergePath...) 
if err != nil { return errors.WithStack(ErrMergeResult{ Subgraph: res.ds.Name, @@ -662,7 +665,7 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson if idx == -1 { continue } - items[i], _, err = astjson.MergeValuesWithPath(items[i], batch[idx], res.postProcessing.MergePath...) + items[i], _, err = astjson.MergeValuesWithPath(l.jsonArena, items[i], batch[idx], res.postProcessing.MergePath...) if err != nil { return errors.WithStack(ErrMergeResult{ Subgraph: res.ds.Name, @@ -683,7 +686,7 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson } for i := range items { - items[i], _, err = astjson.MergeValuesWithPath(items[i], batch[i], res.postProcessing.MergePath...) + items[i], _, err = astjson.MergeValuesWithPath(l.jsonArena, items[i], batch[i], res.postProcessing.MergePath...) if err != nil { return errors.WithStack(ErrMergeResult{ Subgraph: res.ds.Name, @@ -749,7 +752,7 @@ func (l *Loader) mergeErrors(res *result, fetchItem *FetchItem, value *astjson.V values := value.GetArray() l.optionallyOmitErrorLocations(values) if l.rewriteSubgraphErrorPaths { - rewriteErrorPaths(fetchItem, values) + rewriteErrorPaths(l.jsonArena, fetchItem, values) } l.optionallyEnsureExtensionErrorCode(values) @@ -792,7 +795,7 @@ func (l *Loader) mergeErrors(res *result, fetchItem *FetchItem, value *astjson.V } // Wrap mode (default) - errorObject, err := astjson.ParseWithoutCache(l.renderSubgraphBaseError(res.ds, fetchItem.ResponsePath, failedToFetchNoReason)) + errorObject, err := astjson.ParseWithArena(l.jsonArena, l.renderSubgraphBaseError(res.ds, fetchItem.ResponsePath, failedToFetchNoReason)) if err != nil { return err } @@ -861,17 +864,17 @@ func (l *Loader) optionallyEnsureExtensionErrorCode(values []*astjson.Value) { switch extensions.Type() { case astjson.TypeObject: if !extensions.Exists("code") { - extensions.Set("code", l.resolvable.astjsonArena.NewString(l.defaultErrorExtensionCode)) + extensions.Set(l.jsonArena, 
"code", astjson.StringValue(l.jsonArena, l.defaultErrorExtensionCode)) } case astjson.TypeNull: - extensionsObj := l.resolvable.astjsonArena.NewObject() - extensionsObj.Set("code", l.resolvable.astjsonArena.NewString(l.defaultErrorExtensionCode)) - value.Set("extensions", extensionsObj) + extensionsObj := astjson.ObjectValue(l.jsonArena) + extensionsObj.Set(l.jsonArena, "code", astjson.StringValue(l.jsonArena, l.defaultErrorExtensionCode)) + value.Set(l.jsonArena, "extensions", extensionsObj) } } else { - extensionsObj := l.resolvable.astjsonArena.NewObject() - extensionsObj.Set("code", l.resolvable.astjsonArena.NewString(l.defaultErrorExtensionCode)) - value.Set("extensions", extensionsObj) + extensionsObj := astjson.ObjectValue(l.jsonArena) + extensionsObj.Set(l.jsonArena, "code", astjson.StringValue(l.jsonArena, l.defaultErrorExtensionCode)) + value.Set(l.jsonArena, "extensions", extensionsObj) } } } @@ -888,16 +891,16 @@ func (l *Loader) optionallyAttachServiceNameToErrorExtension(values []*astjson.V extensions := value.Get("extensions") switch extensions.Type() { case astjson.TypeObject: - extensions.Set("serviceName", l.resolvable.astjsonArena.NewString(serviceName)) + extensions.Set(l.jsonArena, "serviceName", astjson.StringValue(l.jsonArena, serviceName)) case astjson.TypeNull: - extensionsObj := l.resolvable.astjsonArena.NewObject() - extensionsObj.Set("serviceName", l.resolvable.astjsonArena.NewString(serviceName)) - value.Set("extensions", extensionsObj) + extensionsObj := astjson.ObjectValue(l.jsonArena) + extensionsObj.Set(l.jsonArena, "serviceName", astjson.StringValue(l.jsonArena, serviceName)) + value.Set(l.jsonArena, "extensions", extensionsObj) } } else { - extensionsObj := l.resolvable.astjsonArena.NewObject() - extensionsObj.Set("serviceName", l.resolvable.astjsonArena.NewString(serviceName)) - value.Set("extensions", extensionsObj) + extensionsObj := astjson.ObjectValue(l.jsonArena) + extensionsObj.Set(l.jsonArena, "serviceName", 
astjson.StringValue(l.jsonArena, serviceName)) + value.Set(l.jsonArena, "extensions", extensionsObj) } } } @@ -951,7 +954,7 @@ func (l *Loader) optionallyOmitErrorLocations(values []*astjson.Value) { // - Drops the numeric index immediately following "_entities". // - Converts all subsequent numeric segments to strings (e.g., 1 -> "1"). // - Skips non-string/non-number segments. -func rewriteErrorPaths(fetchItem *FetchItem, values []*astjson.Value) { +func rewriteErrorPaths(a arena.Arena, fetchItem *FetchItem, values []*astjson.Value) { pathPrefix := make([]string, len(fetchItem.ResponsePathElements)) copy(pathPrefix, fetchItem.ResponsePathElements) // remove the trailing @ in case we're in an array as it looks weird in the path @@ -993,11 +996,11 @@ func rewriteErrorPaths(fetchItem *FetchItem, values []*astjson.Value) { } } newPathJSON, _ := json.Marshal(newPath) - pathBytes, err := astjson.ParseBytesWithoutCache(newPathJSON) + pathBytes, err := astjson.ParseBytesWithArena(a, newPathJSON) if err != nil { continue } - value.Set("path", pathBytes) + value.Set(a, "path", pathBytes) break } } @@ -1018,17 +1021,17 @@ func (l *Loader) setSubgraphStatusCode(values []*astjson.Value, statusCode int) if extensions.Type() != astjson.TypeObject { continue } - v, err := astjson.ParseWithoutCache(strconv.Itoa(statusCode)) + v, err := astjson.ParseWithArena(l.jsonArena, strconv.Itoa(statusCode)) if err != nil { continue } - extensions.Set("statusCode", v) + extensions.Set(l.jsonArena, "statusCode", v) } else { - v, err := astjson.ParseWithoutCache(`{"statusCode":` + strconv.Itoa(statusCode) + `}`) + v, err := astjson.ParseWithArena(l.jsonArena, `{"statusCode":`+strconv.Itoa(statusCode)+`}`) if err != nil { continue } - value.Set("extensions", v) + value.Set(l.jsonArena, "extensions", v) } } } @@ -1065,7 +1068,7 @@ func (l *Loader) addApolloRouterCompatibilityError(res *result) error { } } }`, res.ds.Name, http.StatusText(res.statusCode), res.statusCode) - 
apolloRouterStatusError, err := astjson.ParseWithoutCache(apolloRouterStatusErrorJSON) + apolloRouterStatusError, err := astjson.ParseWithArena(l.jsonArena, apolloRouterStatusErrorJSON) if err != nil { return err } @@ -1078,7 +1081,7 @@ func (l *Loader) addApolloRouterCompatibilityError(res *result) error { func (l *Loader) renderErrorsFailedDeps(fetchItem *FetchItem, res *result) error { path := l.renderAtPathErrorPart(fetchItem.ResponsePath) msg := fmt.Sprintf(`{"message":"Failed to obtain field dependencies from Subgraph '%s'%s."}`, res.ds.Name, path) - errorObject, err := astjson.ParseWithoutCache(msg) + errorObject, err := astjson.ParseWithArena(l.jsonArena, msg) if err != nil { return err } @@ -1089,7 +1092,7 @@ func (l *Loader) renderErrorsFailedDeps(fetchItem *FetchItem, res *result) error func (l *Loader) renderErrorsFailedToFetch(fetchItem *FetchItem, res *result, reason string) error { l.ctx.appendSubgraphErrors(res.err, NewSubgraphError(res.ds, fetchItem.ResponsePath, reason, res.statusCode)) - errorObject, err := astjson.ParseWithoutCache(l.renderSubgraphBaseError(res.ds, fetchItem.ResponsePath, reason)) + errorObject, err := astjson.ParseWithArena(l.jsonArena, l.renderSubgraphBaseError(res.ds, fetchItem.ResponsePath, reason)) if err != nil { return err } @@ -1106,7 +1109,7 @@ func (l *Loader) renderErrorsStatusFallback(fetchItem *FetchItem, res *result, s l.ctx.appendSubgraphErrors(res.err, NewSubgraphError(res.ds, fetchItem.ResponsePath, reason, res.statusCode)) - errorObject, err := astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"%s"}`, reason)) + errorObject, err := astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"%s"}`, reason)) if err != nil { return err } @@ -1140,13 +1143,13 @@ func (l *Loader) renderAuthorizationRejectedErrors(fetchItem *FetchItem, res *re if res.ds.Name == "" { for _, reason := range res.authorizationRejectedReasons { if reason == "" { - errorObject, err := 
astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Unauthorized Subgraph request%s.",%s}`, pathPart, extensionErrorCode)) + errorObject, err := astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Unauthorized Subgraph request%s.",%s}`, pathPart, extensionErrorCode)) if err != nil { continue } astjson.AppendToArray(l.resolvable.errors, errorObject) } else { - errorObject, err := astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Unauthorized Subgraph request%s, Reason: %s.",%s}`, pathPart, reason, extensionErrorCode)) + errorObject, err := astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Unauthorized Subgraph request%s, Reason: %s.",%s}`, pathPart, reason, extensionErrorCode)) if err != nil { continue } @@ -1156,13 +1159,13 @@ func (l *Loader) renderAuthorizationRejectedErrors(fetchItem *FetchItem, res *re } else { for _, reason := range res.authorizationRejectedReasons { if reason == "" { - errorObject, err := astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Unauthorized request to Subgraph '%s'%s.",%s}`, res.ds.Name, pathPart, extensionErrorCode)) + errorObject, err := astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Unauthorized request to Subgraph '%s'%s.",%s}`, res.ds.Name, pathPart, extensionErrorCode)) if err != nil { continue } astjson.AppendToArray(l.resolvable.errors, errorObject) } else { - errorObject, err := astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Unauthorized request to Subgraph '%s'%s, Reason: %s.",%s}`, res.ds.Name, pathPart, reason, extensionErrorCode)) + errorObject, err := astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Unauthorized request to Subgraph '%s'%s, Reason: %s.",%s}`, res.ds.Name, pathPart, reason, extensionErrorCode)) if err != nil { continue } @@ -1182,35 +1185,35 @@ func (l *Loader) renderRateLimitRejectedErrors(fetchItem *FetchItem, res *result ) if res.ds.Name == "" { if res.rateLimitRejectedReason == "" { - errorObject, err = 
astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph request%s."}`, pathPart)) + errorObject, err = astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph request%s."}`, pathPart)) if err != nil { return err } } else { - errorObject, err = astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph request%s, Reason: %s."}`, pathPart, res.rateLimitRejectedReason)) + errorObject, err = astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph request%s, Reason: %s."}`, pathPart, res.rateLimitRejectedReason)) if err != nil { return err } } } else { if res.rateLimitRejectedReason == "" { - errorObject, err = astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph '%s'%s."}`, res.ds.Name, pathPart)) + errorObject, err = astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph '%s'%s."}`, res.ds.Name, pathPart)) if err != nil { return err } } else { - errorObject, err = astjson.ParseWithoutCache(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph '%s'%s, Reason: %s."}`, res.ds.Name, pathPart, res.rateLimitRejectedReason)) + errorObject, err = astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph '%s'%s, Reason: %s."}`, res.ds.Name, pathPart, res.rateLimitRejectedReason)) if err != nil { return err } } } if l.ctx.RateLimitOptions.ErrorExtensionCode.Enabled { - extension, err := astjson.ParseWithoutCache(fmt.Sprintf(`{"code":"%s"}`, l.ctx.RateLimitOptions.ErrorExtensionCode.Code)) + extension, err := astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"code":"%s"}`, l.ctx.RateLimitOptions.ErrorExtensionCode.Code)) if err != nil { return err } - errorObject, _, err = astjson.MergeValuesWithPath(errorObject, extension, "extensions") + errorObject, _, err = astjson.MergeValuesWithPath(l.jsonArena, errorObject, extension, "extensions") if 
err != nil { return err } @@ -1287,7 +1290,7 @@ func (l *Loader) loadSingleFetch(ctx context.Context, fetch *SingleFetch, fetchI res.init(fetch.PostProcessing, fetch.Info) buf := &bytes.Buffer{} - inputData := itemsData(items) + inputData := itemsData(l.jsonArena, items) if l.ctx.TracingOptions.Enable { fetch.Trace = &DataSourceLoadTrace{} if !l.ctx.TracingOptions.ExcludeRawInputData && inputData != nil { @@ -1353,7 +1356,7 @@ func (l *Loader) loadEntityFetch(ctx context.Context, fetchItem *FetchItem, fetc res.init(fetch.PostProcessing, fetch.Info) buf := acquireEntityFetchBuffer() defer releaseEntityFetchBuffer(buf) - input := itemsData(items) + input := itemsData(l.jsonArena, items) if l.ctx.TracingOptions.Enable { fetch.Trace = &DataSourceLoadTrace{} if !l.ctx.TracingOptions.ExcludeRawInputData && input != nil { @@ -1465,7 +1468,7 @@ func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetchItem *FetchItem, if l.ctx.TracingOptions.Enable { fetch.Trace = &DataSourceLoadTrace{} if !l.ctx.TracingOptions.ExcludeRawInputData && len(items) != 0 { - data := itemsData(items) + data := itemsData(l.jsonArena, items) if data != nil { fetch.Trace.RawInputData, _ = l.compactJSON(data.MarshalTo(nil)) } @@ -1840,7 +1843,7 @@ func (l *Loader) compactJSON(data []byte) ([]byte, error) { return nil, err } out := dst.Bytes() - v, err := astjson.ParseBytesWithoutCache(out) + v, err := astjson.ParseBytesWithArena(l.jsonArena, out) if err != nil { return nil, err } diff --git a/v2/pkg/engine/resolve/loader_test.go b/v2/pkg/engine/resolve/loader_test.go index 4ed83d4443..01c5ef5dca 100644 --- a/v2/pkg/engine/resolve/loader_test.go +++ b/v2/pkg/engine/resolve/loader_test.go @@ -287,7 +287,7 @@ func TestLoader_LoadGraphQLResponseData(t *testing.T) { ctx := &Context{ ctx: context.Background(), } - resolvable := NewResolvable(ResolvableOptions{}) + resolvable := NewResolvable(nil, ResolvableOptions{}) loader := &Loader{} err := resolvable.Init(ctx, nil, ast.OperationTypeQuery) 
assert.NoError(t, err) @@ -376,7 +376,7 @@ func TestLoader_MergeErrorDifferingTypes(t *testing.T) { ctx := &Context{ ctx: context.Background(), } - resolvable := NewResolvable(ResolvableOptions{}) + resolvable := NewResolvable(nil, ResolvableOptions{}) loader := &Loader{} err := resolvable.Init(ctx, nil, ast.OperationTypeQuery) assert.NoError(t, err) @@ -467,7 +467,7 @@ func TestLoader_MergeErrorDifferingArrayLength(t *testing.T) { ctx := &Context{ ctx: context.Background(), } - resolvable := NewResolvable(ResolvableOptions{}) + resolvable := NewResolvable(nil, ResolvableOptions{}) loader := &Loader{} err := resolvable.Init(ctx, nil, ast.OperationTypeQuery) assert.NoError(t, err) @@ -749,7 +749,7 @@ func TestLoader_LoadGraphQLResponseDataWithExtensions(t *testing.T) { ctx: context.Background(), Extensions: []byte(`{"foo":"bar"}`), } - resolvable := NewResolvable(ResolvableOptions{}) + resolvable := NewResolvable(nil, ResolvableOptions{}) loader := &Loader{} err := resolvable.Init(ctx, nil, ast.OperationTypeQuery) assert.NoError(t, err) @@ -1024,7 +1024,7 @@ func BenchmarkLoader_LoadGraphQLResponseData(b *testing.B) { ctx := &Context{ ctx: context.Background(), } - resolvable := NewResolvable(ResolvableOptions{}) + resolvable := NewResolvable(nil, ResolvableOptions{}) loader := &Loader{} expected := `{"errors":[],"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}}` b.SetBytes(int64(len(expected))) @@ -1125,7 +1125,7 @@ func 
TestLoader_RedactHeaders(t *testing.T) { Enable: true, }, } - resolvable := NewResolvable(ResolvableOptions{}) + resolvable := NewResolvable(nil, ResolvableOptions{}) loader := &Loader{} err := resolvable.Init(ctx, nil, ast.OperationTypeQuery) @@ -1421,7 +1421,7 @@ func TestLoader_InvalidBatchItemCount(t *testing.T) { ctx := &Context{ ctx: context.Background(), } - resolvable := NewResolvable(ResolvableOptions{}) + resolvable := NewResolvable(nil, ResolvableOptions{}) loader := &Loader{} err := resolvable.Init(ctx, nil, ast.OperationTypeQuery) assert.NoError(t, err) @@ -1521,13 +1521,13 @@ func TestRewriteErrorPaths(t *testing.T) { for i, inputError := range tc.inputErrors { // Create a copy by marshaling and parsing again data := inputError.MarshalTo(nil) - value, err := astjson.ParseBytesWithoutCache(data) + value, err := astjson.ParseBytesWithArena(nil, data) assert.NoError(t, err, "Failed to copy input error") values[i] = value } // Call the function under test - rewriteErrorPaths(fetchItem, values) + rewriteErrorPaths(nil, fetchItem, values) // Compare the results assert.Equal(t, len(tc.expectedErrors), len(values), diff --git a/v2/pkg/engine/resolve/resolvable.go b/v2/pkg/engine/resolve/resolvable.go index 5219c910d1..5aceb2110c 100644 --- a/v2/pkg/engine/resolve/resolvable.go +++ b/v2/pkg/engine/resolve/resolvable.go @@ -11,6 +11,7 @@ import ( "github.com/cespare/xxhash/v2" "github.com/pkg/errors" "github.com/tidwall/gjson" + "github.com/wundergraph/go-arena" "github.com/wundergraph/astjson" @@ -31,7 +32,7 @@ type Resolvable struct { valueCompletion *astjson.Value skipAddingNullErrors bool - astjsonArena *astjson.Arena + astjsonArena arena.Arena parsers []*astjson.Parser print bool @@ -67,13 +68,13 @@ type ResolvableOptions struct { ApolloCompatibilityReplaceInvalidVarError bool } -func NewResolvable(options ResolvableOptions) *Resolvable { +func NewResolvable(a arena.Arena, options ResolvableOptions) *Resolvable { return &Resolvable{ options: options, xxh: 
xxhash.New(), authorizationAllow: make(map[uint64]struct{}), authorizationDeny: make(map[uint64]string), - astjsonArena: &astjson.Arena{}, + astjsonArena: a, } } @@ -95,7 +96,7 @@ func (r *Resolvable) Reset() { r.operationType = ast.OperationTypeUnknown r.renameTypeNames = r.renameTypeNames[:0] r.authorizationError = nil - r.astjsonArena.Reset() + r.astjsonArena = nil r.xxh.Reset() for k := range r.authorizationAllow { delete(r.authorizationAllow, k) @@ -109,14 +110,14 @@ func (r *Resolvable) Init(ctx *Context, initialData []byte, operationType ast.Op r.ctx = ctx r.operationType = operationType r.renameTypeNames = ctx.RenameTypeNames - r.data = r.astjsonArena.NewObject() - r.errors = r.astjsonArena.NewArray() + r.data = astjson.ObjectValue(r.astjsonArena) + r.errors = astjson.ArrayValue(r.astjsonArena) if initialData != nil { - initialValue, err := astjson.ParseBytesWithoutCache(initialData) + initialValue, err := astjson.ParseBytesWithArena(r.astjsonArena, initialData) if err != nil { return err } - r.data, _, err = astjson.MergeValues(r.data, initialValue) + r.data, _, err = astjson.MergeValues(r.astjsonArena, r.data, initialValue) if err != nil { return err } @@ -129,19 +130,19 @@ func (r *Resolvable) InitSubscription(ctx *Context, initialData []byte, postProc r.operationType = ast.OperationTypeSubscription r.renameTypeNames = ctx.RenameTypeNames if initialData != nil { - initialValue, err := astjson.ParseBytesWithoutCache(initialData) + initialValue, err := astjson.ParseBytesWithArena(r.astjsonArena, initialData) if err != nil { return err } if postProcessing.SelectResponseDataPath == nil { - r.data, _, err = astjson.MergeValuesWithPath(r.data, initialValue, postProcessing.MergePath...) + r.data, _, err = astjson.MergeValuesWithPath(r.astjsonArena, r.data, initialValue, postProcessing.MergePath...) if err != nil { return err } } else { selectedInitialValue := initialValue.Get(postProcessing.SelectResponseDataPath...) 
if selectedInitialValue != nil { - r.data, _, err = astjson.MergeValuesWithPath(r.data, selectedInitialValue, postProcessing.MergePath...) + r.data, _, err = astjson.MergeValuesWithPath(r.astjsonArena, r.data, selectedInitialValue, postProcessing.MergePath...) if err != nil { return err } @@ -155,10 +156,10 @@ func (r *Resolvable) InitSubscription(ctx *Context, initialData []byte, postProc } } if r.data == nil { - r.data = r.astjsonArena.NewObject() + r.data = astjson.ObjectValue(r.astjsonArena) } if r.errors == nil { - r.errors = r.astjsonArena.NewArray() + r.errors = astjson.ArrayValue(r.astjsonArena) } return } @@ -168,7 +169,7 @@ func (r *Resolvable) ResolveNode(node Node, data *astjson.Value, out io.Writer) r.print = false r.printErr = nil r.authorizationError = nil - r.errors = r.astjsonArena.NewArray() + r.errors = astjson.ArrayValue(r.astjsonArena) hasErrors := r.walkNode(node, data) if hasErrors { @@ -464,7 +465,7 @@ func (r *Resolvable) renderScalarFieldValue(value *astjson.Value, nullable bool) // renderScalarFieldString - is used when value require some pre-processing, e.g. 
unescaping or custom rendering func (r *Resolvable) renderScalarFieldBytes(data []byte, nullable bool) { - value, err := astjson.ParseBytesWithoutCache(data) + value, err := astjson.ParseBytesWithArena(r.astjsonArena, data) if err != nil { r.printErr = err return @@ -853,7 +854,7 @@ func (r *Resolvable) walkArray(arr *Array, value *astjson.Value) bool { r.popArrayPathElement() if err { if arr.Item.NodeKind() == NodeKindObject && arr.Item.NodeNullable() { - value.SetArrayItem(i, astjson.NullValue) + value.SetArrayItem(r.astjsonArena, i, astjson.NullValue) continue } if arr.Nullable { @@ -1287,14 +1288,14 @@ func (r *Resolvable) addErrorWithCodeAndPath(message, code string, fieldPath []s func (r *Resolvable) addValueCompletion(message, code string) { if r.valueCompletion == nil { - r.valueCompletion = r.astjsonArena.NewArray() + r.valueCompletion = astjson.ArrayValue(r.astjsonArena) } fastjsonext.AppendErrorWithExtensionsCodeToArray(r.astjsonArena, r.valueCompletion, message, code, r.path) } func (r *Resolvable) addValueCompletionWithPath(message, code string, fieldPath []string) { if r.valueCompletion == nil { - r.valueCompletion = r.astjsonArena.NewArray() + r.valueCompletion = astjson.ArrayValue(r.astjsonArena) } r.pushNodePathElement(fieldPath) fastjsonext.AppendErrorWithExtensionsCodeToArray(r.astjsonArena, r.valueCompletion, message, code, r.path) diff --git a/v2/pkg/engine/resolve/resolvable_custom_field_renderer_test.go b/v2/pkg/engine/resolve/resolvable_custom_field_renderer_test.go index 843c6e6969..0dbb0394b3 100644 --- a/v2/pkg/engine/resolve/resolvable_custom_field_renderer_test.go +++ b/v2/pkg/engine/resolve/resolvable_custom_field_renderer_test.go @@ -440,7 +440,7 @@ func TestResolvable_CustomFieldRenderer(t *testing.T) { t.Parallel() // Setup - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{} var input []byte @@ -543,7 +543,7 @@ func TestResolvable_CustomFieldRenderer(t *testing.T) { 
t.Parallel() input := []byte(tc.input) - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{} err := res.Init(ctx, input, ast.OperationTypeQuery) assert.NoError(t, err) diff --git a/v2/pkg/engine/resolve/resolvable_test.go b/v2/pkg/engine/resolve/resolvable_test.go index 4b92f85914..aea4e78eff 100644 --- a/v2/pkg/engine/resolve/resolvable_test.go +++ b/v2/pkg/engine/resolve/resolvable_test.go @@ -12,7 +12,7 @@ import ( func TestResolvable_Resolve(t *testing.T) { topProducts := `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}` - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{ Variables: nil, } @@ -84,7 +84,7 @@ func TestResolvable_Resolve(t *testing.T) { func TestResolvable_ResolveWithTypeMismatch(t *testing.T) { topProducts := `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":true}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}` - res := 
NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{ Variables: nil, } @@ -157,7 +157,7 @@ func TestResolvable_ResolveWithTypeMismatch(t *testing.T) { func TestResolvable_ResolveWithErrorBubbleUp(t *testing.T) { topProducts := `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}` - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{ Variables: nil, } @@ -231,7 +231,7 @@ func TestResolvable_ResolveWithErrorBubbleUp(t *testing.T) { func TestResolvable_ApolloCompatibilityMode_NonNullability(t *testing.T) { t.Run("Non-nullable root field", func(t *testing.T) { topProducts := `{"topProducts":null}` - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := &Context{ @@ -258,7 +258,7 @@ func TestResolvable_ApolloCompatibilityMode_NonNullability(t *testing.T) { }) t.Run("Non-Nullable root field and nested field", func(t *testing.T) { topProducts := `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too 
expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}` - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := &Context{ @@ -333,7 +333,7 @@ func TestResolvable_ApolloCompatibilityMode_NonNullability(t *testing.T) { }) t.Run("Nullable root field and non-Nullable nested field", func(t *testing.T) { topProducts := `{"topProduct":{"name":null}}` - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := &Context{ @@ -370,7 +370,7 @@ func TestResolvable_ApolloCompatibilityMode_NonNullability(t *testing.T) { }) t.Run("Non-Nullable sibling field", func(t *testing.T) { topProducts := `{"topProducts":[{"name":"Table","__typename":"Product","reviews":[{"author":{"__typename":"User","name":"Bob"},"body":null}]}]}` - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := &Context{ @@ -439,7 +439,7 @@ func TestResolvable_ApolloCompatibilityMode_NonNullability(t *testing.T) { }) t.Run("Non-nullable array and array item", func(t *testing.T) { topProducts := `{"topProducts":[null]}` - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := &Context{ @@ -469,7 +469,7 @@ func TestResolvable_ApolloCompatibilityMode_NonNullability(t *testing.T) { }) t.Run("Nullable array and non-nullable array item", func(t *testing.T) { topProducts := `{"topProducts":[null]}` - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := &Context{ 
@@ -500,7 +500,7 @@ func TestResolvable_ApolloCompatibilityMode_NonNullability(t *testing.T) { }) t.Run("Non-Nullable array, array item, and array item field", func(t *testing.T) { topProducts := `{"topProducts":[{"author":{"name":"Name"}},{"author":null}]}` - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := &Context{ @@ -549,7 +549,7 @@ func TestResolvable_ApolloCompatibilityMode_NonNullability(t *testing.T) { func TestResolvable_ResolveWithErrorBubbleUpUntilData(t *testing.T) { topProducts := `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}` - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{ Variables: nil, } @@ -622,7 +622,7 @@ func TestResolvable_ResolveWithErrorBubbleUpUntilData(t *testing.T) { func TestResolvable_InvalidEnumValues(t *testing.T) { t.Run("Invalid enum value", func(t *testing.T) { enum := `{"enum":"B"}` - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{ Variables: nil, } @@ -653,7 +653,7 @@ func TestResolvable_InvalidEnumValues(t *testing.T) { t.Run("Inaccessible enum value", func(t *testing.T) { enum := `{"enum":"B"}` - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{ Variables: nil, } @@ -686,7 +686,7 @@ func TestResolvable_InvalidEnumValues(t *testing.T) { 
t.Run("Invalid enum value with value completion Apollo compatibility flag", func(t *testing.T) { enum := `{"enum":"B"}` - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := &Context{ @@ -719,7 +719,7 @@ func TestResolvable_InvalidEnumValues(t *testing.T) { t.Run("Inaccessible enum value with value completion Apollo compatibility flag", func(t *testing.T) { enum := `{"enum":"B"}` - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := &Context{ @@ -755,7 +755,7 @@ func TestResolvable_InvalidEnumValues(t *testing.T) { func BenchmarkResolvable_Resolve(b *testing.B) { topProducts := `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}` - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{ Variables: nil, } @@ -838,7 +838,7 @@ func BenchmarkResolvable_Resolve(b *testing.B) { func BenchmarkResolvable_ResolveWithErrorBubbleUp(b *testing.B) { topProducts := `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too 
expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}` - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := &Context{ Variables: nil, } @@ -923,7 +923,7 @@ func BenchmarkResolvable_ResolveWithErrorBubbleUp(b *testing.B) { } func TestResolvable_WithTracingNotStarted(t *testing.T) { - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) // Do not start a trace with SetTraceStart(), but request it to be output ctx := NewContext(context.Background()) ctx.TracingOptions.Enable = true @@ -950,7 +950,7 @@ func TestResolvable_WithTracingNotStarted(t *testing.T) { func TestResolveFloat(t *testing.T) { t.Run("default behaviour", func(t *testing.T) { - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := NewContext(context.Background()) err := res.Init(ctx, []byte(`{"f":1.0}`), ast.OperationTypeQuery) assert.NoError(t, err) @@ -972,7 +972,7 @@ func TestResolveFloat(t *testing.T) { assert.Equal(t, `{"data":{"f":1.0}}`, out.String()) }) t.Run("invalid float", func(t *testing.T) { - res := NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) ctx := NewContext(context.Background()) err := res.Init(ctx, []byte(`{"f":false}`), ast.OperationTypeQuery) assert.NoError(t, err) @@ -994,7 +994,7 @@ func TestResolveFloat(t *testing.T) { assert.Equal(t, `{"errors":[{"message":"Float cannot represent non-float value: \"false\"","path":["f"]}],"data":null}`, out.String()) }) t.Run("truncate float", func(t *testing.T) { - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityTruncateFloatValues: true, }) ctx := NewContext(context.Background()) @@ -1018,7 +1018,7 @@ func TestResolveFloat(t 
*testing.T) { assert.Equal(t, `{"data":{"f":1}}`, out.String()) }) t.Run("truncate float with decimal place", func(t *testing.T) { - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityTruncateFloatValues: true, }) ctx := NewContext(context.Background()) @@ -1045,7 +1045,7 @@ func TestResolveFloat(t *testing.T) { func TestResolvable_ValueCompletion(t *testing.T) { t.Run("nested object", func(t *testing.T) { - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := NewContext(context.Background()) @@ -1143,7 +1143,7 @@ func TestResolvable_ValueCompletion(t *testing.T) { }`) t.Run("nullable", func(t *testing.T) { - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := NewContext(context.Background()) @@ -1241,7 +1241,7 @@ func TestResolvable_ValueCompletion(t *testing.T) { }) t.Run("mixed nullability", func(t *testing.T) { - res := NewResolvable(ResolvableOptions{ + res := NewResolvable(nil, ResolvableOptions{ ApolloCompatibilityValueCompletionInExtensions: true, }) ctx := NewContext(context.Background()) @@ -1342,7 +1342,7 @@ func TestResolvable_ValueCompletion(t *testing.T) { func TestResolvable_WithTracing(t *testing.T) { topProducts := `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}` - res := 
NewResolvable(ResolvableOptions{}) + res := NewResolvable(nil, ResolvableOptions{}) background := SetTraceStart(context.Background(), true) ctx := NewContext(background) ctx.TracingOptions.Enable = true diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 14d8ad4b52..92501bd2eb 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -235,7 +235,7 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { func newTools(options ResolverOptions, allowedExtensionFields map[string]struct{}, allowedErrorFields map[string]struct{}) *tools { return &tools{ - resolvable: NewResolvable(options.ResolvableOptions), + resolvable: NewResolvable(nil, options.ResolvableOptions), loader: &Loader{ propagateSubgraphErrors: options.PropagateSubgraphErrors, propagateSubgraphStatusCodes: options.PropagateSubgraphStatusCodes, @@ -291,6 +291,38 @@ func (r *Resolver) ResolveGraphQLResponse(ctx *Context, response *GraphQLRespons return resp, err } +func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLResponse, writer io.Writer) (*GraphQLResolveInfo, error) { + resp := &GraphQLResolveInfo{} + + start := time.Now() + <-r.maxConcurrency + resp.ResolveAcquireWaitTime = time.Since(start) + defer func() { + r.maxConcurrency <- struct{}{} + }() + + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields) + + err := t.resolvable.Init(ctx, nil, response.Info.OperationType) + if err != nil { + return nil, err + } + + if !ctx.ExecutionOptions.SkipLoader { + err = t.loader.LoadGraphQLResponseData(ctx, response, t.resolvable) + if err != nil { + return nil, err + } + } + + err = t.resolvable.Resolve(ctx.ctx, response.Data, response.Fetches, writer) + if err != nil { + return nil, err + } + + return resp, err +} + type trigger struct { id uint64 cancel context.CancelFunc diff --git a/v2/pkg/engine/resolve/tainted_objects_test.go b/v2/pkg/engine/resolve/tainted_objects_test.go index 
0eeb344407..b8205dc724 100644 --- a/v2/pkg/engine/resolve/tainted_objects_test.go +++ b/v2/pkg/engine/resolve/tainted_objects_test.go @@ -70,7 +70,7 @@ func TestSelectObjectAndIndex(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - response, err := astjson.ParseBytesWithoutCache([]byte(tt.responseJSON)) + response, err := astjson.ParseBytes([]byte(tt.responseJSON)) assert.NoError(t, err, "Failed to parse response JSON") // Convert path elements to astjson.Value slice @@ -94,7 +94,7 @@ func TestSelectObjectAndIndex(t *testing.T) { assert.Nil(t, entity, "Expected nil entity") } else { assert.NotNil(t, entity, "Expected non-nil entity") - expectedEntity, err := astjson.ParseBytesWithoutCache([]byte(tt.expectedEntity)) + expectedEntity, err := astjson.ParseBytes([]byte(tt.expectedEntity)) assert.NoError(t, err, "Failed to parse expected entity JSON") // Compare JSON representations @@ -320,10 +320,10 @@ func TestGetTaintedIndices(t *testing.T) { } mockFetch := &mockFetchWithInfo{info: fetchInfo} - response, err := astjson.ParseBytesWithoutCache([]byte(tt.responseJSON)) + response, err := astjson.ParseBytes([]byte(tt.responseJSON)) assert.NoError(t, err, "Failed to parse response JSON") - errors, err := astjson.ParseBytesWithoutCache([]byte(tt.errorsJSON)) + errors, err := astjson.ParseBytes([]byte(tt.errorsJSON)) assert.NoError(t, err, "Failed to parse errors JSON") indices := getTaintedIndices(mockFetch, response, errors) diff --git a/v2/pkg/engine/resolve/variables_renderer.go b/v2/pkg/engine/resolve/variables_renderer.go index 4cbb471f8f..0fa1d3ee14 100644 --- a/v2/pkg/engine/resolve/variables_renderer.go +++ b/v2/pkg/engine/resolve/variables_renderer.go @@ -350,7 +350,7 @@ var ( func (g *GraphQLVariableResolveRenderer) getResolvable() *Resolvable { v := _graphQLVariableResolveRendererPool.Get() if v == nil { - return NewResolvable(ResolvableOptions{}) + return NewResolvable(nil, ResolvableOptions{}) } return v.(*Resolvable) } diff 
--git a/v2/pkg/fastjsonext/fastjsonext.go b/v2/pkg/fastjsonext/fastjsonext.go index 0480fcbd49..4929e8a96a 100644 --- a/v2/pkg/fastjsonext/fastjsonext.go +++ b/v2/pkg/fastjsonext/fastjsonext.go @@ -2,27 +2,28 @@ package fastjsonext import ( "github.com/wundergraph/astjson" + "github.com/wundergraph/go-arena" ) -func AppendErrorToArray(arena *astjson.Arena, v *astjson.Value, msg string, path []PathElement) { +func AppendErrorToArray(a arena.Arena, v *astjson.Value, msg string, path []PathElement) { if v.Type() != astjson.TypeArray { return } - errorObject := CreateErrorObjectWithPath(arena, msg, path) + errorObject := CreateErrorObjectWithPath(a, msg, path) items, _ := v.Array() - v.SetArrayItem(len(items), errorObject) + v.SetArrayItem(a, len(items), errorObject) } -func AppendErrorWithExtensionsCodeToArray(arena *astjson.Arena, v *astjson.Value, msg, code string, path []PathElement) { +func AppendErrorWithExtensionsCodeToArray(a arena.Arena, v *astjson.Value, msg, code string, path []PathElement) { if v.Type() != astjson.TypeArray { return } - errorObject := CreateErrorObjectWithPath(arena, msg, path) - extensions := arena.NewObject() - extensions.Set("code", arena.NewString(code)) - errorObject.Set("extensions", extensions) + errorObject := CreateErrorObjectWithPath(a, msg, path) + extensions := astjson.ObjectValue(a) + extensions.Set(a, "code", astjson.StringValue(a, code)) + errorObject.Set(a, "extensions", extensions) items, _ := v.Array() - v.SetArrayItem(len(items), errorObject) + v.SetArrayItem(a, len(items), errorObject) } type PathElement struct { @@ -30,29 +31,29 @@ type PathElement struct { Idx int } -func CreateErrorObjectWithPath(arena *astjson.Arena, message string, path []PathElement) *astjson.Value { - errorObject := arena.NewObject() - errorObject.Set("message", arena.NewString(message)) +func CreateErrorObjectWithPath(a arena.Arena, message string, path []PathElement) *astjson.Value { + errorObject := astjson.ObjectValue(a) + errorObject.Set(a, 
"message", astjson.StringValue(a, message)) if len(path) == 0 { return errorObject } - errorPath := arena.NewArray() + errorPath := astjson.ArrayValue(a) for i := range path { if path[i].Name != "" { - errorPath.SetArrayItem(i, arena.NewString(path[i].Name)) + errorPath.SetArrayItem(a, i, astjson.StringValue(a, path[i].Name)) } else { - errorPath.SetArrayItem(i, arena.NewNumberInt(path[i].Idx)) + errorPath.SetArrayItem(a, i, astjson.IntValue(a, path[i].Idx)) } } - errorObject.Set("path", errorPath) + errorObject.Set(a, "path", errorPath) return errorObject } func PrintGraphQLResponse(data, errors *astjson.Value) string { out := astjson.MustParse(`{}`) if astjson.ValueIsNonNull(errors) { - out.Set("errors", errors) + out.Set(nil, "errors", errors) } - out.Set("data", data) + out.Set(nil, "data", data) return string(out.MarshalTo(nil)) } diff --git a/v2/pkg/fastjsonext/fastjsonext_test.go b/v2/pkg/fastjsonext/fastjsonext_test.go index af42716308..e48a2ad1c5 100644 --- a/v2/pkg/fastjsonext/fastjsonext_test.go +++ b/v2/pkg/fastjsonext/fastjsonext_test.go @@ -21,28 +21,28 @@ func TestGetArray(t *testing.T) { func TestAppendErrorWithMessage(t *testing.T) { a := astjson.MustParse(`[]`) - AppendErrorToArray(&astjson.Arena{}, a, "error", nil) + AppendErrorToArray(nil, a, "error", nil) out := a.MarshalTo(nil) require.Equal(t, `[{"message":"error"}]`, string(out)) - AppendErrorToArray(&astjson.Arena{}, a, "error2", []PathElement{{Name: "a"}}) + AppendErrorToArray(nil, a, "error2", []PathElement{{Name: "a"}}) out = a.MarshalTo(nil) require.Equal(t, `[{"message":"error"},{"message":"error2","path":["a"]}]`, string(out)) } func TestCreateErrorObjectWithPath(t *testing.T) { - v := CreateErrorObjectWithPath(&astjson.Arena{}, "my error message", []PathElement{ + v := CreateErrorObjectWithPath(nil, "my error message", []PathElement{ {Name: "a"}, }) out := v.MarshalTo(nil) require.Equal(t, `{"message":"my error message","path":["a"]}`, string(out)) - v = 
CreateErrorObjectWithPath(&astjson.Arena{}, "my error message", []PathElement{ + v = CreateErrorObjectWithPath(nil, "my error message", []PathElement{ {Name: "a"}, {Idx: 1}, {Name: "b"}, }) out = v.MarshalTo(nil) require.Equal(t, `{"message":"my error message","path":["a",1,"b"]}`, string(out)) - v = CreateErrorObjectWithPath(&astjson.Arena{}, "my error message", []PathElement{ + v = CreateErrorObjectWithPath(nil, "my error message", []PathElement{ {Name: "a"}, {Name: "b"}, }) diff --git a/v2/pkg/variablesvalidation/variablesvalidation.go b/v2/pkg/variablesvalidation/variablesvalidation.go index 70bb6033ba..6953a5970d 100644 --- a/v2/pkg/variablesvalidation/variablesvalidation.go +++ b/v2/pkg/variablesvalidation/variablesvalidation.go @@ -98,7 +98,7 @@ func (v *VariablesValidator) ValidateWithRemap(operation, definition *ast.Docume func (v *VariablesValidator) Validate(operation, definition *ast.Document, variables []byte) error { v.visitor.definition = definition v.visitor.operation = operation - v.visitor.variables, v.visitor.err = astjson.ParseBytesWithoutCache(variables) + v.visitor.variables, v.visitor.err = astjson.ParseBytes(variables) if v.visitor.err != nil { return v.visitor.err } From 20bf416b618279626dd4c1bb4f60c9c808c473e6 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Wed, 15 Oct 2025 16:03:32 +0200 Subject: [PATCH 02/57] chore: refactor & simplify DataSource interface --- .../graphql_datasource/graphql_datasource.go | 8 +- .../graphql_datasource_test.go | 24 +- .../grpc_datasource/grpc_datasource.go | 26 +- .../grpc_datasource/grpc_datasource_test.go | 72 +- .../datasource/httpclient/httpclient_test.go | 11 +- .../datasource/httpclient/nethttpclient.go | 27 +- .../fixtures/schema_introspection.golden | 2 +- ...on_with_custom_root_operation_types.golden | 2 +- .../fixtures/type_introspection.golden | 2 +- .../introspection_datasource/source.go | 22 +- .../introspection_datasource/source_test.go | 13 +- .../pubsub_datasource/pubsub_kafka.go | 16 +- 
.../pubsub_datasource/pubsub_nats.go | 30 +- .../staticdatasource/static_datasource.go | 8 +- v2/pkg/engine/plan/planner_test.go | 9 +- v2/pkg/engine/resolve/authorization_test.go | 49 +- v2/pkg/engine/resolve/datasource.go | 5 +- v2/pkg/engine/resolve/loader.go | 59 +- v2/pkg/engine/resolve/loader_hooks_test.go | 114 +- v2/pkg/engine/resolve/loader_test.go | 26 +- v2/pkg/engine/resolve/resolve.go | 6 + .../engine/resolve/resolve_federation_test.go | 225 ++- v2/pkg/engine/resolve/resolve_mock_test.go | 27 +- v2/pkg/engine/resolve/resolve_test.go | 1417 ++++++++++++----- 24 files changed, 1412 insertions(+), 788 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go index 3acdd07603..6f301d52d9 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go @@ -1907,14 +1907,14 @@ func (s *Source) replaceEmptyObject(variables []byte) ([]byte, bool) { return variables, false } -func (s *Source) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) { +func (s *Source) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { input = s.compactAndUnNullVariables(input) - return httpclient.DoMultipartForm(s.httpClient, ctx, input, files, out) + return httpclient.DoMultipartForm(s.httpClient, ctx, input, files) } -func (s *Source) Load(ctx context.Context, input []byte, out *bytes.Buffer) (err error) { +func (s *Source) Load(ctx context.Context, input []byte) (data []byte, err error) { input = s.compactAndUnNullVariables(input) - return httpclient.Do(s.httpClient, ctx, input, out) + return httpclient.Do(s.httpClient, ctx, input) } type GraphQLSubscriptionClient interface { diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go 
b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go index f7031fc3a3..75a23f5ed7 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go @@ -8693,10 +8693,9 @@ func TestSource_Load(t *testing.T) { input = httpclient.SetInputBodyWithPath(input, variables, "variables") input = httpclient.SetInputURL(input, []byte(serverUrl)) - buf := bytes.NewBuffer(nil) - - require.NoError(t, src.Load(context.Background(), input, buf)) - assert.Equal(t, `{"variables":{"a":null,"b":"b","c":{}}}`, buf.String()) + data, err := src.Load(context.Background(), input) + require.NoError(t, err) + assert.Equal(t, `{"variables":{"a":null,"b":"b","c":{}}}`, string(data)) }) }) t.Run("remove undefined variables", func(t *testing.T) { @@ -8709,7 +8708,6 @@ func TestSource_Load(t *testing.T) { var input []byte input = httpclient.SetInputBodyWithPath(input, variables, "variables") input = httpclient.SetInputURL(input, []byte(serverUrl)) - buf := bytes.NewBuffer(nil) undefinedVariables := []string{"a", "c"} ctx := context.Background() @@ -8717,8 +8715,9 @@ func TestSource_Load(t *testing.T) { input, err = httpclient.SetUndefinedVariables(input, undefinedVariables) assert.NoError(t, err) - require.NoError(t, src.Load(ctx, input, buf)) - assert.Equal(t, `{"variables":{"b":null}}`, buf.String()) + data, err := src.Load(ctx, input) + require.NoError(t, err) + assert.Equal(t, `{"variables":{"b":null}}`, string(data)) }) }) } @@ -8800,10 +8799,10 @@ func TestLoadFiles(t *testing.T) { input = httpclient.SetInputBodyWithPath(input, variables, "variables") input = httpclient.SetInputBodyWithPath(input, query, "query") input = httpclient.SetInputURL(input, []byte(serverUrl)) - buf := bytes.NewBuffer(nil) ctx := context.Background() - require.NoError(t, src.LoadWithFiles(ctx, input, []*httpclient.FileUpload{httpclient.NewFileUpload(f.Name(), fileName, "variables.file")}, buf)) + _, err 
= src.LoadWithFiles(ctx, input, []*httpclient.FileUpload{httpclient.NewFileUpload(f.Name(), fileName, "variables.file")}) + require.NoError(t, err) }) t.Run("multiple files", func(t *testing.T) { @@ -8844,7 +8843,6 @@ func TestLoadFiles(t *testing.T) { input = httpclient.SetInputBodyWithPath(input, variables, "variables") input = httpclient.SetInputBodyWithPath(input, query, "query") input = httpclient.SetInputURL(input, []byte(serverUrl)) - buf := bytes.NewBuffer(nil) dir := t.TempDir() f1, err := os.CreateTemp(dir, file1Name) @@ -8858,11 +8856,11 @@ func TestLoadFiles(t *testing.T) { assert.NoError(t, err) ctx := context.Background() - require.NoError(t, src.LoadWithFiles(ctx, input, + _, err = src.LoadWithFiles(ctx, input, []*httpclient.FileUpload{ httpclient.NewFileUpload(f1.Name(), file1Name, "variables.files.0"), - httpclient.NewFileUpload(f2.Name(), file2Name, "variables.files.1")}, - buf)) + httpclient.NewFileUpload(f2.Name(), file2Name, "variables.files.1")}) + require.NoError(t, err) }) } diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go index 78cdce9f79..58729e33c2 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go @@ -7,7 +7,6 @@ package grpcdatasource import ( - "bytes" "context" "fmt" "sync" @@ -73,25 +72,24 @@ func NewDataSource(client grpc.ClientConnInterface, config DataSourceConfig) (*D } // Load implements resolve.DataSource interface. -// It processes the input JSON data to make gRPC calls and writes -// the response to the output buffer. +// It processes the input JSON data to make gRPC calls and returns +// the response data. // // The input is expected to contain the necessary information to make // a gRPC call, including service name, method name, and request data. 
-func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) (err error) { +func (d *DataSource) Load(ctx context.Context, input []byte) (data []byte, err error) { // get variables from input variables := gjson.Parse(string(input)).Get("body.variables") builder := newJSONBuilder(d.mapping, variables) if d.disabled { - out.Write(builder.writeErrorBytes(fmt.Errorf("gRPC datasource needs to be enabled to be used"))) - return nil + return builder.writeErrorBytes(fmt.Errorf("gRPC datasource needs to be enabled to be used")), nil } // get invocations from plan invocations, err := d.rc.Compile(d.plan, variables) if err != nil { - return err + return nil, err } responses := make([]*astjson.Value, len(invocations)) @@ -130,23 +128,19 @@ func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) } if err := errGrp.Wait(); err != nil { - out.Write(builder.writeErrorBytes(err)) - return nil + return builder.writeErrorBytes(err), nil } root := astjson.ObjectValue(builder.jsonArena) for _, response := range responses { root, err = builder.mergeValues(root, response) if err != nil { - out.Write(builder.writeErrorBytes(err)) - return err + return builder.writeErrorBytes(err), err } } - data := builder.toDataObject(root) - out.Write(data.MarshalTo(nil)) - - return nil + dataObj := builder.toDataObject(root) + return dataObj.MarshalTo(nil), nil } // LoadWithFiles implements resolve.DataSource interface. @@ -156,6 +150,6 @@ func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) // might not be applicable for most gRPC use cases. // // Currently unimplemented. 
-func (d *DataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) { +func (d *DataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { panic("unimplemented") } diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go index f7340cec80..2a18e2f176 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go @@ -1,7 +1,6 @@ package grpcdatasource import ( - "bytes" "context" "encoding/json" "fmt" @@ -147,12 +146,10 @@ func Test_DataSource_Load(t *testing.T) { require.NoError(t, err) - output := new(bytes.Buffer) - - err = ds.Load(context.Background(), []byte(`{"query":"`+query+`","variables":`+variables+`}`), output) + output, err := ds.Load(context.Background(), []byte(`{"query":"`+query+`","variables":`+variables+`}`)) require.NoError(t, err) - fmt.Println(output.String()) + fmt.Println(string(output)) } // Test_DataSource_Load_WithMockService tests the datasource.Load method with an actual gRPC server @@ -220,12 +217,11 @@ func Test_DataSource_Load_WithMockService(t *testing.T) { require.NoError(t, err) // 3. 
Execute the query through our datasource - output := new(bytes.Buffer) - err = ds.Load(context.Background(), []byte(`{"query":"`+query+`","body":`+variables+`}`), output) + output, err := ds.Load(context.Background(), []byte(`{"query":"`+query+`","body":`+variables+`}`)) require.NoError(t, err) // Print the response for debugging - // fmt.Println(output.String()) + // fmt.Println(string(output)) type response struct { Data struct { @@ -238,7 +234,7 @@ func Test_DataSource_Load_WithMockService(t *testing.T) { var resp response - bytes := output.Bytes() + bytes := output fmt.Println(string(bytes)) err = json.Unmarshal(bytes, &resp) @@ -310,12 +306,10 @@ func Test_DataSource_Load_WithMockService_WithResponseMapping(t *testing.T) { require.NoError(t, err) // 3. Execute the query through our datasource - output := new(bytes.Buffer) - // Format the input with query and variables inputJSON := fmt.Sprintf(`{"query":%q,"body":%s}`, query, variables) - err = ds.Load(context.Background(), []byte(inputJSON), output) + output, err := ds.Load(context.Background(), []byte(inputJSON)) require.NoError(t, err) // Set up the correct response structure based on your GraphQL schema @@ -332,7 +326,7 @@ func Test_DataSource_Load_WithMockService_WithResponseMapping(t *testing.T) { } var resp response - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") // Check if there are any errors in the response @@ -407,11 +401,10 @@ func Test_DataSource_Load_WithGrpcError(t *testing.T) { require.NoError(t, err) // 4. 
Execute the query - output := new(bytes.Buffer) - err = ds.Load(context.Background(), []byte(`{"query":"`+query+`","body":`+variables+`}`), output) + output, err := ds.Load(context.Background(), []byte(`{"query":"`+query+`","body":`+variables+`}`)) require.NoError(t, err, "Load should not return an error even when the gRPC call fails") - responseJson := output.String() + responseJson := string(output) // 5. Verify the response format according to GraphQL specification // The response should have an "errors" array with the error message @@ -425,7 +418,7 @@ func Test_DataSource_Load_WithGrpcError(t *testing.T) { } `json:"errors"` } - err = json.Unmarshal(output.Bytes(), &response) + err = json.Unmarshal(output, &response) require.NoError(t, err, "Failed to parse response JSON") // Verify there's at least one error @@ -733,9 +726,8 @@ func Test_DataSource_Load_WithAnimalInterface(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), []byte(input)) require.NoError(t, err) // Parse the response @@ -746,7 +738,7 @@ func Test_DataSource_Load_WithAnimalInterface(t *testing.T) { } `json:"errors,omitempty"` } - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") require.Empty(t, resp.Errors, "Response should not contain errors") require.NotEmpty(t, resp.Data, "Response should contain data") @@ -1004,9 +996,8 @@ func Test_Datasource_Load_WithUnionTypes(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), []byte(input)) 
require.NoError(t, err) // Parse the response @@ -1017,7 +1008,7 @@ func Test_Datasource_Load_WithUnionTypes(t *testing.T) { } `json:"errors,omitempty"` } - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") require.Empty(t, resp.Errors, "Response should not contain errors") require.NotEmpty(t, resp.Data, "Response should contain data") @@ -1141,9 +1132,8 @@ func Test_DataSource_Load_WithCategoryQueries(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), []byte(input)) require.NoError(t, err) // Parse the response @@ -1154,7 +1144,7 @@ func Test_DataSource_Load_WithCategoryQueries(t *testing.T) { } `json:"errors,omitempty"` } - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") require.Empty(t, resp.Errors, "Response should not contain errors") require.NotEmpty(t, resp.Data, "Response should contain data") @@ -1222,9 +1212,8 @@ func Test_DataSource_Load_WithTotalCalculation(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, query, variables) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), []byte(input)) require.NoError(t, err) // Parse the response @@ -1246,7 +1235,7 @@ func Test_DataSource_Load_WithTotalCalculation(t *testing.T) { } `json:"errors,omitempty"` } - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") require.Empty(t, resp.Errors, "Response should not contain errors") @@ -1313,9 +1302,8 @@ func 
Test_DataSource_Load_WithTypename(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":{}}`, query) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), []byte(input)) require.NoError(t, err) // Parse the response @@ -1332,7 +1320,7 @@ func Test_DataSource_Load_WithTypename(t *testing.T) { } `json:"errors,omitempty"` } - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") require.Empty(t, resp.Errors, "Response should not contain errors") @@ -1783,9 +1771,8 @@ func Test_DataSource_Load_WithAliases(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), []byte(input)) require.NoError(t, err) // Parse the response @@ -1796,7 +1783,7 @@ func Test_DataSource_Load_WithAliases(t *testing.T) { } `json:"errors,omitempty"` } - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") require.Empty(t, resp.Errors, "Response should not contain errors") require.NotEmpty(t, resp.Data, "Response should contain data") @@ -2162,9 +2149,8 @@ func Test_DataSource_Load_WithNullableFieldsType(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), []byte(input)) require.NoError(t, err) // Parse the response @@ -2175,7 +2161,7 @@ func Test_DataSource_Load_WithNullableFieldsType(t *testing.T) { } 
`json:"errors,omitempty"` } - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") require.Empty(t, resp.Errors, "Response should not contain errors") require.NotEmpty(t, resp.Data, "Response should contain data") @@ -3464,9 +3450,8 @@ func Test_DataSource_Load_WithNestedLists(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), []byte(input)) require.NoError(t, err) // Parse the response @@ -3477,7 +3462,7 @@ func Test_DataSource_Load_WithNestedLists(t *testing.T) { } `json:"errors,omitempty"` } - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") require.Empty(t, resp.Errors, "Response should not contain errors") require.NotEmpty(t, resp.Data, "Response should contain data") @@ -3617,15 +3602,14 @@ func Test_DataSource_Load_WithEntity_Calls(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), []byte(input)) require.NoError(t, err) // Parse the response var resp graphqlResponse - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") tc.validate(t, resp.Data) diff --git a/v2/pkg/engine/datasource/httpclient/httpclient_test.go b/v2/pkg/engine/datasource/httpclient/httpclient_test.go index 223e5d8332..cbef2d1f7d 100644 --- a/v2/pkg/engine/datasource/httpclient/httpclient_test.go +++ b/v2/pkg/engine/datasource/httpclient/httpclient_test.go @@ -1,7 +1,6 @@ 
package httpclient import ( - "bytes" "compress/gzip" "context" "io" @@ -80,10 +79,9 @@ func TestHttpClientDo(t *testing.T) { runTest := func(ctx context.Context, input []byte, expectedOutput string) func(t *testing.T) { return func(t *testing.T) { - out := &bytes.Buffer{} - err := Do(http.DefaultClient, ctx, input, out) + output, err := Do(http.DefaultClient, ctx, input) assert.NoError(t, err) - assert.Equal(t, expectedOutput, out.String()) + assert.Equal(t, expectedOutput, string(output)) } } @@ -211,9 +209,8 @@ func TestHttpClientDo(t *testing.T) { input = SetInputURL(input, []byte(server.URL)) input, err := sjson.SetBytes(input, TRACE, true) assert.NoError(t, err) - out := &bytes.Buffer{} - err = Do(http.DefaultClient, context.Background(), input, out) + output, err := Do(http.DefaultClient, context.Background(), input) assert.NoError(t, err) - assert.Contains(t, out.String(), `"Authorization":["****"]`) + assert.Contains(t, string(output), `"Authorization":["****"]`) }) } diff --git a/v2/pkg/engine/datasource/httpclient/nethttpclient.go b/v2/pkg/engine/datasource/httpclient/nethttpclient.go index 4e8ca9b31e..0eb4360fa1 100644 --- a/v2/pkg/engine/datasource/httpclient/nethttpclient.go +++ b/v2/pkg/engine/datasource/httpclient/nethttpclient.go @@ -254,21 +254,27 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, head return err } -func Do(client *http.Client, ctx context.Context, requestInput []byte, out *bytes.Buffer) (err error) { +func Do(client *http.Client, ctx context.Context, requestInput []byte) (data []byte, err error) { url, method, body, headers, queryParams, enableTrace := requestInputParams(requestInput) h := pool.Hash64.Get() _, _ = h.Write(body) bodyHash := h.Sum64() pool.Hash64.Put(h) ctx = context.WithValue(ctx, bodyHashContextKey{}, bodyHash) - return makeHTTPRequest(client, ctx, url, method, headers, queryParams, bytes.NewReader(body), enableTrace, out, ContentTypeJSON) + + var buf bytes.Buffer + err = 
makeHTTPRequest(client, ctx, url, method, headers, queryParams, bytes.NewReader(body), enableTrace, &buf, ContentTypeJSON) + if err != nil { + return nil, err + } + return buf.Bytes(), nil } func DoMultipartForm( - client *http.Client, ctx context.Context, requestInput []byte, files []*FileUpload, out *bytes.Buffer, -) (err error) { + client *http.Client, ctx context.Context, requestInput []byte, files []*FileUpload, +) (data []byte, err error) { if len(files) == 0 { - return errors.New("no files provided") + return nil, errors.New("no files provided") } url, method, body, headers, queryParams, enableTrace := requestInputParams(requestInput) @@ -300,7 +306,7 @@ func DoMultipartForm( temporaryFile, err := os.Open(file.Path()) tempFiles = append(tempFiles, temporaryFile) if err != nil { - return err + return nil, err } formValues[key] = bufio.NewReader(temporaryFile) } @@ -309,7 +315,7 @@ func DoMultipartForm( multipartBody, contentType, err := multipartBytes(formValues, files) if err != nil { - return err + return nil, err } defer func() { @@ -327,7 +333,12 @@ func DoMultipartForm( bodyHash := h.Sum64() ctx = context.WithValue(ctx, bodyHashContextKey{}, bodyHash) - return makeHTTPRequest(client, ctx, url, method, headers, queryParams, multipartBody, enableTrace, out, contentType) + var buf bytes.Buffer + err = makeHTTPRequest(client, ctx, url, method, headers, queryParams, multipartBody, enableTrace, &buf, contentType) + if err != nil { + return nil, err + } + return buf.Bytes(), nil } func multipartBytes(values map[string]io.Reader, files []*FileUpload) (*io.PipeReader, string, error) { diff --git a/v2/pkg/engine/datasource/introspection_datasource/fixtures/schema_introspection.golden b/v2/pkg/engine/datasource/introspection_datasource/fixtures/schema_introspection.golden index 0064f2d6bf..43d477605d 100644 --- a/v2/pkg/engine/datasource/introspection_datasource/fixtures/schema_introspection.golden +++ 
b/v2/pkg/engine/datasource/introspection_datasource/fixtures/schema_introspection.golden @@ -353,4 +353,4 @@ } ], "__typename": "__Schema" -} +} \ No newline at end of file diff --git a/v2/pkg/engine/datasource/introspection_datasource/fixtures/schema_introspection_with_custom_root_operation_types.golden b/v2/pkg/engine/datasource/introspection_datasource/fixtures/schema_introspection_with_custom_root_operation_types.golden index 0e8d299c2c..240e7f0c3d 100644 --- a/v2/pkg/engine/datasource/introspection_datasource/fixtures/schema_introspection_with_custom_root_operation_types.golden +++ b/v2/pkg/engine/datasource/introspection_datasource/fixtures/schema_introspection_with_custom_root_operation_types.golden @@ -501,4 +501,4 @@ } ], "__typename": "__Schema" -} +} \ No newline at end of file diff --git a/v2/pkg/engine/datasource/introspection_datasource/fixtures/type_introspection.golden b/v2/pkg/engine/datasource/introspection_datasource/fixtures/type_introspection.golden index 41827c0f69..16017d1314 100644 --- a/v2/pkg/engine/datasource/introspection_datasource/fixtures/type_introspection.golden +++ b/v2/pkg/engine/datasource/introspection_datasource/fixtures/type_introspection.golden @@ -56,4 +56,4 @@ "interfaces": [], "possibleTypes": [], "__typename": "__Type" -} +} \ No newline at end of file diff --git a/v2/pkg/engine/datasource/introspection_datasource/source.go b/v2/pkg/engine/datasource/introspection_datasource/source.go index b9a06489d5..a55549ace9 100644 --- a/v2/pkg/engine/datasource/introspection_datasource/source.go +++ b/v2/pkg/engine/datasource/introspection_datasource/source.go @@ -1,7 +1,6 @@ package introspection_datasource import ( - "bytes" "context" "encoding/json" "errors" @@ -19,21 +18,21 @@ type Source struct { introspectionData *introspection.Data } -func (s *Source) Load(ctx context.Context, input []byte, out *bytes.Buffer) (err error) { +func (s *Source) Load(ctx context.Context, input []byte) (data []byte, err error) { var req 
introspectionInput if err := json.Unmarshal(input, &req); err != nil { - return err + return nil, err } if req.RequestType == TypeRequestType { - return s.singleType(out, req.TypeName) + return s.singleTypeBytes(req.TypeName) } - return json.NewEncoder(out).Encode(s.introspectionData.Schema) + return json.Marshal(s.introspectionData.Schema) } -func (s *Source) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) { - return errors.New("introspection data source does not support file uploads") +func (s *Source) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { + return nil, errors.New("introspection data source does not support file uploads") } func (s *Source) typeInfo(typeName *string) *introspection.FullType { @@ -57,3 +56,12 @@ func (s *Source) singleType(w io.Writer, typeName *string) error { return json.NewEncoder(w).Encode(typeInfo) } + +func (s *Source) singleTypeBytes(typeName *string) ([]byte, error) { + typeInfo := s.typeInfo(typeName) + if typeInfo == nil { + return null, nil + } + + return json.Marshal(typeInfo) +} diff --git a/v2/pkg/engine/datasource/introspection_datasource/source_test.go b/v2/pkg/engine/datasource/introspection_datasource/source_test.go index bb4a911433..7c331b7d14 100644 --- a/v2/pkg/engine/datasource/introspection_datasource/source_test.go +++ b/v2/pkg/engine/datasource/introspection_datasource/source_test.go @@ -27,13 +27,18 @@ func TestSource_Load(t *testing.T) { gen.Generate(&def, &report, &data) require.False(t, report.HasErrors()) - buf := &bytes.Buffer{} source := &Source{introspectionData: &data} - require.NoError(t, source.Load(context.Background(), []byte(input), buf)) + responseData, err := source.Load(context.Background(), []byte(input)) + require.NoError(t, err) actualResponse := &bytes.Buffer{} - require.NoError(t, json.Indent(actualResponse, buf.Bytes(), "", " ")) - goldie.Assert(t, fixtureName, 
actualResponse.Bytes()) + require.NoError(t, json.Indent(actualResponse, responseData, "", " ")) + // Trim the trailing newline that json.Indent adds + responseBytes := actualResponse.Bytes() + if len(responseBytes) > 0 && responseBytes[len(responseBytes)-1] == '\n' { + responseBytes = responseBytes[:len(responseBytes)-1] + } + goldie.Assert(t, fixtureName, responseBytes) } } diff --git a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_kafka.go b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_kafka.go index cc562b803e..7f1a6226b2 100644 --- a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_kafka.go +++ b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_kafka.go @@ -1,10 +1,8 @@ package pubsub_datasource import ( - "bytes" "context" "encoding/json" - "io" "github.com/buger/jsonparser" "github.com/cespare/xxhash/v2" @@ -68,21 +66,19 @@ type KafkaPublishDataSource struct { pubSub KafkaPubSub } -func (s *KafkaPublishDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) error { +func (s *KafkaPublishDataSource) Load(ctx context.Context, input []byte) (data []byte, err error) { var publishConfiguration KafkaPublishEventConfiguration - err := json.Unmarshal(input, &publishConfiguration) + err = json.Unmarshal(input, &publishConfiguration) if err != nil { - return err + return nil, err } if err := s.pubSub.Publish(ctx, publishConfiguration); err != nil { - _, err = io.WriteString(out, `{"success": false}`) - return err + return []byte(`{"success": false}`), err } - _, err = io.WriteString(out, `{"success": true}`) - return err + return []byte(`{"success": true}`), nil } -func (s *KafkaPublishDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) { +func (s *KafkaPublishDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { panic("not implemented") } diff --git 
a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go index 31cb6d4154..e5d3bec0f0 100644 --- a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go +++ b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go @@ -77,23 +77,21 @@ type NatsPublishDataSource struct { pubSub NatsPubSub } -func (s *NatsPublishDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) error { +func (s *NatsPublishDataSource) Load(ctx context.Context, input []byte) (data []byte, err error) { var publishConfiguration NatsPublishAndRequestEventConfiguration - err := json.Unmarshal(input, &publishConfiguration) + err = json.Unmarshal(input, &publishConfiguration) if err != nil { - return err + return nil, err } if err := s.pubSub.Publish(ctx, publishConfiguration); err != nil { - _, err = io.WriteString(out, `{"success": false}`) - return err + return []byte(`{"success": false}`), err } - _, err = io.WriteString(out, `{"success": true}`) - return err + return []byte(`{"success": true}`), nil } -func (s *NatsPublishDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) error { +func (s *NatsPublishDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { panic("not implemented") } @@ -101,16 +99,22 @@ type NatsRequestDataSource struct { pubSub NatsPubSub } -func (s *NatsRequestDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) error { +func (s *NatsRequestDataSource) Load(ctx context.Context, input []byte) (data []byte, err error) { var subscriptionConfiguration NatsPublishAndRequestEventConfiguration - err := json.Unmarshal(input, &subscriptionConfiguration) + err = json.Unmarshal(input, &subscriptionConfiguration) if err != nil { - return err + return nil, err + } + + var buf bytes.Buffer + err = s.pubSub.Request(ctx, subscriptionConfiguration, &buf) + if err != nil { 
+ return nil, err } - return s.pubSub.Request(ctx, subscriptionConfiguration, out) + return buf.Bytes(), nil } -func (s *NatsRequestDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) error { +func (s *NatsRequestDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { panic("not implemented") } diff --git a/v2/pkg/engine/datasource/staticdatasource/static_datasource.go b/v2/pkg/engine/datasource/staticdatasource/static_datasource.go index e9074635cc..626a1d9f94 100644 --- a/v2/pkg/engine/datasource/staticdatasource/static_datasource.go +++ b/v2/pkg/engine/datasource/staticdatasource/static_datasource.go @@ -1,7 +1,6 @@ package staticdatasource import ( - "bytes" "context" "github.com/jensneuse/abstractlogger" @@ -71,11 +70,10 @@ func (p *Planner[T]) ConfigureSubscription() plan.SubscriptionConfiguration { type Source struct{} -func (Source) Load(ctx context.Context, input []byte, out *bytes.Buffer) (err error) { - _, err = out.Write(input) - return +func (Source) Load(ctx context.Context, input []byte) (data []byte, err error) { + return input, nil } -func (Source) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) { +func (Source) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { panic("not implemented") } diff --git a/v2/pkg/engine/plan/planner_test.go b/v2/pkg/engine/plan/planner_test.go index 270140381f..658ff3fc72 100644 --- a/v2/pkg/engine/plan/planner_test.go +++ b/v2/pkg/engine/plan/planner_test.go @@ -1,7 +1,6 @@ package plan import ( - "bytes" "context" "encoding/json" "fmt" @@ -1075,10 +1074,10 @@ type FakeDataSource struct { source *StatefulSource } -func (f *FakeDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) (err error) { - return +func (f *FakeDataSource) Load(ctx context.Context, input 
[]byte) (data []byte, err error) { + return nil, nil } -func (f *FakeDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) { - return +func (f *FakeDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { + return nil, nil } diff --git a/v2/pkg/engine/resolve/authorization_test.go b/v2/pkg/engine/resolve/authorization_test.go index 263724a77c..ea83c77259 100644 --- a/v2/pkg/engine/resolve/authorization_test.go +++ b/v2/pkg/engine/resolve/authorization_test.go @@ -1,7 +1,6 @@ package resolve import ( - "bytes" "context" "encoding/json" "errors" @@ -510,38 +509,32 @@ func TestAuthorization(t *testing.T) { func generateTestFederationGraphQLResponse(t *testing.T, ctrl *gomock.Controller) *GraphQLResponse { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"me":{"id":"1234","username":"Me","__typename": "User"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"me":{"id":"1234","username":"Me","__typename": "User"}}}`), nil }).AnyTimes() reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"__typename":"User","id":"1234"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities": [{"__typename":"User","reviews": [{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities": [{"__typename":"User","reviews": [{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}}`), nil }).AnyTimes() productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... 
on Product {name}}}","variables":{"representations":[{"__typename":"Product","upc":"top-1"},{"__typename":"Product","upc":"top-2"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities": [{"name": "Trilby"},{"name": "Fedora"}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities": [{"name": "Trilby"},{"name": "Fedora"}]}}`), nil }).AnyTimes() return &GraphQLResponse{ @@ -821,38 +814,32 @@ func generateTestFederationGraphQLResponse(t *testing.T, ctrl *gomock.Controller func generateTestFederationGraphQLResponseWithoutAuthorizationRules(t *testing.T, ctrl *gomock.Controller) *GraphQLResponse { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"me":{"id":"1234","username":"Me","__typename": "User"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"me":{"id":"1234","username":"Me","__typename": "User"}}}`), nil }).AnyTimes() reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... 
on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"__typename":"User","id":"1234"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities": [{"__typename":"User","reviews": [{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities": [{"__typename":"User","reviews": [{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}}`), nil }).AnyTimes() productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... 
on Product {name}}}","variables":{"representations":[{"__typename":"Product","upc":"top-1"},{"__typename":"Product","upc":"top-2"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities": [{"name": "Trilby"},{"name": "Fedora"}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities": [{"name": "Trilby"},{"name": "Fedora"}]}}`), nil }).AnyTimes() return &GraphQLResponse{ diff --git a/v2/pkg/engine/resolve/datasource.go b/v2/pkg/engine/resolve/datasource.go index c679d7693a..8063541f6d 100644 --- a/v2/pkg/engine/resolve/datasource.go +++ b/v2/pkg/engine/resolve/datasource.go @@ -1,7 +1,6 @@ package resolve import ( - "bytes" "context" "github.com/cespare/xxhash/v2" @@ -10,8 +9,8 @@ import ( ) type DataSource interface { - Load(ctx context.Context, input []byte, out *bytes.Buffer) (err error) - LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) + Load(ctx context.Context, input []byte) (data []byte, err error) + LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) } type SubscriptionDataSource interface { diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index ad4e78e472..1bab9779b9 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -57,11 +57,7 @@ type ResponseInfo struct { // ResponseHeaders contains a clone of the headers of the response from the subgraph. 
ResponseHeaders http.Header // This should be private as we do not want user's to access the raw responseBody directly - responseBody *bytes.Buffer -} - -func (ri *ResponseInfo) GetResponseBody() string { - return ri.responseBody.String() + responseBody []byte } func newResponseInfo(res *result, subgraphError error) *ResponseInfo { @@ -119,7 +115,6 @@ func (b *batchStats) getUniqueIndexes() int { type result struct { postProcessing PostProcessingConfiguration - out *bytes.Buffer batchStats batchStats fetchSkipped bool nestedMergeItems []*result @@ -139,6 +134,7 @@ type result struct { loaderHookContext context.Context httpResponseContext *httpclient.ResponseContext + out []byte } func (r *result) init(postProcessing PostProcessingConfiguration, info *FetchInfo) { @@ -283,9 +279,7 @@ func (l *Loader) resolveSingle(item *FetchItem) error { switch f := item.Fetch.(type) { case *SingleFetch: - res := &result{ - out: &bytes.Buffer{}, - } + res := &result{} err := l.loadSingleFetch(l.ctx.ctx, f, item, items, res) if err != nil { return err @@ -297,9 +291,7 @@ func (l *Loader) resolveSingle(item *FetchItem) error { return err case *BatchEntityFetch: - res := &result{ - out: &bytes.Buffer{}, - } + res := &result{} err := l.loadBatchEntityFetch(l.ctx.ctx, item, f, items, res) if err != nil { return errors.WithStack(err) @@ -310,9 +302,7 @@ func (l *Loader) resolveSingle(item *FetchItem) error { } return err case *EntityFetch: - res := &result{ - out: &bytes.Buffer{}, - } + res := &result{} err := l.loadEntityFetch(l.ctx.ctx, item, f, items, res) if err != nil { return errors.WithStack(err) @@ -330,9 +320,7 @@ func (l *Loader) resolveSingle(item *FetchItem) error { g, ctx := errgroup.WithContext(l.ctx.ctx) for i := range items { i := i - results[i] = &result{ - out: &bytes.Buffer{}, - } + results[i] = &result{} if l.ctx.TracingOptions.Enable { f.Traces[i] = new(SingleFetch) *f.Traces[i] = *f.Fetch @@ -453,7 +441,6 @@ func itemsData(a arena.Arena, items []*astjson.Value) 
*astjson.Value { func (l *Loader) loadFetch(ctx context.Context, fetch Fetch, fetchItem *FetchItem, items []*astjson.Value, res *result) error { switch f := fetch.(type) { case *SingleFetch: - res.out = &bytes.Buffer{} return l.loadSingleFetch(ctx, f, fetchItem, items, res) case *ParallelListItemFetch: results := make([]*result, len(items)) @@ -463,9 +450,7 @@ func (l *Loader) loadFetch(ctx context.Context, fetch Fetch, fetchItem *FetchIte g, ctx := errgroup.WithContext(l.ctx.ctx) for i := range items { i := i - results[i] = &result{ - out: &bytes.Buffer{}, - } + results[i] = &result{} if l.ctx.TracingOptions.Enable { f.Traces[i] = new(SingleFetch) *f.Traces[i] = *f.Fetch @@ -485,10 +470,8 @@ func (l *Loader) loadFetch(ctx context.Context, fetch Fetch, fetchItem *FetchIte res.nestedMergeItems = results return nil case *EntityFetch: - res.out = &bytes.Buffer{} return l.loadEntityFetch(ctx, fetchItem, f, items, res) case *BatchEntityFetch: - res.out = &bytes.Buffer{} return l.loadBatchEntityFetch(ctx, fetchItem, f, items, res) } return nil @@ -551,11 +534,12 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson if res.fetchSkipped { return nil } - if res.out.Len() == 0 { + if len(res.out) == 0 { return l.renderErrorsFailedToFetch(fetchItem, res, emptyGraphQLResponse) } - - response, err := astjson.ParseBytesWithArena(l.jsonArena, res.out.Bytes()) + slice := arena.AllocateSlice[byte](l.jsonArena, len(res.out), len(res.out)) + copy(slice, res.out) + response, err := astjson.ParseBytesWithArena(l.jsonArena, slice) if err != nil { // Fall back to status code if parsing fails and non-2XX if (res.statusCode > 0 && res.statusCode < 200) || res.statusCode >= 300 { @@ -706,7 +690,8 @@ var ( errorsInvalidInputFooter = []byte(`]}]}`) ) -func (l *Loader) renderErrorsInvalidInput(fetchItem *FetchItem, out *bytes.Buffer) error { +func (l *Loader) renderErrorsInvalidInput(fetchItem *FetchItem) []byte { + out := &bytes.Buffer{} elements := 
fetchItem.ResponsePathElements if len(elements) > 0 && elements[len(elements)-1] == "@" { elements = elements[:len(elements)-1] @@ -724,7 +709,7 @@ func (l *Loader) renderErrorsInvalidInput(fetchItem *FetchItem, out *bytes.Buffe _, _ = out.Write(quote) } _, _ = out.Write(errorsInvalidInputFooter) - return nil + return out.Bytes() } func (l *Loader) appendSubgraphError(res *result, fetchItem *FetchItem, value *astjson.Value, values []*astjson.Value) error { @@ -1312,7 +1297,8 @@ func (l *Loader) loadSingleFetch(ctx context.Context, fetch *SingleFetch, fetchI err := fetch.InputTemplate.Render(l.ctx, inputData, buf) if err != nil { - return l.renderErrorsInvalidInput(fetchItem, res.out) + res.out = l.renderErrorsInvalidInput(fetchItem) + return nil } fetchInput := buf.Bytes() allowed, err := l.validatePreFetch(fetchInput, fetch.Info, res) @@ -1648,9 +1634,14 @@ func (l *Loader) setTracingInput(fetchItem *FetchItem, input []byte, trace *Data func (l *Loader) loadByContext(ctx context.Context, source DataSource, input []byte, res *result) error { if l.ctx.Files != nil { - return source.LoadWithFiles(ctx, input, l.ctx.Files, res.out) + res.out, res.err = source.LoadWithFiles(ctx, input, l.ctx.Files) + } else { + res.out, res.err = source.Load(ctx, input) } - return source.Load(ctx, input, res.out) + if res.err != nil { + return errors.WithStack(res.err) + } + return nil } func (l *Loader) executeSourceLoad(ctx context.Context, fetchItem *FetchItem, source DataSource, input []byte, res *result, trace *DataSourceLoadTrace) { @@ -1813,8 +1804,8 @@ func (l *Loader) executeSourceLoad(ctx context.Context, fetchItem *FetchItem, so trace.SingleFlightUsed = stats.SingleFlightUsed trace.SingleFlightSharedResponse = stats.SingleFlightSharedResponse } - if !l.ctx.TracingOptions.ExcludeOutput && res.out.Len() > 0 { - trace.Output, _ = l.compactJSON(res.out.Bytes()) + if !l.ctx.TracingOptions.ExcludeOutput && len(res.out) > 0 { + trace.Output, _ = l.compactJSON(res.out) if 
l.ctx.TracingOptions.EnablePredictableDebugTimings { trace.Output, _ = sjson.DeleteBytes(trace.Output, "extensions.trace.response.headers.Date") } diff --git a/v2/pkg/engine/resolve/loader_hooks_test.go b/v2/pkg/engine/resolve/loader_hooks_test.go index 4b7b3ea6c5..d82857598d 100644 --- a/v2/pkg/engine/resolve/loader_hooks_test.go +++ b/v2/pkg/engine/resolve/loader_hooks_test.go @@ -3,7 +3,6 @@ package resolve import ( "bytes" "context" - "io" "sync" "sync/atomic" "testing" @@ -50,11 +49,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("simple fetch with simple subgraph error", testFnWithPostEvaluation(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T)) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) resolveCtx := Context{ ctx: context.Background(), @@ -124,11 +121,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) resolveCtx := &Context{ ctx: context.Background(), @@ -192,11 +187,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("parallel fetch with simple subgraph error", testFnWithPostEvaluation(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T)) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) resolveCtx := &Context{ ctx: context.Background(), @@ -257,11 +250,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("parallel list item fetch with simple subgraph error", testFnWithPostEvaluation(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T)) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) resolveCtx := Context{ ctx: context.Background(), @@ -322,12 +313,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("fetch with subgraph error and custom extension code. No extension fields are propagated by default", testFnWithPostEvaluation(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T)) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, []byte("{\"code\":\"GRAPHQL_VALIDATION_FAILED\"}")) - pair.WriteErr([]byte("errorMessage2"), nil, nil, []byte("{\"code\":\"BAD_USER_INPUT\"}")) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}},{"message":"errorMessage2","extensions":{"code":"BAD_USER_INPUT"}}]}`), nil }) resolveCtx := Context{ ctx: context.Background(), @@ -388,12 +376,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Propagate only extension code field from subgraph errors", testFnSubgraphErrorsWithExtensionFieldCode(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). 
- DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, []byte("{\"code\":\"GRAPHQL_VALIDATION_FAILED\",\"foo\":\"bar\"}")) - pair.WriteErr([]byte("errorMessage2"), nil, nil, []byte("{\"code\":\"BAD_USER_INPUT\"}")) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":{"code":"GRAPHQL_VALIDATION_FAILED","foo":"bar"}},{"message":"errorMessage2","extensions":{"code":"BAD_USER_INPUT"}}]}`), nil }) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ @@ -426,12 +411,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Propagate all extension fields from subgraph errors when allow all option is enabled", testFnSubgraphErrorsWithAllowAllExtensionFields(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, []byte("{\"code\":\"GRAPHQL_VALIDATION_FAILED\",\"foo\":\"bar\"}")) - pair.WriteErr([]byte("errorMessage2"), nil, nil, []byte("{\"code\":\"BAD_USER_INPUT\"}")) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":{"code":"GRAPHQL_VALIDATION_FAILED","foo":"bar"}},{"message":"errorMessage2","extensions":{"code":"BAD_USER_INPUT"}}]}`), nil }) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ @@ -464,12 +446,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Include datasource name as serviceName extension field", testFnSubgraphErrorsWithExtensionFieldServiceName(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, []byte("{\"code\":\"GRAPHQL_VALIDATION_FAILED\"}")) - pair.WriteErr([]byte("errorMessage2"), nil, nil, []byte("{\"code\":\"BAD_USER_INPUT\"}")) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}},{"message":"errorMessage2","extensions":{"code":"BAD_USER_INPUT"}}]}`), nil }) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ @@ -502,12 +481,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Include datasource name as serviceName when extensions is null", testFnSubgraphErrorsWithExtensionFieldServiceName(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). 
- DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, []byte("null")) - pair.WriteErr([]byte("errorMessage2"), nil, nil, []byte("null")) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":null},{"message":"errorMessage2","extensions":null}]}`), nil }) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ @@ -540,12 +516,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Include datasource name as serviceName when extensions is an empty object", testFnSubgraphErrorsWithExtensionFieldServiceName(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, []byte("{}")) - pair.WriteErr([]byte("errorMessage2"), nil, nil, []byte("null")) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":{}},{"message":"errorMessage2","extensions":null}]}`), nil }) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ @@ -578,12 +551,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Fallback to default extension code value when no code field was set", testFnSubgraphErrorsWithExtensionDefaultCode(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). 
- Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, []byte("{\"code\":\"GRAPHQL_VALIDATION_FAILED\"}")) - pair.WriteErr([]byte("errorMessage2"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}},{"message":"errorMessage2"}]}`), nil }) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ @@ -616,12 +586,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Fallback to default extension code value when extensions is null", testFnSubgraphErrorsWithExtensionDefaultCode(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, []byte("null")) - pair.WriteErr([]byte("errorMessage2"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":null},{"message":"errorMessage2"}]}`), nil }) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ @@ -654,12 +621,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Fallback to default extension code value when extensions is an empty object", testFnSubgraphErrorsWithExtensionDefaultCode(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, []byte("{}")) - pair.WriteErr([]byte("errorMessage2"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":{}},{"message":"errorMessage2"}]}`), nil }) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ diff --git a/v2/pkg/engine/resolve/loader_test.go b/v2/pkg/engine/resolve/loader_test.go index 01c5ef5dca..0fe38ddc79 100644 --- a/v2/pkg/engine/resolve/loader_test.go +++ b/v2/pkg/engine/resolve/loader_test.go @@ -19,19 +19,19 @@ func TestLoader_LoadGraphQLResponseData(t *testing.T) { ctrl := gomock.NewController(t) productsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://products","body":{"query":"query{topProducts{name __typename upc}}"}}`, - `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}`) + 
`{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}}`) reviewsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://reviews","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {reviews {body author {__typename id}}}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]}}}`, - `{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}`) + `{"data":{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}}`) stockService := mockedDS(t, ctrl, `{"method":"POST","url":"http://stock","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... 
on Product {stock}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]}}}`, - `{"_entities":[{"stock":8},{"stock":2},{"stock":5}]}`) + `{"data":{"_entities":[{"stock":8},{"stock":2},{"stock":5}]}}`) usersService := mockedDS(t, ctrl, `{"method":"POST","url":"http://users","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on User {name}}}","variables":{"representations":[{"__typename":"User","id":"1"},{"__typename":"User","id":"2"}]}}}`, - `{"_entities":[{"name":"user-1"},{"name":"user-2"}]}`) + `{"data":{"_entities":[{"name":"user-1"},{"name":"user-2"}]}}`) response := &GraphQLResponse{ Fetches: Sequence( Single(&SingleFetch{ @@ -480,19 +480,19 @@ func TestLoader_LoadGraphQLResponseDataWithExtensions(t *testing.T) { ctrl := gomock.NewController(t) productsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://products","body":{"query":"query{topProducts{name __typename upc}}","extensions":{"foo":"bar"}}}`, - `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}`) + `{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}}`) reviewsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://reviews","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... 
on Product {reviews {body author {__typename id}}}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]},"extensions":{"foo":"bar"}}}`, - `{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}`) + `{"data":{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}}`) stockService := mockedDS(t, ctrl, `{"method":"POST","url":"http://stock","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {stock}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]},"extensions":{"foo":"bar"}}}`, - `{"_entities":[{"stock":8},{"stock":2},{"stock":5}]}`) + `{"data":{"_entities":[{"stock":8},{"stock":2},{"stock":5}]}}`) usersService := mockedDS(t, ctrl, `{"method":"POST","url":"http://users","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... 
on User {name}}}","variables":{"representations":[{"__typename":"User","id":"1"},{"__typename":"User","id":"2"}]},"extensions":{"foo":"bar"}}}`, - `{"_entities":[{"name":"user-1"},{"name":"user-2"}]}`) + `{"data":{"_entities":[{"name":"user-1"},{"name":"user-2"}]}}`) response := &GraphQLResponse{ Fetches: Sequence( Single(&SingleFetch{ @@ -1054,7 +1054,7 @@ func TestLoader_RedactHeaders(t *testing.T) { productsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://products","header":{"Authorization":"value"},"body":{"query":"query{topProducts{name __typename upc}}"},"__trace__":true}`, - `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}`) + `{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}}`) response := &GraphQLResponse{ Fetches: Single(&SingleFetch{ @@ -1153,19 +1153,19 @@ func TestLoader_InvalidBatchItemCount(t *testing.T) { ctrl := gomock.NewController(t) productsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://products","body":{"query":"query{topProducts{name __typename upc}}"}}`, - `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}`) + `{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}}`) reviewsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://reviews","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... 
on Product {reviews {body author {__typename id}}}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]}}}`, - `{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}`) + `{"data":{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}}`) stockService := mockedDS(t, ctrl, `{"method":"POST","url":"http://stock","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {stock}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]}}}`, - `{"_entities":[{"stock":8},{"stock":2}]}`) // 3 items expected, 2 returned + `{"data":{"_entities":[{"stock":8},{"stock":2}]}}`) // 3 items expected, 2 returned usersService := mockedDS(t, ctrl, `{"method":"POST","url":"http://users","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... 
on User {name}}}","variables":{"representations":[{"__typename":"User","id":"1"},{"__typename":"User","id":"2"}]}}}`, - `{"_entities":[{"name":"user-1"},{"name":"user-2"},{"name":"user-3"}]}`) // 2 items expected, 3 returned + `{"data":{"_entities":[{"name":"user-1"},{"name":"user-2"},{"name":"user-3"}]}}`) // 2 items expected, 3 returned response := &GraphQLResponse{ Fetches: Sequence( Single(&SingleFetch{ diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 92501bd2eb..4a0075f6b4 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -13,6 +13,7 @@ import ( "github.com/pkg/errors" "go.uber.org/atomic" + "github.com/wundergraph/go-arena" "github.com/wundergraph/graphql-go-tools/v2/pkg/internal/xcontext" "github.com/wundergraph/graphql-go-tools/v2/pkg/pool" ) @@ -303,6 +304,11 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields) + jsonArena := arena.NewMonotonicArena() + defer jsonArena.Release() + t.loader.jsonArena = jsonArena + t.resolvable.astjsonArena = jsonArena + err := t.resolvable.Init(ctx, nil, response.Info.OperationType) if err != nil { return nil, err diff --git a/v2/pkg/engine/resolve/resolve_federation_test.go b/v2/pkg/engine/resolve/resolve_federation_test.go index 2547c6d104..64d969c6c6 100644 --- a/v2/pkg/engine/resolve/resolve_federation_test.go +++ b/v2/pkg/engine/resolve/resolve_federation_test.go @@ -1,9 +1,7 @@ package resolve import ( - "bytes" "context" - "io" "testing" "github.com/golang/mock/gomock" @@ -21,18 +19,11 @@ func mockedDS(t TestingTB, ctrl *gomock.Controller, expectedInput, responseData t.Helper() service := NewMockDataSource(ctrl) service.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). 
- DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - actual := string(input) - expected := expectedInput - - require.Equal(t, expected, actual) - - pair := NewBufPair() - pair.Data.WriteString(responseData) - - return writeGraphqlResponse(pair, w, false) - }).AnyTimes() + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + require.Equal(t, expectedInput, string(input)) + return []byte(responseData), nil + }).Times(1) return service } @@ -48,7 +39,7 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { DataSource: mockedDS( t, ctrl, `{"method":"POST","url":"http://user.service","body":{"query":"{user {account {__typename id info {a b}}}}"}}`, - `{"user":{"account":{"__typename":"Account","id":"1234","info":{"a":"foo","b":"bar"}}}}`, + `{"data":{"user":{"account":{"__typename":"Account","id":"1234","info":{"a":"foo","b":"bar"}}}}}`, ), Input: `{"method":"POST","url":"http://user.service","body":{"query":"{user {account {__typename id info {a b}}}}"}}`, PostProcessing: PostProcessingConfiguration{ @@ -70,7 +61,7 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { DataSource: mockedDS( t, ctrl, expectedAccountsQuery, - `{"_entities":[{"__typename":"Account","name":"John Doe","shippingInfo":{"zip":"12345"}}]}`, + `{"data":{"_entities":[{"__typename":"Account","name":"John Doe","shippingInfo":{"zip":"12345"}}]}}`, ), Input: `{"method":"POST","url":"http://account.service","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... 
on Account {name shippingInfo {zip}}}}","variables":{"representations":$$0$$}}}`, PostProcessing: PostProcessingConfiguration{ @@ -182,38 +173,38 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("federation with shareable", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { firstService := NewMockDataSource(ctrl) firstService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://first.service","body":{"query":"{me {details {forename middlename} __typename id}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"me": {"__typename": "User", "id": "1234", "details": {"forename": "John", "middlename": "A"}}}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"me": {"__typename": "User", "id": "1234", "details": {"forename": "John", "middlename": "A"}}}}`) + return pair.Data.Bytes(), nil }) secondService := NewMockDataSource(ctrl) secondService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://second.service","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... 
on User {details {surname}}}}","variables":{"representations":[{"__typename":"User","id":"1234"}]}}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"_entities": [{"__typename": "User", "details": {"surname": "Smith"}}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"_entities": [{"__typename": "User", "details": {"surname": "Smith"}}]}}`) + return pair.Data.Bytes(), nil }) thirdService := NewMockDataSource(ctrl) thirdService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://third.service","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on User {details {age}}}}","variables":{"representations":[{"__typename":"User","id":"1234"}]}}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"_entities": [{"__typename": "User", "details": {"age": 21}}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"_entities": [{"__typename": "User", "details": {"age": 21}}]}}`) + return pair.Data.Bytes(), nil }) return &GraphQLResponse{ @@ -377,26 +368,26 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ user { name infoOrAddress { ... 
on Info {id __typename} ... on Address {id __typename}}}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"user":{"name":"Bill","infoOrAddress":[{"id":11,"__typename":"Info"},{"id": 55,"__typename":"Address"}]}}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"user":{"name":"Bill","infoOrAddress":[{"id":11,"__typename":"Info"},{"id": 55,"__typename":"Address"}]}}}`) + return pair.Data.Bytes(), nil }) infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){query($representations: [_Any!]!){_entities(representations: $representations) { ... on Info { age } ... on Address { line1 }}}}}","variables":{"representations":[{"id":11,"__typename":"Info"},{"id":55,"__typename":"Address"}]}}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"age":21,"__typename":"Info"},{"line1":"Munich","__typename":"Address"}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"_entities":[{"age":21,"__typename":"Info"},{"line1":"Munich","__typename":"Address"}]}}`) + return pair.Data.Bytes(), nil }) return &GraphQLResponse{ @@ -530,19 +521,19 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ user { name infoOrAddress { ... on Info {id __typename} ... on Address {id __typename}}}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"user":{"name":"Bill","infoOrAddress":[{"id":11,"__typename":"Whatever"},{"id": 55,"__typename":"Whatever"}]}}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"user":{"name":"Bill","infoOrAddress":[{"id":11,"__typename":"Whatever"},{"id": 55,"__typename":"Whatever"}]}}}`) + return pair.Data.Bytes(), nil }) infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). + Load(gomock.Any(), gomock.Any()). Times(0) return &GraphQLResponse{ @@ -675,26 +666,26 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("batching on a field", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ users { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"users":[{"name":"Bill","info":{"id":11,"__typename":"Info"}},{"name":"John","info":{"id":12,"__typename":"Info"}},{"name":"Jane","info":{"id":13,"__typename":"Info"}}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"users":[{"name":"Bill","info":{"id":11,"__typename":"Info"}},{"name":"John","info":{"id":12,"__typename":"Info"}},{"name":"Jane","info":{"id":13,"__typename":"Info"}}]}}`) + return pair.Data.Bytes(), nil }) infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations) { ... 
on Info { age }}}}}","variables":{"representations":[{"id":11,"__typename":"Info"},{"id":12,"__typename":"Info"},{"id":13,"__typename":"Info"}]}}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"age":21,"__typename":"Info"},{"age":22,"__typename":"Info"},{"age":23,"__typename":"Info"}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"_entities":[{"age":21,"__typename":"Info"},{"age":22,"__typename":"Info"},{"age":23,"__typename":"Info"}]}}`) + return pair.Data.Bytes(), nil }) return &GraphQLResponse{ @@ -819,26 +810,26 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("batching with duplicates", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ users { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"users":[{"name":"Bill","info":{"id":11,"__typename":"Info"}},{"name":"John","info":{"id":11,"__typename":"Info"}},{"name":"Jane","info":{"id":11,"__typename":"Info"}}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"users":[{"name":"Bill","info":{"id":11,"__typename":"Info"}},{"name":"John","info":{"id":11,"__typename":"Info"}},{"name":"Jane","info":{"id":11,"__typename":"Info"}}]}}`) + return pair.Data.Bytes(), nil }) infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). 
- DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations) { ... on Info { age }}}}}","variables":{"representations":[{"id":11,"__typename":"Info"}]}}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"age":77,"__typename":"Info"}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"_entities":[{"age":77,"__typename":"Info"}]}}`) + return pair.Data.Bytes(), nil }) return &GraphQLResponse{ @@ -960,26 +951,26 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("batching with null entry", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ users { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"users":[{"name":"Bill","info":{"id":11,"__typename":"Info"}},{"name":"John","info":null},{"name":"Jane","info":{"id":13,"__typename":"Info"}}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"users":[{"name":"Bill","info":{"id":11,"__typename":"Info"}},{"name":"John","info":null},{"name":"Jane","info":{"id":13,"__typename":"Info"}}]}}`) + return pair.Data.Bytes(), nil }) infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations) { ... 
on Info { age }}}}}","variables":{"representations":[{"id":11,"__typename":"Info"},{"id":13,"__typename":"Info"}]}}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"age":21,"__typename":"Info"},{"age":23,"__typename":"Info"}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"_entities":[{"age":21,"__typename":"Info"},{"age":23,"__typename":"Info"}]}}`) + return pair.Data.Bytes(), nil }) return &GraphQLResponse{ @@ -1105,19 +1096,19 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("batching with all null entries", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ users { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"users":[{"name":"Bill","info":null},{"name":"John","info":null},{"name":"Jane","info":null}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"users":[{"name":"Bill","info":null},{"name":"John","info":null},{"name":"Jane","info":null}]}}`) + return pair.Data.Bytes(), nil }) infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). + Load(gomock.Any(), gomock.Any()). 
Times(0) return &GraphQLResponse{ @@ -1243,27 +1234,27 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("batching with render error", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ users { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() // render error - first item id is boolean - pair.Data.WriteString(`{"users":[{"name":"Bill","info":{"id":true,"__typename":"Info"}},{"name":"John","info":{"id":12,"__typename":"Info"}},{"name":"Jane","info":{"id":13,"__typename":"Info"}}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"users":[{"name":"Bill","info":{"id":true,"__typename":"Info"}},{"name":"John","info":{"id":12,"__typename":"Info"}},{"name":"Jane","info":{"id":13,"__typename":"Info"}}]}}`) + return pair.Data.Bytes(), nil }) infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations) { ... 
on Info { age }}}}}","variables":{"representations":[{"id":12,"__typename":"Info"},{"id":13,"__typename":"Info"}]}}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"age":21,"__typename":"Info"},{"age":22,"__typename":"Info"}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"_entities":[{"age":21,"__typename":"Info"},{"age":22,"__typename":"Info"}]}}`) + return pair.Data.Bytes(), nil }) return &GraphQLResponse{ @@ -1390,26 +1381,26 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("all data", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ user { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"user":{"name":"Bill","info":{"id":11,"__typename":"Info"}}}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"user":{"name":"Bill","info":{"id":11,"__typename":"Info"}}}}`) + return pair.Data.Bytes(), nil }) infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations) { ... on Info { age }}}}}","variables":{"representations":[{"id":11,"__typename":"Info"}]}}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"age":21,"__typename":"Info"}]}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"_entities":[{"age":21,"__typename":"Info"}]}}`) + return pair.Data.Bytes(), nil }) return &GraphQLResponse{ @@ -1524,19 +1515,19 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("null info data", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ user { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"user":{"name":"Bill","info":null}}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"user":{"name":"Bill","info":null}}}`) + return pair.Data.Bytes(), nil }) infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). + Load(gomock.Any(), gomock.Any()). 
Times(0) return &GraphQLResponse{ @@ -1652,19 +1643,19 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("wrong type data", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ user { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"user":{"name":"Bill","info":{"id":false,"__typename":"Info"}}}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"user":{"name":"Bill","info":{"id":false,"__typename":"Info"}}}}`) + return pair.Data.Bytes(), nil }) infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). + Load(gomock.Any(), gomock.Any()). Times(0) return &GraphQLResponse{ @@ -1780,19 +1771,19 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("not matching type data", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ user { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) pair := NewBufPair() - pair.Data.WriteString(`{"user":{"name":"Bill","info":{"id":1,"__typename":"Whatever"}}}`) - return writeGraphqlResponse(pair, w, false) + pair.Data.WriteString(`{"data":{"user":{"name":"Bill","info":{"id":1,"__typename":"Whatever"}}}}`) + return pair.Data.Bytes(), nil }) infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). + Load(gomock.Any(), gomock.Any()). Times(0) return &GraphQLResponse{ @@ -1912,19 +1903,19 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { user := mockedDS(t, ctrl, `{"method":"POST","url":"http://user.service","body":{"query":"{user {account {address {__typename id line1 line2}}}}"}}`, - `{"user":{"account":{"address":{"__typename":"Address","id":"address-1","line1":"line1","line2":"line2"}}}}`) + `{"data":{"user":{"account":{"address":{"__typename":"Address","id":"address-1","line1":"line1","line2":"line2"}}}}}`) addressEnricher := mockedDS(t, ctrl, `{"method":"POST","url":"http://address-enricher.service","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Address {country city}}}","variables":{"representations":[{"__typename":"Address","id":"address-1"}]}}}`, - `{"__typename":"Address","country":"country-1","city":"city-1"}`) + `{"data":{"__typename":"Address","country":"country-1","city":"city-1"}}`) address := mockedDS(t, ctrl, `{"method":"POST","url":"http://address.service","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... 
on Address {line3(test: "BOOM") zip}}}","variables":{"representations":[{"__typename":"Address","id":"address-1","country":"country-1","city":"city-1"}]}}}`, - `{"__typename": "Address", "line3": "line3-1", "zip": "zip-1"}`) + `{"data":{"__typename": "Address", "line3": "line3-1", "zip": "zip-1"}}`) account := mockedDS(t, ctrl, `{"method":"POST","url":"http://account.service","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Address {fullAddress}}}","variables":{"representations":[{"__typename":"Address","id":"address-1","line1":"line1","line2":"line2","line3":"line3-1","zip":"zip-1"}]}}}`, - `{"__typename":"Address","fullAddress":"line1 line2 line3-1 city-1 country-1 zip-1"}`) + `{"data":{"__typename":"Address","fullAddress":"line1 line2 line3-1 city-1 country-1 zip-1"}}`) return &GraphQLResponse{ Fetches: Sequence( @@ -2152,19 +2143,19 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { productsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://products","body":{"query":"query{topProducts{name __typename upc}}"}}`, - `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}`) + `{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}}`) reviewsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://reviews","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... 
on Product {reviews {body author {__typename id}}}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]}}}`, - `{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}`) + `{"data":{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}}`) stockService := mockedDS(t, ctrl, `{"method":"POST","url":"http://stock","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {stock}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]}}}`, - `{"_entities":[{"stock":8},{"stock":2},{"stock":5}]}`) + `{"data":{"_entities":[{"stock":8},{"stock":2},{"stock":5}]}}`) usersService := mockedDS(t, ctrl, `{"method":"POST","url":"http://users","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... 
on User {name}}}","variables":{"representations":[{"__typename":"User","id":"1"},{"__typename":"User","id":"2"}]}}}`, - `{"_entities":[{"name":"user-1"},{"name":"user-2"}]}`) + `{"data":{"_entities":[{"name":"user-1"},{"name":"user-2"}]}}`) return &GraphQLResponse{ Fetches: Sequence( @@ -2424,19 +2415,19 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { productsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://products","body":{"query":"query{topProducts{name __typename upc}}"}}`, - `{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"}]}`) + `{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"}]}}`) reviewsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://reviews","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {reviews {body author {__typename id}}}}}","variables":{"representations":[{"__typename":"Product","upc":"1"}]}}}`, - `{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}}]}]}`) + `{"data":{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}}]}]}}`) stockService := mockedDS(t, ctrl, `{"method":"POST","url":"http://stock","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {stock}}}","variables":{"representations":[{"__typename":"Product","upc":"1"}]}}}`, - `{"_entities":[{"stock":8}]}`) + `{"data":{"_entities":[{"stock":8}]}}`) usersService := mockedDS(t, ctrl, `{"method":"POST","url":"http://users","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... 
on User {name}}}","variables":{"representations":[{"__typename":"User","id":"1"}]}}}`, - `{"_entities":[{"name":"user-1"}]}`) + `{"data":{"_entities":[{"name":"user-1"}]}}`) return &GraphQLResponse{ Fetches: Sequence( @@ -2696,11 +2687,11 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { accountsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://accounts","body":{"query":"{accounts{__typename ... on User {__typename id} ... on Moderator {__typename moderatorID} ... on Admin {__typename adminID}}}"}}`, - `{"accounts":[{"__typename":"User","id":"3"},{"__typename":"Admin","adminID":"2"},{"__typename":"Moderator","moderatorID":"1"}]}`) + `{"data":{"accounts":[{"__typename":"User","id":"3"},{"__typename":"Admin","adminID":"2"},{"__typename":"Moderator","moderatorID":"1"}]}}`) namesService := mockedDS(t, ctrl, `{"method":"POST","url":"http://names","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on User {name} ... on Moderator {subject} ... on Admin {type}}}","variables":{"representations":[{"__typename":"User","id":"3"},{"__typename":"Admin","adminID":"2"},{"__typename":"Moderator","moderatorID":"1"}]}}}`, - `{"_entities":[{"__typename":"User","name":"User"},{"__typename":"Admin","type":"super"},{"__typename":"Moderator","subject":"posts"}]}`) + `{"data":{"_entities":[{"__typename":"User","name":"User"},{"__typename":"Admin","type":"super"},{"__typename":"Moderator","subject":"posts"}]}}`) return &GraphQLResponse{ Fetches: Sequence( @@ -2836,11 +2827,11 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { accountsService := mockedDS(t, ctrl, `{"method":"POST","url":"http://accounts","body":{"query":"{accounts {__typename ... on User {some {__typename id}} ... 
on Admin {some {__typename id}}}}"}}`, - `{"accounts":[{"__typename":"User","some":{"__typename":"User","id":"1"}},{"__typename":"Admin","some":{"__typename":"User","id":"2"}},{"__typename":"User","some":{"__typename":"User","id":"3"}}]}`) + `{"data":{"accounts":[{"__typename":"User","some":{"__typename":"User","id":"1"}},{"__typename":"Admin","some":{"__typename":"User","id":"2"}},{"__typename":"User","some":{"__typename":"User","id":"3"}}]}}`) namesService := mockedDS(t, ctrl, `{"method":"POST","url":"http://names","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {__typename title}}}","variables":{"representations":[{"__typename":"User","id":"1"},{"__typename":"User","id":"3"}]}}}`, - `{"_entities":[{"__typename":"User","title":"User1"},{"__typename":"User","title":"User3"}]}`) + `{"data":{"_entities":[{"__typename":"User","title":"User1"},{"__typename":"User","title":"User3"}]}}`) return &GraphQLResponse{ Fetches: Sequence( diff --git a/v2/pkg/engine/resolve/resolve_mock_test.go b/v2/pkg/engine/resolve/resolve_mock_test.go index 3f72cc3d89..d493ff4bdf 100644 --- a/v2/pkg/engine/resolve/resolve_mock_test.go +++ b/v2/pkg/engine/resolve/resolve_mock_test.go @@ -5,7 +5,6 @@ package resolve import ( - bytes "bytes" context "context" reflect "reflect" @@ -37,29 +36,31 @@ func (m *MockDataSource) EXPECT() *MockDataSourceMockRecorder { } // Load mocks base method. -func (m *MockDataSource) Load(arg0 context.Context, arg1 []byte, arg2 *bytes.Buffer) error { +func (m *MockDataSource) Load(arg0 context.Context, arg1 []byte) ([]byte, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Load", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 + ret := m.ctrl.Call(m, "Load", arg0, arg1) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 } // Load indicates an expected call of Load. 
-func (mr *MockDataSourceMockRecorder) Load(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockDataSourceMockRecorder) Load(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Load", reflect.TypeOf((*MockDataSource)(nil).Load), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Load", reflect.TypeOf((*MockDataSource)(nil).Load), arg0, arg1) } // LoadWithFiles mocks base method. -func (m *MockDataSource) LoadWithFiles(arg0 context.Context, arg1 []byte, arg2 []*httpclient.FileUpload, arg3 *bytes.Buffer) error { +func (m *MockDataSource) LoadWithFiles(arg0 context.Context, arg1 []byte, arg2 []*httpclient.FileUpload) ([]byte, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "LoadWithFiles", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 + ret := m.ctrl.Call(m, "LoadWithFiles", arg0, arg1, arg2) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 } // LoadWithFiles indicates an expected call of LoadWithFiles. 
-func (mr *MockDataSourceMockRecorder) LoadWithFiles(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockDataSourceMockRecorder) LoadWithFiles(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadWithFiles", reflect.TypeOf((*MockDataSource)(nil).LoadWithFiles), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadWithFiles", reflect.TypeOf((*MockDataSource)(nil).LoadWithFiles), arg0, arg1, arg2) } diff --git a/v2/pkg/engine/resolve/resolve_test.go b/v2/pkg/engine/resolve/resolve_test.go index 8e15ff98a9..d19156f365 100644 --- a/v2/pkg/engine/resolve/resolve_test.go +++ b/v2/pkg/engine/resolve/resolve_test.go @@ -32,7 +32,7 @@ type _fakeDataSource struct { artificialLatency time.Duration } -func (f *_fakeDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) (err error) { +func (f *_fakeDataSource) Load(ctx context.Context, input []byte) (data []byte, err error) { if f.artificialLatency != 0 { time.Sleep(f.artificialLatency) } @@ -41,11 +41,10 @@ func (f *_fakeDataSource) Load(ctx context.Context, input []byte, out *bytes.Buf require.Equal(f.t, string(f.input), string(input), "input mismatch") } } - _, err = out.Write(f.data) - return + return f.data, nil } -func (f *_fakeDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) { +func (f *_fakeDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { if f.artificialLatency != 0 { time.Sleep(f.artificialLatency) } @@ -54,8 +53,7 @@ func (f *_fakeDataSource) LoadWithFiles(ctx context.Context, input []byte, files require.Equal(f.t, string(f.input), string(input), "input mismatch") } } - _, err = out.Write(f.data) - return + return f.data, nil } func FakeDataSource(data string) *_fakeDataSource { @@ -351,12 +349,11 @@ func TestResolver_ResolveNode(t 
*testing.T) { t.Run("fetch with context variable resolver", testFn(true, func(t *testing.T, ctrl *gomock.Controller) (response *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), []byte(`{"id":1}`), gomock.AssignableToTypeOf(&bytes.Buffer{})). - Do(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { - _, err = w.Write([]byte(`{"name":"Jens"}`)) - return + Load(gomock.Any(), []byte(`{"id":1}`)). + Do(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"name":"Jens"}`), nil }). - Return(nil) + Return([]byte(`{"name":"Jens"}`), nil) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ FetchConfiguration: FetchConfiguration{ @@ -1802,11 +1799,9 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with simple error without datasource ID", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) return &GraphQLResponse{ Fetches: SingleWithPath(&SingleFetch{ @@ -1834,11 +1829,9 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with simple error without datasource ID no subgraph error forwarding", testFnNoSubgraphErrorForwarding(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). 
- Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) return &GraphQLResponse{ Fetches: SingleWithPath(&SingleFetch{ @@ -1866,11 +1859,9 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with simple error", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) return &GraphQLResponse{ Fetches: SingleWithPath(&SingleFetch{ @@ -1902,11 +1893,9 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with simple error in pass through Subgraph Error Mode", testFnSubgraphErrorsPassthrough(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). 
- DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) return &GraphQLResponse{ Fetches: Single(&SingleFetch{ @@ -1938,10 +1927,9 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with pass through mode and omit custom fields", testFnSubgraphErrorsPassthroughAndOmitCustomFields(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) error { - _, err := w.Write([]byte(`{"errors":[{"message":"errorMessage","longMessage":"This is a long message","extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}}],"data":{"name":null}}`)) - return err + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","longMessage":"This is a long message","extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}}],"data":{"name":null}}`), nil }) return &GraphQLResponse{ Info: &GraphQLResponseInfo{ @@ -1976,9 +1964,9 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with returned err (with DataSourceID)", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). 
- DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - return &net.AddrError{} + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return nil, &net.AddrError{} }) return &GraphQLResponse{ Fetches: SingleWithPath(&SingleFetch{ @@ -2010,9 +1998,9 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with returned err (no DataSourceID)", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - return &net.AddrError{} + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return nil, &net.AddrError{} }) return &GraphQLResponse{ Fetches: SingleWithPath(&SingleFetch{ @@ -2040,9 +2028,9 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with returned err and non-nullable root field", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - return &net.AddrError{} + Load(gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return nil, &net.AddrError{} }) return &GraphQLResponse{ Fetches: SingleWithPath(&SingleFetch{ @@ -2218,14 +2206,10 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with two Errors", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - Do(func(ctx context.Context, input []byte, w io.Writer) (err error) { - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage1"), nil, nil, nil) - pair.WriteErr([]byte("errorMessage2"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) - }). - Return(nil) + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage1"},{"message":"errorMessage2"}]}`), nil + }).Times(1) return &GraphQLResponse{ Fetches: SingleWithPath(&SingleFetch{ FetchConfiguration: FetchConfiguration{ @@ -2578,39 +2562,32 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("complex GraphQL Server plan", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { serviceOne := NewMockDataSource(ctrl) serviceOne.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"url":"https://service.one","body":{"query":"query($firstArg: String, $thirdArg: Int){serviceOne(serviceOneArg: $firstArg){fieldOne} anotherServiceOne(anotherServiceOneArg: $thirdArg){fieldOne} reusingServiceOne(reusingServiceOneArg: $firstArg){fieldOne}}","variables":{"thirdArg":123,"firstArg":"firstArgValue"}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"serviceOne":{"fieldOne":"fieldOneValue"},"anotherServiceOne":{"fieldOne":"anotherFieldOneValue"},"reusingServiceOne":{"fieldOne":"reUsingFieldOneValue"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"serviceOne":{"fieldOne":"fieldOneValue"},"anotherServiceOne":{"fieldOne":"anotherFieldOneValue"},"reusingServiceOne":{"fieldOne":"reUsingFieldOneValue"}}}`), nil }) serviceTwo := NewMockDataSource(ctrl) serviceTwo.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"url":"https://service.two","body":{"query":"query($secondArg: Boolean, $fourthArg: Float){serviceTwo(serviceTwoArg: $secondArg){fieldTwo} secondServiceTwo(secondServiceTwoArg: $fourthArg){fieldTwo}}","variables":{"fourthArg":12.34,"secondArg":true}}}` assert.Equal(t, expected, actual) - - pair := NewBufPair() - pair.Data.WriteString(`{"serviceTwo":{"fieldTwo":"fieldTwoValue"},"secondServiceTwo":{"fieldTwo":"secondFieldTwoValue"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"serviceTwo":{"fieldTwo":"fieldTwoValue"},"secondServiceTwo":{"fieldTwo":"secondFieldTwoValue"}}}`), nil }) nestedServiceOne := NewMockDataSource(ctrl) nestedServiceOne.EXPECT(). 
- Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"url":"https://service.one","body":{"query":"{serviceOne {fieldOne}}"}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"serviceOne":{"fieldOne":"fieldOneValue"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"serviceOne":{"fieldOne":"fieldOneValue"}}}`), nil }) return &GraphQLResponse{ @@ -2821,52 +2798,42 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"me":{"id":"1234","username":"Me","__typename":"User"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"me":{"id":"1234","username":"Me","__typename":"User"}}}`), nil }) reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) - // {"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... 
on User {reviews {body product {upc __typename}}}}}","variables":{"representations":["id":"1234","__typename":"User"]}}} expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"id":"1234","__typename":"User"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"reviews":[{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities":[{"reviews":[{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}}`), nil }) var productServiceCallCount atomic.Int64 productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - Do(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) productServiceCallCount.Add(1) switch actual { case `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... 
on Product {name}}}","variables":{"representations":[{"upc":"top-1","__typename":"Product"}]}}}`: - pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"name": "Furby"}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities":[{"name": "Furby"}]}}`), nil case `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"upc":"top-2","__typename":"Product"}]}}}`: - pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"name": "Trilby"}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities":[{"name": "Trilby"}]}}`), nil default: t.Fatalf("unexpected request: %s", actual) } - return - }). - Return(nil).Times(2) + return nil, nil + }).Times(2) return &GraphQLResponse{ Fetches: Sequence( @@ -3038,38 +3005,32 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("federation with batch", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"me":{"id":"1234","username":"Me","__typename": "User"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"me":{"id":"1234","username":"Me","__typename": "User"}}}`), nil }) reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). 
- Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"__typename":"User","id":"1234"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities": [{"__typename":"User","reviews": [{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities": [{"__typename":"User","reviews": [{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}}`), nil }) productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... 
on Product {name}}}","variables":{"representations":[{"__typename":"Product","upc":"top-1"},{"__typename":"Product","upc":"top-2"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities": [{"name": "Trilby"},{"name": "Fedora"}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities": [{"name": "Trilby"},{"name": "Fedora"}]}}`), nil }) return &GraphQLResponse{ @@ -3241,38 +3202,32 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("federation with merge paths", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"me":{"id":"1234","username":"Me","__typename": "User"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"me":{"id":"1234","username":"Me","__typename": "User"}}}`), nil }) reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... 
on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"__typename":"User","id":"1234"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities": [{"__typename":"User","reviews": [{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities": [{"__typename":"User","reviews": [{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}}`), nil }) productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w *bytes.Buffer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... 
on Product {name}}}","variables":{"representations":[{"__typename":"Product","upc":"top-1"},{"__typename":"Product","upc":"top-2"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities": [{"name": "Trilby"},{"name": "Fedora"}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities": [{"name": "Trilby"},{"name": "Fedora"}]}}`), nil }) return &GraphQLResponse{ @@ -3445,45 +3400,39 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("federation with null response", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"me":{"id":"1234","username":"Me","__typename": "User"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"me":{"id":"1234","username":"Me","__typename": "User"}}}`), nil }) reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... 
on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"id":"1234","__typename":"User"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"reviews": [ + return []byte(`{"data":{"_entities":[{"reviews": [ {"body": "foo","product": {"upc": "top-1","__typename": "Product"}}, {"body": "bar","product": {"upc": "top-2","__typename": "Product"}}, {"body": "baz","product": null}, {"body": "bat","product": {"upc": "top-4","__typename": "Product"}}, {"body": "bal","product": {"upc": "top-5","__typename": "Product"}}, {"body": "ban","product": {"upc": "top-6","__typename": "Product"}} -]}]}`) - return writeGraphqlResponse(pair, w, false) +]}]}}`), nil }) productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... 
on Product {name}}}","variables":{"representations":[{"upc":"top-1","__typename":"Product"},{"upc":"top-2","__typename":"Product"},{"upc":"top-4","__typename":"Product"},{"upc":"top-5","__typename":"Product"},{"upc":"top-6","__typename":"Product"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"name":"Trilby"},{"name":"Fedora"},{"name":"Boater"},{"name":"Top Hat"},{"name":"Bowler"}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities":[{"name":"Trilby"},{"name":"Fedora"},{"name":"Boater"},{"name":"Top Hat"},{"name":"Bowler"}]}}`), nil }) return &GraphQLResponse{ @@ -3678,38 +3627,32 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"me": {"id": "1234","username": "Me","__typename": "User"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"me": {"id": "1234","username": "Me","__typename": "User"}}}`), nil }) reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"id":"1234","__typename":"User"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"reviews":[{"body": "A highly effective form of birth control.","product":{"upc": "top-1","__typename":"Product"}},{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product":{"upc":"top-2","__typename":"Product"}}]}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities":[{"reviews":[{"body": "A highly effective form of birth control.","product":{"upc": "top-1","__typename":"Product"}},{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product":{"upc":"top-2","__typename":"Product"}}]}]}}`), nil }) productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... 
on Product {name}}}","variables":{"representations":[{"upc":"top-1","__typename":"Product"},{"upc":"top-2","__typename":"Product"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) return &GraphQLResponse{ @@ -3871,38 +3814,32 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"me": {"id": "1234","username": "Me","__typename": "User"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"me": {"id": "1234","username": "Me","__typename": "User"}}}`), nil }) reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... 
on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"id":"1234","__typename":"User"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"reviews":[{"body": "A highly effective form of birth control.","product":{"upc": "top-1","__typename":"Product"}},{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product":{"upc":"top-2","__typename":"Product"}}]}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities":[{"reviews":[{"body": "A highly effective form of birth control.","product":{"upc": "top-1","__typename":"Product"}},{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product":{"upc":"top-2","__typename":"Product"}}]}]}}`), nil }) productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... 
on Product {name}}}","variables":{"representations":[{"upc":"top-1","__typename":"Product"},{"upc":"top-2","__typename":"Product"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) return &GraphQLResponse{ @@ -4061,38 +3998,32 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("federation with optional variable", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:8080/query","body":{"query":"{me {id}}"}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"me":{"id":"1234","__typename":"User"}}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"me":{"id":"1234","__typename":"User"}}}`), nil }) employeeService := NewMockDataSource(ctrl) employeeService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:8081/query","body":{"query":"query($representations: [_Any!]!, $companyId: ID!){_entities(representations: $representations){... 
on User {employment(companyId: $companyId){id}}}}","variables":{"companyId":"abc123","representations":[{"id":"1234","__typename":"User"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"employment":{"id":"xyz987"}}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities":[{"employment":{"id":"xyz987"}}]}}`), nil }) timeService := NewMockDataSource(ctrl) timeService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:8082/query","body":{"query":"query($representations: [_Any!]!, $date: LocalTime){_entities(representations: $representations){... on Employee {times(date: $date){id employee {id} start end}}}}","variables":{"date":null,"representations":[{"id":"xyz987","__typename":"Employee"}]}}}` assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"times":[{"id": "t1","employee":{"id":"xyz987"},"start":"2022-11-02T08:00:00","end":"2022-11-02T12:00:00"}]}]}`) - return writeGraphqlResponse(pair, w, false) + return []byte(`{"data":{"_entities":[{"times":[{"id": "t1","employee":{"id":"xyz987"},"start":"2022-11-02T08:00:00","end":"2022-11-02T12:00:00"}]}]}}`), nil }) return &GraphQLResponse{ @@ -4263,148 +4194,597 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }) } -func TestResolver_ApolloCompatibilityMode_FetchError(t *testing.T) { - options := apolloCompatibilityOptions{ - valueCompletion: true, - suppressFetchErrors: true, +// testFnArena is a helper function for testing ArenaResolveGraphQLResponse +func testFnArena(fn func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string)) func(t 
*testing.T) { + return func(t *testing.T) { + t.Helper() + + ctrl := gomock.NewController(t) + rCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + r := newResolver(rCtx) + node, ctx, expectedOutput := fn(t, ctrl) + + if node.Info == nil { + node.Info = &GraphQLResponseInfo{ + OperationType: ast.OperationTypeQuery, + } + } + + if t.Skipped() { + return + } + + buf := &bytes.Buffer{} + _, err := r.ArenaResolveGraphQLResponse(&ctx, node, buf) + assert.NoError(t, err) + assert.Equal(t, expectedOutput, buf.String()) + ctrl.Finish() } - t.Run("simple fetch with fetch error suppression - empty response", testFnApolloCompatibility(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { - mockDataSource := NewMockDataSource(ctrl) - mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - _, _ = w.Write([]byte("{}")) - return - }) +} + +func TestResolver_ArenaResolveGraphQLResponse(t *testing.T) { + + t.Run("empty graphql response", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { return &GraphQLResponse{ - Fetches: SingleWithPath(&SingleFetch{ - InputTemplate: InputTemplate{ - Segments: []TemplateSegment{ - { - Data: []byte(`{"method":"POST","url":"http://localhost:4001","body":{"query":"{query{name}}"}}`), - SegmentType: StaticSegmentType, + Data: &Object{ + Nullable: true, + }, + }, Context{ctx: context.Background()}, `{"data":{}}` + })) + + t.Run("simple data source", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + return &GraphQLResponse{ + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: FakeDataSource(`{"id":"1","name":"Jens","registered":true}`)}, + }), + Data: &Object{ + Fields: []*Field{ + { 
+ Name: []byte("user"), + Value: &Object{ + Fields: []*Field{ + { + Name: []byte("id"), + Value: &String{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + { + Name: []byte("registered"), + Value: &Boolean{ + Path: []string{"registered"}, + Nullable: false, + }, + }, + }, }, }, }, - FetchConfiguration: FetchConfiguration{ - DataSource: mockDataSource, - PostProcessing: PostProcessingConfiguration{ - SelectResponseDataPath: []string{"data"}, - SelectResponseErrorsPath: []string{"errors"}, - }, - }, - }, "query"), + }, + }, Context{ctx: context.Background()}, `{"data":{"user":{"id":"1","name":"Jens","registered":true}}}` + })) + + t.Run("array of strings", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + return &GraphQLResponse{ + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: FakeDataSource(`{"strings": ["Alex", "true", "123"]}`)}, + }), Data: &Object{ Fields: []*Field{ { - Name: []byte("name"), - Value: &String{ - Path: []string{"name"}, + Name: []byte("strings"), + Value: &Array{ + Path: []string{"strings"}, + Item: &String{ + Nullable: false, + }, }, }, }, }, - }, Context{ctx: context.Background()}, `{"data":null,"extensions":{"valueCompletion":[{"message":"Cannot return null for non-nullable field Query.name.","path":["name"],"extensions":{"code":"INVALID_GRAPHQL"}}]}}` - }, &options)) + }, Context{ctx: context.Background()}, `{"data":{"strings":["Alex","true","123"]}}` + })) - t.Run("simple fetch with fetch error suppression - response with error", testFnApolloCompatibility(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { - mockDataSource := NewMockDataSource(ctrl) - mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). 
- DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - _, _ = w.Write([]byte(`{"errors":[{"message":"Cannot query field 'name' on type 'Query'"}]}`)) - return - }) + t.Run("array of objects", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { return &GraphQLResponse{ - Fetches: SingleWithPath(&SingleFetch{ - InputTemplate: InputTemplate{ - Segments: []TemplateSegment{ - { - Data: []byte(`{"method":"POST","url":"http://localhost:4001","body":{"query":"{query{name}}"}}`), - SegmentType: StaticSegmentType, + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: FakeDataSource(`{"friends":[{"id":1,"name":"Alex"},{"id":2,"name":"Patric"}]}`)}, + }), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("friends"), + Value: &Array{ + Path: []string{"friends"}, + Item: &Object{ + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Integer{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + }, + }, }, }, }, - FetchConfiguration: FetchConfiguration{ - DataSource: mockDataSource, - PostProcessing: PostProcessingConfiguration{ - SelectResponseDataPath: []string{"data"}, - SelectResponseErrorsPath: []string{"errors"}, - }, - }, - }, "query"), + }, + }, Context{ctx: context.Background()}, `{"data":{"friends":[{"id":1,"name":"Alex"},{"id":2,"name":"Patric"}]}}` + })) + + t.Run("nested objects", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + return &GraphQLResponse{ + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: FakeDataSource(`{"id":"1","name":"Jens","pet":{"name":"Barky","kind":"Dog"}}`)}, + }), Data: &Object{ Fields: []*Field{ { - Name: []byte("name"), - Value: &String{ - Path: []string{"name"}, + Name: []byte("user"), + Value: &Object{ + 
Fields: []*Field{ + { + Name: []byte("id"), + Value: &String{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + { + Name: []byte("pet"), + Value: &Object{ + Path: []string{"pet"}, + Fields: []*Field{ + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + { + Name: []byte("kind"), + Value: &String{ + Path: []string{"kind"}, + Nullable: false, + }, + }, + }, + }, + }, + }, }, }, }, }, - }, Context{ctx: context.Background()}, `{"errors":[{"message":"Cannot query field 'name' on type 'Query'"}],"data":null}` - }, &options)) - - t.Run("complex fetch with fetch error suppression", testFnApolloCompatibility(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { - userService := NewMockDataSource(ctrl) - userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - actual := string(input) - expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` - assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"me": {"id": "1234","username": "Me","__typename": "User"}}`) - return writeGraphqlResponse(pair, w, false) - }) - - reviewsService := NewMockDataSource(ctrl) - reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - actual := string(input) - expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... 
on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"id":"1234","__typename":"User"}]}}}` - assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.Data.WriteString(`{"_entities":[{"reviews":[{"body": "A highly effective form of birth control.","product":{"upc": "top-1","__typename":"Product"}},{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product":{"upc":"top-2","__typename":"Product"}}]}]}`) - return writeGraphqlResponse(pair, w, false) - }) - - productService := NewMockDataSource(ctrl) - productService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - DoAndReturn(func(ctx context.Context, input []byte, w io.Writer) (err error) { - actual := string(input) - expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"upc":"top-1","__typename":"Product"},{"upc":"top-2","__typename":"Product"}]}}}` - assert.Equal(t, expected, actual) - pair := NewBufPair() - pair.WriteErr([]byte("errorMessage"), nil, nil, nil) - return writeGraphqlResponse(pair, w, false) - }) + }, Context{ctx: context.Background()}, `{"data":{"user":{"id":"1","name":"Jens","pet":{"name":"Barky","kind":"Dog"}}}}` + })) + t.Run("scalar types", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { return &GraphQLResponse{ - Fetches: Sequence( - SingleWithPath(&SingleFetch{ - InputTemplate: InputTemplate{ - Segments: []TemplateSegment{ - { - Data: []byte(`{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}`), - SegmentType: StaticSegmentType, - }, + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: FakeDataSource(`{"int": 12345, "float": 3.5, "str":"value", "bool": true}`)}, + 
}), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("int"), + Value: &Integer{ + Path: []string{"int"}, + Nullable: false, }, }, - FetchConfiguration: FetchConfiguration{ - DataSource: userService, - PostProcessing: PostProcessingConfiguration{ - SelectResponseDataPath: []string{"data"}, + { + Name: []byte("float"), + Value: &Float{ + Path: []string{"float"}, + Nullable: false, }, }, - }, "query"), - SingleWithPath(&SingleFetch{ - InputTemplate: InputTemplate{ - Segments: []TemplateSegment{ - { + { + Name: []byte("str"), + Value: &String{ + Path: []string{"str"}, + Nullable: false, + }, + }, + { + Name: []byte("bool"), + Value: &Boolean{ + Path: []string{"bool"}, + Nullable: false, + }, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"data":{"int":12345,"float":3.5,"str":"value","bool":true}}` + })) + + t.Run("null field", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + return &GraphQLResponse{ + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("foo"), + Value: &Null{}, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"data":{"foo":null}}` + })) + + t.Run("__typename field", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + return &GraphQLResponse{ + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: FakeDataSource(`{"id":1,"name":"Jannik","__typename":"User"}`)}, + }), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("user"), + Value: &Object{ + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Integer{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + { + Name: []byte("__typename"), + Value: &String{ + Path: []string{"__typename"}, + Nullable: false, + IsTypeName: true, + }, + }, + }, + }, + }, + }, + }, + }, Context{ctx: context.Background()}, 
`{"data":{"user":{"id":1,"name":"Jannik","__typename":"User"}}}` + })) + + t.Run("multiple fetches", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + return &GraphQLResponse{ + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: FakeDataSource(`{"user1":{"id":1,"name":"User1"},"user2":{"id":2,"name":"User2"}}`)}, + }), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("user1"), + Value: &Object{ + Path: []string{"user1"}, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Integer{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + }, + }, + }, + { + Name: []byte("user2"), + Value: &Object{ + Path: []string{"user2"}, + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Integer{ + Path: []string{"id"}, + Nullable: false, + }, + }, + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + }, + }, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"data":{"user1":{"id":1,"name":"User1"},"user2":{"id":2,"name":"User2"}}}` + })) + + t.Run("with variables", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + mockDataSource := NewMockDataSource(ctrl) + mockDataSource.EXPECT(). + Load(gomock.Any(), []byte(`{"id":1}`)). 
+ DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"name":"Jens"}`), nil + }) + return &GraphQLResponse{ + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: mockDataSource}, + InputTemplate: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`{"id":`), + SegmentType: StaticSegmentType, + }, + { + Data: []byte(`{{.arguments.id}}`), + SegmentType: VariableSegmentType, + VariableKind: ContextVariableKind, + VariableSourcePath: []string{"id"}, + Renderer: NewPlainVariableRenderer(), + }, + { + Data: []byte(`}`), + SegmentType: StaticSegmentType, + }, + }, + }, + }), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + }, + }, + }, Context{ctx: context.Background(), Variables: astjson.MustParseBytes([]byte(`{"id":1}`))}, `{"data":{"name":"Jens"}}` + })) + + t.Run("error handling", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + mockDataSource := NewMockDataSource(ctrl) + mockDataSource.EXPECT(). + Load(gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return nil, errors.New("data source error") + }) + return &GraphQLResponse{ + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: mockDataSource}, + }), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: false, + }, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"errors":[{"message":"Failed to fetch from Subgraph."}],"data":null}` + })) + + t.Run("bigint handling", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + return &GraphQLResponse{ + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{DataSource: FakeDataSource(`{"n": 12345, "ns_small": "12346", "ns_big": "1152921504606846976"}`)}, + }), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("n"), + Value: &BigInt{ + Path: []string{"n"}, + Nullable: false, + }, + }, + { + Name: []byte("ns_small"), + Value: &BigInt{ + Path: []string{"ns_small"}, + Nullable: false, + }, + }, + { + Name: []byte("ns_big"), + Value: &BigInt{ + Path: []string{"ns_big"}, + Nullable: false, + }, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"data":{"n":12345,"ns_small":"12346","ns_big":"1152921504606846976"}}` + })) + + t.Run("skip loader", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + return &GraphQLResponse{ + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("static"), + Value: &Null{}, + }, + }, + }, + }, Context{ctx: context.Background(), ExecutionOptions: ExecutionOptions{SkipLoader: true}}, `{"data":null}` + })) +} + +func TestResolver_ApolloCompatibilityMode_FetchError(t *testing.T) { + options := apolloCompatibilityOptions{ + valueCompletion: true, + suppressFetchErrors: true, + } + t.Run("simple fetch with fetch error suppression - empty response", 
testFnApolloCompatibility(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + mockDataSource := NewMockDataSource(ctrl) + mockDataSource.EXPECT(). + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte("{}"), nil + }) + return &GraphQLResponse{ + Fetches: SingleWithPath(&SingleFetch{ + InputTemplate: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`{"method":"POST","url":"http://localhost:4001","body":{"query":"{query{name}}"}}`), + SegmentType: StaticSegmentType, + }, + }, + }, + FetchConfiguration: FetchConfiguration{ + DataSource: mockDataSource, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data"}, + SelectResponseErrorsPath: []string{"errors"}, + }, + }, + }, "query"), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + }, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"data":null,"extensions":{"valueCompletion":[{"message":"Cannot return null for non-nullable field Query.name.","path":["name"],"extensions":{"code":"INVALID_GRAPHQL"}}]}}` + }, &options)) + + t.Run("simple fetch with fetch error suppression - response with error", testFnApolloCompatibility(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + mockDataSource := NewMockDataSource(ctrl) + mockDataSource.EXPECT(). + Load(gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"Cannot query field 'name' on type 'Query'"}]}`), nil + }) + return &GraphQLResponse{ + Fetches: SingleWithPath(&SingleFetch{ + InputTemplate: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`{"method":"POST","url":"http://localhost:4001","body":{"query":"{query{name}}"}}`), + SegmentType: StaticSegmentType, + }, + }, + }, + FetchConfiguration: FetchConfiguration{ + DataSource: mockDataSource, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data"}, + SelectResponseErrorsPath: []string{"errors"}, + }, + }, + }, "query"), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + }, + }, + }, + }, + }, Context{ctx: context.Background()}, `{"errors":[{"message":"Cannot query field 'name' on type 'Query'"}],"data":null}` + }, &options)) + + t.Run("complex fetch with fetch error suppression", testFnApolloCompatibility(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { + userService := NewMockDataSource(ctrl) + userService.EXPECT(). + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + actual := string(input) + expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` + assert.Equal(t, expected, actual) + return []byte(`{"data":{"me": {"id": "1234","username": "Me","__typename": "User"}}}`), nil + }) + + reviewsService := NewMockDataSource(ctrl) + reviewsService.EXPECT(). + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + actual := string(input) + expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... 
on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"id":"1234","__typename":"User"}]}}}` + assert.Equal(t, expected, actual) + return []byte(`{"data":{"_entities":[{"reviews":[{"body": "A highly effective form of birth control.","product":{"upc": "top-1","__typename":"Product"}},{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product":{"upc":"top-2","__typename":"Product"}}]}]}}`), nil + }) + + productService := NewMockDataSource(ctrl) + productService.EXPECT(). + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + actual := string(input) + expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"upc":"top-1","__typename":"Product"},{"upc":"top-2","__typename":"Product"}]}}}` + assert.Equal(t, expected, actual) + return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil + }) + + return &GraphQLResponse{ + Fetches: Sequence( + SingleWithPath(&SingleFetch{ + InputTemplate: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}`), + SegmentType: StaticSegmentType, + }, + }, + }, + FetchConfiguration: FetchConfiguration{ + DataSource: userService, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data"}, + }, + }, + }, "query"), + SingleWithPath(&SingleFetch{ + InputTemplate: InputTemplate{ + Segments: []TemplateSegment{ + { Data: []byte(`{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... 
on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"id":"`), SegmentType: StaticSegmentType, }, @@ -4566,14 +4946,12 @@ func TestResolver_WithHeader(t *testing.T) { ctrl := gomock.NewController(t) fakeService := NewMockDataSource(ctrl) fakeService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - Do(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) assert.Equal(t, "foo", actual) - _, err = w.Write([]byte(`{"bar":"baz"}`)) - return - }). - Return(nil) + return []byte(`{"bar":"baz"}`), nil + }) out := &bytes.Buffer{} res := &GraphQLResponse{ @@ -4639,14 +5017,12 @@ func TestResolver_WithVariableRemapping(t *testing.T) { ctrl := gomock.NewController(t) fakeService := NewMockDataSource(ctrl) fakeService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.AssignableToTypeOf(&bytes.Buffer{})). - Do(func(ctx context.Context, input []byte, w io.Writer) (err error) { + Load(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { actual := string(input) assert.Equal(t, tc.expectedOutput, actual) - _, err = w.Write([]byte(`{"bar":"baz"}`)) - return - }). 
- Return(nil) + return []byte(`{"bar":"baz"}`), nil + }) out := &bytes.Buffer{} res := &GraphQLResponse{ @@ -5909,50 +6285,353 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) { Data: []byte(`{"method":"POST","url":"http://localhost:4000"}`), }, }, - }, - }, - Filter: &SubscriptionFilter{ - In: &SubscriptionFieldFilter{ - FieldPath: []string{"id"}, - Values: []InputTemplate{ - { + }, + }, + Filter: &SubscriptionFilter{ + In: &SubscriptionFieldFilter{ + FieldPath: []string{"id"}, + Values: []InputTemplate{ + { + Segments: []TemplateSegment{ + { + SegmentType: StaticSegmentType, + Data: []byte(`x.`), + }, + { + SegmentType: VariableSegmentType, + VariableKind: ContextVariableKind, + VariableSourcePath: []string{"a"}, + Renderer: NewPlainVariableRenderer(), + }, + { + SegmentType: StaticSegmentType, + Data: []byte(`.`), + }, + { + SegmentType: VariableSegmentType, + VariableKind: ContextVariableKind, + VariableSourcePath: []string{"b"}, + Renderer: NewPlainVariableRenderer(), + }, + }, + }, + }, + }, + }, + Response: &GraphQLResponse{ + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("oneUserByID"), + Value: &Object{ + Fields: []*Field{ + { + Name: []byte("id"), + Value: &String{ + Path: []string{"id"}, + }, + }, + }, + }, + }, + }, + }, + }, + } + + out := &SubscriptionRecorder{ + buf: &bytes.Buffer{}, + messages: []string{}, + complete: atomic.Bool{}, + } + out.complete.Store(false) + + id := SubscriptionIdentifier{ + ConnectionID: 1, + SubscriptionID: 1, + } + + resolver := newResolver(c) + + ctx := &Context{ + ctx: context.Background(), + Variables: astjson.MustParseBytes([]byte(`{"a":[1,2],"b":[3,4]}`)), + } + + err := resolver.AsyncResolveGraphQLSubscription(ctx, plan, out, id) + assert.NoError(t, err) + out.AwaitComplete(t, defaultTimeout) + assert.Equal(t, 4, len(out.Messages())) + assert.ElementsMatch(t, []string{ + `{"errors":[{"message":"invalid subscription filter template"}],"data":null}`, + `{"errors":[{"message":"invalid 
subscription filter template"}],"data":null}`, + `{"errors":[{"message":"invalid subscription filter template"}],"data":null}`, + `{"errors":[{"message":"invalid subscription filter template"}],"data":null}`, + }, out.Messages()) + }) +} + +func Benchmark_NestedBatching(b *testing.B) { + rCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resolver := newResolver(rCtx) + + productsService := fakeDataSourceWithInputCheck(b, + []byte(`{"method":"POST","url":"http://products","body":{"query":"query{topProducts{name __typename upc}}"}}`), + []byte(`{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}}`)) + stockService := fakeDataSourceWithInputCheck(b, + []byte(`{"method":"POST","url":"http://stock","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Product {stock}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]}}}`), + []byte(`{"data":{"_entities":[{"stock":8},{"stock":2},{"stock":5}]}}`)) + reviewsService := fakeDataSourceWithInputCheck(b, + []byte(`{"method":"POST","url":"http://reviews","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... 
on Product {reviews {body author {__typename id}}}}}","variables":{"representations":[{"__typename":"Product","upc":"1"},{"__typename":"Product","upc":"2"},{"__typename":"Product","upc":"3"}]}}}`), + []byte(`{"data":{"_entities":[{"__typename":"Product","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2"}}]},{"__typename":"Product","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1"}}]},{"__typename":"Product","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2"}}]}]}}`)) + usersService := fakeDataSourceWithInputCheck(b, + []byte(`{"method":"POST","url":"http://users","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on User {name}}}","variables":{"representations":[{"__typename":"User","id":"1"},{"__typename":"User","id":"2"}]}}}`), + []byte(`{"data":{"_entities":[{"name":"user-1"},{"name":"user-2"}]}}`)) + + plan := &GraphQLResponse{ + Fetches: Sequence( + SingleWithPath(&SingleFetch{ + InputTemplate: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`{"method":"POST","url":"http://products","body":{"query":"query{topProducts{name __typename upc}}"}}`), + SegmentType: StaticSegmentType, + }, + }, + }, + FetchConfiguration: FetchConfiguration{ + DataSource: productsService, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data"}, + }, + }, + }, ""), + Parallel( + SingleWithPath(&BatchEntityFetch{ + Input: BatchInput{ + Header: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`{"method":"POST","url":"http://reviews","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... 
on Product {reviews {body author {__typename id}}}}}","variables":{"representations":[`), + SegmentType: StaticSegmentType, + }, + }, + }, + Items: []InputTemplate{ + { + Segments: []TemplateSegment{ + { + SegmentType: VariableSegmentType, + VariableKind: ResolvableObjectVariableKind, + Renderer: NewGraphQLVariableResolveRenderer(&Object{ + Fields: []*Field{ + { + Name: []byte("__typename"), + Value: &String{ + Path: []string{"__typename"}, + }, + }, + { + Name: []byte("upc"), + Value: &String{ + Path: []string{"upc"}, + }, + }, + }, + }), + }, + }, + }, + }, + Separator: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`,`), + SegmentType: StaticSegmentType, + }, + }, + }, + Footer: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`]}}}`), + SegmentType: StaticSegmentType, + }, + }, + }, + }, + DataSource: reviewsService, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data", "_entities"}, + }, + }, "topProducts", ArrayPath("topProducts")), + SingleWithPath(&BatchEntityFetch{ + Input: BatchInput{ + Header: InputTemplate{ Segments: []TemplateSegment{ { + Data: []byte(`{"method":"POST","url":"http://stock","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... 
on Product {stock}}}","variables":{"representations":[`), SegmentType: StaticSegmentType, - Data: []byte(`x.`), }, + }, + }, + Items: []InputTemplate{ + { + Segments: []TemplateSegment{ + { + SegmentType: VariableSegmentType, + VariableKind: ResolvableObjectVariableKind, + Renderer: NewGraphQLVariableResolveRenderer(&Object{ + Fields: []*Field{ + { + Name: []byte("__typename"), + Value: &String{ + Path: []string{"__typename"}, + }, + }, + { + Name: []byte("upc"), + Value: &String{ + Path: []string{"upc"}, + }, + }, + }, + }), + }, + }, + }, + }, + Separator: InputTemplate{ + Segments: []TemplateSegment{ { - SegmentType: VariableSegmentType, - VariableKind: ContextVariableKind, - VariableSourcePath: []string{"a"}, - Renderer: NewPlainVariableRenderer(), + Data: []byte(`,`), + SegmentType: StaticSegmentType, }, + }, + }, + Footer: InputTemplate{ + Segments: []TemplateSegment{ { + Data: []byte(`]}}}`), SegmentType: StaticSegmentType, - Data: []byte(`.`), }, + }, + }, + }, + DataSource: stockService, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data", "_entities"}, + }, + }, "topProducts", ArrayPath("topProducts")), + ), + SingleWithPath(&BatchEntityFetch{ + Input: BatchInput{ + Header: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`{"method":"POST","url":"http://users","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... 
on User {name}}}","variables":{"representations":[`), + SegmentType: StaticSegmentType, + }, + }, + }, + Items: []InputTemplate{ + { + Segments: []TemplateSegment{ { - SegmentType: VariableSegmentType, - VariableKind: ContextVariableKind, - VariableSourcePath: []string{"b"}, - Renderer: NewPlainVariableRenderer(), + SegmentType: VariableSegmentType, + VariableKind: ResolvableObjectVariableKind, + Renderer: NewGraphQLVariableResolveRenderer(&Object{ + Fields: []*Field{ + { + Name: []byte("__typename"), + Value: &String{ + Path: []string{"__typename"}, + }, + }, + { + Name: []byte("id"), + Value: &String{ + Path: []string{"id"}, + }, + }, + }, + }), }, }, }, }, + Separator: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`,`), + SegmentType: StaticSegmentType, + }, + }, + }, + Footer: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`]}}}`), + SegmentType: StaticSegmentType, + }, + }, + }, }, - }, - Response: &GraphQLResponse{ - Data: &Object{ - Fields: []*Field{ - { - Name: []byte("oneUserByID"), - Value: &Object{ - Fields: []*Field{ - { - Name: []byte("id"), - Value: &String{ - Path: []string{"id"}, + DataSource: usersService, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data", "_entities"}, + }, + }, "topProducts.@.reviews.@.author", ArrayPath("topProducts"), ArrayPath("reviews"), ObjectPath("author")), + ), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("topProducts"), + Value: &Array{ + Path: []string{"topProducts"}, + Item: &Object{ + Fields: []*Field{ + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + }, + }, + { + Name: []byte("stock"), + Value: &Integer{ + Path: []string{"stock"}, + }, + }, + { + Name: []byte("reviews"), + Value: &Array{ + Path: []string{"reviews"}, + Item: &Object{ + Fields: []*Field{ + { + Name: []byte("body"), + Value: &String{ + Path: []string{"body"}, + }, + }, + { + Name: []byte("author"), + Value: &Object{ + Path: []string{"author"}, 
+ Fields: []*Field{ + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + }, + }, + }, + }, + }, + }, }, }, }, @@ -5961,41 +6640,53 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) { }, }, }, - } + }, + Info: &GraphQLResponseInfo{ + OperationType: ast.OperationTypeQuery, + }, + } - out := &SubscriptionRecorder{ - buf: &bytes.Buffer{}, - messages: []string{}, - complete: atomic.Bool{}, - } - out.complete.Store(false) + expected := []byte(`{"data":{"topProducts":[{"name":"Table","stock":8,"reviews":[{"body":"Love Table!","author":{"name":"user-1"}},{"body":"Prefer other Table.","author":{"name":"user-2"}}]},{"name":"Couch","stock":2,"reviews":[{"body":"Couch Too expensive.","author":{"name":"user-1"}}]},{"name":"Chair","stock":5,"reviews":[{"body":"Chair Could be better.","author":{"name":"user-2"}}]}]}}`) - id := SubscriptionIdentifier{ - ConnectionID: 1, - SubscriptionID: 1, - } + pool := sync.Pool{ + New: func() interface{} { + return bytes.NewBuffer(make([]byte, 0, 1024)) + }, + } - resolver := newResolver(c) + ctxPool := sync.Pool{ + New: func() interface{} { + return NewContext(context.Background()) + }, + } - ctx := &Context{ - ctx: context.Background(), - Variables: astjson.MustParseBytes([]byte(`{"a":[1,2],"b":[3,4]}`)), - } + b.ReportAllocs() + b.SetBytes(int64(len(expected))) + b.ResetTimer() - err := resolver.AsyncResolveGraphQLSubscription(ctx, plan, out, id) - assert.NoError(t, err) - out.AwaitComplete(t, defaultTimeout) - assert.Equal(t, 4, len(out.Messages())) - assert.ElementsMatch(t, []string{ - `{"errors":[{"message":"invalid subscription filter template"}],"data":null}`, - `{"errors":[{"message":"invalid subscription filter template"}],"data":null}`, - `{"errors":[{"message":"invalid subscription filter template"}],"data":null}`, - `{"errors":[{"message":"invalid subscription filter template"}],"data":null}`, - }, out.Messages()) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + ctx := 
ctxPool.Get().(*Context) + buf := pool.Get().(*bytes.Buffer) + ctx.ctx = context.Background() + _, err := resolver.ResolveGraphQLResponse(ctx, plan, nil, buf) + if err != nil { + b.Fatal(err) + } + if !bytes.Equal(expected, buf.Bytes()) { + require.Equal(b, string(expected), buf.String()) + } + + buf.Reset() + pool.Put(buf) + + ctx.Free() + ctxPool.Put(ctx) + } }) } -func Benchmark_NestedBatching(b *testing.B) { +func Benchmark_NestedBatchingArena(b *testing.B) { rCtx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -6293,7 +6984,7 @@ func Benchmark_NestedBatching(b *testing.B) { ctx := ctxPool.Get().(*Context) buf := pool.Get().(*bytes.Buffer) ctx.ctx = context.Background() - _, err := resolver.ResolveGraphQLResponse(ctx, plan, nil, buf) + _, err := resolver.ArenaResolveGraphQLResponse(ctx, plan, buf) if err != nil { b.Fatal(err) } @@ -6310,7 +7001,7 @@ func Benchmark_NestedBatching(b *testing.B) { }) } -func Benchmark_NestedBatchingWithoutChecks(b *testing.B) { +func Benchmark_NoCheckNestedBatching(b *testing.B) { rCtx, cancel := context.WithCancel(context.Background()) defer cancel() From 3142c9011da0c0e587c6464fe8df2f7c13b620bf Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Wed, 15 Oct 2025 16:28:42 +0200 Subject: [PATCH 03/57] chore: implement weak arena pool --- v2/pkg/engine/resolve/loader.go | 4 +++ v2/pkg/engine/resolve/resolve.go | 60 +++++++++++++++++++++++++++++--- 2 files changed, 60 insertions(+), 4 deletions(-) diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 1bab9779b9..70626cbe4b 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -60,6 +60,10 @@ type ResponseInfo struct { responseBody []byte } +func (r *ResponseInfo) GetResponseBody() string { + return string(r.responseBody) +} + func newResponseInfo(res *result, subgraphError error) *ResponseInfo { responseInfo := &ResponseInfo{ StatusCode: res.statusCode, diff --git a/v2/pkg/engine/resolve/resolve.go 
b/v2/pkg/engine/resolve/resolve.go index 4a0075f6b4..01417606f9 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -7,7 +7,9 @@ import ( "context" "fmt" "io" + "sync" "time" + "weak" "github.com/buger/jsonparser" "github.com/pkg/errors" @@ -70,6 +72,14 @@ type Resolver struct { heartbeatInterval time.Duration // maxSubscriptionFetchTimeout defines the maximum time a subscription fetch can take before it is considered timed out maxSubscriptionFetchTimeout time.Duration + + arenaPool []weak.Pointer[arenaPoolItem] + arenaSize map[uint64]int + arenaPoolMu sync.Mutex +} + +type arenaPoolItem struct { + jsonArena arena.Arena } func (r *Resolver) SetAsyncErrorWriter(w AsyncErrorWriter) { @@ -229,6 +239,8 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { resolver.maxConcurrency <- struct{}{} } + resolver.arenaSize = make(map[uint64]int) + go resolver.processEvents() return resolver @@ -292,6 +304,46 @@ func (r *Resolver) ResolveGraphQLResponse(ctx *Context, response *GraphQLRespons return resp, err } +func (r *Resolver) acquireArena(id uint64) *arenaPoolItem { + r.arenaPoolMu.Lock() + defer r.arenaPoolMu.Unlock() + + for i := 0; i < len(r.arenaPool); i++ { + v := r.arenaPool[i].Value() + r.arenaPool = append(r.arenaPool[:i], r.arenaPool[i+1:]...) 
+ if v == nil { + continue + } + return v + } + + size := arena.WithMinBufferSize(r.getArenaSize(id)) + + return &arenaPoolItem{ + jsonArena: arena.NewMonotonicArena(size), + } +} + +func (r *Resolver) getArenaSize(id uint64) int { + if size, ok := r.arenaSize[id]; ok { + return size + } + return 1024 * 1024 +} + +func (r *Resolver) releaseArena(id uint64, item *arenaPoolItem) { + peak := item.jsonArena.Peak() + item.jsonArena.Reset() + + r.arenaPoolMu.Lock() + defer r.arenaPoolMu.Unlock() + + r.arenaSize[id] = peak + + w := weak.Make(item) + r.arenaPool = append(r.arenaPool, w) +} + func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLResponse, writer io.Writer) (*GraphQLResolveInfo, error) { resp := &GraphQLResolveInfo{} @@ -304,10 +356,10 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields) - jsonArena := arena.NewMonotonicArena() - defer jsonArena.Release() - t.loader.jsonArena = jsonArena - t.resolvable.astjsonArena = jsonArena + poolItem := r.acquireArena(ctx.Request.ID) + defer r.releaseArena(ctx.Request.ID, poolItem) + t.loader.jsonArena = poolItem.jsonArena + t.resolvable.astjsonArena = poolItem.jsonArena err := t.resolvable.Init(ctx, nil, response.Info.OperationType) if err != nil { From 1c9b87758cc4e212a8f66a3233cf24c218119911 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Wed, 15 Oct 2025 17:49:34 +0200 Subject: [PATCH 04/57] chore: default buffer size --- v2/pkg/engine/datasource/httpclient/nethttpclient.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/v2/pkg/engine/datasource/httpclient/nethttpclient.go b/v2/pkg/engine/datasource/httpclient/nethttpclient.go index 0eb4360fa1..30b01f0120 100644 --- a/v2/pkg/engine/datasource/httpclient/nethttpclient.go +++ b/v2/pkg/engine/datasource/httpclient/nethttpclient.go @@ -262,8 +262,9 @@ func Do(client *http.Client, ctx context.Context, 
requestInput []byte) (data []b pool.Hash64.Put(h) ctx = context.WithValue(ctx, bodyHashContextKey{}, bodyHash) - var buf bytes.Buffer - err = makeHTTPRequest(client, ctx, url, method, headers, queryParams, bytes.NewReader(body), enableTrace, &buf, ContentTypeJSON) + buf := bytes.NewBuffer(make([]byte, 0, 1024*4)) + + err = makeHTTPRequest(client, ctx, url, method, headers, queryParams, bytes.NewReader(body), enableTrace, buf, ContentTypeJSON) if err != nil { return nil, err } @@ -333,8 +334,9 @@ func DoMultipartForm( bodyHash := h.Sum64() ctx = context.WithValue(ctx, bodyHashContextKey{}, bodyHash) - var buf bytes.Buffer - err = makeHTTPRequest(client, ctx, url, method, headers, queryParams, multipartBody, enableTrace, &buf, contentType) + buf := bytes.NewBuffer(make([]byte, 0, 1024*4)) + + err = makeHTTPRequest(client, ctx, url, method, headers, queryParams, multipartBody, enableTrace, buf, contentType) if err != nil { return nil, err } From 112171e9515440da04ab8d06c7ff267c70aaa5ad Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Thu, 16 Oct 2025 23:37:08 +0200 Subject: [PATCH 05/57] chore: move single flight into loader --- .../datasource/httpclient/nethttpclient.go | 81 +++--------- v2/pkg/engine/resolve/loader.go | 119 ++++++++++++------ v2/pkg/engine/resolve/resolvable.go | 19 ++- v2/pkg/engine/resolve/resolve.go | 17 ++- v2/pkg/engine/resolve/singleflight.go | 86 +++++++++++++ 5 files changed, 214 insertions(+), 108 deletions(-) create mode 100644 v2/pkg/engine/resolve/singleflight.go diff --git a/v2/pkg/engine/datasource/httpclient/nethttpclient.go b/v2/pkg/engine/datasource/httpclient/nethttpclient.go index 30b01f0120..3fa74b9497 100644 --- a/v2/pkg/engine/datasource/httpclient/nethttpclient.go +++ b/v2/pkg/engine/datasource/httpclient/nethttpclient.go @@ -20,7 +20,6 @@ import ( "github.com/buger/jsonparser" "github.com/wundergraph/graphql-go-tools/v2/pkg/lexer/literal" - "github.com/wundergraph/graphql-go-tools/v2/pkg/pool" ) const ( @@ -130,21 +129,11 @@ 
func respBodyReader(res *http.Response) (io.Reader, error) { } } -type bodyHashContextKey struct{} - -func BodyHashFromContext(ctx context.Context) (uint64, bool) { - value := ctx.Value(bodyHashContextKey{}) - if value == nil { - return 0, false - } - return value.(uint64), true -} - -func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, headers, queryParams []byte, body io.Reader, enableTrace bool, out *bytes.Buffer, contentType string) (err error) { +func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, headers, queryParams []byte, body io.Reader, enableTrace bool, contentType string) ([]byte, error) { request, err := http.NewRequestWithContext(ctx, string(method), string(url), body) if err != nil { - return err + return nil, err } if headers != nil { @@ -161,7 +150,7 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, head return err }) if err != nil { - return err + return nil, err } } @@ -190,7 +179,7 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, head } }) if err != nil { - return err + return nil, err } request.URL.RawQuery = query.Encode() } @@ -204,7 +193,7 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, head response, err := client.Do(request) if err != nil { - return err + return nil, err } defer response.Body.Close() @@ -212,23 +201,20 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, head respReader, err := respBodyReader(response) if err != nil { - return err + return nil, err } - if !enableTrace { - if response.ContentLength > 0 { - out.Grow(int(response.ContentLength)) - } else { - out.Grow(1024 * 4) - } - _, err = out.ReadFrom(respReader) - return + out := bytes.NewBuffer(make([]byte, 0, 1024*4)) + _, err = out.ReadFrom(respReader) + if err != nil { + return nil, err } - data, err := io.ReadAll(respReader) - if err != nil { - return err + if !enableTrace { + return out.Bytes(), nil } + + 
data := out.Bytes() responseTrace := TraceHTTP{ Request: TraceHTTPRequest{ Method: request.Method, @@ -244,31 +230,18 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, head } trace, err := json.Marshal(responseTrace) if err != nil { - return err + return nil, err } responseWithTraceExtension, err := jsonparser.Set(data, trace, "extensions", "trace") if err != nil { - return err + return nil, err } - _, err = out.Write(responseWithTraceExtension) - return err + return responseWithTraceExtension, nil } func Do(client *http.Client, ctx context.Context, requestInput []byte) (data []byte, err error) { url, method, body, headers, queryParams, enableTrace := requestInputParams(requestInput) - h := pool.Hash64.Get() - _, _ = h.Write(body) - bodyHash := h.Sum64() - pool.Hash64.Put(h) - ctx = context.WithValue(ctx, bodyHashContextKey{}, bodyHash) - - buf := bytes.NewBuffer(make([]byte, 0, 1024*4)) - - err = makeHTTPRequest(client, ctx, url, method, headers, queryParams, bytes.NewReader(body), enableTrace, buf, ContentTypeJSON) - if err != nil { - return nil, err - } - return buf.Bytes(), nil + return makeHTTPRequest(client, ctx, url, method, headers, queryParams, bytes.NewReader(body), enableTrace, ContentTypeJSON) } func DoMultipartForm( @@ -280,10 +253,6 @@ func DoMultipartForm( url, method, body, headers, queryParams, enableTrace := requestInputParams(requestInput) - h := pool.Hash64.Get() - defer pool.Hash64.Put(h) - _, _ = h.Write(body) - formValues := map[string]io.Reader{ "operations": bytes.NewReader(body), } @@ -300,10 +269,9 @@ func DoMultipartForm( } hasWrittenFileName = true - fmt.Fprintf(fileMap, `"%d":["%s"]`, i, file.variablePath) + _, _ = fmt.Fprintf(fileMap, `"%d":["%s"]`, i, file.variablePath) key := fmt.Sprintf("%d", i) - _, _ = h.WriteString(file.Path()) temporaryFile, err := os.Open(file.Path()) tempFiles = append(tempFiles, temporaryFile) if err != nil { @@ -331,16 +299,7 @@ func DoMultipartForm( } }() - bodyHash := 
h.Sum64() - ctx = context.WithValue(ctx, bodyHashContextKey{}, bodyHash) - - buf := bytes.NewBuffer(make([]byte, 0, 1024*4)) - - err = makeHTTPRequest(client, ctx, url, method, headers, queryParams, multipartBody, enableTrace, buf, contentType) - if err != nil { - return nil, err - } - return buf.Bytes(), nil + return makeHTTPRequest(client, ctx, url, method, headers, queryParams, multipartBody, enableTrace, contentType) } func multipartBytes(values map[string]io.Reader, files []*FileUpload) (*io.PipeReader, string, error) { diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 70626cbe4b..7031190538 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -139,6 +139,7 @@ type result struct { httpResponseContext *httpclient.ResponseContext out []byte + singleFlightStats *singleFlightStats } func (r *result) init(postProcessing PostProcessingConfiguration, info *FetchInfo) { @@ -183,6 +184,7 @@ type Loader struct { taintedObjs taintedObjects jsonArena arena.Arena + sf *SingleFlight } func (l *Loader) Free() { @@ -772,6 +774,7 @@ func (l *Loader) mergeErrors(res *result, fetchItem *FetchItem, value *astjson.V } // If the error propagation mode is pass-through, we append the errors to the root array + l.resolvable.ensureErrorsInitialized() l.resolvable.errors.AppendArrayItems(value) return nil } @@ -808,6 +811,7 @@ func (l *Loader) mergeErrors(res *result, fetchItem *FetchItem, value *astjson.V return err } + l.resolvable.ensureErrorsInitialized() astjson.AppendToArray(l.resolvable.errors, errorObject) return nil @@ -1062,6 +1066,7 @@ func (l *Loader) addApolloRouterCompatibilityError(res *result) error { return err } + l.resolvable.ensureErrorsInitialized() astjson.AppendToArray(l.resolvable.errors, apolloRouterStatusError) return nil @@ -1075,6 +1080,7 @@ func (l *Loader) renderErrorsFailedDeps(fetchItem *FetchItem, res *result) error return err } l.setSubgraphStatusCode([]*astjson.Value{errorObject}, 
res.statusCode) + l.resolvable.ensureErrorsInitialized() astjson.AppendToArray(l.resolvable.errors, errorObject) return nil } @@ -1086,6 +1092,7 @@ func (l *Loader) renderErrorsFailedToFetch(fetchItem *FetchItem, res *result, re return err } l.setSubgraphStatusCode([]*astjson.Value{errorObject}, res.statusCode) + l.resolvable.ensureErrorsInitialized() astjson.AppendToArray(l.resolvable.errors, errorObject) return nil } @@ -1104,7 +1111,7 @@ func (l *Loader) renderErrorsStatusFallback(fetchItem *FetchItem, res *result, s } l.setSubgraphStatusCode([]*astjson.Value{errorObject}, res.statusCode) - + l.resolvable.ensureErrorsInitialized() astjson.AppendToArray(l.resolvable.errors, errorObject) return nil } @@ -1129,6 +1136,7 @@ func (l *Loader) renderAuthorizationRejectedErrors(fetchItem *FetchItem, res *re } pathPart := l.renderAtPathErrorPart(fetchItem.ResponsePath) extensionErrorCode := fmt.Sprintf(`"extensions":{"code":"%s"}`, errorcodes.UnauthorizedFieldOrType) + l.resolvable.ensureErrorsInitialized() if res.ds.Name == "" { for _, reason := range res.authorizationRejectedReasons { if reason == "" { @@ -1207,6 +1215,7 @@ func (l *Loader) renderRateLimitRejectedErrors(fetchItem *FetchItem, res *result return err } } + l.resolvable.ensureErrorsInitialized() astjson.AppendToArray(l.resolvable.errors, errorObject) return nil } @@ -1598,29 +1607,8 @@ func redactHeaders(rawJSON json.RawMessage) (json.RawMessage, error) { return redactedJSON, nil } -type disallowSingleFlightContextKey struct{} - -func SingleFlightDisallowed(ctx context.Context) bool { - return ctx.Value(disallowSingleFlightContextKey{}) != nil -} - -type singleFlightStatsKey struct{} - -type SingleFlightStats struct { - SingleFlightUsed bool - SingleFlightSharedResponse bool -} - -func GetSingleFlightStats(ctx context.Context) *SingleFlightStats { - maybeStats := ctx.Value(singleFlightStatsKey{}) - if maybeStats == nil { - return nil - } - return maybeStats.(*SingleFlightStats) -} - -func 
setSingleFlightStats(ctx context.Context, stats *SingleFlightStats) context.Context { - return context.WithValue(ctx, singleFlightStatsKey{}, stats) +type singleFlightStats struct { + used, shared bool } func (l *Loader) setTracingInput(fetchItem *FetchItem, input []byte, trace *DataSourceLoadTrace) { @@ -1636,7 +1624,70 @@ func (l *Loader) setTracingInput(fetchItem *FetchItem, input []byte, trace *Data } } -func (l *Loader) loadByContext(ctx context.Context, source DataSource, input []byte, res *result) error { +type loaderContextKey string + +const ( + operationTypeContextKey loaderContextKey = "operationType" +) + +func GetOperationTypeFromContext(ctx context.Context) ast.OperationType { + if ctx == nil { + return ast.OperationTypeQuery + } + if v := ctx.Value(operationTypeContextKey); v != nil { + if opType, ok := v.(ast.OperationType); ok { + return opType + } + } + return ast.OperationTypeQuery +} + +func (l *Loader) loadByContext(ctx context.Context, source DataSource, fetchItem *FetchItem, input []byte, res *result) error { + + if l.info != nil { + ctx = context.WithValue(ctx, operationTypeContextKey, l.info.OperationType) + } + + if l.info == nil || l.info.OperationType == ast.OperationTypeMutation { + // Disable single flight for mutations + return l.loadByContextDirect(ctx, source, input, res) + } + + key, item, shared := l.sf.GetOrCreateItem(ctx, fetchItem, input) + if res.singleFlightStats != nil { + res.singleFlightStats.used = shared + res.singleFlightStats.shared = shared + } + + if shared { + select { + case <-item.loaded: + case <-ctx.Done(): + return ctx.Err() + } + + if item.err != nil { + return item.err + } + + res.out = item.response + return nil + } + + defer l.sf.Finish(key, item) + + // Perform the actual load + err := l.loadByContextDirect(ctx, source, input, res) + if err != nil { + item.err = err + return err + } + + item.response = res.out + return nil +} + +func (l *Loader) loadByContextDirect(ctx context.Context, source DataSource, 
input []byte, res *result) error { if l.ctx.Files != nil { res.out, res.err = source.LoadWithFiles(ctx, input, l.ctx.Files) } else { @@ -1674,7 +1725,7 @@ func (l *Loader) executeSourceLoad(ctx context.Context, fetchItem *FetchItem, so } } if l.ctx.TracingOptions.Enable { - ctx = setSingleFlightStats(ctx, &SingleFlightStats{}) + res.singleFlightStats = &singleFlightStats{} trace.Path = fetchItem.ResponsePath if !l.ctx.TracingOptions.ExcludeInput { trace.Input = make([]byte, len(input)) @@ -1778,9 +1829,6 @@ func (l *Loader) executeSourceLoad(ctx context.Context, fetchItem *FetchItem, so ctx = httptrace.WithClientTrace(ctx, clientTrace) } } - if l.info != nil && l.info.OperationType == ast.OperationTypeMutation { - ctx = context.WithValue(ctx, disallowSingleFlightContextKey{}, true) - } var responseContext *httpclient.ResponseContext ctx, responseContext = httpclient.InjectResponseContext(ctx) @@ -1789,24 +1837,23 @@ func (l *Loader) executeSourceLoad(ctx context.Context, fetchItem *FetchItem, so // Prevent that the context is destroyed when the loader hook return an empty context if res.loaderHookContext != nil { - res.err = l.loadByContext(res.loaderHookContext, source, input, res) + res.err = l.loadByContext(res.loaderHookContext, source, fetchItem, input, res) } else { - res.err = l.loadByContext(ctx, source, input, res) + res.err = l.loadByContext(ctx, source, fetchItem, input, res) res.loaderHookContext = ctx // Set the context to the original context to ensure that OnFinished hook gets valid context } } else { - res.err = l.loadByContext(ctx, source, input, res) + res.err = l.loadByContext(ctx, source, fetchItem, input, res) } res.statusCode = responseContext.StatusCode res.httpResponseContext = responseContext if l.ctx.TracingOptions.Enable { - stats := GetSingleFlightStats(ctx) - if stats != nil { - trace.SingleFlightUsed = stats.SingleFlightUsed - trace.SingleFlightSharedResponse = stats.SingleFlightSharedResponse + if res.singleFlightStats != nil { + 
trace.SingleFlightUsed = res.singleFlightStats.used + trace.SingleFlightSharedResponse = res.singleFlightStats.shared } if !l.ctx.TracingOptions.ExcludeOutput && len(res.out) > 0 { trace.Output, _ = l.compactJSON(res.out) diff --git a/v2/pkg/engine/resolve/resolvable.go b/v2/pkg/engine/resolve/resolvable.go index 5aceb2110c..21470f475d 100644 --- a/v2/pkg/engine/resolve/resolvable.go +++ b/v2/pkg/engine/resolve/resolvable.go @@ -111,7 +111,7 @@ func (r *Resolvable) Init(ctx *Context, initialData []byte, operationType ast.Op r.operationType = operationType r.renameTypeNames = ctx.RenameTypeNames r.data = astjson.ObjectValue(r.astjsonArena) - r.errors = astjson.ArrayValue(r.astjsonArena) + r.errors = nil if initialData != nil { initialValue, err := astjson.ParseBytesWithArena(r.astjsonArena, initialData) if err != nil { @@ -129,6 +129,7 @@ func (r *Resolvable) InitSubscription(ctx *Context, initialData []byte, postProc r.ctx = ctx r.operationType = ast.OperationTypeSubscription r.renameTypeNames = ctx.RenameTypeNames + r.errors = nil if initialData != nil { initialValue, err := astjson.ParseBytesWithArena(r.astjsonArena, initialData) if err != nil { @@ -158,9 +159,6 @@ func (r *Resolvable) InitSubscription(ctx *Context, initialData []byte, postProc if r.data == nil { r.data = astjson.ObjectValue(r.astjsonArena) } - if r.errors == nil { - r.errors = astjson.ArrayValue(r.astjsonArena) - } return } @@ -169,7 +167,7 @@ func (r *Resolvable) ResolveNode(node Node, data *astjson.Value, out io.Writer) r.print = false r.printErr = nil r.authorizationError = nil - r.errors = astjson.ArrayValue(r.astjsonArena) + r.errors = nil hasErrors := r.walkNode(node, data) if hasErrors { @@ -235,6 +233,12 @@ func (r *Resolvable) Resolve(ctx context.Context, rootData *Object, fetchTree *F return r.printErr } +func (r *Resolvable) ensureErrorsInitialized() { + if r.errors == nil { + r.errors = astjson.ArrayValue(r.astjsonArena) + } +} + func (r *Resolvable) enclosingTypeName() string { if 
len(r.enclosingTypeNames) > 0 { return r.enclosingTypeNames[len(r.enclosingTypeNames)-1] @@ -761,6 +765,7 @@ func (r *Resolvable) addRejectFieldError(reason string, ds DataSourceInfo, field } r.ctx.appendSubgraphErrors(errors.New(errorMessage), NewSubgraphError(ds, fieldPath, reason, 0)) + r.ensureErrorsInitialized() fastjsonext.AppendErrorWithExtensionsCodeToArray(r.astjsonArena, r.errors, errorMessage, errorcodes.UnauthorizedFieldOrType, r.path) r.popNodePathElement(nodePath) } @@ -1202,6 +1207,7 @@ func (r *Resolvable) addNonNullableFieldError(fieldPath []string, parent *astjso r.addValueCompletion(r.renderApolloCompatibleNonNullableErrorMessage(), errorcodes.InvalidGraphql) } else { errorMessage := fmt.Sprintf("Cannot return null for non-nullable field '%s'.", r.renderFieldPath()) + r.ensureErrorsInitialized() fastjsonext.AppendErrorToArray(r.astjsonArena, r.errors, errorMessage, r.path) } r.popNodePathElement(fieldPath) @@ -1272,16 +1278,19 @@ func (r *Resolvable) renderFieldCoordinates() string { func (r *Resolvable) addError(message string, fieldPath []string) { r.pushNodePathElement(fieldPath) + r.ensureErrorsInitialized() fastjsonext.AppendErrorToArray(r.astjsonArena, r.errors, message, r.path) r.popNodePathElement(fieldPath) } func (r *Resolvable) addErrorWithCode(message, code string) { + r.ensureErrorsInitialized() fastjsonext.AppendErrorWithExtensionsCodeToArray(r.astjsonArena, r.errors, message, code, r.path) } func (r *Resolvable) addErrorWithCodeAndPath(message, code string, fieldPath []string) { r.pushNodePathElement(fieldPath) + r.ensureErrorsInitialized() fastjsonext.AppendErrorWithExtensionsCodeToArray(r.astjsonArena, r.errors, message, code, r.path) r.popNodePathElement(fieldPath) } diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 01417606f9..eef77b5b81 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -76,6 +76,9 @@ type Resolver struct { arenaPool 
[]weak.Pointer[arenaPoolItem] arenaSize map[uint64]int arenaPoolMu sync.Mutex + + // Single flight cache for deduplicating requests across all loaders + sf *SingleFlight } type arenaPoolItem struct { @@ -233,6 +236,7 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { allowedErrorFields: allowedErrorFields, heartbeatInterval: options.SubscriptionHeartbeatInterval, maxSubscriptionFetchTimeout: options.MaxSubscriptionFetchTimeout, + sf: NewSingleFlight(), } resolver.maxConcurrency = make(chan struct{}, options.MaxConcurrency) for i := 0; i < options.MaxConcurrency; i++ { @@ -246,7 +250,7 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { return resolver } -func newTools(options ResolverOptions, allowedExtensionFields map[string]struct{}, allowedErrorFields map[string]struct{}) *tools { +func newTools(options ResolverOptions, allowedExtensionFields map[string]struct{}, allowedErrorFields map[string]struct{}, sf *SingleFlight) *tools { return &tools{ resolvable: NewResolvable(nil, options.ResolvableOptions), loader: &Loader{ @@ -264,6 +268,7 @@ func newTools(options ResolverOptions, allowedExtensionFields map[string]struct{ apolloRouterCompatibilitySubrequestHTTPError: options.ApolloRouterCompatibilitySubrequestHTTPError, propagateFetchReasons: options.PropagateFetchReasons, validateRequiredExternalFields: options.ValidateRequiredExternalFields, + sf: sf, }, } } @@ -282,7 +287,7 @@ func (r *Resolver) ResolveGraphQLResponse(ctx *Context, response *GraphQLRespons r.maxConcurrency <- struct{}{} }() - t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields) + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf) err := t.resolvable.Init(ctx, data, response.Info.OperationType) if err != nil { @@ -354,7 +359,7 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe r.maxConcurrency <- struct{}{} }() - t := newTools(r.options, r.allowedErrorExtensionFields, 
r.allowedErrorFields) + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf) poolItem := r.acquireArena(ctx.Request.ID) defer r.releaseArena(ctx.Request.ID, poolItem) @@ -511,7 +516,7 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar input := make([]byte, len(sharedInput)) copy(input, sharedInput) - t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields) + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf) if err := t.resolvable.InitSubscription(resolveCtx, input, sub.resolve.Trigger.PostProcessing); err != nil { r.asyncErrorWriter.WriteError(resolveCtx, err, sub.resolve.Response, sub.writer) @@ -1104,7 +1109,7 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ // If SkipLoader is enabled, we skip retrieving actual data. For example, this is useful when requesting a query plan. // By returning early, we avoid starting a subscription and resolve with empty data instead. if ctx.ExecutionOptions.SkipLoader { - t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields) + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf) err = t.resolvable.InitSubscription(ctx, nil, subscription.Trigger.PostProcessing) if err != nil { @@ -1213,7 +1218,7 @@ func (r *Resolver) AsyncResolveGraphQLSubscription(ctx *Context, subscription *G // If SkipLoader is enabled, we skip retrieving actual data. For example, this is useful when requesting a query plan. // By returning early, we avoid starting a subscription and resolve with empty data instead. 
if ctx.ExecutionOptions.SkipLoader { - t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields) + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf) err = t.resolvable.InitSubscription(ctx, nil, subscription.Trigger.PostProcessing) if err != nil { diff --git a/v2/pkg/engine/resolve/singleflight.go b/v2/pkg/engine/resolve/singleflight.go new file mode 100644 index 0000000000..7843bafece --- /dev/null +++ b/v2/pkg/engine/resolve/singleflight.go @@ -0,0 +1,86 @@ +package resolve + +import ( + "context" + "sync" + + "github.com/cespare/xxhash/v2" +) + +type SingleFlightItem struct { + loaded chan struct{} + response []byte + err error +} + +type SingleFlight struct { + mu *sync.RWMutex + items map[uint64]*SingleFlightItem + xxPool *sync.Pool + cleanup chan func() +} + +func NewSingleFlight() *SingleFlight { + return &SingleFlight{ + items: make(map[uint64]*SingleFlightItem), + mu: new(sync.RWMutex), + xxPool: &sync.Pool{ + New: func() any { + return xxhash.New() + }, + }, + cleanup: make(chan func()), + } +} + +func (s *SingleFlight) GetOrCreateItem(ctx context.Context, fetchItem *FetchItem, input []byte) (key uint64, item *SingleFlightItem, shared bool) { + key = s.key(fetchItem, input) + + // First, try to get the item with a read lock + s.mu.RLock() + item, exists := s.items[key] + s.mu.RUnlock() + if exists { + return key, item, true + } + + // If not exists, acquire a write lock to create the item + s.mu.Lock() + // Double-check if the item was created while acquiring the write lock + item, exists = s.items[key] + if exists { + s.mu.Unlock() + return key, item, true + } + + // Create a new item + item = &SingleFlightItem{ + loaded: make(chan struct{}), + } + s.items[key] = item + s.mu.Unlock() + return key, item, false +} + +func (s *SingleFlight) key(fetchItem *FetchItem, input []byte) uint64 { + h := s.xxPool.Get().(*xxhash.Digest) + if fetchItem != nil && fetchItem.Fetch != nil { + info := 
fetchItem.Fetch.FetchInfo() + if info != nil { + _, _ = h.WriteString(info.DataSourceID) + _, _ = h.WriteString(":") + } + } + _, _ = h.Write(input) + key := h.Sum64() + h.Reset() + s.xxPool.Put(h) + return key +} + +func (s *SingleFlight) Finish(key uint64, item *SingleFlightItem) { + close(item.loaded) + s.mu.Lock() + delete(s.items, key) + s.mu.Unlock() +} From 7a777ea9f163206b40c9a85623931e0526a05b58 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Fri, 17 Oct 2025 00:09:00 +0200 Subject: [PATCH 06/57] chore: add http client buffer size hint --- .../datasource/httpclient/nethttpclient.go | 19 ++++- v2/pkg/engine/resolve/loader.go | 6 +- v2/pkg/engine/resolve/singleflight.go | 81 +++++++++++++++---- 3 files changed, 88 insertions(+), 18 deletions(-) diff --git a/v2/pkg/engine/datasource/httpclient/nethttpclient.go b/v2/pkg/engine/datasource/httpclient/nethttpclient.go index 3fa74b9497..27b0434c11 100644 --- a/v2/pkg/engine/datasource/httpclient/nethttpclient.go +++ b/v2/pkg/engine/datasource/httpclient/nethttpclient.go @@ -129,6 +129,23 @@ func respBodyReader(res *http.Response) (io.Reader, error) { } } +type httpClientContext string + +const ( + sizeHintKey httpClientContext = "size-hint" +) + +func WithHTTPClientSizeHint(ctx context.Context, size int) context.Context { + return context.WithValue(ctx, sizeHintKey, size) +} + +func buffer(ctx context.Context) *bytes.Buffer { + if sizeHint, ok := ctx.Value(sizeHintKey).(int); ok && sizeHint > 0 { + return bytes.NewBuffer(make([]byte, 0, sizeHint)) + } + return bytes.NewBuffer(make([]byte, 0, 1024*4)) // default to 4KB +} + func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, headers, queryParams []byte, body io.Reader, enableTrace bool, contentType string) ([]byte, error) { request, err := http.NewRequestWithContext(ctx, string(method), string(url), body) @@ -204,7 +221,7 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, head return nil, err } - out := 
bytes.NewBuffer(make([]byte, 0, 1024*4)) + out := buffer(ctx) _, err = out.ReadFrom(respReader) if err != nil { return nil, err diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 7031190538..2c923b2c9f 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -1653,7 +1653,7 @@ func (l *Loader) loadByContext(ctx context.Context, source DataSource, fetchItem return l.loadByContextDirect(ctx, source, input, res) } - key, item, shared := l.sf.GetOrCreateItem(ctx, fetchItem, input) + sfKey, fetchKey, item, shared := l.sf.GetOrCreateItem(ctx, fetchItem, input) if res.singleFlightStats != nil { res.singleFlightStats.used = shared res.singleFlightStats.shared = shared @@ -1674,7 +1674,9 @@ func (l *Loader) loadByContext(ctx context.Context, source DataSource, fetchItem return nil } - defer l.sf.Finish(key, item) + ctx = httpclient.WithHTTPClientSizeHint(ctx, item.sizeHint) + + defer l.sf.Finish(sfKey, fetchKey, item) // Perform the actual load err := l.loadByContextDirect(ctx, source, input, res) diff --git a/v2/pkg/engine/resolve/singleflight.go b/v2/pkg/engine/resolve/singleflight.go index 7843bafece..e298531967 100644 --- a/v2/pkg/engine/resolve/singleflight.go +++ b/v2/pkg/engine/resolve/singleflight.go @@ -11,18 +11,26 @@ type SingleFlightItem struct { loaded chan struct{} response []byte err error + sizeHint int } type SingleFlight struct { mu *sync.RWMutex items map[uint64]*SingleFlightItem + sizes map[uint64]*fetchSize xxPool *sync.Pool cleanup chan func() } +type fetchSize struct { + count int + totalBytes int +} + func NewSingleFlight() *SingleFlight { return &SingleFlight{ items: make(map[uint64]*SingleFlightItem), + sizes: make(map[uint64]*fetchSize), mu: new(sync.RWMutex), xxPool: &sync.Pool{ New: func() any { @@ -33,37 +41,49 @@ func NewSingleFlight() *SingleFlight { } } -func (s *SingleFlight) GetOrCreateItem(ctx context.Context, fetchItem *FetchItem, input []byte) (key uint64, item 
*SingleFlightItem, shared bool) { - key = s.key(fetchItem, input) +func (s *SingleFlight) GetOrCreateItem(ctx context.Context, fetchItem *FetchItem, input []byte) (sfKey, fetchKey uint64, item *SingleFlightItem, shared bool) { + sfKey, fetchKey = s.keys(fetchItem, input) // First, try to get the item with a read lock s.mu.RLock() - item, exists := s.items[key] + item, exists := s.items[sfKey] s.mu.RUnlock() if exists { - return key, item, true + return sfKey, fetchKey, item, true } // If not exists, acquire a write lock to create the item s.mu.Lock() // Double-check if the item was created while acquiring the write lock - item, exists = s.items[key] + item, exists = s.items[sfKey] if exists { s.mu.Unlock() - return key, item, true + return sfKey, fetchKey, item, true } // Create a new item item = &SingleFlightItem{ loaded: make(chan struct{}), } - s.items[key] = item + if size, ok := s.sizes[fetchKey]; ok { + item.sizeHint = size.totalBytes / size.count + } + s.items[sfKey] = item s.mu.Unlock() - return key, item, false + return sfKey, fetchKey, item, false } -func (s *SingleFlight) key(fetchItem *FetchItem, input []byte) uint64 { +func (s *SingleFlight) keys(fetchItem *FetchItem, input []byte) (sfKey, fetchKey uint64) { h := s.xxPool.Get().(*xxhash.Digest) + sfKey = s.sfKey(h, fetchItem, input) + h.Reset() + fetchKey = s.fetchKey(h, fetchItem) + h.Reset() + s.xxPool.Put(h) + return sfKey, fetchKey +} + +func (s *SingleFlight) sfKey(h *xxhash.Digest, fetchItem *FetchItem, input []byte) uint64 { if fetchItem != nil && fetchItem.Fetch != nil { info := fetchItem.Fetch.FetchInfo() if info != nil { @@ -72,15 +92,46 @@ func (s *SingleFlight) key(fetchItem *FetchItem, input []byte) uint64 { } } _, _ = h.Write(input) - key := h.Sum64() - h.Reset() - s.xxPool.Put(h) - return key + return h.Sum64() } -func (s *SingleFlight) Finish(key uint64, item *SingleFlightItem) { +func (s *SingleFlight) fetchKey(h *xxhash.Digest, fetchItem *FetchItem) uint64 { + if fetchItem == nil || 
fetchItem.Fetch == nil { + return 0 + } + info := fetchItem.Fetch.FetchInfo() + if info == nil { + return 0 + } + _, _ = h.WriteString(info.DataSourceID) + _, _ = h.WriteString("|") + for i := range info.RootFields { + if i != 0 { + _, _ = h.WriteString(",") + } + _, _ = h.WriteString(info.RootFields[i].TypeName) + _, _ = h.WriteString(".") + _, _ = h.WriteString(info.RootFields[i].FieldName) + } + return h.Sum64() +} + +func (s *SingleFlight) Finish(sfKey, fetchKey uint64, item *SingleFlightItem) { close(item.loaded) s.mu.Lock() - delete(s.items, key) + delete(s.items, sfKey) + if size, ok := s.sizes[fetchKey]; ok { + if size.count == 50 { + size.count = 1 + size.totalBytes = size.totalBytes / 50 + } + size.count++ + size.totalBytes += len(item.response) + } else { + s.sizes[fetchKey] = &fetchSize{ + count: 1, + totalBytes: len(item.response), + } + } s.mu.Unlock() } From c41b4b6300dc500f6fea2795d3b6056b5dfbe6a1 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Fri, 17 Oct 2025 00:20:53 +0200 Subject: [PATCH 07/57] chore: selectItems on arena --- v2/pkg/engine/resolve/loader.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 2c923b2c9f..23da2cbe02 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -367,7 +367,7 @@ func (l *Loader) selectItemsForPath(path []FetchItemPathElement) []*astjson.Valu if len(items) == 0 { break } - items = selectItems(items, path[i]) + items = selectItems(l.jsonArena, items, path[i]) } return l.taintedObjs.filterOutTainted(items) } @@ -388,7 +388,7 @@ func isItemAllowedByTypename(obj *astjson.Value, typeNames []string) bool { return slices.Contains(typeNames, __typeNameStr) } -func selectItems(items []*astjson.Value, element FetchItemPathElement) []*astjson.Value { +func selectItems(a arena.Arena, items []*astjson.Value, element FetchItemPathElement) []*astjson.Value { if len(items) == 0 { return nil } @@ 
-410,7 +410,7 @@ func selectItems(items []*astjson.Value, element FetchItemPathElement) []*astjso } return []*astjson.Value{field} } - selected := make([]*astjson.Value, 0, len(items)) + selected := arena.AllocateSlice[*astjson.Value](a, 0, len(items)) for _, item := range items { if !isItemAllowedByTypename(item, element.TypeNames) { continue @@ -420,10 +420,10 @@ func selectItems(items []*astjson.Value, element FetchItemPathElement) []*astjso continue } if field.Type() == astjson.TypeArray { - selected = append(selected, field.GetArray()...) + selected = arena.SliceAppend(a, selected, field.GetArray()...) continue } - selected = append(selected, field) + selected = arena.SliceAppend(a, selected, field) } return selected } From 3e1454f355faf8a1b9f4060cf41f3bd5cafa4336 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Fri, 17 Oct 2025 12:40:07 +0200 Subject: [PATCH 08/57] chore: refactor arena pool into separate file --- v2/pkg/engine/resolve/arena.go | 78 +++++++++++++++++ v2/pkg/engine/resolve/inputtemplate.go | 25 ++++-- v2/pkg/engine/resolve/loader.go | 114 +++++++------------------ v2/pkg/engine/resolve/resolve.go | 62 ++------------ 4 files changed, 131 insertions(+), 148 deletions(-) create mode 100644 v2/pkg/engine/resolve/arena.go diff --git a/v2/pkg/engine/resolve/arena.go b/v2/pkg/engine/resolve/arena.go new file mode 100644 index 0000000000..1bd5ee4958 --- /dev/null +++ b/v2/pkg/engine/resolve/arena.go @@ -0,0 +1,78 @@ +package resolve + +import ( + "sync" + "weak" + + "github.com/wundergraph/go-arena" +) + +// ArenaPool provides a thread-safe pool of arena.Arena instances for memory-efficient allocations. +// It uses weak pointers to allow garbage collection of unused arenas while maintaining +// a pool of reusable arenas for high-frequency allocation patterns. 
+type ArenaPool struct { + pool []weak.Pointer[ArenaPoolItem] + sizes map[uint64]int + mu sync.Mutex +} + +// ArenaPoolItem wraps an arena.Arena for use in the pool +type ArenaPoolItem struct { + Arena arena.Arena +} + +// NewArenaPool creates a new ArenaPool instance +func NewArenaPool() *ArenaPool { + return &ArenaPool{ + sizes: make(map[uint64]int), + } +} + +// Acquire gets an arena from the pool or creates a new one if none are available. +// The id parameter is used to track arena sizes per use case for optimization. +func (p *ArenaPool) Acquire(id uint64) *ArenaPoolItem { + p.mu.Lock() + defer p.mu.Unlock() + + // Try to find an available arena in the pool + for i := 0; i < len(p.pool); i++ { + v := p.pool[i].Value() + p.pool = append(p.pool[:i], p.pool[i+1:]...) + if v == nil { + continue + } + return v + } + + // No arena available, create a new one + size := arena.WithMinBufferSize(p.getArenaSize(id)) + return &ArenaPoolItem{ + Arena: arena.NewMonotonicArena(size), + } +} + +// Release returns an arena to the pool for reuse. +// The peak memory usage is recorded to optimize future arena sizes for this use case. +func (p *ArenaPool) Release(id uint64, item *ArenaPoolItem) { + peak := item.Arena.Peak() + item.Arena.Reset() + + p.mu.Lock() + defer p.mu.Unlock() + + // Record the peak usage for this use case + p.sizes[id] = peak + + // Add the arena back to the pool using a weak pointer + w := weak.Make(item) + p.pool = append(p.pool, w) +} + +// getArenaSize returns the optimal arena size for a given use case ID. +// If no size is recorded, it defaults to 1MB. 
+func (p *ArenaPool) getArenaSize(id uint64) int { + if size, ok := p.sizes[id]; ok { + return size + } + return 1024 * 1024 // Default 1MB +} diff --git a/v2/pkg/engine/resolve/inputtemplate.go b/v2/pkg/engine/resolve/inputtemplate.go index 82825cac73..80db3cdd82 100644 --- a/v2/pkg/engine/resolve/inputtemplate.go +++ b/v2/pkg/engine/resolve/inputtemplate.go @@ -1,10 +1,10 @@ package resolve import ( - "bytes" "context" "errors" "fmt" + "io" "github.com/wundergraph/astjson" @@ -36,7 +36,7 @@ type InputTemplate struct { SetTemplateOutputToNullOnVariableNull bool } -func SetInputUndefinedVariables(preparedInput *bytes.Buffer, undefinedVariables []string) error { +func SetInputUndefinedVariables(preparedInput InputTemplateWriter, undefinedVariables []string) error { if len(undefinedVariables) > 0 { output, err := httpclient.SetUndefinedVariables(preparedInput.Bytes(), undefinedVariables) if err != nil { @@ -55,7 +55,14 @@ func SetInputUndefinedVariables(preparedInput *bytes.Buffer, undefinedVariables // to callers; renderSegments intercepts it and writes literal.NULL instead. 
var errSetTemplateOutputNull = errors.New("set to null") -func (i *InputTemplate) Render(ctx *Context, data *astjson.Value, preparedInput *bytes.Buffer) error { +type InputTemplateWriter interface { + io.Writer + io.StringWriter + Reset() + Bytes() []byte +} + +func (i *InputTemplate) Render(ctx *Context, data *astjson.Value, preparedInput InputTemplateWriter) error { var undefinedVariables []string if err := i.renderSegments(ctx, data, i.Segments, preparedInput, &undefinedVariables); err != nil { @@ -65,12 +72,12 @@ func (i *InputTemplate) Render(ctx *Context, data *astjson.Value, preparedInput return SetInputUndefinedVariables(preparedInput, undefinedVariables) } -func (i *InputTemplate) RenderAndCollectUndefinedVariables(ctx *Context, data *astjson.Value, preparedInput *bytes.Buffer, undefinedVariables *[]string) (err error) { +func (i *InputTemplate) RenderAndCollectUndefinedVariables(ctx *Context, data *astjson.Value, preparedInput InputTemplateWriter, undefinedVariables *[]string) (err error) { err = i.renderSegments(ctx, data, i.Segments, preparedInput, undefinedVariables) return } -func (i *InputTemplate) renderSegments(ctx *Context, data *astjson.Value, segments []TemplateSegment, preparedInput *bytes.Buffer, undefinedVariables *[]string) (err error) { +func (i *InputTemplate) renderSegments(ctx *Context, data *astjson.Value, segments []TemplateSegment, preparedInput InputTemplateWriter, undefinedVariables *[]string) (err error) { for _, segment := range segments { switch segment.SegmentType { case StaticSegmentType: @@ -107,7 +114,7 @@ func (i *InputTemplate) renderSegments(ctx *Context, data *astjson.Value, segmen return err } -func (i *InputTemplate) renderObjectVariable(ctx context.Context, variables *astjson.Value, segment TemplateSegment, preparedInput *bytes.Buffer) error { +func (i *InputTemplate) renderObjectVariable(ctx context.Context, variables *astjson.Value, segment TemplateSegment, preparedInput InputTemplateWriter) error { value := 
variables.Get(segment.VariableSourcePath...) if value == nil || value.Type() == astjson.TypeNull { if i.SetTemplateOutputToNullOnVariableNull { @@ -119,11 +126,11 @@ func (i *InputTemplate) renderObjectVariable(ctx context.Context, variables *ast return segment.Renderer.RenderVariable(ctx, value, preparedInput) } -func (i *InputTemplate) renderResolvableObjectVariable(ctx context.Context, objectData *astjson.Value, segment TemplateSegment, preparedInput *bytes.Buffer) error { +func (i *InputTemplate) renderResolvableObjectVariable(ctx context.Context, objectData *astjson.Value, segment TemplateSegment, preparedInput InputTemplateWriter) error { return segment.Renderer.RenderVariable(ctx, objectData, preparedInput) } -func (i *InputTemplate) renderContextVariable(ctx *Context, segment TemplateSegment, preparedInput *bytes.Buffer) (variableWasUndefined bool, err error) { +func (i *InputTemplate) renderContextVariable(ctx *Context, segment TemplateSegment, preparedInput InputTemplateWriter) (variableWasUndefined bool, err error) { variableSourcePath := segment.VariableSourcePath if len(variableSourcePath) == 1 && ctx.RemapVariables != nil { nameToUse, hasMapping := ctx.RemapVariables[variableSourcePath[0]] @@ -142,7 +149,7 @@ func (i *InputTemplate) renderContextVariable(ctx *Context, segment TemplateSegm return false, segment.Renderer.RenderVariable(ctx.Context(), value, preparedInput) } -func (i *InputTemplate) renderHeaderVariable(ctx *Context, path []string, preparedInput *bytes.Buffer) error { +func (i *InputTemplate) renderHeaderVariable(ctx *Context, path []string, preparedInput InputTemplateWriter) error { if len(path) != 1 { return errHeaderPathInvalid } diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 23da2cbe02..71a3c5304e 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -11,7 +11,6 @@ import ( "slices" "strconv" "strings" - "sync" "time" "github.com/buger/jsonparser" @@ -359,7 +358,9 
@@ func (l *Loader) resolveSingle(item *FetchItem) error { } func (l *Loader) selectItemsForPath(path []FetchItemPathElement) []*astjson.Value { - items := []*astjson.Value{l.resolvable.data} + // Use arena allocation for the initial items slice + items := arena.AllocateSlice[*astjson.Value](l.jsonArena, 1, 1) + items[0] = l.resolvable.data if len(path) == 0 { return l.taintedObjs.filterOutTainted(items) } @@ -1286,7 +1287,7 @@ func (l *Loader) validatePreFetch(input []byte, info *FetchInfo, res *result) (a func (l *Loader) loadSingleFetch(ctx context.Context, fetch *SingleFetch, fetchItem *FetchItem, items []*astjson.Value, res *result) error { res.init(fetch.PostProcessing, fetch.Info) - buf := &bytes.Buffer{} + buf := bytes.NewBuffer(nil) inputData := itemsData(l.jsonArena, items) if l.ctx.TracingOptions.Enable { @@ -1325,36 +1326,8 @@ func (l *Loader) loadSingleFetch(ctx context.Context, fetch *SingleFetch, fetchI return nil } -var ( - entityFetchPool = sync.Pool{ - New: func() any { - return &entityFetchBuffer{ - item: &bytes.Buffer{}, - preparedInput: &bytes.Buffer{}, - } - }, - } -) - -type entityFetchBuffer struct { - item *bytes.Buffer - preparedInput *bytes.Buffer -} - -func acquireEntityFetchBuffer() *entityFetchBuffer { - return entityFetchPool.Get().(*entityFetchBuffer) -} - -func releaseEntityFetchBuffer(buf *entityFetchBuffer) { - buf.item.Reset() - buf.preparedInput.Reset() - entityFetchPool.Put(buf) -} - func (l *Loader) loadEntityFetch(ctx context.Context, fetchItem *FetchItem, fetch *EntityFetch, items []*astjson.Value, res *result) error { res.init(fetch.PostProcessing, fetch.Info) - buf := acquireEntityFetchBuffer() - defer releaseEntityFetchBuffer(buf) input := itemsData(l.jsonArena, items) if l.ctx.TracingOptions.Enable { fetch.Trace = &DataSourceLoadTrace{} @@ -1363,14 +1336,17 @@ func (l *Loader) loadEntityFetch(ctx context.Context, fetchItem *FetchItem, fetc } } + preparedInput := bytes.NewBuffer(nil) + item := bytes.NewBuffer(nil) + var 
undefinedVariables []string - err := fetch.Input.Header.RenderAndCollectUndefinedVariables(l.ctx, nil, buf.preparedInput, &undefinedVariables) + err := fetch.Input.Header.RenderAndCollectUndefinedVariables(l.ctx, nil, preparedInput, &undefinedVariables) if err != nil { return errors.WithStack(err) } - err = fetch.Input.Item.Render(l.ctx, input, buf.item) + err = fetch.Input.Item.Render(l.ctx, input, item) if err != nil { if fetch.Input.SkipErrItem { // skip fetch on render item error @@ -1382,7 +1358,7 @@ func (l *Loader) loadEntityFetch(ctx context.Context, fetchItem *FetchItem, fetc } return errors.WithStack(err) } - renderedItem := buf.item.Bytes() + renderedItem := item.Bytes() if bytes.Equal(renderedItem, null) { // skip fetch if item is null res.fetchSkipped = true @@ -1401,17 +1377,17 @@ func (l *Loader) loadEntityFetch(ctx context.Context, fetchItem *FetchItem, fetc return nil } } - _, _ = buf.item.WriteTo(buf.preparedInput) - err = fetch.Input.Footer.RenderAndCollectUndefinedVariables(l.ctx, nil, buf.preparedInput, &undefinedVariables) + _, _ = item.WriteTo(preparedInput) + err = fetch.Input.Footer.RenderAndCollectUndefinedVariables(l.ctx, nil, preparedInput, &undefinedVariables) if err != nil { return errors.WithStack(err) } - err = SetInputUndefinedVariables(buf.preparedInput, undefinedVariables) + err = SetInputUndefinedVariables(preparedInput, undefinedVariables) if err != nil { return errors.WithStack(err) } - fetchInput := buf.preparedInput.Bytes() + fetchInput := preparedInput.Bytes() if l.ctx.TracingOptions.Enable && res.fetchSkipped { l.setTracingInput(fetchItem, fetchInput, fetch.Trace) @@ -1429,41 +1405,9 @@ func (l *Loader) loadEntityFetch(ctx context.Context, fetchItem *FetchItem, fetc return nil } -var ( - batchEntityFetchPool = sync.Pool{} -) - -type batchEntityFetchBuffer struct { - preparedInput *bytes.Buffer - itemInput *bytes.Buffer - keyGen *xxhash.Digest -} - -func acquireBatchEntityFetchBuffer() *batchEntityFetchBuffer { - buf := 
batchEntityFetchPool.Get() - if buf == nil { - return &batchEntityFetchBuffer{ - preparedInput: &bytes.Buffer{}, - itemInput: &bytes.Buffer{}, - keyGen: xxhash.New(), - } - } - return buf.(*batchEntityFetchBuffer) -} - -func releaseBatchEntityFetchBuffer(buf *batchEntityFetchBuffer) { - buf.preparedInput.Reset() - buf.itemInput.Reset() - buf.keyGen.Reset() - batchEntityFetchPool.Put(buf) -} - func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetchItem *FetchItem, fetch *BatchEntityFetch, items []*astjson.Value, res *result) error { res.init(fetch.PostProcessing, fetch.Info) - buf := acquireBatchEntityFetchBuffer() - defer releaseBatchEntityFetchBuffer(buf) - if l.ctx.TracingOptions.Enable { fetch.Trace = &DataSourceLoadTrace{} if !l.ctx.TracingOptions.ExcludeRawInputData && len(items) != 0 { @@ -1474,9 +1418,13 @@ func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetchItem *FetchItem, } } + preparedInput := bytes.NewBuffer(make([]byte, 0, 64)) + itemInput := bytes.NewBuffer(make([]byte, 0, 32)) + keyGen := xxhash.New() + var undefinedVariables []string - err := fetch.Input.Header.RenderAndCollectUndefinedVariables(l.ctx, nil, buf.preparedInput, &undefinedVariables) + err := fetch.Input.Header.RenderAndCollectUndefinedVariables(l.ctx, nil, preparedInput, &undefinedVariables) if err != nil { return errors.WithStack(err) } @@ -1488,8 +1436,8 @@ func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetchItem *FetchItem, WithNextItem: for i, item := range items { for j := range fetch.Input.Items { - buf.itemInput.Reset() - err = fetch.Input.Items[j].Render(l.ctx, item, buf.itemInput) + itemInput.Reset() + err = fetch.Input.Items[j].Render(l.ctx, item, itemInput) if err != nil { if fetch.Input.SkipErrItems { err = nil // nolint:ineffassign @@ -1501,18 +1449,18 @@ WithNextItem: } return errors.WithStack(err) } - if fetch.Input.SkipNullItems && buf.itemInput.Len() == 4 && bytes.Equal(buf.itemInput.Bytes(), null) { + if fetch.Input.SkipNullItems && 
itemInput.Len() == 4 && bytes.Equal(itemInput.Bytes(), null) { res.batchStats[i] = append(res.batchStats[i], -1) continue } - if fetch.Input.SkipEmptyObjectItems && buf.itemInput.Len() == 2 && bytes.Equal(buf.itemInput.Bytes(), emptyObject) { + if fetch.Input.SkipEmptyObjectItems && itemInput.Len() == 2 && bytes.Equal(itemInput.Bytes(), emptyObject) { res.batchStats[i] = append(res.batchStats[i], -1) continue } - buf.keyGen.Reset() - _, _ = buf.keyGen.Write(buf.itemInput.Bytes()) - itemHash := buf.keyGen.Sum64() + keyGen.Reset() + _, _ = keyGen.Write(itemInput.Bytes()) + itemHash := keyGen.Sum64() for k := range itemHashes { if itemHashes[k] == itemHash { res.batchStats[i] = append(res.batchStats[i], k) @@ -1521,12 +1469,12 @@ WithNextItem: } itemHashes = append(itemHashes, itemHash) if addSeparator { - err = fetch.Input.Separator.Render(l.ctx, nil, buf.preparedInput) + err = fetch.Input.Separator.Render(l.ctx, nil, preparedInput) if err != nil { return errors.WithStack(err) } } - _, _ = buf.itemInput.WriteTo(buf.preparedInput) + _, _ = itemInput.WriteTo(preparedInput) res.batchStats[i] = append(res.batchStats[i], batchItemIndex) batchItemIndex++ addSeparator = true @@ -1543,16 +1491,16 @@ WithNextItem: } } - err = fetch.Input.Footer.RenderAndCollectUndefinedVariables(l.ctx, nil, buf.preparedInput, &undefinedVariables) + err = fetch.Input.Footer.RenderAndCollectUndefinedVariables(l.ctx, nil, preparedInput, &undefinedVariables) if err != nil { return errors.WithStack(err) } - err = SetInputUndefinedVariables(buf.preparedInput, undefinedVariables) + err = SetInputUndefinedVariables(preparedInput, undefinedVariables) if err != nil { return errors.WithStack(err) } - fetchInput := buf.preparedInput.Bytes() + fetchInput := preparedInput.Bytes() if l.ctx.TracingOptions.Enable && res.fetchSkipped { l.setTracingInput(fetchItem, fetchInput, fetch.Trace) diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index eef77b5b81..ce09fe0863 100644 --- 
a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -7,15 +7,12 @@ import ( "context" "fmt" "io" - "sync" "time" - "weak" "github.com/buger/jsonparser" "github.com/pkg/errors" "go.uber.org/atomic" - "github.com/wundergraph/go-arena" "github.com/wundergraph/graphql-go-tools/v2/pkg/internal/xcontext" "github.com/wundergraph/graphql-go-tools/v2/pkg/pool" ) @@ -73,18 +70,12 @@ type Resolver struct { // maxSubscriptionFetchTimeout defines the maximum time a subscription fetch can take before it is considered timed out maxSubscriptionFetchTimeout time.Duration - arenaPool []weak.Pointer[arenaPoolItem] - arenaSize map[uint64]int - arenaPoolMu sync.Mutex + arenaPool *ArenaPool // Single flight cache for deduplicating requests across all loaders sf *SingleFlight } -type arenaPoolItem struct { - jsonArena arena.Arena -} - func (r *Resolver) SetAsyncErrorWriter(w AsyncErrorWriter) { r.asyncErrorWriter = w } @@ -236,6 +227,7 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { allowedErrorFields: allowedErrorFields, heartbeatInterval: options.SubscriptionHeartbeatInterval, maxSubscriptionFetchTimeout: options.MaxSubscriptionFetchTimeout, + arenaPool: NewArenaPool(), sf: NewSingleFlight(), } resolver.maxConcurrency = make(chan struct{}, options.MaxConcurrency) @@ -243,8 +235,6 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { resolver.maxConcurrency <- struct{}{} } - resolver.arenaSize = make(map[uint64]int) - go resolver.processEvents() return resolver @@ -309,46 +299,6 @@ func (r *Resolver) ResolveGraphQLResponse(ctx *Context, response *GraphQLRespons return resp, err } -func (r *Resolver) acquireArena(id uint64) *arenaPoolItem { - r.arenaPoolMu.Lock() - defer r.arenaPoolMu.Unlock() - - for i := 0; i < len(r.arenaPool); i++ { - v := r.arenaPool[i].Value() - r.arenaPool = append(r.arenaPool[:i], r.arenaPool[i+1:]...) 
- if v == nil { - continue - } - return v - } - - size := arena.WithMinBufferSize(r.getArenaSize(id)) - - return &arenaPoolItem{ - jsonArena: arena.NewMonotonicArena(size), - } -} - -func (r *Resolver) getArenaSize(id uint64) int { - if size, ok := r.arenaSize[id]; ok { - return size - } - return 1024 * 1024 -} - -func (r *Resolver) releaseArena(id uint64, item *arenaPoolItem) { - peak := item.jsonArena.Peak() - item.jsonArena.Reset() - - r.arenaPoolMu.Lock() - defer r.arenaPoolMu.Unlock() - - r.arenaSize[id] = peak - - w := weak.Make(item) - r.arenaPool = append(r.arenaPool, w) -} - func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLResponse, writer io.Writer) (*GraphQLResolveInfo, error) { resp := &GraphQLResolveInfo{} @@ -361,10 +311,10 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf) - poolItem := r.acquireArena(ctx.Request.ID) - defer r.releaseArena(ctx.Request.ID, poolItem) - t.loader.jsonArena = poolItem.jsonArena - t.resolvable.astjsonArena = poolItem.jsonArena + poolItem := r.arenaPool.Acquire(ctx.Request.ID) + defer r.arenaPool.Release(ctx.Request.ID, poolItem) + t.loader.jsonArena = poolItem.Arena + t.resolvable.astjsonArena = poolItem.Arena err := t.resolvable.Init(ctx, nil, response.Info.OperationType) if err != nil { From a41ec06ba8ca13b0121889ce35fa48d9381a8951 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Sun, 19 Oct 2025 19:50:20 +0200 Subject: [PATCH 09/57] refactor: update buffer size in HTTP client and enhance arena pool size tracking --- .../datasource/httpclient/nethttpclient.go | 2 +- v2/pkg/engine/resolve/arena.go | 25 ++++++++++++++++--- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/v2/pkg/engine/datasource/httpclient/nethttpclient.go b/v2/pkg/engine/datasource/httpclient/nethttpclient.go index 27b0434c11..d6276c8375 100644 --- 
a/v2/pkg/engine/datasource/httpclient/nethttpclient.go +++ b/v2/pkg/engine/datasource/httpclient/nethttpclient.go @@ -143,7 +143,7 @@ func buffer(ctx context.Context) *bytes.Buffer { if sizeHint, ok := ctx.Value(sizeHintKey).(int); ok && sizeHint > 0 { return bytes.NewBuffer(make([]byte, 0, sizeHint)) } - return bytes.NewBuffer(make([]byte, 0, 1024*4)) // default to 4KB + return bytes.NewBuffer(make([]byte, 0, 64)) } func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, headers, queryParams []byte, body io.Reader, enableTrace bool, contentType string) ([]byte, error) { diff --git a/v2/pkg/engine/resolve/arena.go b/v2/pkg/engine/resolve/arena.go index 1bd5ee4958..0aae889742 100644 --- a/v2/pkg/engine/resolve/arena.go +++ b/v2/pkg/engine/resolve/arena.go @@ -12,10 +12,15 @@ import ( // a pool of reusable arenas for high-frequency allocation patterns. type ArenaPool struct { pool []weak.Pointer[ArenaPoolItem] - sizes map[uint64]int + sizes map[uint64]*arenaPoolItemSize mu sync.Mutex } +type arenaPoolItemSize struct { + count int + totalBytes int +} + // ArenaPoolItem wraps an arena.Arena for use in the pool type ArenaPoolItem struct { Arena arena.Arena @@ -24,7 +29,7 @@ type ArenaPoolItem struct { // NewArenaPool creates a new ArenaPool instance func NewArenaPool() *ArenaPool { return &ArenaPool{ - sizes: make(map[uint64]int), + sizes: make(map[uint64]*arenaPoolItemSize), } } @@ -61,7 +66,19 @@ func (p *ArenaPool) Release(id uint64, item *ArenaPoolItem) { defer p.mu.Unlock() // Record the peak usage for this use case - p.sizes[id] = peak + if size, ok := p.sizes[id]; ok { + if size.count == 50 { + size.count = 1 + size.totalBytes = size.totalBytes / 50 + } + size.count++ + size.totalBytes += peak + } else { + p.sizes[id] = &arenaPoolItemSize{ + count: 1, + totalBytes: peak, + } + } // Add the arena back to the pool using a weak pointer w := weak.Make(item) @@ -72,7 +89,7 @@ func (p *ArenaPool) Release(id uint64, item *ArenaPoolItem) { // If no 
size is recorded, it defaults to 1MB. func (p *ArenaPool) getArenaSize(id uint64) int { if size, ok := p.sizes[id]; ok { - return size + return size.totalBytes / size.count } return 1024 * 1024 // Default 1MB } From ced27f30f64b24b461e1c1e28b44878ea7c28723 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Mon, 20 Oct 2025 20:24:59 +0200 Subject: [PATCH 10/57] chore: add second arena for response buffer --- v2/pkg/engine/resolve/resolve.go | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index ce09fe0863..90b534174e 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -11,6 +11,7 @@ import ( "github.com/buger/jsonparser" "github.com/pkg/errors" + "github.com/wundergraph/go-arena" "go.uber.org/atomic" "github.com/wundergraph/graphql-go-tools/v2/pkg/internal/xcontext" @@ -70,7 +71,8 @@ type Resolver struct { // maxSubscriptionFetchTimeout defines the maximum time a subscription fetch can take before it is considered timed out maxSubscriptionFetchTimeout time.Duration - arenaPool *ArenaPool + resolveArenaPool *ArenaPool + responseBufferPool *ArenaPool // Single flight cache for deduplicating requests across all loaders sf *SingleFlight @@ -227,7 +229,8 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { allowedErrorFields: allowedErrorFields, heartbeatInterval: options.SubscriptionHeartbeatInterval, maxSubscriptionFetchTimeout: options.MaxSubscriptionFetchTimeout, - arenaPool: NewArenaPool(), + resolveArenaPool: NewArenaPool(), + responseBufferPool: NewArenaPool(), sf: NewSingleFlight(), } resolver.maxConcurrency = make(chan struct{}, options.MaxConcurrency) @@ -311,28 +314,36 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf) - poolItem := r.arenaPool.Acquire(ctx.Request.ID) - defer 
r.arenaPool.Release(ctx.Request.ID, poolItem) - t.loader.jsonArena = poolItem.Arena - t.resolvable.astjsonArena = poolItem.Arena + resolveArena := r.resolveArenaPool.Acquire(ctx.Request.ID) + t.loader.jsonArena = resolveArena.Arena + t.resolvable.astjsonArena = resolveArena.Arena err := t.resolvable.Init(ctx, nil, response.Info.OperationType) if err != nil { + r.resolveArenaPool.Release(ctx.Request.ID, resolveArena) return nil, err } if !ctx.ExecutionOptions.SkipLoader { err = t.loader.LoadGraphQLResponseData(ctx, response, t.resolvable) if err != nil { + r.resolveArenaPool.Release(ctx.Request.ID, resolveArena) return nil, err } } - err = t.resolvable.Resolve(ctx.ctx, response.Data, response.Fetches, writer) + responseArena := r.responseBufferPool.Acquire(ctx.Request.ID) + buf := arena.NewArenaBuffer(responseArena.Arena) + err = t.resolvable.Resolve(ctx.ctx, response.Data, response.Fetches, buf) if err != nil { + r.resolveArenaPool.Release(ctx.Request.ID, resolveArena) + r.responseBufferPool.Release(ctx.Request.ID, responseArena) return nil, err } + r.resolveArenaPool.Release(ctx.Request.ID, resolveArena) + _, err = writer.Write(buf.Bytes()) + r.responseBufferPool.Release(ctx.Request.ID, responseArena) return resp, err } From 67db907e1f1a9a94ff05b09af31d7ee6fb9fdcb2 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Fri, 24 Oct 2025 12:28:20 +0200 Subject: [PATCH 11/57] chore: add headers to DataSource args, add HeadersForSubgraphRequest to resolve Context --- .../graphql_datasource/graphql_datasource.go | 29 +-- .../graphql_datasource_test.go | 35 ++- .../graphql_subscription_client.go | 108 ---------- .../graphql_subscription_client_test.go | 202 ------------------ .../grpc_datasource/grpc_datasource.go | 5 +- .../grpc_datasource/grpc_datasource_test.go | 26 +-- .../datasource/httpclient/httpclient_test.go | 4 +- .../datasource/httpclient/nethttpclient.go | 18 +- .../introspection_datasource/source.go | 5 +- .../introspection_datasource/source_test.go | 2 +- 
.../pubsub_datasource_test.go | 8 + .../pubsub_datasource/pubsub_kafka.go | 31 +-- .../pubsub_datasource/pubsub_nats.go | 35 +-- .../staticdatasource/static_datasource.go | 5 +- v2/pkg/engine/plan/planner_test.go | 5 +- v2/pkg/engine/plan/visitor.go | 2 + v2/pkg/engine/resolve/authorization_test.go | 25 +-- v2/pkg/engine/resolve/context.go | 15 ++ v2/pkg/engine/resolve/datasource.go | 13 +- v2/pkg/engine/resolve/event_loop_test.go | 9 +- v2/pkg/engine/resolve/loader.go | 34 ++- v2/pkg/engine/resolve/loader_hooks_test.go | 53 ++--- v2/pkg/engine/resolve/loader_test.go | 4 +- v2/pkg/engine/resolve/resolve.go | 82 ++++--- .../engine/resolve/resolve_federation_test.go | 95 ++++---- v2/pkg/engine/resolve/resolve_mock_test.go | 17 +- v2/pkg/engine/resolve/resolve_test.go | 188 ++++++++-------- v2/pkg/engine/resolve/response.go | 13 +- v2/pkg/engine/resolve/singleflight.go | 13 +- 29 files changed, 382 insertions(+), 699 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go index 6f301d52d9..4574681849 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go @@ -14,7 +14,6 @@ import ( "unicode" "github.com/buger/jsonparser" - "github.com/cespare/xxhash/v2" "github.com/jensneuse/abstractlogger" "github.com/pkg/errors" "github.com/tidwall/sjson" @@ -1907,20 +1906,19 @@ func (s *Source) replaceEmptyObject(variables []byte) ([]byte, bool) { return variables, false } -func (s *Source) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { +func (s *Source) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { input = s.compactAndUnNullVariables(input) - return httpclient.DoMultipartForm(s.httpClient, ctx, input, files) + return 
httpclient.DoMultipartForm(s.httpClient, ctx, headers, input, files) } -func (s *Source) Load(ctx context.Context, input []byte) (data []byte, err error) { +func (s *Source) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { input = s.compactAndUnNullVariables(input) - return httpclient.Do(s.httpClient, ctx, input) + return httpclient.Do(s.httpClient, ctx, headers, input) } type GraphQLSubscriptionClient interface { // Subscribe to the origin source. The implementation must not block the calling goroutine. Subscribe(ctx *resolve.Context, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error - UniqueRequestID(ctx *resolve.Context, options GraphQLSubscriptionOptions, hash *xxhash.Digest) (err error) SubscribeAsync(ctx *resolve.Context, id uint64, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error Unsubscribe(id uint64) } @@ -1956,12 +1954,13 @@ type SubscriptionSource struct { client GraphQLSubscriptionClient } -func (s *SubscriptionSource) AsyncStart(ctx *resolve.Context, id uint64, input []byte, updater resolve.SubscriptionUpdater) error { +func (s *SubscriptionSource) AsyncStart(ctx *resolve.Context, id uint64, headers http.Header, input []byte, updater resolve.SubscriptionUpdater) error { var options GraphQLSubscriptionOptions err := json.Unmarshal(input, &options) if err != nil { return err } + options.Header = headers if options.Body.Query == "" { return resolve.ErrUnableToResolve } @@ -1975,12 +1974,13 @@ func (s *SubscriptionSource) AsyncStop(id uint64) { } // Start the subscription. The updater is called on new events. Start needs to be called in a separate goroutine. 
-func (s *SubscriptionSource) Start(ctx *resolve.Context, input []byte, updater resolve.SubscriptionUpdater) error { +func (s *SubscriptionSource) Start(ctx *resolve.Context, headers http.Header, input []byte, updater resolve.SubscriptionUpdater) error { var options GraphQLSubscriptionOptions err := json.Unmarshal(input, &options) if err != nil { return err } + options.Header = headers if options.Body.Query == "" { return resolve.ErrUnableToResolve } @@ -1990,16 +1990,3 @@ func (s *SubscriptionSource) Start(ctx *resolve.Context, input []byte, updater r var ( dataSouceName = []byte("graphql") ) - -func (s *SubscriptionSource) UniqueRequestID(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) (err error) { - _, err = xxh.Write(dataSouceName) - if err != nil { - return err - } - var options GraphQLSubscriptionOptions - err = json.Unmarshal(input, &options) - if err != nil { - return err - } - return s.client.UniqueRequestID(ctx, options, xxh) -} diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go index 75a23f5ed7..e064b607e6 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go @@ -16,7 +16,6 @@ import ( "testing" "time" - "github.com/cespare/xxhash/v2" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -4021,6 +4020,8 @@ func TestGraphQLDataSource(t *testing.T) { NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, ctx), }, PostProcessing: DefaultPostProcessingConfiguration, + SourceName: "ds-id", + SourceID: "ds-id", }, Response: &resolve.GraphQLResponse{ Fetches: resolve.Sequence(), @@ -4062,6 +4063,8 @@ func TestGraphQLDataSource(t *testing.T) { client: NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, ctx), }, PostProcessing: DefaultPostProcessingConfiguration, + 
SourceName: "ds-id", + SourceID: "ds-id", }, Response: &resolve.GraphQLResponse{ Data: &resolve.Object{ @@ -8258,10 +8261,6 @@ func (f *FailingSubscriptionClient) Subscribe(ctx *resolve.Context, options Grap return errSubscriptionClientFail } -func (f *FailingSubscriptionClient) UniqueRequestID(ctx *resolve.Context, options GraphQLSubscriptionOptions, hash *xxhash.Digest) (err error) { - return errSubscriptionClientFail -} - type testSubscriptionUpdater struct { updates []string done bool @@ -8375,13 +8374,13 @@ func TestSubscriptionSource_Start(t *testing.T) { t.Run("should return error when input is invalid", func(t *testing.T) { source := SubscriptionSource{client: &FailingSubscriptionClient{}} - err := source.Start(resolve.NewContext(context.Background()), []byte(`{"url": "", "body": "", "header": null}`), nil) + err := source.Start(resolve.NewContext(context.Background()), nil, []byte(`{"url": "", "body": "", "header": null}`), nil) assert.Error(t, err) }) t.Run("should return error when subscription client returns an error", func(t *testing.T) { source := SubscriptionSource{client: &FailingSubscriptionClient{}} - err := source.Start(resolve.NewContext(context.Background()), []byte(`{"url": "", "body": {}, "header": null}`), nil) + err := source.Start(resolve.NewContext(context.Background()), nil, []byte(`{"url": "", "body": {}, "header": null}`), nil) assert.Error(t, err) assert.Equal(t, resolve.ErrUnableToResolve, err) }) @@ -8394,7 +8393,7 @@ func TestSubscriptionSource_Start(t *testing.T) { source := newSubscriptionSource(ctx.Context()) chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomName: "#test") { text createdBy } }"}`) - err := source.Start(ctx, chatSubscriptionOptions, updater) + err := source.Start(ctx, nil, chatSubscriptionOptions, updater) require.ErrorIs(t, err, resolve.ErrUnableToResolve) }) @@ -8406,7 +8405,7 
@@ func TestSubscriptionSource_Start(t *testing.T) { source := newSubscriptionSource(ctx.Context()) chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomNam: \"#test\") { text createdBy } }"}`) - err := source.Start(ctx, chatSubscriptionOptions, updater) + err := source.Start(ctx, nil, chatSubscriptionOptions, updater) require.NoError(t, err) updater.AwaitUpdates(t, time.Second, 1) assert.Len(t, updater.updates, 1) @@ -8424,7 +8423,7 @@ func TestSubscriptionSource_Start(t *testing.T) { source := newSubscriptionSource(resolverLifecycle) chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomName: \"#test\") { text createdBy } }"}`) - err := source.Start(resolve.NewContext(subscriptionLifecycle), chatSubscriptionOptions, updater) + err := source.Start(resolve.NewContext(subscriptionLifecycle), nil, chatSubscriptionOptions, updater) require.NoError(t, err) username := "myuser" @@ -8447,7 +8446,7 @@ func TestSubscriptionSource_Start(t *testing.T) { source := newSubscriptionSource(ctx.Context()) chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomName: \"#test\") { text createdBy } }"}`) - err := source.Start(ctx, chatSubscriptionOptions, updater) + err := source.Start(ctx, nil, chatSubscriptionOptions, updater) require.NoError(t, err) username := "myuser" @@ -8511,7 +8510,7 @@ func TestSubscription_GTWS_SubProtocol(t *testing.T) { source := newSubscriptionSource(ctx.Context()) chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomNam: \"#test\") { text 
createdBy } }"}`) - err := source.Start(ctx, chatSubscriptionOptions, updater) + err := source.Start(ctx, nil, chatSubscriptionOptions, updater) require.NoError(t, err) updater.AwaitUpdates(t, time.Second, 1) @@ -8531,7 +8530,7 @@ func TestSubscription_GTWS_SubProtocol(t *testing.T) { source := newSubscriptionSource(resolverLifecycle) chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomName: \"#test\") { text createdBy } }"}`) - err := source.Start(resolve.NewContext(subscriptionLifecycle), chatSubscriptionOptions, updater) + err := source.Start(resolve.NewContext(subscriptionLifecycle), nil, chatSubscriptionOptions, updater) require.NoError(t, err) username := "myuser" @@ -8555,7 +8554,7 @@ func TestSubscription_GTWS_SubProtocol(t *testing.T) { source := newSubscriptionSource(ctx.Context()) chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomName: \"#test\") { text createdBy } }"}`) - err := source.Start(ctx, chatSubscriptionOptions, updater) + err := source.Start(ctx, nil, chatSubscriptionOptions, updater) require.NoError(t, err) username := "myuser" @@ -8693,7 +8692,7 @@ func TestSource_Load(t *testing.T) { input = httpclient.SetInputBodyWithPath(input, variables, "variables") input = httpclient.SetInputURL(input, []byte(serverUrl)) - data, err := src.Load(context.Background(), input) + data, err := src.Load(context.Background(), nil, input) require.NoError(t, err) assert.Equal(t, `{"variables":{"a":null,"b":"b","c":{}}}`, string(data)) }) @@ -8715,7 +8714,7 @@ func TestSource_Load(t *testing.T) { input, err = httpclient.SetUndefinedVariables(input, undefinedVariables) assert.NoError(t, err) - data, err := src.Load(ctx, input) + data, err := src.Load(ctx, nil, input) require.NoError(t, err) 
assert.Equal(t, `{"variables":{"b":null}}`, string(data)) }) @@ -8801,7 +8800,7 @@ func TestLoadFiles(t *testing.T) { input = httpclient.SetInputURL(input, []byte(serverUrl)) ctx := context.Background() - _, err = src.LoadWithFiles(ctx, input, []*httpclient.FileUpload{httpclient.NewFileUpload(f.Name(), fileName, "variables.file")}) + _, err = src.LoadWithFiles(ctx, nil, input, []*httpclient.FileUpload{httpclient.NewFileUpload(f.Name(), fileName, "variables.file")}) require.NoError(t, err) }) @@ -8856,7 +8855,7 @@ func TestLoadFiles(t *testing.T) { assert.NoError(t, err) ctx := context.Background() - _, err = src.LoadWithFiles(ctx, input, + _, err = src.LoadWithFiles(ctx, nil, input, []*httpclient.FileUpload{ httpclient.NewFileUpload(f1.Name(), file1Name, "variables.files.0"), httpclient.NewFileUpload(f2.Name(), file2Name, "variables.files.1")}) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index c5a52a476c..c8a08df03f 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -9,13 +9,9 @@ import ( "errors" "fmt" "io" - "maps" "net" "net/http" "net/http/httptrace" - "net/textproto" - "slices" - "strconv" "strings" "sync" "syscall" @@ -295,27 +291,6 @@ func (c *subscriptionClient) Subscribe(ctx *resolve.Context, options GraphQLSubs return c.subscribeWS(ctx.Context(), c.engineCtx, options, updater) } -var ( - withSSE = []byte(`sse:true`) - withSSEMethodPost = []byte(`sse_method_post:true`) -) - -func (c *subscriptionClient) UniqueRequestID(ctx *resolve.Context, options GraphQLSubscriptionOptions, hash *xxhash.Digest) (err error) { - if options.UseSSE { - _, err = hash.Write(withSSE) - if err != nil { - return err - } - } - if options.SSEMethodPost { - _, err = hash.Write(withSSEMethodPost) - if err != nil { - return err - } - } - return 
c.requestHash(ctx, options, hash) -} - func (c *subscriptionClient) subscribeSSE(requestContext, engineContext context.Context, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error { options.readTimeout = c.readTimeout if c.streamingClient == nil { @@ -409,89 +384,6 @@ func (c *subscriptionClient) asyncSubscribeWS(requestContext, engineContext cont return nil } -// generateHandlerIDHash generates a Hash based on: URL and Headers to uniquely identify Upgrade Requests -func (c *subscriptionClient) requestHash(ctx *resolve.Context, options GraphQLSubscriptionOptions, xxh *xxhash.Digest) (err error) { - if _, err = xxh.WriteString(options.URL); err != nil { - return err - } - if err := options.Header.Write(xxh); err != nil { - return err - } - // Make sure any header that will be forwarded to the subgraph - // is hashed to create the handlerID, this way requests with - // different headers will use separate connections. - for _, headerName := range options.ForwardedClientHeaderNames { - if _, err = xxh.WriteString(headerName); err != nil { - return err - } - for _, val := range ctx.Request.Header[textproto.CanonicalMIMEHeaderKey(headerName)] { - if _, err = xxh.WriteString(val); err != nil { - return err - } - } - } - - // Sort header names for deterministic hashing since looping through maps - // results in a non-deterministic order of elements - headerKeys := slices.Sorted(maps.Keys(ctx.Request.Header)) - - for _, headerRegexp := range options.ForwardedClientHeaderRegularExpressions { - // Write header pattern - if _, err = xxh.WriteString(headerRegexp.Pattern.String()); err != nil { - return err - } - - // Write negate match - if _, err = xxh.WriteString(strconv.FormatBool(headerRegexp.NegateMatch)); err != nil { - return err - } - - for _, headerName := range headerKeys { - values := ctx.Request.Header[headerName] - result := headerRegexp.Pattern.MatchString(headerName) - if headerRegexp.NegateMatch { - result = !result - } - if result { - 
for _, val := range values { - if _, err = xxh.WriteString(val); err != nil { - return err - } - } - } - } - } - if len(ctx.InitialPayload) > 0 { - if _, err = xxh.Write(ctx.InitialPayload); err != nil { - return err - } - } - if options.Body.Extensions != nil { - if _, err = xxh.Write(options.Body.Extensions); err != nil { - return err - } - } - if options.Body.Query != "" { - _, err = xxh.WriteString(options.Body.Query) - if err != nil { - return err - } - } - if options.Body.Variables != nil { - _, err = xxh.Write(options.Body.Variables) - if err != nil { - return err - } - } - if options.Body.OperationName != "" { - _, err = xxh.WriteString(options.Body.OperationName) - if err != nil { - return err - } - } - return nil -} - type UpgradeRequestError struct { URL string StatusCode int diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go index 279c4bfe83..25eaa29f72 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go @@ -7,7 +7,6 @@ import ( "fmt" "net/http" "net/http/httptest" - "regexp" "runtime" "strings" "sync" @@ -15,7 +14,6 @@ import ( "time" "github.com/buger/jsonparser" - "github.com/cespare/xxhash/v2" "github.com/coder/websocket" ll "github.com/jensneuse/abstractlogger" "github.com/stretchr/testify/assert" @@ -2571,203 +2569,3 @@ func TestInvalidWebSocketAcceptKey(t *testing.T) { }) } } - -func TestRequestHash(t *testing.T) { - t.Parallel() - client := &subscriptionClient{} - - t.Run("basic request with URL and headers", func(t *testing.T) { - t.Parallel() - - ctx := &resolve.Context{ - Request: resolve.Request{ - Header: http.Header{}, - }, - } - options := GraphQLSubscriptionOptions{ - URL: "http://example.com/graphql", - Header: http.Header{ - "Authorization": []string{"Bearer token"}, - }, - } - hash := 
xxhash.New() - - err := client.requestHash(ctx, options, hash) - assert.NoError(t, err) - assert.Equal(t, uint64(0xacbca06c541c2a79), hash.Sum64()) - }) - - t.Run("request with forwarded client headers", func(t *testing.T) { - t.Parallel() - - ctx := &resolve.Context{ - Request: resolve.Request{ - Header: http.Header{ - "X-User-Id": []string{"123"}, - "X-Role": []string{"admin"}, - }, - }, - } - options := GraphQLSubscriptionOptions{ - URL: "http://example.com/graphql", - ForwardedClientHeaderNames: []string{"X-User-Id", "X-Role"}, - } - hash := xxhash.New() - - err := client.requestHash(ctx, options, hash) - assert.NoError(t, err) - assert.Equal(t, uint64(0xf428bef25952044c), hash.Sum64()) - }) - - t.Run("request with forwarded client header regex patterns", func(t *testing.T) { - t.Parallel() - - t.Run("with normal", func(t *testing.T) { - header := http.Header{ - "X-Custom-1": []string{"value1"}, - "X-There-2": []string{"value2"}, - "X-Alright-3": []string{"value3"}, - } - ctx := &resolve.Context{ - Request: resolve.Request{ - Header: header, - }, - } - options := GraphQLSubscriptionOptions{ - URL: "http://example.com/graphql", - ForwardedClientHeaderRegularExpressions: []RegularExpression{ - { - Pattern: regexp.MustCompile("^X-Custom-.*$"), - NegateMatch: false, - }, - }, - } - hash := xxhash.New() - - err := client.requestHash(ctx, options, hash) - assert.NoError(t, err) - assert.Equal(t, uint64(0xb1557904bfa9d86a), hash.Sum64()) - }) - - t.Run("with negative", func(t *testing.T) { - t.Parallel() - - ctx := &resolve.Context{ - Request: resolve.Request{ - Header: http.Header{ - "X-Custom-1": []string{"valueThere1"}, - "X-Custom-2": []string{"valueThere2"}, - }, - }, - } - options := GraphQLSubscriptionOptions{ - URL: "http://example.com/graphql", - ForwardedClientHeaderRegularExpressions: []RegularExpression{ - { - Pattern: regexp.MustCompile("^X-Custom-2"), - NegateMatch: true, - }, - }, - } - hash := xxhash.New() - - err := client.requestHash(ctx, options, 
hash) - assert.NoError(t, err) - assert.Equal(t, uint64(0x5888642db454ccab), hash.Sum64()) - }) - - t.Run("with multiple tries to ensure the hash is idempotent", func(t *testing.T) { - for range 100 { - header := http.Header{ - "X-Custom-1": []string{"a1"}, - "X-There-2": []string{"a2"}, - "X-Custom-6": []string{"a3"}, - "X-Alright-3": []string{"a4"}, - "X-Custom-5": []string{"a5"}, - } - ctx := &resolve.Context{ - Request: resolve.Request{ - Header: header, - }, - } - options := GraphQLSubscriptionOptions{ - URL: "http://example.com/graphql", - ForwardedClientHeaderRegularExpressions: []RegularExpression{ - { - Pattern: regexp.MustCompile("^X-Custom-.*$"), - NegateMatch: false, - }, - }, - } - hash := xxhash.New() - - err := client.requestHash(ctx, options, hash) - assert.NoError(t, err) - assert.Equal(t, uint64(0x6c9c1099adab987d), hash.Sum64()) - } - }) - }) - - t.Run("request with initial payload", func(t *testing.T) { - t.Parallel() - - ctx := &resolve.Context{ - Request: resolve.Request{ - Header: http.Header{}, - }, - InitialPayload: []byte(`{"auth": "token"}`), - } - options := GraphQLSubscriptionOptions{ - URL: "http://example.com/graphql", - } - hash := xxhash.New() - - err := client.requestHash(ctx, options, hash) - assert.NoError(t, err) - assert.Equal(t, uint64(0x3c5af329478bfcce), hash.Sum64()) - - }) - - t.Run("request with body components", func(t *testing.T) { - t.Parallel() - - ctx := &resolve.Context{ - Request: resolve.Request{ - Header: http.Header{}, - }, - } - options := GraphQLSubscriptionOptions{ - URL: "http://example.com/graphql", - Body: GraphQLBody{ - Query: "query { hello }", - Variables: []byte(`{"var": "value"}`), - OperationName: "HelloQuery", - Extensions: []byte(`{"ext": "value"}`), - }, - } - hash := xxhash.New() - - err := client.requestHash(ctx, options, hash) - assert.NoError(t, err) - assert.Equal(t, uint64(0xd8d5588c8a466cf2), hash.Sum64()) - }) - - t.Run("empty components", func(t *testing.T) { - t.Parallel() - - ctx := 
&resolve.Context{ - Request: resolve.Request{ - Header: http.Header{}, - }, - } - options := GraphQLSubscriptionOptions{ - URL: "http://example.com/graphql", - } - hash := xxhash.New() - - err := client.requestHash(ctx, options, hash) - assert.NoError(t, err) - assert.Equal(t, uint64(0x767db2231989769), hash.Sum64()) - }) - -} diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go index 58729e33c2..1305fda5f1 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go @@ -9,6 +9,7 @@ package grpcdatasource import ( "context" "fmt" + "net/http" "sync" "github.com/tidwall/gjson" @@ -77,7 +78,7 @@ func NewDataSource(client grpc.ClientConnInterface, config DataSourceConfig) (*D // // The input is expected to contain the necessary information to make // a gRPC call, including service name, method name, and request data. -func (d *DataSource) Load(ctx context.Context, input []byte) (data []byte, err error) { +func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { // get variables from input variables := gjson.Parse(string(input)).Get("body.variables") builder := newJSONBuilder(d.mapping, variables) @@ -150,6 +151,6 @@ func (d *DataSource) Load(ctx context.Context, input []byte) (data []byte, err e // might not be applicable for most gRPC use cases. // // Currently unimplemented. 
-func (d *DataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { +func (d *DataSource) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { panic("unimplemented") } diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go index 2a18e2f176..348a502d72 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go @@ -146,7 +146,7 @@ func Test_DataSource_Load(t *testing.T) { require.NoError(t, err) - output, err := ds.Load(context.Background(), []byte(`{"query":"`+query+`","variables":`+variables+`}`)) + output, err := ds.Load(context.Background(), nil, []byte(`{"query":"`+query+`","variables":`+variables+`}`)) require.NoError(t, err) fmt.Println(string(output)) @@ -217,7 +217,7 @@ func Test_DataSource_Load_WithMockService(t *testing.T) { require.NoError(t, err) // 3. Execute the query through our datasource - output, err := ds.Load(context.Background(), []byte(`{"query":"`+query+`","body":`+variables+`}`)) + output, err := ds.Load(context.Background(), nil, []byte(`{"query":"`+query+`","body":`+variables+`}`)) require.NoError(t, err) // Print the response for debugging @@ -309,7 +309,7 @@ func Test_DataSource_Load_WithMockService_WithResponseMapping(t *testing.T) { // Format the input with query and variables inputJSON := fmt.Sprintf(`{"query":%q,"body":%s}`, query, variables) - output, err := ds.Load(context.Background(), []byte(inputJSON)) + output, err := ds.Load(context.Background(), nil, []byte(inputJSON)) require.NoError(t, err) // Set up the correct response structure based on your GraphQL schema @@ -401,7 +401,7 @@ func Test_DataSource_Load_WithGrpcError(t *testing.T) { require.NoError(t, err) // 4. 
Execute the query - output, err := ds.Load(context.Background(), []byte(`{"query":"`+query+`","body":`+variables+`}`)) + output, err := ds.Load(context.Background(), nil, []byte(`{"query":"`+query+`","body":`+variables+`}`)) require.NoError(t, err, "Load should not return an error even when the gRPC call fails") responseJson := string(output) @@ -727,7 +727,7 @@ func Test_DataSource_Load_WithAnimalInterface(t *testing.T) { // Execute the query through our datasource input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - output, err := ds.Load(context.Background(), []byte(input)) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response @@ -997,7 +997,7 @@ func Test_Datasource_Load_WithUnionTypes(t *testing.T) { // Execute the query through our datasource input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - output, err := ds.Load(context.Background(), []byte(input)) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response @@ -1133,7 +1133,7 @@ func Test_DataSource_Load_WithCategoryQueries(t *testing.T) { // Execute the query through our datasource input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - output, err := ds.Load(context.Background(), []byte(input)) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response @@ -1213,7 +1213,7 @@ func Test_DataSource_Load_WithTotalCalculation(t *testing.T) { // Execute the query through our datasource input := fmt.Sprintf(`{"query":%q,"body":%s}`, query, variables) - output, err := ds.Load(context.Background(), []byte(input)) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response @@ -1303,7 +1303,7 @@ func Test_DataSource_Load_WithTypename(t *testing.T) { // Execute the query through our datasource input := fmt.Sprintf(`{"query":%q,"body":{}}`, query) - output, 
err := ds.Load(context.Background(), []byte(input)) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response @@ -1772,7 +1772,7 @@ func Test_DataSource_Load_WithAliases(t *testing.T) { // Execute the query through our datasource input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - output, err := ds.Load(context.Background(), []byte(input)) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response @@ -2150,7 +2150,7 @@ func Test_DataSource_Load_WithNullableFieldsType(t *testing.T) { // Execute the query through our datasource input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - output, err := ds.Load(context.Background(), []byte(input)) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response @@ -3451,7 +3451,7 @@ func Test_DataSource_Load_WithNestedLists(t *testing.T) { // Execute the query through our datasource input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - output, err := ds.Load(context.Background(), []byte(input)) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response @@ -3603,7 +3603,7 @@ func Test_DataSource_Load_WithEntity_Calls(t *testing.T) { // Execute the query through our datasource input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - output, err := ds.Load(context.Background(), []byte(input)) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response diff --git a/v2/pkg/engine/datasource/httpclient/httpclient_test.go b/v2/pkg/engine/datasource/httpclient/httpclient_test.go index cbef2d1f7d..98685ceceb 100644 --- a/v2/pkg/engine/datasource/httpclient/httpclient_test.go +++ b/v2/pkg/engine/datasource/httpclient/httpclient_test.go @@ -79,7 +79,7 @@ func TestHttpClientDo(t *testing.T) { runTest := func(ctx 
context.Context, input []byte, expectedOutput string) func(t *testing.T) { return func(t *testing.T) { - output, err := Do(http.DefaultClient, ctx, input) + output, err := Do(http.DefaultClient, ctx, nil, input) assert.NoError(t, err) assert.Equal(t, expectedOutput, string(output)) } @@ -209,7 +209,7 @@ func TestHttpClientDo(t *testing.T) { input = SetInputURL(input, []byte(server.URL)) input, err := sjson.SetBytes(input, TRACE, true) assert.NoError(t, err) - output, err := Do(http.DefaultClient, context.Background(), input) + output, err := Do(http.DefaultClient, context.Background(), nil, input) assert.NoError(t, err) assert.Contains(t, string(output), `"Authorization":["****"]`) }) diff --git a/v2/pkg/engine/datasource/httpclient/nethttpclient.go b/v2/pkg/engine/datasource/httpclient/nethttpclient.go index d6276c8375..4c4f2de3d4 100644 --- a/v2/pkg/engine/datasource/httpclient/nethttpclient.go +++ b/v2/pkg/engine/datasource/httpclient/nethttpclient.go @@ -27,6 +27,7 @@ const ( AcceptEncodingHeader = "Accept-Encoding" AcceptHeader = "Accept" ContentTypeHeader = "Content-Type" + ContentLengthHeader = "Content-Length" EncodingGzip = "gzip" EncodingDeflate = "deflate" @@ -146,13 +147,17 @@ func buffer(ctx context.Context) *bytes.Buffer { return bytes.NewBuffer(make([]byte, 0, 64)) } -func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, headers, queryParams []byte, body io.Reader, enableTrace bool, contentType string) ([]byte, error) { +func makeHTTPRequest(client *http.Client, ctx context.Context, baseHeaders http.Header, url, method, headers, queryParams []byte, body io.Reader, enableTrace bool, contentType string, contentLength int) ([]byte, error) { request, err := http.NewRequestWithContext(ctx, string(method), string(url), body) if err != nil { return nil, err } + if baseHeaders != nil { + request.Header = baseHeaders + } + if headers != nil { err = jsonparser.ObjectEach(headers, func(key []byte, value []byte, dataType 
jsonparser.ValueType, offset int) error { _, err := jsonparser.ArrayEach(value, func(value []byte, dataType jsonparser.ValueType, offset int, err error) { @@ -205,6 +210,9 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, head request.Header.Add(ContentTypeHeader, contentType) request.Header.Set(AcceptEncodingHeader, EncodingGzip) request.Header.Add(AcceptEncodingHeader, EncodingDeflate) + if contentLength > 0 { + request.Header.Set(ContentLengthHeader, fmt.Sprintf("%d", contentLength)) + } setRequest(ctx, request) @@ -256,13 +264,13 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, head return responseWithTraceExtension, nil } -func Do(client *http.Client, ctx context.Context, requestInput []byte) (data []byte, err error) { +func Do(client *http.Client, ctx context.Context, baseHeaders http.Header, requestInput []byte) (data []byte, err error) { url, method, body, headers, queryParams, enableTrace := requestInputParams(requestInput) - return makeHTTPRequest(client, ctx, url, method, headers, queryParams, bytes.NewReader(body), enableTrace, ContentTypeJSON) + return makeHTTPRequest(client, ctx, baseHeaders, url, method, headers, queryParams, bytes.NewReader(body), enableTrace, ContentTypeJSON, len(body)) } func DoMultipartForm( - client *http.Client, ctx context.Context, requestInput []byte, files []*FileUpload, + client *http.Client, ctx context.Context, baseHeaders http.Header, requestInput []byte, files []*FileUpload, ) (data []byte, err error) { if len(files) == 0 { return nil, errors.New("no files provided") @@ -316,7 +324,7 @@ func DoMultipartForm( } }() - return makeHTTPRequest(client, ctx, url, method, headers, queryParams, multipartBody, enableTrace, contentType) + return makeHTTPRequest(client, ctx, baseHeaders, url, method, headers, queryParams, multipartBody, enableTrace, contentType, 0) } func multipartBytes(values map[string]io.Reader, files []*FileUpload) (*io.PipeReader, string, error) { diff 
--git a/v2/pkg/engine/datasource/introspection_datasource/source.go b/v2/pkg/engine/datasource/introspection_datasource/source.go index a55549ace9..67195e44a7 100644 --- a/v2/pkg/engine/datasource/introspection_datasource/source.go +++ b/v2/pkg/engine/datasource/introspection_datasource/source.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "io" + "net/http" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" "github.com/wundergraph/graphql-go-tools/v2/pkg/introspection" @@ -18,7 +19,7 @@ type Source struct { introspectionData *introspection.Data } -func (s *Source) Load(ctx context.Context, input []byte) (data []byte, err error) { +func (s *Source) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { var req introspectionInput if err := json.Unmarshal(input, &req); err != nil { return nil, err @@ -31,7 +32,7 @@ func (s *Source) Load(ctx context.Context, input []byte) (data []byte, err error return json.Marshal(s.introspectionData.Schema) } -func (s *Source) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { +func (s *Source) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { return nil, errors.New("introspection data source does not support file uploads") } diff --git a/v2/pkg/engine/datasource/introspection_datasource/source_test.go b/v2/pkg/engine/datasource/introspection_datasource/source_test.go index 7c331b7d14..9737a4ee9f 100644 --- a/v2/pkg/engine/datasource/introspection_datasource/source_test.go +++ b/v2/pkg/engine/datasource/introspection_datasource/source_test.go @@ -28,7 +28,7 @@ func TestSource_Load(t *testing.T) { require.False(t, report.HasErrors()) source := &Source{introspectionData: &data} - responseData, err := source.Load(context.Background(), []byte(input)) + responseData, err := source.Load(context.Background(), nil, []byte(input)) 
require.NoError(t, err) actualResponse := &bytes.Buffer{} diff --git a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_datasource_test.go b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_datasource_test.go index 28a37df33b..2ea8114ad4 100644 --- a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_datasource_test.go +++ b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_datasource_test.go @@ -424,6 +424,8 @@ func TestPubSub(t *testing.T) { PostProcessing: resolve.PostProcessingConfiguration{ MergePath: []string{"helloSubscription"}, }, + SourceName: "test", + SourceID: "test", }, Response: &resolve.GraphQLResponse{ Data: &resolve.Object{ @@ -487,6 +489,8 @@ func TestPubSub(t *testing.T) { PostProcessing: resolve.PostProcessingConfiguration{ MergePath: []string{"subscriptionWithMultipleSubjects"}, }, + SourceName: "test", + SourceID: "test", }, Response: &resolve.GraphQLResponse{ Data: &resolve.Object{ @@ -532,6 +536,8 @@ func TestPubSub(t *testing.T) { PostProcessing: resolve.PostProcessingConfiguration{ MergePath: []string{"subscriptionWithStaticValues"}, }, + SourceName: "test", + SourceID: "test", }, Response: &resolve.GraphQLResponse{ Data: &resolve.Object{ @@ -583,6 +589,8 @@ func TestPubSub(t *testing.T) { PostProcessing: resolve.PostProcessingConfiguration{ MergePath: []string{"subscriptionWithArgTemplateAndStaticValue"}, }, + SourceName: "test", + SourceID: "test", }, Response: &resolve.GraphQLResponse{ Data: &resolve.Object{ diff --git a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_kafka.go b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_kafka.go index 7f1a6226b2..3f688b6b14 100644 --- a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_kafka.go +++ b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_kafka.go @@ -3,9 +3,7 @@ package pubsub_datasource import ( "context" "encoding/json" - - "github.com/buger/jsonparser" - "github.com/cespare/xxhash/v2" + "net/http" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" 
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" @@ -31,28 +29,7 @@ type KafkaSubscriptionSource struct { pubSub KafkaPubSub } -func (s *KafkaSubscriptionSource) UniqueRequestID(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { - - val, _, _, err := jsonparser.Get(input, "topics") - if err != nil { - return err - } - - _, err = xxh.Write(val) - if err != nil { - return err - } - - val, _, _, err = jsonparser.Get(input, "providerId") - if err != nil { - return err - } - - _, err = xxh.Write(val) - return err -} - -func (s *KafkaSubscriptionSource) Start(ctx *resolve.Context, input []byte, updater resolve.SubscriptionUpdater) error { +func (s *KafkaSubscriptionSource) Start(ctx *resolve.Context, headers http.Header, input []byte, updater resolve.SubscriptionUpdater) error { var subscriptionConfiguration KafkaSubscriptionEventConfiguration err := json.Unmarshal(input, &subscriptionConfiguration) if err != nil { @@ -66,7 +43,7 @@ type KafkaPublishDataSource struct { pubSub KafkaPubSub } -func (s *KafkaPublishDataSource) Load(ctx context.Context, input []byte) (data []byte, err error) { +func (s *KafkaPublishDataSource) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { var publishConfiguration KafkaPublishEventConfiguration err = json.Unmarshal(input, &publishConfiguration) if err != nil { @@ -79,6 +56,6 @@ func (s *KafkaPublishDataSource) Load(ctx context.Context, input []byte) (data [ return []byte(`{"success": true}`), nil } -func (s *KafkaPublishDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { +func (s *KafkaPublishDataSource) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { panic("not implemented") } diff --git a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go index 
e5d3bec0f0..776b5deac1 100644 --- a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go +++ b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go @@ -5,9 +5,7 @@ import ( "context" "encoding/json" "io" - - "github.com/buger/jsonparser" - "github.com/cespare/xxhash/v2" + "net/http" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" @@ -42,28 +40,7 @@ type NatsSubscriptionSource struct { pubSub NatsPubSub } -func (s *NatsSubscriptionSource) UniqueRequestID(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { - - val, _, _, err := jsonparser.Get(input, "subjects") - if err != nil { - return err - } - - _, err = xxh.Write(val) - if err != nil { - return err - } - - val, _, _, err = jsonparser.Get(input, "providerId") - if err != nil { - return err - } - - _, err = xxh.Write(val) - return err -} - -func (s *NatsSubscriptionSource) Start(ctx *resolve.Context, input []byte, updater resolve.SubscriptionUpdater) error { +func (s *NatsSubscriptionSource) Start(ctx *resolve.Context, headers http.Header, input []byte, updater resolve.SubscriptionUpdater) error { var subscriptionConfiguration NatsSubscriptionEventConfiguration err := json.Unmarshal(input, &subscriptionConfiguration) if err != nil { @@ -77,7 +54,7 @@ type NatsPublishDataSource struct { pubSub NatsPubSub } -func (s *NatsPublishDataSource) Load(ctx context.Context, input []byte) (data []byte, err error) { +func (s *NatsPublishDataSource) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { var publishConfiguration NatsPublishAndRequestEventConfiguration err = json.Unmarshal(input, &publishConfiguration) if err != nil { @@ -91,7 +68,7 @@ func (s *NatsPublishDataSource) Load(ctx context.Context, input []byte) (data [] return []byte(`{"success": true}`), nil } -func (s *NatsPublishDataSource) LoadWithFiles(ctx context.Context, input []byte, files 
[]*httpclient.FileUpload) (data []byte, err error) { +func (s *NatsPublishDataSource) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { panic("not implemented") } @@ -99,7 +76,7 @@ type NatsRequestDataSource struct { pubSub NatsPubSub } -func (s *NatsRequestDataSource) Load(ctx context.Context, input []byte) (data []byte, err error) { +func (s *NatsRequestDataSource) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { var subscriptionConfiguration NatsPublishAndRequestEventConfiguration err = json.Unmarshal(input, &subscriptionConfiguration) if err != nil { @@ -115,6 +92,6 @@ func (s *NatsRequestDataSource) Load(ctx context.Context, input []byte) (data [] return buf.Bytes(), nil } -func (s *NatsRequestDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { +func (s *NatsRequestDataSource) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { panic("not implemented") } diff --git a/v2/pkg/engine/datasource/staticdatasource/static_datasource.go b/v2/pkg/engine/datasource/staticdatasource/static_datasource.go index 626a1d9f94..3fb75c8b36 100644 --- a/v2/pkg/engine/datasource/staticdatasource/static_datasource.go +++ b/v2/pkg/engine/datasource/staticdatasource/static_datasource.go @@ -2,6 +2,7 @@ package staticdatasource import ( "context" + "net/http" "github.com/jensneuse/abstractlogger" @@ -70,10 +71,10 @@ func (p *Planner[T]) ConfigureSubscription() plan.SubscriptionConfiguration { type Source struct{} -func (Source) Load(ctx context.Context, input []byte) (data []byte, err error) { +func (Source) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { return input, nil } -func (Source) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err 
error) { +func (Source) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { panic("not implemented") } diff --git a/v2/pkg/engine/plan/planner_test.go b/v2/pkg/engine/plan/planner_test.go index 658ff3fc72..b952107f07 100644 --- a/v2/pkg/engine/plan/planner_test.go +++ b/v2/pkg/engine/plan/planner_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "net/http" "reflect" "slices" "testing" @@ -1074,10 +1075,10 @@ type FakeDataSource struct { source *StatefulSource } -func (f *FakeDataSource) Load(ctx context.Context, input []byte) (data []byte, err error) { +func (f *FakeDataSource) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { return nil, nil } -func (f *FakeDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { +func (f *FakeDataSource) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { return nil, nil } diff --git a/v2/pkg/engine/plan/visitor.go b/v2/pkg/engine/plan/visitor.go index ef8d094757..72dbe719c6 100644 --- a/v2/pkg/engine/plan/visitor.go +++ b/v2/pkg/engine/plan/visitor.go @@ -1290,6 +1290,8 @@ func (v *Visitor) configureSubscription(config *objectFetchConfiguration) { v.subscription.Trigger.QueryPlan = subscription.QueryPlan v.resolveInputTemplates(config, &subscription.Input, &v.subscription.Trigger.Variables) v.subscription.Trigger.Input = []byte(subscription.Input) + v.subscription.Trigger.SourceName = config.sourceName + v.subscription.Trigger.SourceID = config.sourceID v.subscription.Filter = config.filter } diff --git a/v2/pkg/engine/resolve/authorization_test.go b/v2/pkg/engine/resolve/authorization_test.go index ea83c77259..95051def7e 100644 --- a/v2/pkg/engine/resolve/authorization_test.go +++ b/v2/pkg/engine/resolve/authorization_test.go @@ -5,6 +5,7 @@ import ( 
"encoding/json" "errors" "io" + "net/http" "sync/atomic" "testing" @@ -509,8 +510,8 @@ func TestAuthorization(t *testing.T) { func generateTestFederationGraphQLResponse(t *testing.T, ctrl *gomock.Controller) *GraphQLResponse { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) @@ -519,8 +520,8 @@ func generateTestFederationGraphQLResponse(t *testing.T, ctrl *gomock.Controller reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"__typename":"User","id":"1234"}]}}}` assert.Equal(t, expected, actual) @@ -529,8 +530,8 @@ func generateTestFederationGraphQLResponse(t *testing.T, ctrl *gomock.Controller productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"__typename":"Product","upc":"top-1"},{"__typename":"Product","upc":"top-2"}]}}}` assert.Equal(t, expected, actual) @@ -814,8 +815,8 @@ func generateTestFederationGraphQLResponse(t *testing.T, ctrl *gomock.Controller func generateTestFederationGraphQLResponseWithoutAuthorizationRules(t *testing.T, ctrl *gomock.Controller) *GraphQLResponse { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) @@ -824,8 +825,8 @@ func generateTestFederationGraphQLResponseWithoutAuthorizationRules(t *testing.T reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... 
on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"__typename":"User","id":"1234"}]}}}` assert.Equal(t, expected, actual) @@ -834,8 +835,8 @@ func generateTestFederationGraphQLResponseWithoutAuthorizationRules(t *testing.T productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"__typename":"Product","upc":"top-1"},{"__typename":"Product","upc":"top-2"}]}}}` assert.Equal(t, expected, actual) diff --git a/v2/pkg/engine/resolve/context.go b/v2/pkg/engine/resolve/context.go index e9958d24ef..b0b82f5787 100644 --- a/v2/pkg/engine/resolve/context.go +++ b/v2/pkg/engine/resolve/context.go @@ -32,12 +32,27 @@ type Context struct { fieldRenderer FieldValueRenderer subgraphErrors error + + SubgraphHeadersBuilder HeadersForSubgraphRequest +} + +type HeadersForSubgraphRequest interface { + HeadersForSubgraph(subgraphName string) (http.Header, uint64) +} + +func (c *Context) HeadersForSubgraphRequest(subgraphName string) (http.Header, uint64) { + if c.SubgraphHeadersBuilder == nil { + return nil, 0 + } + return c.SubgraphHeadersBuilder.HeadersForSubgraph(subgraphName) } type ExecutionOptions struct { SkipLoader bool IncludeQueryPlanInResponse bool SendHeartbeat bool + // DisableRequestDeduplication disables deduplication of requests to the same subgraph with the same input within a single operation execution. 
+ DisableRequestDeduplication bool } type FieldValue struct { diff --git a/v2/pkg/engine/resolve/datasource.go b/v2/pkg/engine/resolve/datasource.go index 8063541f6d..7855fa6378 100644 --- a/v2/pkg/engine/resolve/datasource.go +++ b/v2/pkg/engine/resolve/datasource.go @@ -2,26 +2,23 @@ package resolve import ( "context" - - "github.com/cespare/xxhash/v2" + "net/http" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" ) type DataSource interface { - Load(ctx context.Context, input []byte) (data []byte, err error) - LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) + Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) + LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) } type SubscriptionDataSource interface { // Start is called when a new subscription is created. It establishes the connection to the data source. // The updater is used to send updates to the client. Deduplication of the request must be done before calling this method. 
- Start(ctx *Context, input []byte, updater SubscriptionUpdater) error - UniqueRequestID(ctx *Context, input []byte, xxh *xxhash.Digest) (err error) + Start(ctx *Context, headers http.Header, input []byte, updater SubscriptionUpdater) error } type AsyncSubscriptionDataSource interface { - AsyncStart(ctx *Context, id uint64, input []byte, updater SubscriptionUpdater) error + AsyncStart(ctx *Context, id uint64, headers http.Header, input []byte, updater SubscriptionUpdater) error AsyncStop(id uint64) - UniqueRequestID(ctx *Context, input []byte, xxh *xxhash.Digest) (err error) } diff --git a/v2/pkg/engine/resolve/event_loop_test.go b/v2/pkg/engine/resolve/event_loop_test.go index 11389630a9..ba8b7c8e2f 100644 --- a/v2/pkg/engine/resolve/event_loop_test.go +++ b/v2/pkg/engine/resolve/event_loop_test.go @@ -3,12 +3,12 @@ package resolve import ( "context" "io" + "net/http" "sync" "sync/atomic" "testing" "time" - "github.com/cespare/xxhash/v2" "github.com/stretchr/testify/require" ) @@ -71,12 +71,7 @@ type FakeSource struct { interval time.Duration } -func (f *FakeSource) UniqueRequestID(ctx *Context, input []byte, xxh *xxhash.Digest) (err error) { - _, err = xxh.Write(input) - return err -} - -func (f *FakeSource) Start(ctx *Context, input []byte, updater SubscriptionUpdater) error { +func (f *FakeSource) Start(ctx *Context, headers http.Header, input []byte, updater SubscriptionUpdater) error { go func() { for i, u := range f.updates { updater.Update([]byte(u)) diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 71a3c5304e..a429087d06 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -14,10 +14,10 @@ import ( "time" "github.com/buger/jsonparser" - "github.com/cespare/xxhash/v2" "github.com/pkg/errors" "github.com/tidwall/gjson" "github.com/tidwall/sjson" + "github.com/wundergraph/graphql-go-tools/v2/pkg/pool" "golang.org/x/sync/errgroup" "github.com/wundergraph/astjson" @@ -1420,7 +1420,8 @@ func (l 
*Loader) loadBatchEntityFetch(ctx context.Context, fetchItem *FetchItem, preparedInput := bytes.NewBuffer(make([]byte, 0, 64)) itemInput := bytes.NewBuffer(make([]byte, 0, 32)) - keyGen := xxhash.New() + keyGen := pool.Hash64.Get() + defer pool.Hash64.Put(keyGen) var undefinedVariables []string @@ -1590,18 +1591,33 @@ func GetOperationTypeFromContext(ctx context.Context) ast.OperationType { return ast.OperationTypeQuery } +func (l *Loader) headersForSubgraphRequest(fetchItem *FetchItem) (http.Header, uint64) { + if fetchItem == nil || fetchItem.Fetch == nil { + return nil, 0 + } + info := fetchItem.Fetch.FetchInfo() + if info == nil { + return nil, 0 + } + return l.ctx.HeadersForSubgraphRequest(info.DataSourceName) +} + func (l *Loader) loadByContext(ctx context.Context, source DataSource, fetchItem *FetchItem, input []byte, res *result) error { if l.info != nil { ctx = context.WithValue(ctx, operationTypeContextKey, l.info.OperationType) } - if l.info == nil || l.info.OperationType == ast.OperationTypeMutation { + headers, extraKey := l.headersForSubgraphRequest(fetchItem) + + if l.info == nil || + l.info.OperationType == ast.OperationTypeMutation || + l.ctx.ExecutionOptions.DisableRequestDeduplication { // Disable single flight for mutations - return l.loadByContextDirect(ctx, source, input, res) + return l.loadByContextDirect(ctx, source, headers, input, res) } - sfKey, fetchKey, item, shared := l.sf.GetOrCreateItem(ctx, fetchItem, input) + sfKey, fetchKey, item, shared := l.sf.GetOrCreateItem(fetchItem, input, extraKey) if res.singleFlightStats != nil { res.singleFlightStats.used = shared res.singleFlightStats.shared = shared @@ -1627,7 +1643,7 @@ func (l *Loader) loadByContext(ctx context.Context, source DataSource, fetchItem defer l.sf.Finish(sfKey, fetchKey, item) // Perform the actual load - err := l.loadByContextDirect(ctx, source, input, res) + err := l.loadByContextDirect(ctx, source, headers, input, res) if err != nil { item.err = err return err @@ 
-1637,11 +1653,11 @@ func (l *Loader) loadByContext(ctx context.Context, source DataSource, fetchItem return nil } -func (l *Loader) loadByContextDirect(ctx context.Context, source DataSource, input []byte, res *result) error { +func (l *Loader) loadByContextDirect(ctx context.Context, source DataSource, headers http.Header, input []byte, res *result) error { if l.ctx.Files != nil { - res.out, res.err = source.LoadWithFiles(ctx, input, l.ctx.Files) + res.out, res.err = source.LoadWithFiles(ctx, headers, input, l.ctx.Files) } else { - res.out, res.err = source.Load(ctx, input) + res.out, res.err = source.Load(ctx, headers, input) } if res.err != nil { return errors.WithStack(res.err) diff --git a/v2/pkg/engine/resolve/loader_hooks_test.go b/v2/pkg/engine/resolve/loader_hooks_test.go index d82857598d..ebe263dcd9 100644 --- a/v2/pkg/engine/resolve/loader_hooks_test.go +++ b/v2/pkg/engine/resolve/loader_hooks_test.go @@ -3,6 +3,7 @@ package resolve import ( "bytes" "context" + "net/http" "sync" "sync/atomic" "testing" @@ -49,8 +50,8 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("simple fetch with simple subgraph error", testFnWithPostEvaluation(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T)) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) resolveCtx := Context{ @@ -121,8 +122,8 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). 
- DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) resolveCtx := &Context{ @@ -187,8 +188,8 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("parallel fetch with simple subgraph error", testFnWithPostEvaluation(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T)) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) resolveCtx := &Context{ @@ -250,8 +251,8 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("parallel list item fetch with simple subgraph error", testFnWithPostEvaluation(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T)) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) resolveCtx := Context{ @@ -313,8 +314,8 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("fetch with subgraph error and custom extension code. 
No extension fields are propagated by default", testFnWithPostEvaluation(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T)) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage","extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}},{"message":"errorMessage2","extensions":{"code":"BAD_USER_INPUT"}}]}`), nil }) resolveCtx := Context{ @@ -376,8 +377,8 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Propagate only extension code field from subgraph errors", testFnSubgraphErrorsWithExtensionFieldCode(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage","extensions":{"code":"GRAPHQL_VALIDATION_FAILED","foo":"bar"}},{"message":"errorMessage2","extensions":{"code":"BAD_USER_INPUT"}}]}`), nil }) return &GraphQLResponse{ @@ -411,8 +412,8 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Propagate all extension fields from subgraph errors when allow all option is enabled", testFnSubgraphErrorsWithAllowAllExtensionFields(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). 
- DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage","extensions":{"code":"GRAPHQL_VALIDATION_FAILED","foo":"bar"}},{"message":"errorMessage2","extensions":{"code":"BAD_USER_INPUT"}}]}`), nil }) return &GraphQLResponse{ @@ -446,8 +447,8 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Include datasource name as serviceName extension field", testFnSubgraphErrorsWithExtensionFieldServiceName(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage","extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}},{"message":"errorMessage2","extensions":{"code":"BAD_USER_INPUT"}}]}`), nil }) return &GraphQLResponse{ @@ -481,8 +482,8 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Include datasource name as serviceName when extensions is null", testFnSubgraphErrorsWithExtensionFieldServiceName(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage","extensions":null},{"message":"errorMessage2","extensions":null}]}`), nil }) return &GraphQLResponse{ @@ -516,8 +517,8 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Include datasource name as serviceName when extensions is an empty object", testFnSubgraphErrorsWithExtensionFieldServiceName(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage","extensions":{}},{"message":"errorMessage2","extensions":null}]}`), nil }) return &GraphQLResponse{ @@ -551,8 +552,8 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Fallback to default extension code value when no code field was set", testFnSubgraphErrorsWithExtensionDefaultCode(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage","extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}},{"message":"errorMessage2"}]}`), nil }) return &GraphQLResponse{ @@ -586,8 +587,8 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Fallback to default extension code value when extensions is null", testFnSubgraphErrorsWithExtensionDefaultCode(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage","extensions":null},{"message":"errorMessage2"}]}`), nil }) return &GraphQLResponse{ @@ -621,8 +622,8 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { t.Run("Fallback to default extension code value when extensions is an empty object", testFnSubgraphErrorsWithExtensionDefaultCode(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage","extensions":{}},{"message":"errorMessage2"}]}`), nil }) return &GraphQLResponse{ diff --git a/v2/pkg/engine/resolve/loader_test.go b/v2/pkg/engine/resolve/loader_test.go index 0fe38ddc79..d6c002393b 100644 --- a/v2/pkg/engine/resolve/loader_test.go +++ b/v2/pkg/engine/resolve/loader_test.go @@ -296,7 +296,7 @@ func TestLoader_LoadGraphQLResponseData(t *testing.T) { ctrl.Finish() out := fastjsonext.PrintGraphQLResponse(resolvable.data, resolvable.errors) assert.NoError(t, err) - expected := `{"errors":[],"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}}` + expected := `{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}}` assert.Equal(t, expected, out) } @@ -758,7 +758,7 @@ func TestLoader_LoadGraphQLResponseDataWithExtensions(t *testing.T) { ctrl.Finish() out := 
fastjsonext.PrintGraphQLResponse(resolvable.data, resolvable.errors) assert.NoError(t, err) - expected := `{"errors":[],"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}}` + expected := `{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}}` assert.Equal(t, expected, out) } diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 90b534174e..107f0cb794 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "io" + "net/http" "time" "github.com/buger/jsonparser" @@ -707,14 +708,16 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription) asyncDataSource = async } + headers, _ := r.triggerHeaders(add.ctx, add.sourceName) + go func() { if r.options.Debug { fmt.Printf("resolver:trigger:start:%d\n", triggerID) } if asyncDataSource != nil { - err = asyncDataSource.AsyncStart(cloneCtx, triggerID, 
add.input, updater) + err = asyncDataSource.AsyncStart(cloneCtx, triggerID, headers, add.input, updater) } else { - err = add.resolve.Trigger.Source.Start(cloneCtx, add.input, updater) + err = add.resolve.Trigger.Source.Start(cloneCtx, headers, add.input, updater) } if err != nil { if r.options.Debug { @@ -1057,6 +1060,13 @@ func (r *Resolver) AsyncUnsubscribeClient(connectionID int64) error { return nil } +func (r *Resolver) triggerHeaders(ctx *Context, sourceName string) (http.Header, uint64) { + if ctx.SubgraphHeadersBuilder != nil { + return ctx.SubgraphHeadersBuilder.HeadersForSubgraph(sourceName) + } + return nil, 0 +} + func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQLSubscription, writer SubscriptionResponseWriter) error { if subscription.Trigger.Source == nil { return errors.New("no data source found") @@ -1094,14 +1104,14 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ return nil } + _, headersHash := r.triggerHeaders(ctx, subscription.Trigger.SourceName) + xxh := pool.Hash64.Get() - defer pool.Hash64.Put(xxh) - err = subscription.Trigger.Source.UniqueRequestID(ctx, input, xxh) - if err != nil { - msg := []byte(`{"errors":[{"message":"unable to resolve"}]}`) - return writeFlushComplete(writer, msg) - } - uniqueID := xxh.Sum64() + _, _ = xxh.Write(input) + // the hash for subgraph headers is pre-computed + // we can just add it to the input hash to get a unique id + uniqueID := xxh.Sum64() + headersHash + pool.Hash64.Put(xxh) id := SubscriptionIdentifier{ ConnectionID: ConnectionIDs.Inc(), SubscriptionID: 0, @@ -1120,12 +1130,13 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ triggerID: uniqueID, kind: subscriptionEventKindAddSubscription, addSubscription: &addSubscription{ - ctx: ctx, - input: input, - resolve: subscription, - writer: writer, - id: id, - completed: completed, + ctx: ctx, + input: input, + resolve: subscription, + writer: writer, + id: 
id, + completed: completed, + sourceName: subscription.Trigger.SourceName, }, }: } @@ -1203,13 +1214,14 @@ func (r *Resolver) AsyncResolveGraphQLSubscription(ctx *Context, subscription *G return nil } + _, headersHash := r.triggerHeaders(ctx, subscription.Trigger.SourceName) + xxh := pool.Hash64.Get() - defer pool.Hash64.Put(xxh) - err = subscription.Trigger.Source.UniqueRequestID(ctx, input, xxh) - if err != nil { - msg := []byte(`{"errors":[{"message":"unable to resolve"}]}`) - return writeFlushComplete(writer, msg) - } + _, _ = xxh.Write(input) + // the hash for subgraph headers is pre-computed + // we can just add it to the input hash to get a unique id + uniqueID := xxh.Sum64() + headersHash + pool.Hash64.Put(xxh) select { case <-r.ctx.Done(): @@ -1219,15 +1231,16 @@ func (r *Resolver) AsyncResolveGraphQLSubscription(ctx *Context, subscription *G // Stop resolving if the client is gone return ctx.ctx.Err() case r.events <- subscriptionEvent{ - triggerID: xxh.Sum64(), + triggerID: uniqueID, kind: subscriptionEventKindAddSubscription, addSubscription: &addSubscription{ - ctx: ctx, - input: input, - resolve: subscription, - writer: writer, - id: id, - completed: make(chan struct{}), + ctx: ctx, + input: input, + resolve: subscription, + writer: writer, + id: id, + completed: make(chan struct{}), + sourceName: subscription.Trigger.SourceName, }, }: } @@ -1335,12 +1348,13 @@ type subscriptionEvent struct { } type addSubscription struct { - ctx *Context - input []byte - resolve *GraphQLSubscription - writer SubscriptionResponseWriter - id SubscriptionIdentifier - completed chan struct{} + ctx *Context + input []byte + resolve *GraphQLSubscription + writer SubscriptionResponseWriter + id SubscriptionIdentifier + completed chan struct{} + sourceName string } type subscriptionEventKind int diff --git a/v2/pkg/engine/resolve/resolve_federation_test.go b/v2/pkg/engine/resolve/resolve_federation_test.go index 64d969c6c6..1c32db689a 100644 --- 
a/v2/pkg/engine/resolve/resolve_federation_test.go +++ b/v2/pkg/engine/resolve/resolve_federation_test.go @@ -2,6 +2,7 @@ package resolve import ( "context" + "net/http" "testing" "github.com/golang/mock/gomock" @@ -19,8 +20,8 @@ func mockedDS(t TestingTB, ctrl *gomock.Controller, expectedInput, responseData t.Helper() service := NewMockDataSource(ctrl) service.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { require.Equal(t, expectedInput, string(input)) return []byte(responseData), nil }).Times(1) @@ -173,8 +174,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("federation with shareable", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { firstService := NewMockDataSource(ctrl) firstService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://first.service","body":{"query":"{me {details {forename middlename} __typename id}}"}}` assert.Equal(t, expected, actual) @@ -185,8 +186,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { secondService := NewMockDataSource(ctrl) secondService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://second.service","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on User {details {surname}}}}","variables":{"representations":[{"__typename":"User","id":"1234"}]}}}` assert.Equal(t, expected, actual) @@ -197,8 +198,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { thirdService := NewMockDataSource(ctrl) thirdService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://third.service","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on User {details {age}}}}","variables":{"representations":[{"__typename":"User","id":"1234"}]}}}` assert.Equal(t, expected, actual) @@ -368,8 +369,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ user { name infoOrAddress { ... on Info {id __typename} ... on Address {id __typename}}}}"}}` assert.Equal(t, expected, actual) @@ -380,8 +381,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any()). 
- DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){query($representations: [_Any!]!){_entities(representations: $representations) { ... on Info { age } ... on Address { line1 }}}}}","variables":{"representations":[{"id":11,"__typename":"Info"},{"id":55,"__typename":"Address"}]}}}` assert.Equal(t, expected, actual) @@ -521,8 +522,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ user { name infoOrAddress { ... on Info {id __typename} ... on Address {id __typename}}}}"}}` assert.Equal(t, expected, actual) @@ -533,7 +534,7 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any()). + Load(gomock.Any(), gomock.Any(), gomock.Any()). Times(0) return &GraphQLResponse{ @@ -666,8 +667,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("batching on a field", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ users { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) @@ -678,8 +679,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations) { ... on Info { age }}}}}","variables":{"representations":[{"id":11,"__typename":"Info"},{"id":12,"__typename":"Info"},{"id":13,"__typename":"Info"}]}}}` assert.Equal(t, expected, actual) @@ -810,8 +811,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("batching with duplicates", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ users { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) @@ -822,8 +823,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any()). 
- DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations) { ... on Info { age }}}}}","variables":{"representations":[{"id":11,"__typename":"Info"}]}}}` assert.Equal(t, expected, actual) @@ -951,8 +952,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("batching with null entry", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ users { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) @@ -963,8 +964,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations) { ... 
on Info { age }}}}}","variables":{"representations":[{"id":11,"__typename":"Info"},{"id":13,"__typename":"Info"}]}}}` assert.Equal(t, expected, actual) @@ -1096,8 +1097,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("batching with all null entries", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ users { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) @@ -1108,7 +1109,7 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any()). + Load(gomock.Any(), gomock.Any(), gomock.Any()). Times(0) return &GraphQLResponse{ @@ -1234,8 +1235,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("batching with render error", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ users { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) @@ -1247,8 +1248,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { infoService := NewMockDataSource(ctrl) infoService.EXPECT(). 
- Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations) { ... on Info { age }}}}}","variables":{"representations":[{"id":12,"__typename":"Info"},{"id":13,"__typename":"Info"}]}}}` assert.Equal(t, expected, actual) @@ -1381,8 +1382,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("all data", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ user { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) @@ -1393,8 +1394,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations) { ... 
on Info { age }}}}}","variables":{"representations":[{"id":11,"__typename":"Info"}]}}}` assert.Equal(t, expected, actual) @@ -1515,8 +1516,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("null info data", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ user { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) @@ -1527,7 +1528,7 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any()). + Load(gomock.Any(), gomock.Any(), gomock.Any()). Times(0) return &GraphQLResponse{ @@ -1643,8 +1644,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("wrong type data", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ user { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) @@ -1655,7 +1656,7 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any()). 
+ Load(gomock.Any(), gomock.Any(), gomock.Any()). Times(0) return &GraphQLResponse{ @@ -1771,8 +1772,8 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { t.Run("not matching type data", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{ user { name info {id __typename}}}}"}}` assert.Equal(t, expected, actual) @@ -1783,7 +1784,7 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { infoService := NewMockDataSource(ctrl) infoService.EXPECT(). - Load(gomock.Any(), gomock.Any()). + Load(gomock.Any(), gomock.Any(), gomock.Any()). Times(0) return &GraphQLResponse{ diff --git a/v2/pkg/engine/resolve/resolve_mock_test.go b/v2/pkg/engine/resolve/resolve_mock_test.go index d493ff4bdf..a64b7dd831 100644 --- a/v2/pkg/engine/resolve/resolve_mock_test.go +++ b/v2/pkg/engine/resolve/resolve_mock_test.go @@ -6,6 +6,7 @@ package resolve import ( context "context" + http "net/http" reflect "reflect" gomock "github.com/golang/mock/gomock" @@ -36,31 +37,31 @@ func (m *MockDataSource) EXPECT() *MockDataSourceMockRecorder { } // Load mocks base method. -func (m *MockDataSource) Load(arg0 context.Context, arg1 []byte) ([]byte, error) { +func (m *MockDataSource) Load(arg0 context.Context, arg1 http.Header, arg2 []byte) ([]byte, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Load", arg0, arg1) + ret := m.ctrl.Call(m, "Load", arg0, arg1, arg2) ret0, _ := ret[0].([]byte) ret1, _ := ret[1].(error) return ret0, ret1 } // Load indicates an expected call of Load. 
-func (mr *MockDataSourceMockRecorder) Load(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockDataSourceMockRecorder) Load(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Load", reflect.TypeOf((*MockDataSource)(nil).Load), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Load", reflect.TypeOf((*MockDataSource)(nil).Load), arg0, arg1, arg2) } // LoadWithFiles mocks base method. -func (m *MockDataSource) LoadWithFiles(arg0 context.Context, arg1 []byte, arg2 []*httpclient.FileUpload) ([]byte, error) { +func (m *MockDataSource) LoadWithFiles(arg0 context.Context, arg1 http.Header, arg2 []byte, arg3 []*httpclient.FileUpload) ([]byte, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "LoadWithFiles", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "LoadWithFiles", arg0, arg1, arg2, arg3) ret0, _ := ret[0].([]byte) ret1, _ := ret[1].(error) return ret0, ret1 } // LoadWithFiles indicates an expected call of LoadWithFiles. 
-func (mr *MockDataSourceMockRecorder) LoadWithFiles(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockDataSourceMockRecorder) LoadWithFiles(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadWithFiles", reflect.TypeOf((*MockDataSource)(nil).LoadWithFiles), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadWithFiles", reflect.TypeOf((*MockDataSource)(nil).LoadWithFiles), arg0, arg1, arg2, arg3) } diff --git a/v2/pkg/engine/resolve/resolve_test.go b/v2/pkg/engine/resolve/resolve_test.go index d19156f365..5c2ea4ed66 100644 --- a/v2/pkg/engine/resolve/resolve_test.go +++ b/v2/pkg/engine/resolve/resolve_test.go @@ -13,7 +13,6 @@ import ( "testing" "time" - "github.com/cespare/xxhash/v2" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -32,7 +31,7 @@ type _fakeDataSource struct { artificialLatency time.Duration } -func (f *_fakeDataSource) Load(ctx context.Context, input []byte) (data []byte, err error) { +func (f *_fakeDataSource) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { if f.artificialLatency != 0 { time.Sleep(f.artificialLatency) } @@ -44,7 +43,7 @@ func (f *_fakeDataSource) Load(ctx context.Context, input []byte) (data []byte, return f.data, nil } -func (f *_fakeDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { +func (f *_fakeDataSource) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { if f.artificialLatency != 0 { time.Sleep(f.artificialLatency) } @@ -349,8 +348,8 @@ func TestResolver_ResolveNode(t *testing.T) { t.Run("fetch with context variable resolver", testFn(true, func(t *testing.T, ctrl *gomock.Controller) (response *GraphQLResponse, ctx Context, expectedOutput string) { 
mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), []byte(`{"id":1}`)). - Do(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), []byte(`{"id":1}`)). + Do(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"name":"Jens"}`), nil }). Return([]byte(`{"name":"Jens"}`), nil) @@ -1799,8 +1798,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with simple error without datasource ID", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) return &GraphQLResponse{ @@ -1829,8 +1828,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with simple error without datasource ID no subgraph error forwarding", testFnNoSubgraphErrorForwarding(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) return &GraphQLResponse{ @@ -1859,8 +1858,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with simple error", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) return &GraphQLResponse{ @@ -1893,8 +1892,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with simple error in pass through Subgraph Error Mode", testFnSubgraphErrorsPassthrough(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil }) return &GraphQLResponse{ @@ -1927,8 +1926,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with pass through mode and omit custom fields", testFnSubgraphErrorsPassthroughAndOmitCustomFields(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). 
- DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage","longMessage":"This is a long message","extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}}],"data":{"name":null}}`), nil }) return &GraphQLResponse{ @@ -1964,8 +1963,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with returned err (with DataSourceID)", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return nil, &net.AddrError{} }) return &GraphQLResponse{ @@ -1998,8 +1997,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with returned err (no DataSourceID)", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return nil, &net.AddrError{} }) return &GraphQLResponse{ @@ -2028,8 +2027,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with returned err and non-nullable root field", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). 
- Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return nil, &net.AddrError{} }) return &GraphQLResponse{ @@ -2206,8 +2205,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("fetch with two Errors", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"errorMessage1"},{"message":"errorMessage2"}]}`), nil }).Times(1) return &GraphQLResponse{ @@ -2562,8 +2561,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("complex GraphQL Server plan", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { serviceOne := NewMockDataSource(ctrl) serviceOne.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"url":"https://service.one","body":{"query":"query($firstArg: String, $thirdArg: Int){serviceOne(serviceOneArg: $firstArg){fieldOne} anotherServiceOne(anotherServiceOneArg: $thirdArg){fieldOne} reusingServiceOne(reusingServiceOneArg: $firstArg){fieldOne}}","variables":{"thirdArg":123,"firstArg":"firstArgValue"}}}` assert.Equal(t, expected, actual) @@ -2572,8 +2571,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { serviceTwo := NewMockDataSource(ctrl) serviceTwo.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"url":"https://service.two","body":{"query":"query($secondArg: Boolean, $fourthArg: Float){serviceTwo(serviceTwoArg: $secondArg){fieldTwo} secondServiceTwo(secondServiceTwoArg: $fourthArg){fieldTwo}}","variables":{"fourthArg":12.34,"secondArg":true}}}` assert.Equal(t, expected, actual) @@ -2582,8 +2581,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { nestedServiceOne := NewMockDataSource(ctrl) nestedServiceOne.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"url":"https://service.one","body":{"query":"{serviceOne {fieldOne}}"}}` assert.Equal(t, expected, actual) @@ -2798,8 +2797,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). 
- DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) @@ -2808,8 +2807,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"id":"1234","__typename":"User"}]}}}` assert.Equal(t, expected, actual) @@ -2820,8 +2819,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) productServiceCallCount.Add(1) switch actual { @@ -3005,8 +3004,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("federation with batch", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). 
- DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) @@ -3015,8 +3014,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"__typename":"User","id":"1234"}]}}}` assert.Equal(t, expected, actual) @@ -3025,8 +3024,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... 
on Product {name}}}","variables":{"representations":[{"__typename":"Product","upc":"top-1"},{"__typename":"Product","upc":"top-2"}]}}}` assert.Equal(t, expected, actual) @@ -3202,8 +3201,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("federation with merge paths", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) @@ -3212,8 +3211,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"__typename":"User","id":"1234"}]}}}` assert.Equal(t, expected, actual) @@ -3222,8 +3221,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"__typename":"Product","upc":"top-1"},{"__typename":"Product","upc":"top-2"}]}}}` assert.Equal(t, expected, actual) @@ -3400,8 +3399,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("federation with null response", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) @@ -3410,8 +3409,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... 
on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"id":"1234","__typename":"User"}]}}}` assert.Equal(t, expected, actual) @@ -3427,8 +3426,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"upc":"top-1","__typename":"Product"},{"upc":"top-2","__typename":"Product"},{"upc":"top-4","__typename":"Product"},{"upc":"top-5","__typename":"Product"},{"upc":"top-6","__typename":"Product"}]}}}` assert.Equal(t, expected, actual) @@ -3627,8 +3626,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) @@ -3637,8 +3636,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"id":"1234","__typename":"User"}]}}}` assert.Equal(t, expected, actual) @@ -3647,8 +3646,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"upc":"top-1","__typename":"Product"},{"upc":"top-2","__typename":"Product"}]}}}` assert.Equal(t, expected, actual) @@ -3814,8 +3813,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) @@ -3824,8 +3823,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any()). 
- DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"id":"1234","__typename":"User"}]}}}` assert.Equal(t, expected, actual) @@ -3834,8 +3833,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"upc":"top-1","__typename":"Product"},{"upc":"top-2","__typename":"Product"}]}}}` assert.Equal(t, expected, actual) @@ -3998,8 +3997,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("federation with optional variable", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:8080/query","body":{"query":"{me {id}}"}}` assert.Equal(t, expected, actual) @@ -4008,8 +4007,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { employeeService := NewMockDataSource(ctrl) employeeService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:8081/query","body":{"query":"query($representations: [_Any!]!, $companyId: ID!){_entities(representations: $representations){... on User {employment(companyId: $companyId){id}}}}","variables":{"companyId":"abc123","representations":[{"id":"1234","__typename":"User"}]}}}` assert.Equal(t, expected, actual) @@ -4018,8 +4017,8 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { timeService := NewMockDataSource(ctrl) timeService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:8082/query","body":{"query":"query($representations: [_Any!]!, $date: LocalTime){_entities(representations: $representations){... 
on Employee {times(date: $date){id employee {id} start end}}}}","variables":{"date":null,"representations":[{"id":"xyz987","__typename":"Employee"}]}}}` assert.Equal(t, expected, actual) @@ -4538,8 +4537,8 @@ func TestResolver_ArenaResolveGraphQLResponse(t *testing.T) { t.Run("with variables", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), []byte(`{"id":1}`)). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), []byte(`{"id":1}`)). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"name":"Jens"}`), nil }) return &GraphQLResponse{ @@ -4582,8 +4581,8 @@ func TestResolver_ArenaResolveGraphQLResponse(t *testing.T) { t.Run("error handling", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return nil, errors.New("data source error") }) return &GraphQLResponse{ @@ -4659,8 +4658,8 @@ func TestResolver_ApolloCompatibilityMode_FetchError(t *testing.T) { t.Run("simple fetch with fetch error suppression - empty response", testFnApolloCompatibility(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte("{}"), nil }) return &GraphQLResponse{ @@ -4697,8 +4696,8 @@ func TestResolver_ApolloCompatibilityMode_FetchError(t *testing.T) { t.Run("simple fetch with fetch error suppression - response with error", testFnApolloCompatibility(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { return []byte(`{"errors":[{"message":"Cannot query field 'name' on type 'Query'"}]}`), nil }) return &GraphQLResponse{ @@ -4735,8 +4734,8 @@ func TestResolver_ApolloCompatibilityMode_FetchError(t *testing.T) { t.Run("complex fetch with fetch error suppression", testFnApolloCompatibility(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` assert.Equal(t, expected, actual) @@ -4745,8 +4744,8 @@ func TestResolver_ApolloCompatibilityMode_FetchError(t *testing.T) { reviewsService := NewMockDataSource(ctrl) reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"id":"1234","__typename":"User"}]}}}` assert.Equal(t, expected, actual) @@ -4755,8 +4754,8 @@ func TestResolver_ApolloCompatibilityMode_FetchError(t *testing.T) { productService := NewMockDataSource(ctrl) productService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) expected := `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"upc":"top-1","__typename":"Product"},{"upc":"top-2","__typename":"Product"}]}}}` assert.Equal(t, expected, actual) @@ -4946,8 +4945,8 @@ func TestResolver_WithHeader(t *testing.T) { ctrl := gomock.NewController(t) fakeService := NewMockDataSource(ctrl) fakeService.EXPECT(). - Load(gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) assert.Equal(t, "foo", actual) return []byte(`{"bar":"baz"}`), nil @@ -5017,8 +5016,8 @@ func TestResolver_WithVariableRemapping(t *testing.T) { ctrl := gomock.NewController(t) fakeService := NewMockDataSource(ctrl) fakeService.EXPECT(). - Load(gomock.Any(), gomock.Any()). 
- DoAndReturn(func(ctx context.Context, input []byte) ([]byte, error) { + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { actual := string(input) assert.Equal(t, tc.expectedOutput, actual) return []byte(`{"bar":"baz"}`), nil @@ -5203,16 +5202,7 @@ func (f *_fakeStream) AwaitIsDone(t *testing.T, timeout time.Duration) { } } -func (f *_fakeStream) UniqueRequestID(ctx *Context, input []byte, xxh *xxhash.Digest) (err error) { - _, err = fmt.Fprint(xxh, fakeStreamRequestId.Add(1)) - if err != nil { - return - } - _, err = xxh.Write(input) - return -} - -func (f *_fakeStream) Start(ctx *Context, input []byte, updater SubscriptionUpdater) error { +func (f *_fakeStream) Start(ctx *Context, headers http.Header, input []byte, updater SubscriptionUpdater) error { if f.onStart != nil { f.onStart(input) } diff --git a/v2/pkg/engine/resolve/response.go b/v2/pkg/engine/resolve/response.go index b98f4c00fa..c02d92f497 100644 --- a/v2/pkg/engine/resolve/response.go +++ b/v2/pkg/engine/resolve/response.go @@ -16,12 +16,13 @@ type GraphQLSubscription struct { } type GraphQLSubscriptionTrigger struct { - Input []byte - InputTemplate InputTemplate - Variables Variables - Source SubscriptionDataSource - PostProcessing PostProcessingConfiguration - QueryPlan *QueryPlan + Input []byte + InputTemplate InputTemplate + Variables Variables + Source SubscriptionDataSource + PostProcessing PostProcessingConfiguration + QueryPlan *QueryPlan + SourceName, SourceID string } // GraphQLResponse contains an ordered tree of fetches and the response shape. 
diff --git a/v2/pkg/engine/resolve/singleflight.go b/v2/pkg/engine/resolve/singleflight.go index e298531967..a179602492 100644 --- a/v2/pkg/engine/resolve/singleflight.go +++ b/v2/pkg/engine/resolve/singleflight.go @@ -1,7 +1,6 @@ package resolve import ( - "context" "sync" "github.com/cespare/xxhash/v2" @@ -41,8 +40,8 @@ func NewSingleFlight() *SingleFlight { } } -func (s *SingleFlight) GetOrCreateItem(ctx context.Context, fetchItem *FetchItem, input []byte) (sfKey, fetchKey uint64, item *SingleFlightItem, shared bool) { - sfKey, fetchKey = s.keys(fetchItem, input) +func (s *SingleFlight) GetOrCreateItem(fetchItem *FetchItem, input []byte, extraKey uint64) (sfKey, fetchKey uint64, item *SingleFlightItem, shared bool) { + sfKey, fetchKey = s.keys(fetchItem, input, extraKey) // First, try to get the item with a read lock s.mu.RLock() @@ -73,9 +72,9 @@ func (s *SingleFlight) GetOrCreateItem(ctx context.Context, fetchItem *FetchItem return sfKey, fetchKey, item, false } -func (s *SingleFlight) keys(fetchItem *FetchItem, input []byte) (sfKey, fetchKey uint64) { +func (s *SingleFlight) keys(fetchItem *FetchItem, input []byte, extraKey uint64) (sfKey, fetchKey uint64) { h := s.xxPool.Get().(*xxhash.Digest) - sfKey = s.sfKey(h, fetchItem, input) + sfKey = s.sfKey(h, fetchItem, input, extraKey) h.Reset() fetchKey = s.fetchKey(h, fetchItem) h.Reset() @@ -83,7 +82,7 @@ func (s *SingleFlight) keys(fetchItem *FetchItem, input []byte) (sfKey, fetchKey return sfKey, fetchKey } -func (s *SingleFlight) sfKey(h *xxhash.Digest, fetchItem *FetchItem, input []byte) uint64 { +func (s *SingleFlight) sfKey(h *xxhash.Digest, fetchItem *FetchItem, input []byte, extraKey uint64) uint64 { if fetchItem != nil && fetchItem.Fetch != nil { info := fetchItem.Fetch.FetchInfo() if info != nil { @@ -92,7 +91,7 @@ func (s *SingleFlight) sfKey(h *xxhash.Digest, fetchItem *FetchItem, input []byt } } _, _ = h.Write(input) - return h.Sum64() + return h.Sum64() + extraKey } func (s *SingleFlight) 
fetchKey(h *xxhash.Digest, fetchItem *FetchItem) uint64 { From 26f22b33f89c94a6ec64682d870b45600bfae244 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Fri, 24 Oct 2025 18:56:18 +0200 Subject: [PATCH 12/57] chore: rename HeadersForSubgraphRequest to SubgraphHeadersBuilder --- v2/pkg/engine/resolve/context.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/v2/pkg/engine/resolve/context.go b/v2/pkg/engine/resolve/context.go index b0b82f5787..dd4f32e8cb 100644 --- a/v2/pkg/engine/resolve/context.go +++ b/v2/pkg/engine/resolve/context.go @@ -33,10 +33,10 @@ type Context struct { subgraphErrors error - SubgraphHeadersBuilder HeadersForSubgraphRequest + SubgraphHeadersBuilder SubgraphHeadersBuilder } -type HeadersForSubgraphRequest interface { +type SubgraphHeadersBuilder interface { HeadersForSubgraph(subgraphName string) (http.Header, uint64) } From 4392770d09eb34630a0e10666d693fdfdd118780 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Fri, 24 Oct 2025 18:56:41 +0200 Subject: [PATCH 13/57] chore: fix bug --- v2/pkg/engine/resolve/loader.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index a429087d06..8269768094 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -1851,7 +1851,7 @@ func (l *Loader) compactJSON(data []byte) ([]byte, error) { return nil, err } out := dst.Bytes() - v, err := astjson.ParseBytesWithArena(l.jsonArena, out) + v, err := astjson.ParseBytes(out) if err != nil { return nil, err } From 94f3d27c578ebd85761b6f97c482675c4b778b96 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Sat, 25 Oct 2025 13:30:24 +0200 Subject: [PATCH 14/57] chore: use are to execute subscription updates --- v2/pkg/engine/resolve/resolve.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 107f0cb794..5acfc6aad4 100644 --- 
a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -480,7 +480,12 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf) + resolveArena := r.resolveArenaPool.Acquire(resolveCtx.Request.ID) + t.loader.jsonArena = resolveArena.Arena + t.resolvable.astjsonArena = resolveArena.Arena + if err := t.resolvable.InitSubscription(resolveCtx, input, sub.resolve.Trigger.PostProcessing); err != nil { + r.resolveArenaPool.Release(resolveCtx.Request.ID, resolveArena) r.asyncErrorWriter.WriteError(resolveCtx, err, sub.resolve.Response, sub.writer) if r.options.Debug { fmt.Printf("resolver:trigger:subscription:init:failed:%d\n", sub.id.SubscriptionID) @@ -492,6 +497,7 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar } if err := t.loader.LoadGraphQLResponseData(resolveCtx, sub.resolve.Response, t.resolvable); err != nil { + r.resolveArenaPool.Release(resolveCtx.Request.ID, resolveArena) r.asyncErrorWriter.WriteError(resolveCtx, err, sub.resolve.Response, sub.writer) if r.options.Debug { fmt.Printf("resolver:trigger:subscription:load:failed:%d\n", sub.id.SubscriptionID) @@ -503,6 +509,7 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar } if err := t.resolvable.Resolve(resolveCtx.ctx, sub.resolve.Response.Data, sub.resolve.Response.Fetches, sub.writer); err != nil { + r.resolveArenaPool.Release(resolveCtx.Request.ID, resolveArena) r.asyncErrorWriter.WriteError(resolveCtx, err, sub.resolve.Response, sub.writer) if r.options.Debug { fmt.Printf("resolver:trigger:subscription:resolve:failed:%d\n", sub.id.SubscriptionID) @@ -513,6 +520,8 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar return } + r.resolveArenaPool.Release(resolveCtx.Request.ID, resolveArena) + if err := sub.writer.Flush(); err != nil { // If flush fails (e.g. 
client disconnected), remove the subscription. _ = r.AsyncUnsubscribeSubscription(sub.id) From e7407d1fd2a3023989eb572fe30ef9bad4d694d5 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Sat, 25 Oct 2025 15:36:20 +0200 Subject: [PATCH 15/57] chore: merge main --- v2/go.sum | 2 -- 1 file changed, 2 deletions(-) diff --git a/v2/go.sum b/v2/go.sum index 690d15a884..5a7781e3a2 100644 --- a/v2/go.sum +++ b/v2/go.sum @@ -134,8 +134,6 @@ github.com/urfave/cli/v2 v2.27.7 h1:bH59vdhbjLv3LAvIu6gd0usJHgoTTPhCFib8qqOwXYU= github.com/urfave/cli/v2 v2.27.7/go.mod h1:CyNAG/xg+iAOg0N4MPGZqVmv2rCoP267496AOXUZjA4= github.com/vektah/gqlparser/v2 v2.5.30 h1:EqLwGAFLIzt1wpx1IPpY67DwUujF1OfzgEyDsLrN6kE= github.com/vektah/gqlparser/v2 v2.5.30/go.mod h1:D1/VCZtV3LPnQrcPBeR/q5jkSQIPti0uYCP/RI0gIeo= -github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 h1:8/D7f8gKxTBjW+SZK4mhxTTBVpxcqeBgWF1Rfmltbfk= -github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083/go.mod h1:eOTL6acwctsN4F3b7YE+eE2t8zcJ/doLm9sZzsxxxrE= github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 h1:FnBeRrxr7OU4VvAzt5X7s6266i6cSVkkFPS0TuXWbIg= github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= From 60b5c3b390d0af3e1e54a71fd669138744e72689 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Sat, 25 Oct 2025 20:59:31 +0200 Subject: [PATCH 16/57] chore: update deps --- v2/go.mod | 9 ++------- v2/go.sum | 4 ++++ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/v2/go.mod b/v2/go.mod index 308eea5345..43ada453b4 100644 --- a/v2/go.mod +++ b/v2/go.mod @@ -28,8 +28,8 @@ require ( github.com/tidwall/gjson v1.17.0 github.com/tidwall/sjson v1.2.5 github.com/vektah/gqlparser/v2 v2.5.30 - github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 - github.com/wundergraph/go-arena v0.0.1 + github.com/wundergraph/astjson v1.0.0 + 
github.com/wundergraph/go-arena v1.0.0 go.uber.org/atomic v1.11.0 go.uber.org/goleak v1.3.0 go.uber.org/zap v1.26.0 @@ -79,8 +79,3 @@ require ( ) tool github.com/99designs/gqlgen - -replace ( - github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 => ../../wundergraph-projects/astjson - github.com/wundergraph/go-arena v0.0.1 => ../../wundergraph-projects/go-arena -) diff --git a/v2/go.sum b/v2/go.sum index 5a7781e3a2..6d0fb36360 100644 --- a/v2/go.sum +++ b/v2/go.sum @@ -134,6 +134,10 @@ github.com/urfave/cli/v2 v2.27.7 h1:bH59vdhbjLv3LAvIu6gd0usJHgoTTPhCFib8qqOwXYU= github.com/urfave/cli/v2 v2.27.7/go.mod h1:CyNAG/xg+iAOg0N4MPGZqVmv2rCoP267496AOXUZjA4= github.com/vektah/gqlparser/v2 v2.5.30 h1:EqLwGAFLIzt1wpx1IPpY67DwUujF1OfzgEyDsLrN6kE= github.com/vektah/gqlparser/v2 v2.5.30/go.mod h1:D1/VCZtV3LPnQrcPBeR/q5jkSQIPti0uYCP/RI0gIeo= +github.com/wundergraph/astjson v1.0.0 h1:rETLJuQkMWWW03HCF6WBttEBOu8gi5vznj5KEUPVV2Q= +github.com/wundergraph/astjson v1.0.0/go.mod h1:h12D/dxxnedtLzsKyBLK7/Oe4TAoGpRVC9nDpDrZSWw= +github.com/wundergraph/go-arena v1.0.0 h1:RVYWpDkJ1/6851BRHYehBeEcTLKmZygYIZsvBorcOjw= +github.com/wundergraph/go-arena v1.0.0/go.mod h1:ROOysEHWJjLQ8FSfNxZCziagb7Qw2nXY3/vgKRh7eWw= github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 h1:FnBeRrxr7OU4VvAzt5X7s6266i6cSVkkFPS0TuXWbIg= github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= From 3fb0272893d828d5a574d43396e8278b250a5ef8 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Sat, 25 Oct 2025 21:45:26 +0200 Subject: [PATCH 17/57] chore: add comments --- .../graphql_subscription_client_test.go | 2 +- .../datasource/httpclient/nethttpclient.go | 14 ++++++ v2/pkg/engine/resolve/context.go | 7 +++ v2/pkg/engine/resolve/loader.go | 49 ++++++++++++++++--- v2/pkg/engine/resolve/resolvable.go | 7 ++- v2/pkg/engine/resolve/resolve.go | 9 +++- 6 files 
changed, 78 insertions(+), 10 deletions(-) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go index 25eaa29f72..86dd57c030 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go @@ -2437,7 +2437,7 @@ func TestWebSocketUpgradeFailures(t *testing.T) { w.Header().Set(key, value) } w.WriteHeader(tc.statusCode) - fmt.Fprintf(w, `{"error": "WebSocket upgrade failed", "status": %d}`, tc.statusCode) + _, _ = fmt.Fprintf(w, `{"error": "WebSocket upgrade failed", "status": %d}`, tc.statusCode) })) defer server.Close() diff --git a/v2/pkg/engine/datasource/httpclient/nethttpclient.go b/v2/pkg/engine/datasource/httpclient/nethttpclient.go index 4c4f2de3d4..c4ce9915ff 100644 --- a/v2/pkg/engine/datasource/httpclient/nethttpclient.go +++ b/v2/pkg/engine/datasource/httpclient/nethttpclient.go @@ -136,6 +136,9 @@ const ( sizeHintKey httpClientContext = "size-hint" ) +// WithHTTPClientSizeHint allows the engine to keep track of response sizes per subgraph fetch +// If a hint is supplied, we can create a buffer of size close to the required size +// This reduces allocations by reducing the buffer grow calls, which always copies the buffer func WithHTTPClientSizeHint(ctx context.Context, size int) context.Context { return context.WithValue(ctx, sizeHintKey, size) } @@ -144,6 +147,9 @@ func buffer(ctx context.Context) *bytes.Buffer { if sizeHint, ok := ctx.Value(sizeHintKey).(int); ok && sizeHint > 0 { return bytes.NewBuffer(make([]byte, 0, sizeHint)) } + // if we start with zero, doubling will take a while until we reach the required size + // if we start with a high number, e.g. 
1024, we just increase the memory usage of the engine + // 64 seems to be a healthy middle ground return bytes.NewBuffer(make([]byte, 0, 64)) } @@ -211,6 +217,8 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, baseHeaders http. request.Header.Set(AcceptEncodingHeader, EncodingGzip) request.Header.Add(AcceptEncodingHeader, EncodingDeflate) if contentLength > 0 { + // always set the Content-Length Header so that chunking can be avoided + // and other parties can more efficiently parse request.Header.Set(ContentLengthHeader, fmt.Sprintf("%d", contentLength)) } @@ -229,6 +237,12 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, baseHeaders http. return nil, err } + // we intentionally don't use a pool of sorts here + // we're buffering the response and then later, in the engine, + // parse it into an JSON AST with the use of an arena, which is quite efficient + // Through trial and error it turned out that it's best to leave this buffer to the GC + // It'll know best the lifecycle of the buffer + // Using an arena here just increased overall memory usage out := buffer(ctx) _, err = out.ReadFrom(respReader) if err != nil { diff --git a/v2/pkg/engine/resolve/context.go b/v2/pkg/engine/resolve/context.go index dd4f32e8cb..fdb2ebb581 100644 --- a/v2/pkg/engine/resolve/context.go +++ b/v2/pkg/engine/resolve/context.go @@ -36,10 +36,17 @@ type Context struct { SubgraphHeadersBuilder SubgraphHeadersBuilder } +// SubgraphHeadersBuilder allows the user of the engine to "define" the headers for a subgraph request +// Instead of going back and forth between engine & transport, +// you can simply define a function that returns headers for a Subgraph request +// In addition to just the header, the implementer can return a hash for the header which will be used by request deduplication type SubgraphHeadersBuilder interface { + // HeadersForSubgraph must return the headers and a hash for a Subgraph Request + // The hash will be used for request 
deduplication HeadersForSubgraph(subgraphName string) (http.Header, uint64) } +// HeadersForSubgraphRequest returns headers and a hash for a request that the engine will make to a subgraph func (c *Context) HeadersForSubgraphRequest(subgraphName string) (http.Header, uint64) { if c.SubgraphHeadersBuilder == nil { return nil, 0 diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 8269768094..4b22dbbcf2 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -137,8 +137,9 @@ type result struct { loaderHookContext context.Context httpResponseContext *httpclient.ResponseContext - out []byte - singleFlightStats *singleFlightStats + // out is the subgraph response body + out []byte + singleFlightStats *singleFlightStats } func (r *result) init(postProcessing PostProcessingConfiguration, info *FetchInfo) { @@ -182,6 +183,14 @@ type Loader struct { taintedObjs taintedObjects + // jsonArena is the arena to allocation json, supplied by the Resolver + // Disclaimer: this arena is NOT thread safe! + // Only use from main goroutine + // Don't Reset or Release, the Resolver handles this + // Disclaimer: When parsing json into the arena, the underlying bytes must also be allocated on the arena! 
+ // This is very important to "tie" their lifecycles together + // If you're not doing this, you will see segfaults + // Example of correct usage in func "mergeResult" jsonArena arena.Arena sf *SingleFlight } @@ -773,9 +782,11 @@ func (l *Loader) mergeErrors(res *result, fetchItem *FetchItem, value *astjson.V return err } } - - // If the error propagation mode is pass-through, we append the errors to the root array + // for efficiency purposes, resolvable.errors is not initialized + // don't change this, it's measurable + // downside: we have to verify it's initialized before appending to it l.resolvable.ensureErrorsInitialized() + // If the error propagation mode is pass-through, we append the errors to the root array l.resolvable.errors.AppendArrayItems(value) return nil } @@ -811,7 +822,9 @@ func (l *Loader) mergeErrors(res *result, fetchItem *FetchItem, value *astjson.V if err := l.addApolloRouterCompatibilityError(res); err != nil { return err } - + // for efficiency purposes, resolvable.errors is not initialized + // don't change this, it's measurable + // downside: we have to verify it's initialized before appending to it l.resolvable.ensureErrorsInitialized() astjson.AppendToArray(l.resolvable.errors, errorObject) @@ -1066,7 +1079,9 @@ func (l *Loader) addApolloRouterCompatibilityError(res *result) error { if err != nil { return err } - + // for efficiency purposes, resolvable.errors is not initialized + // don't change this, it's measurable + // downside: we have to verify it's initialized before appending to it l.resolvable.ensureErrorsInitialized() astjson.AppendToArray(l.resolvable.errors, apolloRouterStatusError) @@ -1081,6 +1096,9 @@ func (l *Loader) renderErrorsFailedDeps(fetchItem *FetchItem, res *result) error return err } l.setSubgraphStatusCode([]*astjson.Value{errorObject}, res.statusCode) + // for efficiency purposes, resolvable.errors is not initialized + // don't change this, it's measurable + // downside: we have to verify it's initialized 
before appending to it l.resolvable.ensureErrorsInitialized() astjson.AppendToArray(l.resolvable.errors, errorObject) return nil @@ -1093,6 +1111,9 @@ func (l *Loader) renderErrorsFailedToFetch(fetchItem *FetchItem, res *result, re return err } l.setSubgraphStatusCode([]*astjson.Value{errorObject}, res.statusCode) + // for efficiency purposes, resolvable.errors is not initialized + // don't change this, it's measurable + // downside: we have to verify it's initialized before appending to it l.resolvable.ensureErrorsInitialized() astjson.AppendToArray(l.resolvable.errors, errorObject) return nil @@ -1112,6 +1133,9 @@ func (l *Loader) renderErrorsStatusFallback(fetchItem *FetchItem, res *result, s } l.setSubgraphStatusCode([]*astjson.Value{errorObject}, res.statusCode) + // for efficiency purposes, resolvable.errors is not initialized + // don't change this, it's measurable + // downside: we have to verify it's initialized before appending to it l.resolvable.ensureErrorsInitialized() astjson.AppendToArray(l.resolvable.errors, errorObject) return nil @@ -1137,6 +1161,9 @@ func (l *Loader) renderAuthorizationRejectedErrors(fetchItem *FetchItem, res *re } pathPart := l.renderAtPathErrorPart(fetchItem.ResponsePath) extensionErrorCode := fmt.Sprintf(`"extensions":{"code":"%s"}`, errorcodes.UnauthorizedFieldOrType) + // for efficiency purposes, resolvable.errors is not initialized + // don't change this, it's measurable + // downside: we have to verify it's initialized before appending to it l.resolvable.ensureErrorsInitialized() if res.ds.Name == "" { for _, reason := range res.authorizationRejectedReasons { @@ -1216,6 +1243,9 @@ func (l *Loader) renderRateLimitRejectedErrors(fetchItem *FetchItem, res *result return err } } + // for efficiency purposes, resolvable.errors is not initialized + // don't change this, it's measurable + // downside: we have to verify it's initialized before appending to it l.resolvable.ensureErrorsInitialized() 
astjson.AppendToArray(l.resolvable.errors, errorObject) return nil @@ -1417,7 +1447,7 @@ func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetchItem *FetchItem, } } } - + // I tried using arena here but it only worsened the situation preparedInput := bytes.NewBuffer(make([]byte, 0, 64)) itemInput := bytes.NewBuffer(make([]byte, 0, 32)) keyGen := pool.Hash64.Get() @@ -1579,6 +1609,7 @@ const ( operationTypeContextKey loaderContextKey = "operationType" ) +// GetOperationTypeFromContext can be used, e.g. by the transport, to check if the operation is a Mutation func GetOperationTypeFromContext(ctx context.Context) ast.OperationType { if ctx == nil { return ast.OperationTypeQuery @@ -1638,6 +1669,7 @@ func (l *Loader) loadByContext(ctx context.Context, source DataSource, fetchItem return nil } + // helps the http client to create buffers at the right size ctx = httpclient.WithHTTPClientSizeHint(ctx, item.sizeHint) defer l.sf.Finish(sfKey, fetchKey, item) @@ -1851,6 +1883,9 @@ func (l *Loader) compactJSON(data []byte) ([]byte, error) { return nil, err } out := dst.Bytes() + // don't use arena here or segfault + // it's also not a hot path and not important to optimize + // arena requires the parsed content to be on the arena as well v, err := astjson.ParseBytes(out) if err != nil { return nil, err diff --git a/v2/pkg/engine/resolve/resolvable.go b/v2/pkg/engine/resolve/resolvable.go index 21470f475d..cbd1df5ea4 100644 --- a/v2/pkg/engine/resolve/resolvable.go +++ b/v2/pkg/engine/resolve/resolvable.go @@ -31,7 +31,8 @@ type Resolvable struct { errors *astjson.Value valueCompletion *astjson.Value skipAddingNullErrors bool - + // astjsonArena is the arena to handle json, supplied by Resolver + // not thread safe, but Resolvable is single threaded anyways astjsonArena arena.Arena parsers []*astjson.Parser @@ -111,6 +112,7 @@ func (r *Resolvable) Init(ctx *Context, initialData []byte, operationType ast.Op r.operationType = operationType r.renameTypeNames = 
ctx.RenameTypeNames r.data = astjson.ObjectValue(r.astjsonArena) + // don't init errors! It will heavily increase memory usage r.errors = nil if initialData != nil { initialValue, err := astjson.ParseBytesWithArena(r.astjsonArena, initialData) @@ -129,6 +131,7 @@ func (r *Resolvable) InitSubscription(ctx *Context, initialData []byte, postProc r.ctx = ctx r.operationType = ast.OperationTypeSubscription r.renameTypeNames = ctx.RenameTypeNames + // don't init errors! It will heavily increase memory usage r.errors = nil if initialData != nil { initialValue, err := astjson.ParseBytesWithArena(r.astjsonArena, initialData) @@ -167,6 +170,7 @@ func (r *Resolvable) ResolveNode(node Node, data *astjson.Value, out io.Writer) r.print = false r.printErr = nil r.authorizationError = nil + // don't init errors! It will heavily increase memory usage r.errors = nil hasErrors := r.walkNode(node, data) @@ -233,6 +237,7 @@ func (r *Resolvable) Resolve(ctx context.Context, rootData *Object, fetchTree *F return r.printErr } +// ensureErrorsInitialized is used to lazily init r.errors if needed func (r *Resolvable) ensureErrorsInitialized() { if r.errors == nil { r.errors = astjson.ArrayValue(r.astjsonArena) diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 5acfc6aad4..39b1a3beaa 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -72,7 +72,13 @@ type Resolver struct { // maxSubscriptionFetchTimeout defines the maximum time a subscription fetch can take before it is considered timed out maxSubscriptionFetchTimeout time.Duration - resolveArenaPool *ArenaPool + // resolveArenaPool is the arena pool dedicated for Loader & Resolvable + // ArenaPool automatically adjusts arena buffer sizes per workload + // resolving & response buffering are very different tasks + // as such, it was best to have two arena pools in terms of memory usage + // A single pool for both was much less efficient + resolveArenaPool *ArenaPool + // 
responseBufferPool is the arena pool dedicated for response buffering before sending to the client responseBufferPool *ArenaPool // Single flight cache for deduplicating requests across all loaders @@ -246,6 +252,7 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { func newTools(options ResolverOptions, allowedExtensionFields map[string]struct{}, allowedErrorFields map[string]struct{}, sf *SingleFlight) *tools { return &tools{ + // we set the arena manually resolvable: NewResolvable(nil, options.ResolvableOptions), loader: &Loader{ propagateSubgraphErrors: options.PropagateSubgraphErrors, From bb33b4b527ada0a4622d1aa1a63927448595d9e7 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Sat, 25 Oct 2025 21:48:27 +0200 Subject: [PATCH 18/57] chore: set content length correctly --- v2/pkg/engine/datasource/httpclient/nethttpclient.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/v2/pkg/engine/datasource/httpclient/nethttpclient.go b/v2/pkg/engine/datasource/httpclient/nethttpclient.go index c4ce9915ff..c5f53c7e02 100644 --- a/v2/pkg/engine/datasource/httpclient/nethttpclient.go +++ b/v2/pkg/engine/datasource/httpclient/nethttpclient.go @@ -219,7 +219,7 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, baseHeaders http. 
if contentLength > 0 { // always set the Content-Length Header so that chunking can be avoided // and other parties can more efficiently parse - request.Header.Set(ContentLengthHeader, fmt.Sprintf("%d", contentLength)) + request.ContentLength = int64(contentLength) } setRequest(ctx, request) From bb31735c6849ba3340a0ded2440450d6d1ea84a4 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Sat, 25 Oct 2025 22:12:22 +0200 Subject: [PATCH 19/57] chore: fix bench --- v2/pkg/engine/resolve/loader_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/v2/pkg/engine/resolve/loader_test.go b/v2/pkg/engine/resolve/loader_test.go index d6c002393b..f88d7227f6 100644 --- a/v2/pkg/engine/resolve/loader_test.go +++ b/v2/pkg/engine/resolve/loader_test.go @@ -1026,7 +1026,7 @@ func BenchmarkLoader_LoadGraphQLResponseData(b *testing.B) { } resolvable := NewResolvable(nil, ResolvableOptions{}) loader := &Loader{} - expected := `{"errors":[],"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}}` + expected := `{"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too 
expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}}` b.SetBytes(int64(len(expected))) b.ReportAllocs() b.ResetTimer() From ce83a7b763be51b37445310919d2d7241b96fa7e Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Sat, 25 Oct 2025 22:18:35 +0200 Subject: [PATCH 20/57] chore: fix lint --- v2/pkg/engine/resolve/inputtemplate.go | 12 +++++++++--- v2/pkg/engine/resolve/loader.go | 3 +-- v2/pkg/engine/resolve/resolvable.go | 3 +-- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/v2/pkg/engine/resolve/inputtemplate.go b/v2/pkg/engine/resolve/inputtemplate.go index 80db3cdd82..0ad72ec949 100644 --- a/v2/pkg/engine/resolve/inputtemplate.go +++ b/v2/pkg/engine/resolve/inputtemplate.go @@ -158,14 +158,20 @@ func (i *InputTemplate) renderHeaderVariable(ctx *Context, path []string, prepar return nil } if len(value) == 1 { - preparedInput.WriteString(value[0]) + if _, err := preparedInput.WriteString(value[0]); err != nil { + return err + } return nil } for j := range value { if j != 0 { - _, _ = preparedInput.Write(literal.COMMA) + if _, err := preparedInput.Write(literal.COMMA); err != nil { + return err + } + } + if _, err := preparedInput.WriteString(value[j]); err != nil { + return err } - preparedInput.WriteString(value[j]) } return nil } diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 4b22dbbcf2..73cef311f1 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -17,16 +17,15 @@ import ( "github.com/pkg/errors" "github.com/tidwall/gjson" "github.com/tidwall/sjson" - "github.com/wundergraph/graphql-go-tools/v2/pkg/pool" "golang.org/x/sync/errgroup" "github.com/wundergraph/astjson" "github.com/wundergraph/go-arena" - "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" 
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" "github.com/wundergraph/graphql-go-tools/v2/pkg/errorcodes" "github.com/wundergraph/graphql-go-tools/v2/pkg/internal/unsafebytes" + "github.com/wundergraph/graphql-go-tools/v2/pkg/pool" ) const ( diff --git a/v2/pkg/engine/resolve/resolvable.go b/v2/pkg/engine/resolve/resolvable.go index cbd1df5ea4..5879396e7e 100644 --- a/v2/pkg/engine/resolve/resolvable.go +++ b/v2/pkg/engine/resolve/resolvable.go @@ -11,10 +11,9 @@ import ( "github.com/cespare/xxhash/v2" "github.com/pkg/errors" "github.com/tidwall/gjson" - "github.com/wundergraph/go-arena" "github.com/wundergraph/astjson" - + "github.com/wundergraph/go-arena" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/errorcodes" "github.com/wundergraph/graphql-go-tools/v2/pkg/fastjsonext" From 5cfd72d8da0d3074ab4ffc2d139ec71d706da2bc Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Sat, 25 Oct 2025 22:32:35 +0200 Subject: [PATCH 21/57] chore: fix lint --- v2/pkg/engine/datasource/httpclient/nethttpclient.go | 2 +- v2/pkg/engine/resolve/loader.go | 1 + v2/pkg/engine/resolve/resolvable.go | 1 + v2/pkg/engine/resolve/resolve.go | 3 ++- 4 files changed, 5 insertions(+), 2 deletions(-) diff --git a/v2/pkg/engine/datasource/httpclient/nethttpclient.go b/v2/pkg/engine/datasource/httpclient/nethttpclient.go index c5f53c7e02..46af845e4f 100644 --- a/v2/pkg/engine/datasource/httpclient/nethttpclient.go +++ b/v2/pkg/engine/datasource/httpclient/nethttpclient.go @@ -217,7 +217,7 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, baseHeaders http. 
request.Header.Set(AcceptEncodingHeader, EncodingGzip) request.Header.Add(AcceptEncodingHeader, EncodingDeflate) if contentLength > 0 { - // always set the Content-Length Header so that chunking can be avoided + // always set the ContentLength field so that chunking can be avoided // and other parties can more efficiently parse request.ContentLength = int64(contentLength) } diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 73cef311f1..340c41894b 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -21,6 +21,7 @@ import ( "github.com/wundergraph/astjson" "github.com/wundergraph/go-arena" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" "github.com/wundergraph/graphql-go-tools/v2/pkg/errorcodes" diff --git a/v2/pkg/engine/resolve/resolvable.go b/v2/pkg/engine/resolve/resolvable.go index 5879396e7e..226705a706 100644 --- a/v2/pkg/engine/resolve/resolvable.go +++ b/v2/pkg/engine/resolve/resolvable.go @@ -14,6 +14,7 @@ import ( "github.com/wundergraph/astjson" "github.com/wundergraph/go-arena" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/errorcodes" "github.com/wundergraph/graphql-go-tools/v2/pkg/fastjsonext" diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 39b1a3beaa..2f7bec6602 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -12,9 +12,10 @@ import ( "github.com/buger/jsonparser" "github.com/pkg/errors" - "github.com/wundergraph/go-arena" "go.uber.org/atomic" + "github.com/wundergraph/go-arena" + "github.com/wundergraph/graphql-go-tools/v2/pkg/internal/xcontext" "github.com/wundergraph/graphql-go-tools/v2/pkg/pool" ) From 4d4b4c5f1679eed3e5761a596835d2409f1ace1f Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Sun, 26 Oct 2025 08:55:27 +0100 Subject: [PATCH 22/57] chore: cleanup & 
comments --- v2/pkg/engine/resolve/resolve.go | 21 ++++++------ v2/pkg/engine/resolve/response.go | 15 +++++---- v2/pkg/engine/resolve/singleflight.go | 46 +++++++++++++++++++++------ 3 files changed, 54 insertions(+), 28 deletions(-) diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 2f7bec6602..3420e93277 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -251,10 +251,9 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { return resolver } -func newTools(options ResolverOptions, allowedExtensionFields map[string]struct{}, allowedErrorFields map[string]struct{}, sf *SingleFlight) *tools { +func newTools(options ResolverOptions, allowedExtensionFields map[string]struct{}, allowedErrorFields map[string]struct{}, sf *SingleFlight, a arena.Arena) *tools { return &tools{ - // we set the arena manually - resolvable: NewResolvable(nil, options.ResolvableOptions), + resolvable: NewResolvable(a, options.ResolvableOptions), loader: &Loader{ propagateSubgraphErrors: options.PropagateSubgraphErrors, propagateSubgraphStatusCodes: options.PropagateSubgraphStatusCodes, @@ -271,6 +270,7 @@ func newTools(options ResolverOptions, allowedExtensionFields map[string]struct{ propagateFetchReasons: options.PropagateFetchReasons, validateRequiredExternalFields: options.ValidateRequiredExternalFields, sf: sf, + jsonArena: a, }, } } @@ -289,7 +289,7 @@ func (r *Resolver) ResolveGraphQLResponse(ctx *Context, response *GraphQLRespons r.maxConcurrency <- struct{}{} }() - t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf) + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf, nil) err := t.resolvable.Init(ctx, data, response.Info.OperationType) if err != nil { @@ -321,9 +321,9 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe r.maxConcurrency <- struct{}{} }() - t := newTools(r.options, 
r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf) - resolveArena := r.resolveArenaPool.Acquire(ctx.Request.ID) + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf, resolveArena.Arena) + t.loader.jsonArena = resolveArena.Arena t.resolvable.astjsonArena = resolveArena.Arena @@ -486,11 +486,8 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar input := make([]byte, len(sharedInput)) copy(input, sharedInput) - t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf) - resolveArena := r.resolveArenaPool.Acquire(resolveCtx.Request.ID) - t.loader.jsonArena = resolveArena.Arena - t.resolvable.astjsonArena = resolveArena.Arena + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf, resolveArena.Arena) if err := t.resolvable.InitSubscription(resolveCtx, input, sub.resolve.Trigger.PostProcessing); err != nil { r.resolveArenaPool.Release(resolveCtx.Request.ID, resolveArena) @@ -1097,7 +1094,7 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ // If SkipLoader is enabled, we skip retrieving actual data. For example, this is useful when requesting a query plan. // By returning early, we avoid starting a subscription and resolve with empty data instead. if ctx.ExecutionOptions.SkipLoader { - t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf) + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf, nil) err = t.resolvable.InitSubscription(ctx, nil, subscription.Trigger.PostProcessing) if err != nil { @@ -1207,7 +1204,7 @@ func (r *Resolver) AsyncResolveGraphQLSubscription(ctx *Context, subscription *G // If SkipLoader is enabled, we skip retrieving actual data. For example, this is useful when requesting a query plan. // By returning early, we avoid starting a subscription and resolve with empty data instead. 
if ctx.ExecutionOptions.SkipLoader { - t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf) + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf, nil) err = t.resolvable.InitSubscription(ctx, nil, subscription.Trigger.PostProcessing) if err != nil { diff --git a/v2/pkg/engine/resolve/response.go b/v2/pkg/engine/resolve/response.go index c02d92f497..1efe078cca 100644 --- a/v2/pkg/engine/resolve/response.go +++ b/v2/pkg/engine/resolve/response.go @@ -16,13 +16,14 @@ type GraphQLSubscription struct { } type GraphQLSubscriptionTrigger struct { - Input []byte - InputTemplate InputTemplate - Variables Variables - Source SubscriptionDataSource - PostProcessing PostProcessingConfiguration - QueryPlan *QueryPlan - SourceName, SourceID string + Input []byte + InputTemplate InputTemplate + Variables Variables + Source SubscriptionDataSource + PostProcessing PostProcessingConfiguration + QueryPlan *QueryPlan + SourceName string + SourceID string } // GraphQLResponse contains an ordered tree of fetches and the response shape. 
diff --git a/v2/pkg/engine/resolve/singleflight.go b/v2/pkg/engine/resolve/singleflight.go index a179602492..76121d98e9 100644 --- a/v2/pkg/engine/resolve/singleflight.go +++ b/v2/pkg/engine/resolve/singleflight.go @@ -6,13 +6,6 @@ import ( "github.com/cespare/xxhash/v2" ) -type SingleFlightItem struct { - loaded chan struct{} - response []byte - err error - sizeHint int -} - type SingleFlight struct { mu *sync.RWMutex items map[uint64]*SingleFlightItem @@ -21,8 +14,26 @@ type SingleFlight struct { cleanup chan func() } +// SingleFlightItem is used to communicate between leader and followers +// If an Item for a key doesn't exist, the leader creates and followers can join +type SingleFlightItem struct { + // loaded will be closed by the leader to indicate to followers when the work is done + loaded chan struct{} + // response is the shared result, it must not be modified + response []byte + // err is non nil if the leader produced an error while doing the work + err error + // sizeHint keeps track of the last 50 responses per fetchKey to give an estimate on the size + // this gives a leader a hint on how much space it should pre-allocate for buffers when fetching + // this reduces memory usage + sizeHint int +} + +// fetchSize gives an estimate of required buffer size for a given fetchKey when dividing totalBytes / count type fetchSize struct { - count int + // count is the number of fetches tracked + count int + // totalBytes is the cumulative bytes across tracked fetches totalBytes int } @@ -40,6 +51,13 @@ func NewSingleFlight() *SingleFlight { } } +// GetOrCreateItem generates a single flight key (100% identical fetches) and a fetchKey (similar fetches, collisions possible but unproblematic) +// and return a SingleFlightItem as well as an indication if it's shared or not +// If shared == false, the caller is a leader +// If shared == true, the caller is a follower +// item.sizeHint can be used to create an optimal buffer for the fetch in case of a leader +// 
item.err must always be checked +// item.response must never be mutated func (s *SingleFlight) GetOrCreateItem(fetchItem *FetchItem, input []byte, extraKey uint64) (sfKey, fetchKey uint64, item *SingleFlightItem, shared bool) { sfKey, fetchKey = s.keys(fetchItem, input, extraKey) @@ -62,6 +80,7 @@ func (s *SingleFlight) GetOrCreateItem(fetchItem *FetchItem, input []byte, extra // Create a new item item = &SingleFlightItem{ + // empty chan to indicate to all followers when we're done (close) loaded: make(chan struct{}), } if size, ok := s.sizes[fetchKey]; ok { @@ -82,6 +101,8 @@ func (s *SingleFlight) keys(fetchItem *FetchItem, input []byte, extraKey uint64) return sfKey, fetchKey } +// sfKey returns a key that 100% uniquely identifies a fetch with no collision +// two sfKey are only the same when the fetches are 100% equal func (s *SingleFlight) sfKey(h *xxhash.Digest, fetchItem *FetchItem, input []byte, extraKey uint64) uint64 { if fetchItem != nil && fetchItem.Fetch != nil { info := fetchItem.Fetch.FetchInfo() @@ -91,9 +112,13 @@ func (s *SingleFlight) sfKey(h *xxhash.Digest, fetchItem *FetchItem, input []byt } } _, _ = h.Write(input) - return h.Sum64() + extraKey + return h.Sum64() + extraKey // extraKey in this case is the pre-generated hash for the headers } +// fetchKey is a less robust key compared to sfKey +// the purpose is to create a key from the DataSourceID and root fields to have less cardinality +// the goal is to get an estimate buffer size for similar fetches +// there's no point in hashing headers or the body for this purpose func (s *SingleFlight) fetchKey(h *xxhash.Digest, fetchItem *FetchItem) uint64 { if fetchItem == nil || fetchItem.Fetch == nil { return 0 @@ -115,6 +140,9 @@ func (s *SingleFlight) fetchKey(h *xxhash.Digest, fetchItem *FetchItem) uint64 { return h.Sum64() } +// Finish is for the leader to mark the SingleFlightItem as "done" +// trigger all followers to look at the err & response of the item +// and to update the size 
estimates func (s *SingleFlight) Finish(sfKey, fetchKey uint64, item *SingleFlightItem) { close(item.loaded) s.mu.Lock() From 48de6512dede3af421afcf66cceb00ac30e74763 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Sun, 26 Oct 2025 10:13:09 +0100 Subject: [PATCH 23/57] chore: refactor --- v2/pkg/engine/resolve/resolve.go | 62 ++++++++++++++++---------------- 1 file changed, 32 insertions(+), 30 deletions(-) diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 3420e93277..b5e3ff14bd 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -322,11 +322,9 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe }() resolveArena := r.resolveArenaPool.Acquire(ctx.Request.ID) + // we're intentionally not using defer Release to have more control over the timing (see below) t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf, resolveArena.Arena) - t.loader.jsonArena = resolveArena.Arena - t.resolvable.astjsonArena = resolveArena.Arena - err := t.resolvable.Init(ctx, nil, response.Info.OperationType) if err != nil { r.resolveArenaPool.Release(ctx.Request.ID, resolveArena) @@ -341,6 +339,7 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe } } + // only when loading is done, acquire an arena for the response buffer responseArena := r.responseBufferPool.Acquire(ctx.Request.ID) buf := arena.NewArenaBuffer(responseArena.Arena) err = t.resolvable.Resolve(ctx.ctx, response.Data, response.Fetches, buf) @@ -350,8 +349,16 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe return nil, err } + // first release resolverArena + // all data is resolved and written into the response arena r.resolveArenaPool.Release(ctx.Request.ID, resolveArena) + // next we write back to the client + // this includes flushing and syscalls + // as such, it can take some time + // which is why we split the arenas and 
released the first one _, err = writer.Write(buf.Bytes()) + // all data is written to the client + // we're safe to release our buffer r.responseBufferPool.Release(ctx.Request.ID, responseArena) return resp, err } @@ -722,16 +729,14 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription) asyncDataSource = async } - headers, _ := r.triggerHeaders(add.ctx, add.sourceName) - go func() { if r.options.Debug { fmt.Printf("resolver:trigger:start:%d\n", triggerID) } if asyncDataSource != nil { - err = asyncDataSource.AsyncStart(cloneCtx, triggerID, headers, add.input, updater) + err = asyncDataSource.AsyncStart(cloneCtx, triggerID, add.headers, add.input, updater) } else { - err = add.resolve.Trigger.Source.Start(cloneCtx, headers, add.input, updater) + err = add.resolve.Trigger.Source.Start(cloneCtx, add.headers, add.input, updater) } if err != nil { if r.options.Debug { @@ -1074,9 +1079,17 @@ func (r *Resolver) AsyncUnsubscribeClient(connectionID int64) error { return nil } -func (r *Resolver) triggerHeaders(ctx *Context, sourceName string) (http.Header, uint64) { +// prepareTrigger safely gets the headers for the trigger Subgraph and computes the hash across headers and input +// the generated has is the unique triggerID +// the headers must be forwarded to the DataSource to create the trigger +func (r *Resolver) prepareTrigger(ctx *Context, sourceName string, input []byte) (headers http.Header, triggerID uint64) { if ctx.SubgraphHeadersBuilder != nil { - return ctx.SubgraphHeadersBuilder.HeadersForSubgraph(sourceName) + header, headerHash := ctx.SubgraphHeadersBuilder.HeadersForSubgraph(sourceName) + keyGen := pool.Hash64.Get() + _, _ = keyGen.Write(input) + triggerID = keyGen.Sum64() + headerHash + pool.Hash64.Put(keyGen) + return header, triggerID } return nil, 0 } @@ -1118,20 +1131,13 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ return nil } - _, headersHash := r.triggerHeaders(ctx, 
subscription.Trigger.SourceName) - - xxh := pool.Hash64.Get() - _, _ = xxh.Write(input) - // the hash for subgraph headers is pre-computed - // we can just add it to the input hash to get a unique id - uniqueID := xxh.Sum64() + headersHash - pool.Hash64.Put(xxh) + headers, triggerID := r.prepareTrigger(ctx, subscription.Trigger.SourceName, input) id := SubscriptionIdentifier{ ConnectionID: ConnectionIDs.Inc(), SubscriptionID: 0, } if r.options.Debug { - fmt.Printf("resolver:trigger:subscribe:sync:%d:%d\n", uniqueID, id.SubscriptionID) + fmt.Printf("resolver:trigger:subscribe:sync:%d:%d\n", triggerID, id.SubscriptionID) } completed := make(chan struct{}) @@ -1141,7 +1147,7 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ // Stop processing if the resolver is shutting down return r.ctx.Err() case r.events <- subscriptionEvent{ - triggerID: uniqueID, + triggerID: triggerID, kind: subscriptionEventKindAddSubscription, addSubscription: &addSubscription{ ctx: ctx, @@ -1151,6 +1157,7 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ id: id, completed: completed, sourceName: subscription.Trigger.SourceName, + headers: headers, }, }: } @@ -1177,13 +1184,13 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ } if r.options.Debug { - fmt.Printf("resolver:trigger:unsubscribe:sync:%d:%d\n", uniqueID, id.SubscriptionID) + fmt.Printf("resolver:trigger:unsubscribe:sync:%d:%d\n", triggerID, id.SubscriptionID) } // Remove the subscription when the client disconnects. 
r.events <- subscriptionEvent{ - triggerID: uniqueID, + triggerID: triggerID, kind: subscriptionEventKindRemoveSubscription, id: id, } @@ -1228,14 +1235,7 @@ func (r *Resolver) AsyncResolveGraphQLSubscription(ctx *Context, subscription *G return nil } - _, headersHash := r.triggerHeaders(ctx, subscription.Trigger.SourceName) - - xxh := pool.Hash64.Get() - _, _ = xxh.Write(input) - // the hash for subgraph headers is pre-computed - // we can just add it to the input hash to get a unique id - uniqueID := xxh.Sum64() + headersHash - pool.Hash64.Put(xxh) + headers, triggerID := r.prepareTrigger(ctx, subscription.Trigger.SourceName, input) select { case <-r.ctx.Done(): @@ -1245,7 +1245,7 @@ func (r *Resolver) AsyncResolveGraphQLSubscription(ctx *Context, subscription *G // Stop resolving if the client is gone return ctx.ctx.Err() case r.events <- subscriptionEvent{ - triggerID: uniqueID, + triggerID: triggerID, kind: subscriptionEventKindAddSubscription, addSubscription: &addSubscription{ ctx: ctx, @@ -1255,6 +1255,7 @@ func (r *Resolver) AsyncResolveGraphQLSubscription(ctx *Context, subscription *G id: id, completed: make(chan struct{}), sourceName: subscription.Trigger.SourceName, + headers: headers, }, }: } @@ -1369,6 +1370,7 @@ type addSubscription struct { id SubscriptionIdentifier completed chan struct{} sourceName string + headers http.Header } type subscriptionEventKind int From 6653948325e9f4bf91994f0d743968709771ca24 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Sun, 26 Oct 2025 11:56:58 +0100 Subject: [PATCH 24/57] chore: refactor & comments --- v2/pkg/engine/resolve/arena.go | 8 ++++++ v2/pkg/engine/resolve/context.go | 10 ++++++-- v2/pkg/engine/resolve/inputtemplate.go | 2 ++ v2/pkg/engine/resolve/loader.go | 34 ++++++++++++++++++-------- 4 files changed, 42 insertions(+), 12 deletions(-) diff --git a/v2/pkg/engine/resolve/arena.go b/v2/pkg/engine/resolve/arena.go index 0aae889742..cca1f33125 100644 --- a/v2/pkg/engine/resolve/arena.go +++ 
b/v2/pkg/engine/resolve/arena.go @@ -10,12 +10,20 @@ import ( // ArenaPool provides a thread-safe pool of arena.Arena instances for memory-efficient allocations. // It uses weak pointers to allow garbage collection of unused arenas while maintaining // a pool of reusable arenas for high-frequency allocation patterns. +// +// by storing ArenaPoolItem as weak pointers, the GC can collect them at any time +// before using an ArenaPoolItem, we try to get a strong pointer while removing it from the pool +// once we call Release, we turn the item back to the pool and make it a weak pointer again +// this means that at any time, GC can claim back the memory if required, +// allowing GC to automatically manage an appropriate pool size depending on available memory and GC pressure type ArenaPool struct { + // pool is a slice of weak pointers to the struct holding the arena.Arena pool []weak.Pointer[ArenaPoolItem] sizes map[uint64]*arenaPoolItemSize mu sync.Mutex } +// arenaPoolItemSize is used to track the required memory across the last 50 arenas in the pool type arenaPoolItemSize struct { count int totalBytes int diff --git a/v2/pkg/engine/resolve/context.go b/v2/pkg/engine/resolve/context.go index fdb2ebb581..52f2eb3bb7 100644 --- a/v2/pkg/engine/resolve/context.go +++ b/v2/pkg/engine/resolve/context.go @@ -55,9 +55,15 @@ func (c *Context) HeadersForSubgraphRequest(subgraphName string) (http.Header, u } type ExecutionOptions struct { - SkipLoader bool + // SkipLoader will, as the name indicates, skip loading data + // However, it does indeed resolve a response + // This can be useful, e.g. 
in combination with IncludeQueryPlanInResponse + // The purpose is to get a QueryPlan (even for Subscriptions) + SkipLoader bool + // IncludeQueryPlanInResponse generates a QueryPlan as part of the response in Resolvable IncludeQueryPlanInResponse bool - SendHeartbeat bool + // SendHeartbeat sends regular HeartBeats for Subscriptions + SendHeartbeat bool // DisableRequestDeduplication disables deduplication of requests to the same subgraph with the same input within a single operation execution. DisableRequestDeduplication bool } diff --git a/v2/pkg/engine/resolve/inputtemplate.go b/v2/pkg/engine/resolve/inputtemplate.go index 0ad72ec949..e0fc97aa69 100644 --- a/v2/pkg/engine/resolve/inputtemplate.go +++ b/v2/pkg/engine/resolve/inputtemplate.go @@ -55,6 +55,8 @@ func SetInputUndefinedVariables(preparedInput InputTemplateWriter, undefinedVari // to callers; renderSegments intercepts it and writes literal.NULL instead. var errSetTemplateOutputNull = errors.New("set to null") +// InputTemplateWriter is used to decouple Buffer implementations from InputTemplate +// This way, the implementation can easily be swapped, e.g. 
between bytes.Buffer and similar implementations type InputTemplateWriter interface { io.Writer io.StringWriter diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 340c41894b..4b51df7e66 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -192,7 +192,9 @@ type Loader struct { // If you're not doing this, you will see segfaults // Example of correct usage in func "mergeResult" jsonArena arena.Arena - sf *SingleFlight + // sf is the SingleFlight object shared across all client requests + // it's thread safe and can be used to de-duplicate subgraph requests + sf *SingleFlight } func (l *Loader) Free() { @@ -302,7 +304,6 @@ func (l *Loader) resolveSingle(item *FetchItem) error { if l.ctx.LoaderHooks != nil { l.ctx.LoaderHooks.OnFinished(res.loaderHookContext, res.ds, newResponseInfo(res, l.ctx.subgraphErrors)) } - return err case *BatchEntityFetch: res := &result{} @@ -438,7 +439,7 @@ func selectItems(a arena.Arena, items []*astjson.Value, element FetchItemPathEle return selected } -func itemsData(a arena.Arena, items []*astjson.Value) *astjson.Value { +func (l *Loader) itemsData(items []*astjson.Value) *astjson.Value { if len(items) == 0 { return astjson.NullValue } @@ -449,7 +450,7 @@ func itemsData(a arena.Arena, items []*astjson.Value) *astjson.Value { // however, itemsData can be called concurrently, so this might result in a race arr := astjson.MustParseBytes([]byte(`[]`)) for i, item := range items { - arr.SetArrayItem(a, i, item) + arr.SetArrayItem(nil, i, item) } return arr } @@ -553,6 +554,9 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson if len(res.out) == 0 { return l.renderErrorsFailedToFetch(fetchItem, res, emptyGraphQLResponse) } + // before parsing bytes with an arena.Arena, it's important to first allocate the bytes ON the same arena.Arena + // this ties their lifecycles together + // if you don't do this, you'll get segfaults slice := 
arena.AllocateSlice[byte](l.jsonArena, len(res.out), len(res.out)) copy(slice, res.out) response, err := astjson.ParseBytesWithArena(l.jsonArena, slice) @@ -707,7 +711,7 @@ var ( ) func (l *Loader) renderErrorsInvalidInput(fetchItem *FetchItem) []byte { - out := &bytes.Buffer{} + out := bytes.NewBuffer(nil) elements := fetchItem.ResponsePathElements if len(elements) > 0 && elements[len(elements)-1] == "@" { elements = elements[:len(elements)-1] @@ -1319,7 +1323,7 @@ func (l *Loader) loadSingleFetch(ctx context.Context, fetch *SingleFetch, fetchI res.init(fetch.PostProcessing, fetch.Info) buf := bytes.NewBuffer(nil) - inputData := itemsData(l.jsonArena, items) + inputData := l.itemsData(items) if l.ctx.TracingOptions.Enable { fetch.Trace = &DataSourceLoadTrace{} if !l.ctx.TracingOptions.ExcludeRawInputData && inputData != nil { @@ -1358,7 +1362,7 @@ func (l *Loader) loadSingleFetch(ctx context.Context, fetch *SingleFetch, fetchI func (l *Loader) loadEntityFetch(ctx context.Context, fetchItem *FetchItem, fetch *EntityFetch, items []*astjson.Value, res *result) error { res.init(fetch.PostProcessing, fetch.Info) - input := itemsData(l.jsonArena, items) + input := l.itemsData(items) if l.ctx.TracingOptions.Enable { fetch.Trace = &DataSourceLoadTrace{} if !l.ctx.TracingOptions.ExcludeRawInputData && input != nil { @@ -1441,17 +1445,22 @@ func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetchItem *FetchItem, if l.ctx.TracingOptions.Enable { fetch.Trace = &DataSourceLoadTrace{} if !l.ctx.TracingOptions.ExcludeRawInputData && len(items) != 0 { - data := itemsData(l.jsonArena, items) + data := l.itemsData(items) if data != nil { fetch.Trace.RawInputData, _ = l.compactJSON(data.MarshalTo(nil)) } } } - // I tried using arena here but it only worsened the situation + // I tried using arena here, but it only worsened the situation preparedInput := bytes.NewBuffer(make([]byte, 0, 64)) itemInput := bytes.NewBuffer(make([]byte, 0, 32)) keyGen := pool.Hash64.Get() - defer 
pool.Hash64.Put(keyGen) + defer func() { + if keyGen == nil { + return + } + pool.Hash64.Put(keyGen) + }() var undefinedVariables []string @@ -1512,6 +1521,11 @@ WithNextItem: } } + // not used anymore + pool.Hash64.Put(keyGen) + // setting to nil so that the defer func doesn't return it twice + keyGen = nil + if len(itemHashes) == 0 { // all items were skipped - discard fetch res.fetchSkipped = true From 6cbfed0eacdd78479892e925c8bddb8ed905ecce Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Sun, 26 Oct 2025 11:57:13 +0100 Subject: [PATCH 25/57] chore: remove unused ParallelListItemFetch --- .../create_concrete_single_fetch_types.go | 8 - v2/pkg/engine/resolve/fetch.go | 32 --- v2/pkg/engine/resolve/fetchtree.go | 25 --- v2/pkg/engine/resolve/loader.go | 62 ------ v2/pkg/engine/resolve/loader_hooks_test.go | 63 ------ v2/pkg/engine/resolve/resolve_test.go | 208 ------------------ 6 files changed, 398 deletions(-) diff --git a/v2/pkg/engine/postprocess/create_concrete_single_fetch_types.go b/v2/pkg/engine/postprocess/create_concrete_single_fetch_types.go index f5d0b2ae26..44b3225fbe 100644 --- a/v2/pkg/engine/postprocess/create_concrete_single_fetch_types.go +++ b/v2/pkg/engine/postprocess/create_concrete_single_fetch_types.go @@ -51,19 +51,11 @@ func (d *createConcreteSingleFetchTypes) traverseSingleFetch(fetch *resolve.Sing return d.createEntityBatchFetch(fetch) case fetch.RequiresEntityFetch: return d.createEntityFetch(fetch) - case fetch.RequiresParallelListItemFetch: - return d.createParallelListItemFetch(fetch) default: return fetch } } -func (d *createConcreteSingleFetchTypes) createParallelListItemFetch(fetch *resolve.SingleFetch) resolve.Fetch { - return &resolve.ParallelListItemFetch{ - Fetch: fetch, - } -} - func (d *createConcreteSingleFetchTypes) createEntityBatchFetch(fetch *resolve.SingleFetch) resolve.Fetch { representationsVariableIndex := -1 for i, segment := range fetch.InputTemplate.Segments { diff --git a/v2/pkg/engine/resolve/fetch.go 
b/v2/pkg/engine/resolve/fetch.go index deeea25a41..622e731c4b 100644 --- a/v2/pkg/engine/resolve/fetch.go +++ b/v2/pkg/engine/resolve/fetch.go @@ -12,7 +12,6 @@ type FetchKind int const ( FetchKindSingle FetchKind = iota + 1 - FetchKindParallelListItem FetchKindEntity FetchKindEntityBatch ) @@ -227,27 +226,6 @@ func (*EntityFetch) FetchKind() FetchKind { return FetchKindEntity } -// The ParallelListItemFetch can be used to make nested parallel fetches within a list -// Usually, you want to batch fetches within a list, which is the default behavior of SingleFetch -// However, if the data source does not support batching, you can use this fetch to make parallel fetches within a list -type ParallelListItemFetch struct { - Fetch *SingleFetch - Traces []*SingleFetch - Trace *DataSourceLoadTrace -} - -func (p *ParallelListItemFetch) Dependencies() *FetchDependencies { - return &p.Fetch.FetchDependencies -} - -func (p *ParallelListItemFetch) FetchInfo() *FetchInfo { - return p.Fetch.Info -} - -func (*ParallelListItemFetch) FetchKind() FetchKind { - return FetchKindParallelListItem -} - type QueryPlan struct { DependsOnFields []Representation Query string @@ -272,12 +250,6 @@ type FetchConfiguration struct { Variables Variables DataSource DataSource - // RequiresParallelListItemFetch indicates that the single fetches should be executed without batching. - // If we have multiple fetches attached to the object, then after post-processing of a plan - // we will get ParallelListItemFetch instead of ParallelFetch. - // Happens only for objects under the array path and used only for the introspection. - RequiresParallelListItemFetch bool - // RequiresEntityFetch will be set to true if the fetch is an entity fetch on an object. // After post-processing, we will get EntityFetch. RequiresEntityFetch bool @@ -313,9 +285,6 @@ func (fc *FetchConfiguration) Equals(other *FetchConfiguration) bool { // Note: we do not compare datasources, as they will always be a different instance. 
- if fc.RequiresParallelListItemFetch != other.RequiresParallelListItemFetch { - return false - } if fc.RequiresEntityFetch != other.RequiresEntityFetch { return false } @@ -505,5 +474,4 @@ var ( _ Fetch = (*SingleFetch)(nil) _ Fetch = (*BatchEntityFetch)(nil) _ Fetch = (*EntityFetch)(nil) - _ Fetch = (*ParallelListItemFetch)(nil) ) diff --git a/v2/pkg/engine/resolve/fetchtree.go b/v2/pkg/engine/resolve/fetchtree.go index f4fd987cea..9bc38497cf 100644 --- a/v2/pkg/engine/resolve/fetchtree.go +++ b/v2/pkg/engine/resolve/fetchtree.go @@ -130,17 +130,6 @@ func (n *FetchTreeNode) Trace() *FetchTreeTraceNode { Trace: f.Trace, Path: n.Item.ResponsePath, } - case *ParallelListItemFetch: - trace.Fetch = &FetchTraceNode{ - Kind: "ParallelList", - SourceID: f.Fetch.Info.DataSourceID, - SourceName: f.Fetch.Info.DataSourceName, - Traces: make([]*DataSourceLoadTrace, len(f.Traces)), - Path: n.Item.ResponsePath, - } - for i, t := range f.Traces { - trace.Fetch.Traces[i] = t.Trace - } default: } case FetchTreeNodeKindSequence, FetchTreeNodeKindParallel: @@ -253,20 +242,6 @@ func (n *FetchTreeNode) queryPlan() *FetchTreeQueryPlanNode { queryPlan.Fetch.Query = f.Info.QueryPlan.Query queryPlan.Fetch.Representations = f.Info.QueryPlan.DependsOnFields } - case *ParallelListItemFetch: - queryPlan.Fetch = &FetchTreeQueryPlan{ - Kind: "ParallelList", - FetchID: f.Fetch.FetchDependencies.FetchID, - DependsOnFetchIDs: f.Fetch.FetchDependencies.DependsOnFetchIDs, - SubgraphName: f.Fetch.Info.DataSourceName, - SubgraphID: f.Fetch.Info.DataSourceID, - Path: n.Item.ResponsePath, - } - - if f.Fetch.Info.QueryPlan != nil { - queryPlan.Fetch.Query = f.Fetch.Info.QueryPlan.Query - queryPlan.Fetch.Representations = f.Fetch.Info.QueryPlan.DependsOnFields - } default: } case FetchTreeNodeKindSequence, FetchTreeNodeKindParallel: diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 4b51df7e66..cff02a4882 100644 --- a/v2/pkg/engine/resolve/loader.go +++ 
b/v2/pkg/engine/resolve/loader.go @@ -327,41 +327,6 @@ func (l *Loader) resolveSingle(item *FetchItem) error { l.ctx.LoaderHooks.OnFinished(res.loaderHookContext, res.ds, newResponseInfo(res, l.ctx.subgraphErrors)) } return err - case *ParallelListItemFetch: - results := make([]*result, len(items)) - if l.ctx.TracingOptions.Enable { - f.Traces = make([]*SingleFetch, len(items)) - } - g, ctx := errgroup.WithContext(l.ctx.ctx) - for i := range items { - i := i - results[i] = &result{} - if l.ctx.TracingOptions.Enable { - f.Traces[i] = new(SingleFetch) - *f.Traces[i] = *f.Fetch - g.Go(func() error { - return l.loadFetch(ctx, f.Traces[i], item, items[i:i+1], results[i]) - }) - continue - } - g.Go(func() error { - return l.loadFetch(ctx, f.Fetch, item, items[i:i+1], results[i]) - }) - } - err := g.Wait() - if err != nil { - return errors.WithStack(err) - } - for i := range results { - err = l.mergeResult(item, results[i], items[i:i+1]) - if l.ctx.LoaderHooks != nil { - l.ctx.LoaderHooks.OnFinished(results[i].loaderHookContext, results[i].ds, newResponseInfo(results[i], l.ctx.subgraphErrors)) - } - if err != nil { - return errors.WithStack(err) - } - } - return nil default: return nil } @@ -459,33 +424,6 @@ func (l *Loader) loadFetch(ctx context.Context, fetch Fetch, fetchItem *FetchIte switch f := fetch.(type) { case *SingleFetch: return l.loadSingleFetch(ctx, f, fetchItem, items, res) - case *ParallelListItemFetch: - results := make([]*result, len(items)) - if l.ctx.TracingOptions.Enable { - f.Traces = make([]*SingleFetch, len(items)) - } - g, ctx := errgroup.WithContext(l.ctx.ctx) - for i := range items { - i := i - results[i] = &result{} - if l.ctx.TracingOptions.Enable { - f.Traces[i] = new(SingleFetch) - *f.Traces[i] = *f.Fetch - g.Go(func() error { - return l.loadFetch(ctx, f.Traces[i], fetchItem, items[i:i+1], results[i]) - }) - continue - } - g.Go(func() error { - return l.loadFetch(ctx, f.Fetch, fetchItem, items[i:i+1], results[i]) - }) - } - err := g.Wait() - 
if err != nil { - return errors.WithStack(err) - } - res.nestedMergeItems = results - return nil case *EntityFetch: return l.loadEntityFetch(ctx, fetchItem, f, items, res) case *BatchEntityFetch: diff --git a/v2/pkg/engine/resolve/loader_hooks_test.go b/v2/pkg/engine/resolve/loader_hooks_test.go index ebe263dcd9..4a2ce9cb2e 100644 --- a/v2/pkg/engine/resolve/loader_hooks_test.go +++ b/v2/pkg/engine/resolve/loader_hooks_test.go @@ -248,69 +248,6 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { } })) - t.Run("parallel list item fetch with simple subgraph error", testFnWithPostEvaluation(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T)) { - mockDataSource := NewMockDataSource(ctrl) - mockDataSource.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { - return []byte(`{"errors":[{"message":"errorMessage"}]}`), nil - }) - resolveCtx := Context{ - ctx: context.Background(), - LoaderHooks: NewTestLoaderHooks(), - } - return &GraphQLResponse{ - Info: &GraphQLResponseInfo{ - OperationType: ast.OperationTypeQuery, - }, - Fetches: SingleWithPath(&ParallelListItemFetch{ - Fetch: &SingleFetch{ - FetchConfiguration: FetchConfiguration{ - DataSource: mockDataSource, - PostProcessing: PostProcessingConfiguration{ - SelectResponseErrorsPath: []string{"errors"}, - }, - }, - Info: &FetchInfo{ - DataSourceID: "Users", - DataSourceName: "Users", - }, - }, - }, "query"), - Data: &Object{ - Nullable: false, - Fields: []*Field{ - { - Name: []byte("name"), - Value: &String{ - Path: []string{"name"}, - Nullable: true, - }, - }, - }, - }, - }, &resolveCtx, `{"errors":[{"message":"Failed to fetch from Subgraph 'Users' at Path 'query'.","extensions":{"errors":[{"message":"errorMessage"}]}}],"data":{"name":null}}`, - func(t *testing.T) { - loaderHooks := resolveCtx.LoaderHooks.(*TestLoaderHooks) - - 
assert.Equal(t, int64(1), loaderHooks.preFetchCalls.Load()) - assert.Equal(t, int64(1), loaderHooks.postFetchCalls.Load()) - - var subgraphError *SubgraphError - assert.Len(t, loaderHooks.errors, 1) - assert.ErrorAs(t, loaderHooks.errors[0], &subgraphError) - assert.Equal(t, "Users", subgraphError.DataSourceInfo.Name) - assert.Equal(t, "query", subgraphError.Path) - assert.Equal(t, "", subgraphError.Reason) - assert.Equal(t, 0, subgraphError.ResponseCode) - assert.Len(t, subgraphError.DownstreamErrors, 1) - assert.Equal(t, "errorMessage", subgraphError.DownstreamErrors[0].Message) - assert.Nil(t, subgraphError.DownstreamErrors[0].Extensions) - - assert.NotNil(t, resolveCtx.SubgraphErrors()) - } - })) - t.Run("fetch with subgraph error and custom extension code. No extension fields are propagated by default", testFnWithPostEvaluation(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T)) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). 
diff --git a/v2/pkg/engine/resolve/resolve_test.go b/v2/pkg/engine/resolve/resolve_test.go index 5c2ea4ed66..1127760377 100644 --- a/v2/pkg/engine/resolve/resolve_test.go +++ b/v2/pkg/engine/resolve/resolve_test.go @@ -2793,214 +2793,6 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }, Context{ctx: context.Background(), Variables: astjson.MustParseBytes([]byte(`{"firstArg":"firstArgValue","thirdArg":123,"secondArg": true, "fourthArg": 12.34}`))}, `{"data":{"serviceOne":{"fieldOne":"fieldOneValue"},"serviceTwo":{"fieldTwo":"fieldTwoValue","serviceOneResponse":{"fieldOne":"fieldOneValue"}},"anotherServiceOne":{"fieldOne":"anotherFieldOneValue"},"secondServiceTwo":{"fieldTwo":"secondFieldTwoValue"},"reusingServiceOne":{"fieldOne":"reUsingFieldOneValue"}}}` })) t.Run("federation", func(t *testing.T) { - t.Run("simple", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { - - userService := NewMockDataSource(ctrl) - userService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { - actual := string(input) - expected := `{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}` - assert.Equal(t, expected, actual) - return []byte(`{"data":{"me":{"id":"1234","username":"Me","__typename":"User"}}}`), nil - }) - - reviewsService := NewMockDataSource(ctrl) - reviewsService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { - actual := string(input) - expected := `{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... 
on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[{"id":"1234","__typename":"User"}]}}}` - assert.Equal(t, expected, actual) - return []byte(`{"data":{"_entities":[{"reviews":[{"body": "A highly effective form of birth control.","product": {"upc": "top-1","__typename": "Product"}},{"body": "Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product": {"upc": "top-2","__typename": "Product"}}]}]}}`), nil - }) - - var productServiceCallCount atomic.Int64 - - productService := NewMockDataSource(ctrl) - productService.EXPECT(). - Load(gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { - actual := string(input) - productServiceCallCount.Add(1) - switch actual { - case `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... on Product {name}}}","variables":{"representations":[{"upc":"top-1","__typename":"Product"}]}}}`: - return []byte(`{"data":{"_entities":[{"name": "Furby"}]}}`), nil - case `{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... 
on Product {name}}}","variables":{"representations":[{"upc":"top-2","__typename":"Product"}]}}}`: - return []byte(`{"data":{"_entities":[{"name": "Trilby"}]}}`), nil - default: - t.Fatalf("unexpected request: %s", actual) - } - return nil, nil - }).Times(2) - - return &GraphQLResponse{ - Fetches: Sequence( - SingleWithPath(&SingleFetch{ - InputTemplate: InputTemplate{ - Segments: []TemplateSegment{ - { - Data: []byte(`{"method":"POST","url":"http://localhost:4001","body":{"query":"{me {id username}}"}}`), - SegmentType: StaticSegmentType, - }, - }, - }, - FetchConfiguration: FetchConfiguration{ - DataSource: userService, - PostProcessing: PostProcessingConfiguration{ - SelectResponseDataPath: []string{"data"}, - }, - }, - }, "query"), - SingleWithPath(&SingleFetch{ - InputTemplate: InputTemplate{ - Segments: []TemplateSegment{ - { - Data: []byte(`{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... 
on User {reviews {body product {upc __typename}}}}}","variables":{"representations":[`), - SegmentType: StaticSegmentType, - }, - { - SegmentType: VariableSegmentType, - VariableKind: ResolvableObjectVariableKind, - Renderer: NewGraphQLVariableResolveRenderer(&Object{ - Fields: []*Field{ - { - Name: []byte("id"), - Value: &String{ - Path: []string{"id"}, - }, - }, - { - Name: []byte("__typename"), - Value: &String{ - Path: []string{"__typename"}, - }, - }, - }, - }), - }, - { - Data: []byte(`]}}}`), - SegmentType: StaticSegmentType, - }, - }, - }, - FetchConfiguration: FetchConfiguration{ - DataSource: reviewsService, - PostProcessing: PostProcessingConfiguration{ - SelectResponseDataPath: []string{"data", "_entities", "0"}, - }, - }, - }, "query.me", ObjectPath("me")), - SingleWithPath(&ParallelListItemFetch{ - Fetch: &SingleFetch{ - FetchConfiguration: FetchConfiguration{ - DataSource: productService, - PostProcessing: PostProcessingConfiguration{ - SelectResponseDataPath: []string{"data", "_entities", "0"}, - }, - }, - InputTemplate: InputTemplate{ - Segments: []TemplateSegment{ - { - Data: []byte(`{"method":"POST","url":"http://localhost:4003","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){... 
on Product {name}}}","variables":{"representations":[`), - SegmentType: StaticSegmentType, - }, - { - SegmentType: VariableSegmentType, - VariableKind: ResolvableObjectVariableKind, - Renderer: NewGraphQLVariableResolveRenderer(&Object{ - Fields: []*Field{ - { - Name: []byte("upc"), - Value: &String{ - Path: []string{"upc"}, - }, - }, - { - Name: []byte("__typename"), - Value: &String{ - Path: []string{"__typename"}, - }, - }, - }, - }), - }, - { - Data: []byte(`]}}}`), - SegmentType: StaticSegmentType, - }, - }, - }, - }, - }, "query.me.reviews.@.product", ObjectPath("me"), ArrayPath("reviews"), ObjectPath("product")), - ), - Data: &Object{ - Fields: []*Field{ - { - Name: []byte("me"), - Value: &Object{ - Path: []string{"me"}, - Nullable: true, - Fields: []*Field{ - { - Name: []byte("id"), - Value: &String{ - Path: []string{"id"}, - }, - }, - { - Name: []byte("username"), - Value: &String{ - Path: []string{"username"}, - }, - }, - { - - Name: []byte("reviews"), - Value: &Array{ - Path: []string{"reviews"}, - Nullable: true, - Item: &Object{ - Nullable: true, - Fields: []*Field{ - { - Name: []byte("body"), - Value: &String{ - Path: []string{"body"}, - }, - }, - { - Name: []byte("product"), - Value: &Object{ - Path: []string{"product"}, - Fields: []*Field{ - { - Name: []byte("upc"), - Value: &String{ - Path: []string{"upc"}, - }, - }, - { - Name: []byte("name"), - Value: &String{ - Path: []string{"name"}, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, Context{ctx: context.Background(), Variables: nil}, `{"data":{"me":{"id":"1234","username":"Me","reviews":[{"body":"A highly effective form of birth control.","product":{"upc":"top-1","name":"Furby"}},{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product":{"upc":"top-2","name":"Trilby"}}]}}}` - })) t.Run("federation with batch", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput 
string) { userService := NewMockDataSource(ctrl) userService.EXPECT(). From daa18e84c305e3d99d62749706034b27d38c1aad Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Mon, 27 Oct 2025 07:56:36 +0100 Subject: [PATCH 26/57] chore: simplify batchStats logic --- v2/pkg/engine/resolve/loader.go | 106 +++++++++++++++----------------- 1 file changed, 50 insertions(+), 56 deletions(-) diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index cff02a4882..1f63375608 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -91,34 +91,32 @@ func newResponseInfo(res *result, subgraphError error) *ResponseInfo { return responseInfo } -// batchStats represents an index map for batched items. -// It is used to ensure that the correct json values will be merged with the correct items from the batch. +// batchStats represents per-unique-batch-item merge targets. +// Outer slice index corresponds to the unique representation index in the request batch, +// and the inner slice contains all target values that should be merged with the response at that index. // // Example: -// [[0],[1],[0],[1]] We originally have 4 items, but we have 2 unique indexes (0 and 1). -// This means we are deduplicating 2 items by merging them from their response entity indexes. -// 0 -> 0, 1 -> 1, 2 -> 0, 3 -> 1 -type batchStats [][]int - -// getUniqueIndexes returns the number of unique indexes in the batchStats. -// This is used to ensure that we can provide a valid error message in case of differing array lengths. 
-func (b *batchStats) getUniqueIndexes() int { - uniqueIndexes := make(map[int]struct{}) - for _, bi := range *b { - for _, index := range bi { - if index < 0 { - continue - } - uniqueIndexes[index] = struct{}{} - } - } +// For 4 original items that deduplicate to 2 unique representations, we might have: +// [ +// +// [item0, item2], // merge response[0] into item0 and item2 +// [item1, item3], // merge response[1] into item1 and item3 +// +// ] +type batchStats [][]*astjson.Value - return len(uniqueIndexes) +// expectedNumberOfBatchItems returns the number of unique indexes in the batchStats. +// With the new structure, this equals the outer slice length. +func (b *batchStats) expectedNumberOfBatchItems() int { + return len(*b) } type result struct { - postProcessing PostProcessingConfiguration - batchStats batchStats + postProcessing PostProcessingConfiguration + batchStats batchStats + // batchHashToIndex maps a request item hash to its unique batch index. + // Used during request construction and to avoid recomputing uniqueness. + batchHashToIndex map[uint64]int fetchSkipped bool nestedMergeItems []*result @@ -597,26 +595,24 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson } if res.batchStats != nil { - uniqueIndexes := res.batchStats.getUniqueIndexes() - if uniqueIndexes != len(batch) { - return l.renderErrorsFailedToFetch(fetchItem, res, fmt.Sprintf(invalidBatchItemCount, uniqueIndexes, len(batch))) + expectedBatchItems := res.batchStats.expectedNumberOfBatchItems() + if expectedBatchItems != len(batch) { + return l.renderErrorsFailedToFetch(fetchItem, res, fmt.Sprintf(invalidBatchItemCount, expectedBatchItems, len(batch))) } - for i, stats := range res.batchStats { - for _, idx := range stats { - if idx == -1 { - continue - } - items[i], _, err = astjson.MergeValuesWithPath(l.jsonArena, items[i], batch[idx], res.postProcessing.MergePath...) 
- if err != nil { + for batchIndex, targets := range res.batchStats { + src := batch[batchIndex] + for _, target := range targets { + _, _, mErr := astjson.MergeValuesWithPath(l.jsonArena, target, src, res.postProcessing.MergePath...) + if mErr != nil { return errors.WithStack(ErrMergeResult{ Subgraph: res.ds.Name, - Reason: err, + Reason: mErr, Path: fetchItem.ResponsePath, }) } - if slices.Contains(taintedIndices, idx) { - l.taintedObjs.add(items[i]) + if slices.Contains(taintedIndices, batchIndex) { + l.taintedObjs.add(target) } } } @@ -1406,8 +1402,8 @@ func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetchItem *FetchItem, if err != nil { return errors.WithStack(err) } - res.batchStats = make(batchStats, len(items)) - itemHashes := make([]uint64, 0, len(items)) + res.batchStats = make(batchStats, 0, len(items)) + res.batchHashToIndex = make(map[uint64]int, len(items)) batchItemIndex := 0 addSeparator := false @@ -1419,7 +1415,6 @@ WithNextItem: if err != nil { if fetch.Input.SkipErrItems { err = nil // nolint:ineffassign - res.batchStats[i] = append(res.batchStats[i], -1) continue } if l.ctx.TracingOptions.Enable { @@ -1428,34 +1423,33 @@ WithNextItem: return errors.WithStack(err) } if fetch.Input.SkipNullItems && itemInput.Len() == 4 && bytes.Equal(itemInput.Bytes(), null) { - res.batchStats[i] = append(res.batchStats[i], -1) continue } if fetch.Input.SkipEmptyObjectItems && itemInput.Len() == 2 && bytes.Equal(itemInput.Bytes(), emptyObject) { - res.batchStats[i] = append(res.batchStats[i], -1) continue } keyGen.Reset() _, _ = keyGen.Write(itemInput.Bytes()) itemHash := keyGen.Sum64() - for k := range itemHashes { - if itemHashes[k] == itemHash { - res.batchStats[i] = append(res.batchStats[i], k) - continue WithNextItem - } - } - itemHashes = append(itemHashes, itemHash) - if addSeparator { - err = fetch.Input.Separator.Render(l.ctx, nil, preparedInput) - if err != nil { - return errors.WithStack(err) + if existingIndex, ok := 
res.batchHashToIndex[itemHash]; ok { + res.batchStats[existingIndex] = append(res.batchStats[existingIndex], items[i]) + continue WithNextItem + } else { + if addSeparator { + err = fetch.Input.Separator.Render(l.ctx, nil, preparedInput) + if err != nil { + return errors.WithStack(err) + } } + _, _ = itemInput.WriteTo(preparedInput) + // new unique representation + res.batchHashToIndex[itemHash] = batchItemIndex + // create a new targets bucket for this unique index + res.batchStats = append(res.batchStats, []*astjson.Value{items[i]}) + batchItemIndex++ + addSeparator = true } - _, _ = itemInput.WriteTo(preparedInput) - res.batchStats[i] = append(res.batchStats[i], batchItemIndex) - batchItemIndex++ - addSeparator = true } } @@ -1464,7 +1458,7 @@ WithNextItem: // setting to nil so that the defer func doesn't return it twice keyGen = nil - if len(itemHashes) == 0 { + if len(res.batchStats) == 0 { // all items were skipped - discard fetch res.fetchSkipped = true if l.ctx.TracingOptions.Enable { From 2003186c30fa9680eb6900f6b3e6662146631149 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Mon, 27 Oct 2025 08:24:33 +0100 Subject: [PATCH 27/57] chore: simplify --- v2/pkg/engine/resolve/loader.go | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 1f63375608..971dc4a169 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -1389,12 +1389,6 @@ func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetchItem *FetchItem, preparedInput := bytes.NewBuffer(make([]byte, 0, 64)) itemInput := bytes.NewBuffer(make([]byte, 0, 32)) keyGen := pool.Hash64.Get() - defer func() { - if keyGen == nil { - return - } - pool.Hash64.Put(keyGen) - }() var undefinedVariables []string @@ -1420,6 +1414,7 @@ WithNextItem: if l.ctx.TracingOptions.Enable { fetch.Trace.LoadSkipped = true } + pool.Hash64.Put(keyGen) return errors.WithStack(err) } if 
fetch.Input.SkipNullItems && itemInput.Len() == 4 && bytes.Equal(itemInput.Bytes(), null) { @@ -1439,6 +1434,7 @@ WithNextItem: if addSeparator { err = fetch.Input.Separator.Render(l.ctx, nil, preparedInput) if err != nil { + pool.Hash64.Put(keyGen) return errors.WithStack(err) } } @@ -1453,10 +1449,7 @@ WithNextItem: } } - // not used anymore pool.Hash64.Put(keyGen) - // setting to nil so that the defer func doesn't return it twice - keyGen = nil if len(res.batchStats) == 0 { // all items were skipped - discard fetch From 0c0e1ce22ae21f98941d54d4d41deed7948cd3a4 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Mon, 27 Oct 2025 11:37:24 +0100 Subject: [PATCH 28/57] chore: add tools pool for loadBatchEntityFetch --- v2/pkg/engine/resolve/loader.go | 137 +++++++++++++++++++++----------- 1 file changed, 91 insertions(+), 46 deletions(-) diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 971dc4a169..a4893ef73d 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -11,9 +11,11 @@ import ( "slices" "strconv" "strings" + "sync" "time" "github.com/buger/jsonparser" + "github.com/cespare/xxhash/v2" "github.com/pkg/errors" "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ -26,7 +28,6 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" "github.com/wundergraph/graphql-go-tools/v2/pkg/errorcodes" "github.com/wundergraph/graphql-go-tools/v2/pkg/internal/unsafebytes" - "github.com/wundergraph/graphql-go-tools/v2/pkg/pool" ) const ( @@ -91,32 +92,21 @@ func newResponseInfo(res *result, subgraphError error) *ResponseInfo { return responseInfo } -// batchStats represents per-unique-batch-item merge targets. -// Outer slice index corresponds to the unique representation index in the request batch, -// and the inner slice contains all target values that should be merged with the response at that index. 
-// -// Example: -// For 4 original items that deduplicate to 2 unique representations, we might have: -// [ -// -// [item0, item2], // merge response[0] into item0 and item2 -// [item1, item3], // merge response[1] into item1 and item3 -// -// ] -type batchStats [][]*astjson.Value - -// expectedNumberOfBatchItems returns the number of unique indexes in the batchStats. -// With the new structure, this equals the outer slice length. -func (b *batchStats) expectedNumberOfBatchItems() int { - return len(*b) -} - type result struct { postProcessing PostProcessingConfiguration - batchStats batchStats - // batchHashToIndex maps a request item hash to its unique batch index. - // Used during request construction and to avoid recomputing uniqueness. - batchHashToIndex map[uint64]int + // batchStats represents per-unique-batch-item merge targets. + // Outer slice index corresponds to the unique representation index in the request batch, + // and the inner slice contains all target values that should be merged with the response at that index. 
+ // + // Example: + // For 4 original items that deduplicate to 2 unique representations, we might have: + // [ + // + // [item0, item2], // merge response[0] into item0 and item2 + // [item1, item3], // merge response[1] into item1 and item3 + // + // ] + batchStats [][]*astjson.Value fetchSkipped bool nestedMergeItems []*result @@ -138,6 +128,7 @@ type result struct { // out is the subgraph response body out []byte singleFlightStats *singleFlightStats + tools *batchEntityTools } func (r *result) init(postProcessing PostProcessingConfiguration, info *FetchInfo) { @@ -231,6 +222,12 @@ func (l *Loader) resolveParallel(nodes []*FetchTreeNode) error { return nil } results := make([]*result, len(nodes)) + defer func() { + for i := range results { + // no-op if tools == nil + batchEntityToolPool.Put(results[i].tools) + } + }() itemsItems := make([][]*astjson.Value, len(nodes)) g, ctx := errgroup.WithContext(l.ctx.ctx) for i := range nodes { @@ -305,6 +302,7 @@ func (l *Loader) resolveSingle(item *FetchItem) error { return err case *BatchEntityFetch: res := &result{} + defer batchEntityToolPool.Put(res.tools) err := l.loadBatchEntityFetch(l.ctx.ctx, item, f, items, res) if err != nil { return errors.WithStack(err) @@ -595,9 +593,8 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson } if res.batchStats != nil { - expectedBatchItems := res.batchStats.expectedNumberOfBatchItems() - if expectedBatchItems != len(batch) { - return l.renderErrorsFailedToFetch(fetchItem, res, fmt.Sprintf(invalidBatchItemCount, expectedBatchItems, len(batch))) + if len(res.batchStats) != len(batch) { + return l.renderErrorsFailedToFetch(fetchItem, res, fmt.Sprintf(invalidBatchItemCount, len(res.batchStats), len(batch))) } for batchIndex, targets := range res.batchStats { @@ -1373,6 +1370,48 @@ func (l *Loader) loadEntityFetch(ctx context.Context, fetchItem *FetchItem, fetc return nil } +type batchEntityTools struct { + keyGen *xxhash.Digest + batchHashToIndex 
map[uint64]int + a arena.Arena +} + +func (b *batchEntityTools) reset() { + b.keyGen.Reset() + b.a.Reset() + for i := range b.batchHashToIndex { + delete(b.batchHashToIndex, i) + } +} + +type _batchEntityToolPool struct { + pool sync.Pool +} + +func (p *_batchEntityToolPool) Get(items int) *batchEntityTools { + item := p.pool.Get() + if item == nil { + return &batchEntityTools{ + keyGen: xxhash.New(), + batchHashToIndex: make(map[uint64]int, items), + a: arena.NewMonotonicArena(arena.WithMinBufferSize(1024)), + } + } + return item.(*batchEntityTools) +} + +func (p *_batchEntityToolPool) Put(item *batchEntityTools) { + if item == nil { + return + } + item.reset() + p.pool.Put(item) +} + +var ( + batchEntityToolPool = _batchEntityToolPool{} +) + func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetchItem *FetchItem, fetch *BatchEntityFetch, items []*astjson.Value, res *result) error { res.init(fetch.PostProcessing, fetch.Info) @@ -1385,19 +1424,19 @@ func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetchItem *FetchItem, } } } - // I tried using arena here, but it only worsened the situation - preparedInput := bytes.NewBuffer(make([]byte, 0, 64)) - itemInput := bytes.NewBuffer(make([]byte, 0, 32)) - keyGen := pool.Hash64.Get() + res.tools = batchEntityToolPool.Get(len(items)) + preparedInput := arena.NewArenaBuffer(res.tools.a) + itemInput := arena.NewArenaBuffer(res.tools.a) + batchStats := arena.AllocateSlice[[]*astjson.Value](res.tools.a, 0, len(items)) + + // I tried using arena here, but it only worsened the situation var undefinedVariables []string err := fetch.Input.Header.RenderAndCollectUndefinedVariables(l.ctx, nil, preparedInput, &undefinedVariables) if err != nil { return errors.WithStack(err) } - res.batchStats = make(batchStats, 0, len(items)) - res.batchHashToIndex = make(map[uint64]int, len(items)) batchItemIndex := 0 addSeparator := false @@ -1414,7 +1453,6 @@ WithNextItem: if l.ctx.TracingOptions.Enable { fetch.Trace.LoadSkipped = 
true } - pool.Hash64.Put(keyGen) return errors.WithStack(err) } if fetch.Input.SkipNullItems && itemInput.Len() == 4 && bytes.Equal(itemInput.Bytes(), null) { @@ -1424,34 +1462,31 @@ WithNextItem: continue } - keyGen.Reset() - _, _ = keyGen.Write(itemInput.Bytes()) - itemHash := keyGen.Sum64() - if existingIndex, ok := res.batchHashToIndex[itemHash]; ok { - res.batchStats[existingIndex] = append(res.batchStats[existingIndex], items[i]) + res.tools.keyGen.Reset() + _, _ = res.tools.keyGen.Write(itemInput.Bytes()) + itemHash := res.tools.keyGen.Sum64() + if existingIndex, ok := res.tools.batchHashToIndex[itemHash]; ok { + batchStats[existingIndex] = arena.SliceAppend(res.tools.a, batchStats[existingIndex], items[i]) continue WithNextItem } else { if addSeparator { err = fetch.Input.Separator.Render(l.ctx, nil, preparedInput) if err != nil { - pool.Hash64.Put(keyGen) return errors.WithStack(err) } } _, _ = itemInput.WriteTo(preparedInput) // new unique representation - res.batchHashToIndex[itemHash] = batchItemIndex + res.tools.batchHashToIndex[itemHash] = batchItemIndex // create a new targets bucket for this unique index - res.batchStats = append(res.batchStats, []*astjson.Value{items[i]}) + batchStats = arena.SliceAppend(res.tools.a, batchStats, []*astjson.Value{items[i]}) batchItemIndex++ addSeparator = true } } } - pool.Hash64.Put(keyGen) - - if len(res.batchStats) == 0 { + if len(batchStats) == 0 { // all items were skipped - discard fetch res.fetchSkipped = true if l.ctx.TracingOptions.Enable { @@ -1470,7 +1505,16 @@ WithNextItem: if err != nil { return errors.WithStack(err) } + fetchInput := preparedInput.Bytes() + // it's important to copy the *astjson.Value's off the arena to avoid memory corruption + res.batchStats = make([][]*astjson.Value, len(batchStats)) + for i := range batchStats { + res.batchStats[i] = make([]*astjson.Value, len(batchStats[i])) + copy(res.batchStats[i], batchStats[i]) + batchStats[i] = nil + } + batchStats = nil if 
l.ctx.TracingOptions.Enable && res.fetchSkipped { l.setTracingInput(fetchItem, fetchInput, fetch.Trace) @@ -1484,6 +1528,7 @@ WithNextItem: if !allowed { return nil } + l.executeSourceLoad(ctx, fetchItem, fetch.DataSource, fetchInput, res, fetch.Trace) return nil } From 8e3d0df3ed11e4c8a2799f2c16c1759ba160fd0f Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Mon, 27 Oct 2025 13:31:29 +0100 Subject: [PATCH 29/57] chore: improved cleanup --- v2/pkg/engine/resolve/loader.go | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index a4893ef73d..e4bd36d813 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -1429,6 +1429,16 @@ func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetchItem *FetchItem, preparedInput := arena.NewArenaBuffer(res.tools.a) itemInput := arena.NewArenaBuffer(res.tools.a) batchStats := arena.AllocateSlice[[]*astjson.Value](res.tools.a, 0, len(items)) + defer func() { + // we need to clear the batchStats slice to avoid memory corruption + // once the outer func returns, we must not keep pointers to items on the arena + for i := range batchStats { + // nolint:ineffassign + batchStats[i] = nil + } + // nolint:ineffassign + batchStats = nil + }() // I tried using arena here, but it only worsened the situation var undefinedVariables []string @@ -1512,9 +1522,7 @@ WithNextItem: for i := range batchStats { res.batchStats[i] = make([]*astjson.Value, len(batchStats[i])) copy(res.batchStats[i], batchStats[i]) - batchStats[i] = nil } - batchStats = nil if l.ctx.TracingOptions.Enable && res.fetchSkipped { l.setTracingInput(fetchItem, fetchInput, fetch.Trace) From f3f2a8ef3dea9f5d59522ac3ec530bf82bac312e Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Mon, 27 Oct 2025 19:28:46 +0100 Subject: [PATCH 30/57] chore: refactor, docs, inbound sf --- v2/pkg/engine/resolve/const.go | 2 + v2/pkg/engine/resolve/context.go | 3 + 
.../resolve/inbound_request_singleflight.go | 138 ++++++++++++++++++ v2/pkg/engine/resolve/loader.go | 4 +- v2/pkg/engine/resolve/resolve.go | 37 +++-- ...ht.go => subgraph_request_singleflight.go} | 92 ++++++++---- 6 files changed, 232 insertions(+), 44 deletions(-) create mode 100644 v2/pkg/engine/resolve/inbound_request_singleflight.go rename v2/pkg/engine/resolve/{singleflight.go => subgraph_request_singleflight.go} (61%) diff --git a/v2/pkg/engine/resolve/const.go b/v2/pkg/engine/resolve/const.go index 8a259494ec..2958fe1f54 100644 --- a/v2/pkg/engine/resolve/const.go +++ b/v2/pkg/engine/resolve/const.go @@ -8,6 +8,8 @@ var ( lBrack = []byte("[") rBrack = []byte("]") comma = []byte(",") + pipe = []byte("|") + dot = []byte(".") colon = []byte(":") quote = []byte("\"") null = []byte("null") diff --git a/v2/pkg/engine/resolve/context.go b/v2/pkg/engine/resolve/context.go index 52f2eb3bb7..d6a8657e46 100644 --- a/v2/pkg/engine/resolve/context.go +++ b/v2/pkg/engine/resolve/context.go @@ -16,6 +16,7 @@ import ( type Context struct { ctx context.Context Variables *astjson.Value + VariablesHash uint64 Files []*httpclient.FileUpload Request Request RenameTypeNames []RenameTypeName @@ -44,6 +45,8 @@ type SubgraphHeadersBuilder interface { // HeadersForSubgraph must return the headers and a hash for a Subgraph Request // The hash will be used for request deduplication HeadersForSubgraph(subgraphName string) (http.Header, uint64) + // HashAll must return the hash for all subgraph requests combined + HashAll() uint64 } // HeadersForSubgraphRequest returns headers and a hash for a request that the engine will make to a subgraph diff --git a/v2/pkg/engine/resolve/inbound_request_singleflight.go b/v2/pkg/engine/resolve/inbound_request_singleflight.go new file mode 100644 index 0000000000..995ee390c7 --- /dev/null +++ b/v2/pkg/engine/resolve/inbound_request_singleflight.go @@ -0,0 +1,138 @@ +package resolve + +import ( + "sync" + + 
"github.com/wundergraph/graphql-go-tools/v2/pkg/ast" +) + +// InboundRequestSingleFlight is a sharded goroutine safe single flight implementation to de-couple inbound requests +// It's taking into consideration the normalized operation hash, variables hash and headers hash +// making it robust against collisions +// for scalability, you can add more shards in case the mutexes are a bottleneck +type InboundRequestSingleFlight struct { + shards []requestShard +} + +type requestShard struct { + mu sync.Mutex + m map[uint64]*InflightRequest +} + +const defaultRequestSingleFlightShardCount = 4 + +// NewRequestSingleFlight creates a InboundRequestSingleFlight with the provided +// number of shards. If shardCount <= 0, the default of 4 is used. +func NewRequestSingleFlight(shardCount int) *InboundRequestSingleFlight { + if shardCount <= 0 { + shardCount = defaultRequestSingleFlightShardCount + } + r := &InboundRequestSingleFlight{ + shards: make([]requestShard, shardCount), + } + for i := range r.shards { + r.shards[i] = requestShard{ + m: make(map[uint64]*InflightRequest), + } + } + return r +} + +type InflightRequest struct { + Done chan struct{} + Data []byte + Err error + ID uint64 + HasFollowers bool +} + +// GetOrCreate creates a new InflightRequest or returns an existing (shared) one +// The first caller to create an InflightRequest for a given key is a leader, everyone else a follower +// GetOrCreate blocks until ctx.ctx.Done() returns or InflightRequest.Done is closed +// It returns an error if the leader returned an error +// It returns nil,nil if the inbound request is not eligible for request deduplication +func (r *InboundRequestSingleFlight) GetOrCreate(ctx *Context, response *GraphQLResponse) (*InflightRequest, error) { + + if ctx.ExecutionOptions.DisableRequestDeduplication { + return nil, nil + } + + if response != nil && response.Info != nil && response.Info.OperationType == ast.OperationTypeMutation { + return nil, nil + } + + // ctx.Request.ID is the 
unique ID of the normalized GraphQL document +1 (offset) + key := ctx.Request.ID + 1 + // ctx.VariablesHash is the hash of the normalized variables from the client request + // this makes the key unique across different variables + key += ctx.VariablesHash + 1 + if ctx.SubgraphHeadersBuilder != nil { + // ctx.SubgraphHeadersBuilder.HashAll() returns the hash of all headers that will be forwarded to all subgraphs + // this makes the key unique across different client request headers, given that we forward them + // we pre-compute all headers that will be forwarded to each subgraph + // if we combine all the subgraph header hashes, the key will be stable across all headers + key += ctx.SubgraphHeadersBuilder.HashAll() + } + + shard := r.shardFor(key) + shard.mu.Lock() + req, shared := shard.m[key] + if shared { + req.HasFollowers = true + shard.mu.Unlock() + select { + case <-req.Done: + if req.Err != nil { + return nil, req.Err + } + return req, nil + case <-ctx.ctx.Done(): + return nil, ctx.ctx.Err() + } + } + + req = &InflightRequest{ + Done: make(chan struct{}), + ID: key, + } + + shard.m[key] = req + shard.mu.Unlock() + return req, nil +} + +func (r *InboundRequestSingleFlight) FinishOk(req *InflightRequest, data []byte) { + if req == nil { + return + } + shard := r.shardFor(req.ID) + shard.mu.Lock() + delete(shard.m, req.ID) + hasFollowers := req.HasFollowers + shard.mu.Unlock() + if hasFollowers { + // optimization to only copy when we actually have to + req.Data = make([]byte, len(data)) + copy(req.Data, data) + } + close(req.Done) +} + +func (r *InboundRequestSingleFlight) FinishErr(req *InflightRequest, err error) { + if req == nil { + return + } + shard := r.shardFor(req.ID) + shard.mu.Lock() + delete(shard.m, req.ID) + shard.mu.Unlock() + req.Err = err + close(req.Done) +} + +func (r *InboundRequestSingleFlight) shardFor(key uint64) *requestShard { + // Fast modulo using power-of-two shard count if desired in the future. 
+ // For now, use standard modulo for clarity. + idx := int(key % uint64(len(r.shards))) + return &r.shards[idx] +} diff --git a/v2/pkg/engine/resolve/loader.go index e4bd36d813..63cda90b28 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -181,9 +181,9 @@ type Loader struct { // If you're not doing this, you will see segfaults // Example of correct usage in func "mergeResult" jsonArena arena.Arena - // sf is the SingleFlight object shared across all client requests + // sf is the SubgraphRequestSingleFlight object shared across all client requests // it's thread safe and can be used to de-duplicate subgraph requests - sf *SingleFlight + sf *SubgraphRequestSingleFlight } func (l *Loader) Free() { diff --git a/v2/pkg/engine/resolve/resolve.go index b5e3ff14bd..dc1f0ba851 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -82,8 +82,10 @@ type Resolver struct { // responseBufferPool is the arena pool dedicated for response buffering before sending to the client responseBufferPool *ArenaPool - // Single flight cache for deduplicating requests across all loaders - sf *SingleFlight + // subgraphRequestSingleFlight is used to de-duplicate subgraph requests + subgraphRequestSingleFlight *SubgraphRequestSingleFlight + // inboundRequestSingleFlight is used to de-duplicate inbound requests + inboundRequestSingleFlight *InboundRequestSingleFlight } func (r *Resolver) SetAsyncErrorWriter(w AsyncErrorWriter) { @@ -239,7 +241,8 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { maxSubscriptionFetchTimeout: options.MaxSubscriptionFetchTimeout, resolveArenaPool: NewArenaPool(), responseBufferPool: NewArenaPool(), - sf: NewSingleFlight(), + subgraphRequestSingleFlight: NewSingleFlight(8), + inboundRequestSingleFlight: NewRequestSingleFlight(8), } resolver.maxConcurrency = make(chan struct{}, options.MaxConcurrency) for i := 
0; i < options.MaxConcurrency; i++ { @@ -251,7 +254,7 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { return resolver } -func newTools(options ResolverOptions, allowedExtensionFields map[string]struct{}, allowedErrorFields map[string]struct{}, sf *SingleFlight, a arena.Arena) *tools { +func newTools(options ResolverOptions, allowedExtensionFields map[string]struct{}, allowedErrorFields map[string]struct{}, sf *SubgraphRequestSingleFlight, a arena.Arena) *tools { return &tools{ resolvable: NewResolvable(a, options.ResolvableOptions), loader: &Loader{ @@ -289,7 +292,7 @@ func (r *Resolver) ResolveGraphQLResponse(ctx *Context, response *GraphQLRespons r.maxConcurrency <- struct{}{} }() - t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf, nil) + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.subgraphRequestSingleFlight, nil) err := t.resolvable.Init(ctx, data, response.Info.OperationType) if err != nil { @@ -314,6 +317,16 @@ func (r *Resolver) ResolveGraphQLResponse(ctx *Context, response *GraphQLRespons func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLResponse, writer io.Writer) (*GraphQLResolveInfo, error) { resp := &GraphQLResolveInfo{} + inflight, err := r.inboundRequestSingleFlight.GetOrCreate(ctx, response) + if err != nil { + return nil, err + } + + if inflight != nil && inflight.Data != nil { // follower + _, err = writer.Write(inflight.Data) + return resp, err + } + start := time.Now() <-r.maxConcurrency resp.ResolveAcquireWaitTime = time.Since(start) @@ -323,10 +336,11 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe resolveArena := r.resolveArenaPool.Acquire(ctx.Request.ID) // we're intentionally not using defer Release to have more control over the timing (see below) - t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf, resolveArena.Arena) + t := newTools(r.options, 
r.allowedErrorExtensionFields, r.allowedErrorFields, r.subgraphRequestSingleFlight, resolveArena.Arena) - err := t.resolvable.Init(ctx, nil, response.Info.OperationType) + err = t.resolvable.Init(ctx, nil, response.Info.OperationType) if err != nil { + r.inboundRequestSingleFlight.FinishErr(inflight, err) r.resolveArenaPool.Release(ctx.Request.ID, resolveArena) return nil, err } @@ -334,6 +348,7 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe if !ctx.ExecutionOptions.SkipLoader { err = t.loader.LoadGraphQLResponseData(ctx, response, t.resolvable) if err != nil { + r.inboundRequestSingleFlight.FinishErr(inflight, err) r.resolveArenaPool.Release(ctx.Request.ID, resolveArena) return nil, err } @@ -344,6 +359,7 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe buf := arena.NewArenaBuffer(responseArena.Arena) err = t.resolvable.Resolve(ctx.ctx, response.Data, response.Fetches, buf) if err != nil { + r.inboundRequestSingleFlight.FinishErr(inflight, err) r.resolveArenaPool.Release(ctx.Request.ID, resolveArena) r.responseBufferPool.Release(ctx.Request.ID, responseArena) return nil, err @@ -357,6 +373,7 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe // as such, it can take some time // which is why we split the arenas and released the first one _, err = writer.Write(buf.Bytes()) + r.inboundRequestSingleFlight.FinishOk(inflight, buf.Bytes()) // all data is written to the client // we're safe to release our buffer r.responseBufferPool.Release(ctx.Request.ID, responseArena) @@ -494,7 +511,7 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar copy(input, sharedInput) resolveArena := r.resolveArenaPool.Acquire(resolveCtx.Request.ID) - t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf, resolveArena.Arena) + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, 
r.subgraphRequestSingleFlight, resolveArena.Arena) if err := t.resolvable.InitSubscription(resolveCtx, input, sub.resolve.Trigger.PostProcessing); err != nil { r.resolveArenaPool.Release(resolveCtx.Request.ID, resolveArena) @@ -1107,7 +1124,7 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ // If SkipLoader is enabled, we skip retrieving actual data. For example, this is useful when requesting a query plan. // By returning early, we avoid starting a subscription and resolve with empty data instead. if ctx.ExecutionOptions.SkipLoader { - t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf, nil) + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.subgraphRequestSingleFlight, nil) err = t.resolvable.InitSubscription(ctx, nil, subscription.Trigger.PostProcessing) if err != nil { @@ -1211,7 +1228,7 @@ func (r *Resolver) AsyncResolveGraphQLSubscription(ctx *Context, subscription *G // If SkipLoader is enabled, we skip retrieving actual data. For example, this is useful when requesting a query plan. // By returning early, we avoid starting a subscription and resolve with empty data instead. 
if ctx.ExecutionOptions.SkipLoader { - t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.sf, nil) + t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.subgraphRequestSingleFlight, nil) err = t.resolvable.InitSubscription(ctx, nil, subscription.Trigger.PostProcessing) if err != nil { diff --git a/v2/pkg/engine/resolve/singleflight.go b/v2/pkg/engine/resolve/subgraph_request_singleflight.go similarity index 61% rename from v2/pkg/engine/resolve/singleflight.go rename to v2/pkg/engine/resolve/subgraph_request_singleflight.go index 76121d98e9..013d906775 100644 --- a/v2/pkg/engine/resolve/singleflight.go +++ b/v2/pkg/engine/resolve/subgraph_request_singleflight.go @@ -6,14 +6,23 @@ import ( "github.com/cespare/xxhash/v2" ) -type SingleFlight struct { - mu *sync.RWMutex - items map[uint64]*SingleFlightItem - sizes map[uint64]*fetchSize +// SubgraphRequestSingleFlight is a sharded, goroutine safe single flight implementation to de-duplicate subgraph requests +// It's hashing the input and adds the pre-computed subgraph headers hash to avoid collisions +// In addition to single flight, it provides size hints to create right-sized buffers for subgraph requests +type SubgraphRequestSingleFlight struct { + shards []singleFlightShard xxPool *sync.Pool cleanup chan func() } +type singleFlightShard struct { + mu sync.RWMutex + items map[uint64]*SingleFlightItem + sizes map[uint64]*fetchSize +} + +const defaultSingleFlightShardCount = 4 + // SingleFlightItem is used to communicate between leader and followers // If an Item for a key doesn't exist, the leader creates and followers can join type SingleFlightItem struct { @@ -37,11 +46,12 @@ type fetchSize struct { totalBytes int } -func NewSingleFlight() *SingleFlight { - return &SingleFlight{ - items: make(map[uint64]*SingleFlightItem), - sizes: make(map[uint64]*fetchSize), - mu: new(sync.RWMutex), +func NewSingleFlight(shardCount int) *SubgraphRequestSingleFlight { + if 
shardCount <= 0 { + shardCount = defaultSingleFlightShardCount + } + s := &SubgraphRequestSingleFlight{ + shards: make([]singleFlightShard, shardCount), xxPool: &sync.Pool{ New: func() any { return xxhash.New() @@ -49,6 +59,13 @@ func NewSingleFlight() *SingleFlight { }, cleanup: make(chan func()), } + for i := range s.shards { + s.shards[i] = singleFlightShard{ + items: make(map[uint64]*SingleFlightItem), + sizes: make(map[uint64]*fetchSize), + } + } + return s } // GetOrCreateItem generates a single flight key (100% identical fetches) and a fetchKey (similar fetches, collisions possible but unproblematic) @@ -58,23 +75,26 @@ func NewSingleFlight() *SingleFlight { // item.sizeHint can be used to create an optimal buffer for the fetch in case of a leader // item.err must always be checked // item.response must never be mutated -func (s *SingleFlight) GetOrCreateItem(fetchItem *FetchItem, input []byte, extraKey uint64) (sfKey, fetchKey uint64, item *SingleFlightItem, shared bool) { +func (s *SubgraphRequestSingleFlight) GetOrCreateItem(fetchItem *FetchItem, input []byte, extraKey uint64) (sfKey, fetchKey uint64, item *SingleFlightItem, shared bool) { sfKey, fetchKey = s.keys(fetchItem, input, extraKey) - // First, try to get the item with a read lock - s.mu.RLock() - item, exists := s.items[sfKey] - s.mu.RUnlock() + // Get shard based on sfKey for items + shard := s.shardFor(sfKey) + + // First, try to get the item with a read lock on its shard + shard.mu.RLock() + item, exists := shard.items[sfKey] + shard.mu.RUnlock() if exists { return sfKey, fetchKey, item, true } // If not exists, acquire a write lock to create the item - s.mu.Lock() + shard.mu.Lock() // Double-check if the item was created while acquiring the write lock - item, exists = s.items[sfKey] + item, exists = shard.items[sfKey] if exists { - s.mu.Unlock() + shard.mu.Unlock() return sfKey, fetchKey, item, true } @@ -83,15 +103,16 @@ func (s *SingleFlight) GetOrCreateItem(fetchItem *FetchItem, input 
[]byte, extra // empty chan to indicate to all followers when we're done (close) loaded: make(chan struct{}), } - if size, ok := s.sizes[fetchKey]; ok { + // Read size hint from the same shard (both items and sizes use the same shard now) + if size, ok := shard.sizes[fetchKey]; ok { item.sizeHint = size.totalBytes / size.count } - s.items[sfKey] = item - s.mu.Unlock() + shard.items[sfKey] = item + shard.mu.Unlock() return sfKey, fetchKey, item, false } -func (s *SingleFlight) keys(fetchItem *FetchItem, input []byte, extraKey uint64) (sfKey, fetchKey uint64) { +func (s *SubgraphRequestSingleFlight) keys(fetchItem *FetchItem, input []byte, extraKey uint64) (sfKey, fetchKey uint64) { h := s.xxPool.Get().(*xxhash.Digest) sfKey = s.sfKey(h, fetchItem, input, extraKey) h.Reset() @@ -103,7 +124,7 @@ func (s *SingleFlight) keys(fetchItem *FetchItem, input []byte, extraKey uint64) // sfKey returns a key that 100% uniquely identifies a fetch with no collision // two sfKey are only the same when the fetches are 100% equal -func (s *SingleFlight) sfKey(h *xxhash.Digest, fetchItem *FetchItem, input []byte, extraKey uint64) uint64 { +func (s *SubgraphRequestSingleFlight) sfKey(h *xxhash.Digest, fetchItem *FetchItem, input []byte, extraKey uint64) uint64 { if fetchItem != nil && fetchItem.Fetch != nil { info := fetchItem.Fetch.FetchInfo() if info != nil { @@ -119,7 +140,7 @@ func (s *SingleFlight) sfKey(h *xxhash.Digest, fetchItem *FetchItem, input []byt // the purpose is to create a key from the DataSourceID and root fields to have less cardinality // the goal is to get an estimate buffer size for similar fetches // there's no point in hashing headers or the body for this purpose -func (s *SingleFlight) fetchKey(h *xxhash.Digest, fetchItem *FetchItem) uint64 { +func (s *SubgraphRequestSingleFlight) fetchKey(h *xxhash.Digest, fetchItem *FetchItem) uint64 { if fetchItem == nil || fetchItem.Fetch == nil { return 0 } @@ -128,13 +149,13 @@ func (s *SingleFlight) fetchKey(h 
*xxhash.Digest, fetchItem *FetchItem) uint64 { return 0 } _, _ = h.WriteString(info.DataSourceID) - _, _ = h.WriteString("|") + _, _ = h.Write(pipe) for i := range info.RootFields { if i != 0 { - _, _ = h.WriteString(",") + _, _ = h.Write(comma) } _, _ = h.WriteString(info.RootFields[i].TypeName) - _, _ = h.WriteString(".") + _, _ = h.Write(dot) _, _ = h.WriteString(info.RootFields[i].FieldName) } return h.Sum64() @@ -143,11 +164,13 @@ func (s *SingleFlight) fetchKey(h *xxhash.Digest, fetchItem *FetchItem) uint64 { // Finish is for the leader to mark the SingleFlightItem as "done" // trigger all followers to look at the err & response of the item // and to update the size estimates -func (s *SingleFlight) Finish(sfKey, fetchKey uint64, item *SingleFlightItem) { +func (s *SubgraphRequestSingleFlight) Finish(sfKey, fetchKey uint64, item *SingleFlightItem) { close(item.loaded) - s.mu.Lock() - delete(s.items, sfKey) - if size, ok := s.sizes[fetchKey]; ok { + // Update sizes in the same shard as the item (using sfKey to get the shard) + shard := s.shardFor(sfKey) + shard.mu.Lock() + delete(shard.items, sfKey) + if size, ok := shard.sizes[fetchKey]; ok { if size.count == 50 { size.count = 1 size.totalBytes = size.totalBytes / 50 @@ -155,10 +178,15 @@ func (s *SingleFlight) Finish(sfKey, fetchKey uint64, item *SingleFlightItem) { size.count++ size.totalBytes += len(item.response) } else { - s.sizes[fetchKey] = &fetchSize{ + shard.sizes[fetchKey] = &fetchSize{ count: 1, totalBytes: len(item.response), } } - s.mu.Unlock() + shard.mu.Unlock() +} + +func (s *SubgraphRequestSingleFlight) shardFor(key uint64) *singleFlightShard { + idx := int(key % uint64(len(s.shards))) + return &s.shards[idx] } From cd59d03f8ea2b60440b28011850a3f7997bc0b0f Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Mon, 27 Oct 2025 20:23:51 +0100 Subject: [PATCH 31/57] chore: refactor --- .../resolve/inbound_request_singleflight.go | 20 +++++++++---------- v2/pkg/engine/resolve/loader.go | 2 +- 
v2/pkg/engine/resolve/resolve.go | 6 +++++- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/v2/pkg/engine/resolve/inbound_request_singleflight.go b/v2/pkg/engine/resolve/inbound_request_singleflight.go index 995ee390c7..1dbe8c9a74 100644 --- a/v2/pkg/engine/resolve/inbound_request_singleflight.go +++ b/v2/pkg/engine/resolve/inbound_request_singleflight.go @@ -1,8 +1,10 @@ package resolve import ( + "encoding/binary" "sync" + "github.com/cespare/xxhash/v2" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" ) @@ -61,18 +63,16 @@ func (r *InboundRequestSingleFlight) GetOrCreate(ctx *Context, response *GraphQL return nil, nil } - // ctx.Request.ID is the unique ID of the normalized GraphQL document +1 (offset) - key := ctx.Request.ID + 1 - // ctx.VariablesHash is the hash of the normalized variables from the client request - // this makes the key unique across different variables - key += ctx.VariablesHash + 1 + // Derive a robust key from request ID, variables hash and (optional) headers hash + var b [24]byte + binary.LittleEndian.PutUint64(b[0:8], ctx.Request.ID) + binary.LittleEndian.PutUint64(b[8:16], ctx.VariablesHash) + hh := uint64(0) if ctx.SubgraphHeadersBuilder != nil { - // ctx.SubgraphHeadersBuilder.HashAll() returns the hash of all headers that will be forwarded to all subgraphs - // this makes the key unique across different client request headers, given that we forward them - // we pre-compute all headers that will be forwarded to each subgraph - // if we combine all the subgraph header hashes, the key will be stable across all headers - key += ctx.SubgraphHeadersBuilder.HashAll() + hh = ctx.SubgraphHeadersBuilder.HashAll() } + binary.LittleEndian.PutUint64(b[16:24], hh) + key := xxhash.Sum64(b[:]) shard := r.shardFor(key) shard.mu.Lock() diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 63cda90b28..88ef6fec25 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -1642,7 
+1642,7 @@ func (l *Loader) loadByContext(ctx context.Context, source DataSource, fetchItem sfKey, fetchKey, item, shared := l.sf.GetOrCreateItem(fetchItem, input, extraKey) if res.singleFlightStats != nil { - res.singleFlightStats.used = shared + res.singleFlightStats.used = true res.singleFlightStats.shared = shared } diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index dc1f0ba851..b93888a79d 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -5,6 +5,7 @@ package resolve import ( "bytes" "context" + "encoding/binary" "fmt" "io" "net/http" @@ -1104,7 +1105,10 @@ func (r *Resolver) prepareTrigger(ctx *Context, sourceName string, input []byte) header, headerHash := ctx.SubgraphHeadersBuilder.HeadersForSubgraph(sourceName) keyGen := pool.Hash64.Get() _, _ = keyGen.Write(input) - triggerID = keyGen.Sum64() + headerHash + var b [8]byte + binary.LittleEndian.PutUint64(b[:], headerHash) + _, _ = keyGen.Write(b[:]) + triggerID = keyGen.Sum64() pool.Hash64.Put(keyGen) return header, triggerID } From c579f4898d41ff07ef19746e75dfed4d35d783df Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Mon, 27 Oct 2025 20:24:32 +0100 Subject: [PATCH 32/57] chore: fmt --- v2/pkg/engine/resolve/inbound_request_singleflight.go | 1 + 1 file changed, 1 insertion(+) diff --git a/v2/pkg/engine/resolve/inbound_request_singleflight.go b/v2/pkg/engine/resolve/inbound_request_singleflight.go index 1dbe8c9a74..6db40dc707 100644 --- a/v2/pkg/engine/resolve/inbound_request_singleflight.go +++ b/v2/pkg/engine/resolve/inbound_request_singleflight.go @@ -5,6 +5,7 @@ import ( "sync" "github.com/cespare/xxhash/v2" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" ) From 319126c5c61ee5eb75571f9b6af64b52f9aed45a Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Mon, 27 Oct 2025 20:35:11 +0100 Subject: [PATCH 33/57] chore: fix test --- .../engine/testdata/complex_nesting_query_with_art.json | 6 +++--- 1 file changed, 3 insertions(+), 3 
deletions(-) diff --git a/execution/engine/testdata/complex_nesting_query_with_art.json b/execution/engine/testdata/complex_nesting_query_with_art.json index 69a208fe47..ec85c1e5c1 100644 --- a/execution/engine/testdata/complex_nesting_query_with_art.json +++ b/execution/engine/testdata/complex_nesting_query_with_art.json @@ -170,7 +170,7 @@ "duration_since_start_pretty": "1ns", "duration_load_nanoseconds": 1, "duration_load_pretty": "1ns", - "single_flight_used": false, + "single_flight_used": true, "single_flight_shared_response": false, "load_skipped": false, "load_stats": { @@ -310,7 +310,7 @@ "duration_since_start_pretty": "1ns", "duration_load_nanoseconds": 1, "duration_load_pretty": "1ns", - "single_flight_used": false, + "single_flight_used": true, "single_flight_shared_response": false, "load_skipped": false, "load_stats": { @@ -496,7 +496,7 @@ "duration_since_start_pretty": "1ns", "duration_load_nanoseconds": 1, "duration_load_pretty": "1ns", - "single_flight_used": false, + "single_flight_used": true, "single_flight_shared_response": false, "load_skipped": false, "load_stats": { From 0bf8fb37ad1272532e1c81c9bf4ef0f6b75d7ddf Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Tue, 28 Oct 2025 08:54:54 +0100 Subject: [PATCH 34/57] chore: refactor --- v2/pkg/engine/resolve/context.go | 10 ++++++++-- .../resolve/inbound_request_singleflight.go | 7 +++---- v2/pkg/engine/resolve/loader.go | 17 ++++++++++++++--- v2/pkg/engine/resolve/response.go | 13 +++++++++++++ 4 files changed, 38 insertions(+), 9 deletions(-) diff --git a/v2/pkg/engine/resolve/context.go b/v2/pkg/engine/resolve/context.go index d6a8657e46..5783b29a56 100644 --- a/v2/pkg/engine/resolve/context.go +++ b/v2/pkg/engine/resolve/context.go @@ -67,8 +67,14 @@ type ExecutionOptions struct { IncludeQueryPlanInResponse bool // SendHeartbeat sends regular HeartBeats for Subscriptions SendHeartbeat bool - // DisableRequestDeduplication disables deduplication of requests to the same subgraph with the same 
input within a single operation execution. - DisableRequestDeduplication bool + // DisableSubgraphRequestDeduplication disables deduplication of requests to the same subgraph with the same input within a single operation execution. + DisableSubgraphRequestDeduplication bool + // DisableInboundRequestDeduplication disables deduplication of inbound client requests + // The engine is hashing the normalized operation, variables, and forwarded headers to achieve robust deduplication + // By default, overhead is negligible and as such this should be false (not disabled) most of the time + // However, if you're benchmarking internals of the engine, it can be helpful to switch it off + // When disabled (set to true) the code becomes a no-op + DisableInboundRequestDeduplication bool } type FieldValue struct { diff --git a/v2/pkg/engine/resolve/inbound_request_singleflight.go b/v2/pkg/engine/resolve/inbound_request_singleflight.go index 6db40dc707..f5ad8eb4a1 100644 --- a/v2/pkg/engine/resolve/inbound_request_singleflight.go +++ b/v2/pkg/engine/resolve/inbound_request_singleflight.go @@ -5,8 +5,6 @@ import ( "sync" "github.com/cespare/xxhash/v2" - - "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" ) // InboundRequestSingleFlight is a sharded goroutine safe single flight implementation to de-couple inbound requests @@ -54,13 +52,14 @@ type InflightRequest struct { // GetOrCreate blocks until ctx.ctx.Done() returns or InflightRequest.Done is closed // It returns an error if the leader returned an error // It returns nil,nil if the inbound request is not eligible for request deduplication +// or if DisableSubgraphRequestDeduplication or DisableInboundRequestDeduplication is set to true on Context func (r *InboundRequestSingleFlight) GetOrCreate(ctx *Context, response *GraphQLResponse) (*InflightRequest, error) { - if ctx.ExecutionOptions.DisableRequestDeduplication { + if ctx.ExecutionOptions.DisableSubgraphRequestDeduplication || 
ctx.ExecutionOptions.DisableInboundRequestDeduplication { return nil, nil } - if response != nil && response.Info != nil && response.Info.OperationType == ast.OperationTypeMutation { + if !response.SingleFlightAllowed() { return nil, nil } diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 88ef6fec25..893b70638d 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -1625,6 +1625,19 @@ func (l *Loader) headersForSubgraphRequest(fetchItem *FetchItem) (http.Header, u return l.ctx.HeadersForSubgraphRequest(info.DataSourceName) } +func (l *Loader) singleFlightAllowed() bool { + if l.ctx.ExecutionOptions.DisableSubgraphRequestDeduplication { + return false + } + if l.info == nil { + return false + } + if l.info.OperationType == ast.OperationTypeQuery { + return true + } + return false +} + func (l *Loader) loadByContext(ctx context.Context, source DataSource, fetchItem *FetchItem, input []byte, res *result) error { if l.info != nil { @@ -1633,9 +1646,7 @@ func (l *Loader) loadByContext(ctx context.Context, source DataSource, fetchItem headers, extraKey := l.headersForSubgraphRequest(fetchItem) - if l.info == nil || - l.info.OperationType == ast.OperationTypeMutation || - l.ctx.ExecutionOptions.DisableRequestDeduplication { + if !l.singleFlightAllowed() { // Disable single flight for mutations return l.loadByContextDirect(ctx, source, headers, input, res) } diff --git a/v2/pkg/engine/resolve/response.go b/v2/pkg/engine/resolve/response.go index 1efe078cca..d8af8d017b 100644 --- a/v2/pkg/engine/resolve/response.go +++ b/v2/pkg/engine/resolve/response.go @@ -43,6 +43,19 @@ type GraphQLResponse struct { DataSources []DataSourceInfo } +func (g *GraphQLResponse) SingleFlightAllowed() bool { + if g == nil { + return false + } + if g.Info == nil { + return false + } + if g.Info.OperationType == ast.OperationTypeQuery { + return true + } + return false +} + type GraphQLResponseInfo struct { OperationType 
ast.OperationType } From 1ae36b46599e570b4ef31eca674a64c28297040f Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Tue, 28 Oct 2025 09:33:24 +0100 Subject: [PATCH 35/57] chore: refactor --- v2/pkg/engine/resolve/inbound_request_singleflight.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/v2/pkg/engine/resolve/inbound_request_singleflight.go b/v2/pkg/engine/resolve/inbound_request_singleflight.go index f5ad8eb4a1..66505a36a4 100644 --- a/v2/pkg/engine/resolve/inbound_request_singleflight.go +++ b/v2/pkg/engine/resolve/inbound_request_singleflight.go @@ -52,10 +52,10 @@ type InflightRequest struct { // GetOrCreate blocks until ctx.ctx.Done() returns or InflightRequest.Done is closed // It returns an error if the leader returned an error // It returns nil,nil if the inbound request is not eligible for request deduplication -// or if DisableSubgraphRequestDeduplication or DisableInboundRequestDeduplication is set to true on Context +// or if DisableInboundRequestDeduplication is set to true on Context func (r *InboundRequestSingleFlight) GetOrCreate(ctx *Context, response *GraphQLResponse) (*InflightRequest, error) { - if ctx.ExecutionOptions.DisableSubgraphRequestDeduplication || ctx.ExecutionOptions.DisableInboundRequestDeduplication { + if ctx.ExecutionOptions.DisableInboundRequestDeduplication { return nil, nil } From 57e688cc32728979bf942354d2dace6178160763 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Tue, 28 Oct 2025 10:06:54 +0100 Subject: [PATCH 36/57] chore: allow single flight in loader for sub Queries, even if root operation type is Mutation or Subscription --- v2/pkg/engine/resolve/loader.go | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 893b70638d..a33242bc1d 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -1625,14 +1625,25 @@ func (l *Loader) headersForSubgraphRequest(fetchItem 
*FetchItem) (http.Header, u return l.ctx.HeadersForSubgraphRequest(info.DataSourceName) } -func (l *Loader) singleFlightAllowed() bool { +// singleFlightAllowed returns true if the specific GraphQL Operation is a Query +// even if the root operation type is a Mutation or Subscription +// sub-operations can still be of type Query +// even in such cases we allow request de-duplication because such requests are idempotent +func (l *Loader) singleFlightAllowed(fetchItem *FetchItem) bool { if l.ctx.ExecutionOptions.DisableSubgraphRequestDeduplication { return false } - if l.info == nil { + if fetchItem == nil { return false } - if l.info.OperationType == ast.OperationTypeQuery { + if fetchItem.Fetch == nil { + return false + } + info := fetchItem.Fetch.FetchInfo() + if info == nil { + return false + } + if info.OperationType == ast.OperationTypeQuery { return true } return false @@ -1646,7 +1657,7 @@ func (l *Loader) loadByContext(ctx context.Context, source DataSource, fetchItem headers, extraKey := l.headersForSubgraphRequest(fetchItem) - if !l.singleFlightAllowed() { + if !l.singleFlightAllowed(fetchItem) { // Disable single flight for mutations return l.loadByContextDirect(ctx, source, headers, input, res) } From 8f3e30f68444125efe5a83f08cd241f3037f4a11 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Wed, 29 Oct 2025 09:12:39 +0100 Subject: [PATCH 37/57] chore: improve arena pool & add tests --- v2/pkg/engine/resolve/arena.go | 16 +- v2/pkg/engine/resolve/arena_test.go | 257 ++++++++++++++++++++++++++++ 2 files changed, 267 insertions(+), 6 deletions(-) create mode 100644 v2/pkg/engine/resolve/arena_test.go diff --git a/v2/pkg/engine/resolve/arena.go b/v2/pkg/engine/resolve/arena.go index cca1f33125..98bd930873 100644 --- a/v2/pkg/engine/resolve/arena.go +++ b/v2/pkg/engine/resolve/arena.go @@ -48,13 +48,17 @@ func (p *ArenaPool) Acquire(id uint64) *ArenaPoolItem { defer p.mu.Unlock() // Try to find an available arena in the pool - for i := 0; i < len(p.pool); i++ { 
- v := p.pool[i].Value() - p.pool = append(p.pool[:i], p.pool[i+1:]...) - if v == nil { - continue + for len(p.pool) > 0 { + // Pop the last item + lastIdx := len(p.pool) - 1 + wp := p.pool[lastIdx] + p.pool = p.pool[:lastIdx] + + v := wp.Value() + if v != nil { + return v } - return v + // If weak pointer was nil (GC collected), continue to next item } // No arena available, create a new one diff --git a/v2/pkg/engine/resolve/arena_test.go b/v2/pkg/engine/resolve/arena_test.go new file mode 100644 index 0000000000..a6bb0f5570 --- /dev/null +++ b/v2/pkg/engine/resolve/arena_test.go @@ -0,0 +1,257 @@ +package resolve + +import ( + "runtime" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/go-arena" +) + +func TestNewArenaPool(t *testing.T) { + pool := NewArenaPool() + + require.NotNil(t, pool, "NewArenaPool returned nil") + assert.Equal(t, 0, len(pool.pool), "expected empty pool") + assert.Equal(t, 0, len(pool.sizes), "expected empty sizes map") +} + +func TestArenaPool_Acquire_EmptyPool(t *testing.T) { + pool := NewArenaPool() + + item := pool.Acquire(1) + + require.NotNil(t, item, "Acquire returned nil") + assert.NotNil(t, item.Arena, "Arena is nil") + + // Verify we can use the arena + buf := arena.NewArenaBuffer(item.Arena) + buf.WriteString("test") + + assert.Equal(t, 0, len(pool.pool), "pool should still be empty") +} + +func TestArenaPool_ReleaseAndAcquire(t *testing.T) { + pool := NewArenaPool() + id := uint64(42) + + // Acquire first arena + item1 := pool.Acquire(id) + + // Use the arena + buf := arena.NewArenaBuffer(item1.Arena) + buf.WriteString("test data") + + // Release it + pool.Release(id, item1) + + // Pool should have one item + assert.Equal(t, 1, len(pool.pool), "expected pool to have 1 item") + + // Acquire from pool + item2 := pool.Acquire(id) + + require.NotNil(t, item2, "Acquire returned nil") + + // Pool should be empty again + assert.Equal(t, 0, len(pool.pool), "expected 
empty pool after acquire") + + // The acquired arena should be reset and usable + buf2 := arena.NewArenaBuffer(item2.Arena) + buf2.WriteString("new data") + + assert.Equal(t, "new data", buf2.String()) +} + +func TestArenaPool_Acquire_ProvesBugFix(t *testing.T) { + // This test specifically proves the bug fix works + // Creates multiple items, clears some references, then acquires + // to ensure all items are checked without skipping + pool := NewArenaPool() + id := uint64(800) + + numItems := 10 + items := make([]*ArenaPoolItem, numItems) + + // Acquire all items + for i := 0; i < numItems; i++ { + items[i] = pool.Acquire(id) + buf := arena.NewArenaBuffer(items[i].Arena) + _, err := buf.WriteString("item data") + assert.NoError(t, err) + } + + // Release all while keeping strong references + for i := 0; i < numItems; i++ { + pool.Release(id, items[i]) + } + + // Pool should have all items + assert.Equal(t, numItems, len(pool.pool), "expected items in pool") + + // Clear every other item to simulate partial GC + for i := 0; i < numItems; i += 2 { + items[i] = nil + } + + // Force GC + runtime.GC() + runtime.GC() + + // Acquire items - should process ALL items without skipping + processed := 0 + acquired := 0 + + for len(pool.pool) > 0 && processed < numItems*2 { + poolSizeBefore := len(pool.pool) + item := pool.Acquire(id) + poolSizeAfter := len(pool.pool) + processed++ + + assert.Less(t, poolSizeAfter, poolSizeBefore, "Pool size did not decrease - item not removed properly!") + + if item != nil { + acquired++ + } + } + + // Pool should be empty + assert.Equal(t, 0, len(pool.pool), "expected empty pool") +} + +func TestArenaPool_Release_PeakTracking(t *testing.T) { + pool := NewArenaPool() + id := uint64(200) + + // First arena + item1 := pool.Acquire(id) + buf1 := arena.NewArenaBuffer(item1.Arena) + _, err := buf1.WriteString("small") + assert.NoError(t, err) + + peak1 := item1.Arena.Peak() + assert.Equal(t, peak1, 5) + + pool.Release(id, item1) + + // Check that 
size was tracked + size, exists := pool.sizes[id] + require.True(t, exists, "size tracking not created") + assert.Equal(t, 1, size.count, "expected count 1") + + // Second arena + item2 := pool.Acquire(id) + buf2 := arena.NewArenaBuffer(item2.Arena) + _, err = buf2.WriteString("larger data") + assert.NoError(t, err) + + pool.Release(id, item2) + + // Check updated tracking + assert.Equal(t, 2, size.count, "expected count 2") +} + +func TestArenaPool_GetArenaSize(t *testing.T) { + pool := NewArenaPool() + + // Test default size for unknown ID + size1 := pool.getArenaSize(999) + expectedDefault := 1024 * 1024 + assert.Equal(t, expectedDefault, size1, "expected default size") + + // Test calculated size after usage + id := uint64(400) + item := pool.Acquire(id) + buf := arena.NewArenaBuffer(item.Arena) + _, err := buf.WriteString("some data") + assert.NoError(t, err) + pool.Release(id, item) + + size2 := pool.getArenaSize(id) + assert.NotEqual(t, 0, size2, "expected non-zero size after usage") +} + +func TestArenaPool_MultipleItemsInPool(t *testing.T) { + pool := NewArenaPool() + id := uint64(500) + + // Acquire multiple distinct items + numItems := 3 + items := make([]*ArenaPoolItem, numItems) + + for i := 0; i < numItems; i++ { + items[i] = pool.Acquire(id) + buf := arena.NewArenaBuffer(items[i].Arena) + _, err := buf.WriteString("data") + assert.NoError(t, err) + } + + // Release all while keeping references + for i := 0; i < numItems; i++ { + pool.Release(id, items[i]) + } + + // Should have all items in pool + assert.Equal(t, numItems, len(pool.pool), "expected items in pool") + + // Acquire all back + acquired := 0 + for len(pool.pool) > 0 { + item := pool.Acquire(id) + if item != nil { + acquired++ + } + } + + assert.Equal(t, numItems, acquired, "expected to acquire all items") +} + +func TestArenaPool_Release_MovingWindow(t *testing.T) { + pool := NewArenaPool() + id := uint64(600) + + // Release exactly 50 items + for i := 0; i < 50; i++ { + item := 
pool.Acquire(id) + buf := arena.NewArenaBuffer(item.Arena) + _, err := buf.WriteString("test data") + assert.NoError(t, err) + pool.Release(id, item) + } + + // After 50 releases, verify count and total + size := pool.sizes[id] + require.NotNil(t, size, "size tracking should exist") + assert.Equal(t, 50, size.count, "expected count to be 50") + + totalBytesAfter50 := size.totalBytes + + // Release one more item to trigger the window reset + item51 := pool.Acquire(id) + buf51 := arena.NewArenaBuffer(item51.Arena) + _, err := buf51.WriteString("test data") + assert.NoError(t, err) + peak51 := item51.Arena.Peak() + pool.Release(id, item51) + + // After 51st release, verify the window was reset + // count should be 2 (reset to 1, then incremented) + // totalBytes should be (totalBytesAfter50 / 50) + peak51 + assert.Equal(t, 2, size.count, "expected count to be 2 after window reset") + + expectedTotalBytes := (totalBytesAfter50 / 50) + peak51 + assert.Equal(t, expectedTotalBytes, size.totalBytes, "expected totalBytes to be divided by 50 and new peak added") + + // Verify we can continue releasing and counting works correctly + for i := 0; i < 10; i++ { + item := pool.Acquire(id) + buf := arena.NewArenaBuffer(item.Arena) + _, err := buf.WriteString("more data") + assert.NoError(t, err) + pool.Release(id, item) + } + + // After 10 more releases, count should be 12 (2 + 10) + assert.Equal(t, 12, size.count, "expected count to continue incrementing after window reset") +} From 3df9e01d9dcf796ad6910a3266ad9a788a8d89a0 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Wed, 29 Oct 2025 09:13:04 +0100 Subject: [PATCH 38/57] chore: use arena in Walker --- v2/pkg/astvisitor/visitor.go | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/v2/pkg/astvisitor/visitor.go b/v2/pkg/astvisitor/visitor.go index a2cbb102da..bd48ad6923 100644 --- a/v2/pkg/astvisitor/visitor.go +++ b/v2/pkg/astvisitor/visitor.go @@ -5,6 +5,7 @@ import ( "fmt" "sync" + 
"github.com/wundergraph/go-arena" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/lexer/literal" "github.com/wundergraph/graphql-go-tools/v2/pkg/operationreport" @@ -94,6 +95,8 @@ type Walker struct { deferred []func() OnExternalError func(err *operationreport.ExternalError) + + arena arena.Arena } // NewWalker returns a fully initialized Walker @@ -125,6 +128,9 @@ func WalkerFromPool() *Walker { } func (w *Walker) Release() { + if w.arena != nil { + w.arena.Reset() + } w.ResetVisitors() w.Report = nil w.document = nil @@ -1370,6 +1376,11 @@ func (w *Walker) Walk(document, definition *ast.Document, report *operationrepor } else { w.Report = report } + if w.arena == nil { + w.arena = arena.NewMonotonicArena(arena.WithMinBufferSize(64)) + } else { + w.arena.Reset() + } w.Ancestors = w.Ancestors[:0] w.Path = w.Path[:0] w.TypeDefinitions = w.TypeDefinitions[:0] @@ -1822,8 +1833,7 @@ func (w *Walker) walkSelectionSet(ref int, skipFor SkipVisitors) { RefsChanged: for { - refs := make([]int, 0, len(w.document.SelectionSets[ref].SelectionRefs)) - refs = append(refs, w.document.SelectionSets[ref].SelectionRefs...) + refs := arena.SliceAppend(w.arena, nil, w.document.SelectionSets[ref].SelectionRefs...) 
for i, j := range refs { From aa789e070ea383a384355228b2b22e5061451d50 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Wed, 29 Oct 2025 09:14:37 +0100 Subject: [PATCH 39/57] chore: fix lint --- v2/pkg/astvisitor/visitor.go | 1 + v2/pkg/engine/resolve/arena_test.go | 10 +++++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/v2/pkg/astvisitor/visitor.go b/v2/pkg/astvisitor/visitor.go index bd48ad6923..86a29c0c7a 100644 --- a/v2/pkg/astvisitor/visitor.go +++ b/v2/pkg/astvisitor/visitor.go @@ -6,6 +6,7 @@ import ( "sync" "github.com/wundergraph/go-arena" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/lexer/literal" "github.com/wundergraph/graphql-go-tools/v2/pkg/operationreport" diff --git a/v2/pkg/engine/resolve/arena_test.go b/v2/pkg/engine/resolve/arena_test.go index a6bb0f5570..20c1069b86 100644 --- a/v2/pkg/engine/resolve/arena_test.go +++ b/v2/pkg/engine/resolve/arena_test.go @@ -6,6 +6,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/wundergraph/go-arena" ) @@ -27,7 +28,8 @@ func TestArenaPool_Acquire_EmptyPool(t *testing.T) { // Verify we can use the arena buf := arena.NewArenaBuffer(item.Arena) - buf.WriteString("test") + _, err := buf.WriteString("test") + assert.NoError(t, err) assert.Equal(t, 0, len(pool.pool), "pool should still be empty") } @@ -41,7 +43,8 @@ func TestArenaPool_ReleaseAndAcquire(t *testing.T) { // Use the arena buf := arena.NewArenaBuffer(item1.Arena) - buf.WriteString("test data") + _, err := buf.WriteString("test data") + assert.NoError(t, err) // Release it pool.Release(id, item1) @@ -59,7 +62,8 @@ func TestArenaPool_ReleaseAndAcquire(t *testing.T) { // The acquired arena should be reset and usable buf2 := arena.NewArenaBuffer(item2.Arena) - buf2.WriteString("new data") + _, err = buf2.WriteString("new data") + assert.NoError(t, err) assert.Equal(t, "new data", buf2.String()) } From 
d8f04cabe7e7be1d748c657e315ed9e38b15a52b Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Sun, 2 Nov 2025 13:33:20 +0100 Subject: [PATCH 40/57] chore: refactor arena handling --- .../grpc_datasource/grpc_datasource.go | 51 ++++++++++++++----- .../grpc_datasource/grpc_datasource_test.go | 12 ++--- .../grpc_datasource/json_builder.go | 6 +-- v2/pkg/engine/resolve/arena.go | 47 +++++++++++++++-- v2/pkg/engine/resolve/arena_test.go | 18 +++---- v2/pkg/engine/resolve/resolve.go | 20 ++++---- 6 files changed, 106 insertions(+), 48 deletions(-) diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go index c9c37891fa..6cbc4ca125 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go @@ -8,11 +8,11 @@ package grpcdatasource import ( "context" + "encoding/binary" "fmt" "net/http" - "sync" - "errors" + "github.com/cespare/xxhash/v2" "github.com/tidwall/gjson" "golang.org/x/sync/errgroup" "google.golang.org/grpc" @@ -46,6 +46,8 @@ type DataSource struct { mapping *GRPCMapping federationConfigs plan.FederationFieldConfigurations disabled bool + + pool *resolve.ArenaPool } type ProtoConfig struct { @@ -81,6 +83,7 @@ func NewDataSource(client grpc.ClientConnInterface, config DataSourceConfig) (*D mapping: config.Mapping, federationConfigs: config.FederationConfigs, disabled: config.Disabled, + pool: resolve.NewArenaPool(), }, nil } @@ -93,15 +96,23 @@ func NewDataSource(client grpc.ClientConnInterface, config DataSourceConfig) (*D func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { // get variables from input variables := gjson.Parse(unsafebytes.BytesToString(input)).Get("body.variables") - builder := newJSONBuilder(d.mapping, variables) + + var ( + poolItems []*resolve.ArenaPoolItem + ) + defer func() { + d.pool.ReleaseMany(poolItems) + }() + + item := 
d.acquirePoolItem(input, 0) + poolItems = append(poolItems, item) + builder := newJSONBuilder(item.Arena, d.mapping, variables) if d.disabled { return builder.writeErrorBytes(fmt.Errorf("gRPC datasource needs to be enabled to be used")), nil } - arena := astjson.Arena{} - defer arena.Reset() - root := arena.NewObject() + root := astjson.ObjectValue(nil) failed := false @@ -116,8 +127,10 @@ func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte // make gRPC calls for index, serviceCall := range serviceCalls { + item := d.acquirePoolItem(input, index) + poolItems = append(poolItems, item) + builder := newJSONBuilder(item.Arena, d.mapping, variables) errGrp.Go(func() error { - a := astjson.Arena{} // Invoke the gRPC method - this will populate serviceCall.Output err := d.cc.Invoke(errGrpCtx, serviceCall.MethodFullName(), serviceCall.Input, serviceCall.Output) @@ -125,7 +138,7 @@ func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte return err } - response, err := builder.marshalResponseJSON(&a, &serviceCall.RPC.Response, serviceCall.Output) + response, err := builder.marshalResponseJSON(&serviceCall.RPC.Response, serviceCall.Output) if err != nil { return err } @@ -150,7 +163,7 @@ func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte } if err := errGrp.Wait(); err != nil { - out.Write(builder.writeErrorBytes(err)) + data = builder.writeErrorBytes(err) failed = true return nil } @@ -163,19 +176,29 @@ func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte root, err = builder.mergeValues(root, result.response) } if err != nil { - out.Write(builder.writeErrorBytes(err)) + data = builder.writeErrorBytes(err) return err } } return nil }); err != nil || failed { - return err + return data, err } - data := builder.toDataObject(root) - out.Write(data.MarshalTo(nil)) - return nil + value := builder.toDataObject(root) + return value.MarshalTo(nil), err +} + +func (d 
*DataSource) acquirePoolItem(input []byte, index int) *resolve.ArenaPoolItem { + keyGen := xxhash.New() + _, _ = keyGen.Write(input) + var b [8]byte + binary.LittleEndian.PutUint64(b[:], uint64(index)) + _, _ = keyGen.Write(b[:]) + key := keyGen.Sum64() + item := d.pool.Acquire(key) + return item } // LoadWithFiles implements resolve.DataSource interface. diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go index 9b4d6be438..9a427809a9 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go @@ -54,8 +54,7 @@ func Benchmark_DataSource_Load(b *testing.B) { b.ReportAllocs() b.ResetTimer() for b.Loop() { - output := new(bytes.Buffer) - err = ds.Load(context.Background(), []byte(`{"query":"`+query+`","body":`+variables+`}`), output) + _, err = ds.Load(context.Background(), nil, []byte(`{"query":"`+query+`","body":`+variables+`}`)) require.NoError(b, err) } } @@ -93,7 +92,7 @@ func Benchmark_DataSource_Load_WithFieldArguments(b *testing.B) { }) require.NoError(b, err) - err = ds.Load(context.Background(), []byte(`{"query":"`+query+`","body":`+variables+`}`), new(bytes.Buffer)) + _, err = ds.Load(context.Background(), nil, []byte(`{"query":"`+query+`","body":`+variables+`}`)) require.NoError(b, err) } } @@ -564,7 +563,7 @@ func TestMarshalResponseJSON(t *testing.T) { responseMessage := dynamicpb.NewMessage(responseMessageDesc) responseMessage.Mutable(responseMessageDesc.Fields().ByName("result")).List().Append(protoref.ValueOfMessage(productMessage)) - jsonBuilder := newJSONBuilder(nil, gjson.Result{}) + jsonBuilder := newJSONBuilder(nil, nil, gjson.Result{}) responseJSON, err := jsonBuilder.marshalResponseJSON(&response, responseMessage) require.NoError(t, err) require.Equal(t, `{"_entities":[{"__typename":"Product","id":"123","name_different":"test","price_different":123.45}]}`, 
responseJSON.String()) @@ -3723,15 +3722,14 @@ func Test_DataSource_Load_WithEntity_Calls(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - err = ds.Load(context.Background(), []byte(input), output) + output, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response var resp graphqlResponse - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(output, &resp) require.NoError(t, err, "Failed to unmarshal response") tc.validate(t, resp.Data) diff --git a/v2/pkg/engine/datasource/grpc_datasource/json_builder.go b/v2/pkg/engine/datasource/grpc_datasource/json_builder.go index 7eb8745141..0b2edc07c2 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/json_builder.go +++ b/v2/pkg/engine/datasource/grpc_datasource/json_builder.go @@ -114,12 +114,12 @@ type jsonBuilder struct { // newJSONBuilder creates a new JSON builder instance with the provided mapping // and variables. The builder automatically creates an index map for proper // federation entity ordering if representations are present in the variables. 
-func newJSONBuilder(mapping *GRPCMapping, variables gjson.Result) *jsonBuilder { +func newJSONBuilder(a arena.Arena, mapping *GRPCMapping, variables gjson.Result) *jsonBuilder { return &jsonBuilder{ mapping: mapping, variables: variables, indexMap: createRepresentationIndexMap(variables), - jsonArena: arena.NewMonotonicArena(), + jsonArena: a, } } @@ -259,7 +259,7 @@ func (j *jsonBuilder) mergeWithPath(base *astjson.Value, resolved *astjson.Value } for i := range responseValues { - responseValues[i].Set(elementName, resolvedValues[i].Get(elementName)) + responseValues[i].Set(j.jsonArena, elementName, resolvedValues[i].Get(elementName)) } return nil diff --git a/v2/pkg/engine/resolve/arena.go b/v2/pkg/engine/resolve/arena.go index 98bd930873..7909460b29 100644 --- a/v2/pkg/engine/resolve/arena.go +++ b/v2/pkg/engine/resolve/arena.go @@ -32,6 +32,7 @@ type arenaPoolItemSize struct { // ArenaPoolItem wraps an arena.Arena for use in the pool type ArenaPoolItem struct { Arena arena.Arena + Key uint64 } // NewArenaPool creates a new ArenaPool instance @@ -43,7 +44,7 @@ func NewArenaPool() *ArenaPool { // Acquire gets an arena from the pool or creates a new one if none are available. // The id parameter is used to track arena sizes per use case for optimization. -func (p *ArenaPool) Acquire(id uint64) *ArenaPoolItem { +func (p *ArenaPool) Acquire(key uint64) *ArenaPoolItem { p.mu.Lock() defer p.mu.Unlock() @@ -56,21 +57,23 @@ func (p *ArenaPool) Acquire(id uint64) *ArenaPoolItem { v := wp.Value() if v != nil { + v.Key = key return v } // If weak pointer was nil (GC collected), continue to next item } // No arena available, create a new one - size := arena.WithMinBufferSize(p.getArenaSize(id)) + size := arena.WithMinBufferSize(p.getArenaSize(key)) return &ArenaPoolItem{ Arena: arena.NewMonotonicArena(size), + Key: key, } } // Release returns an arena to the pool for reuse. // The peak memory usage is recorded to optimize future arena sizes for this use case. 
-func (p *ArenaPool) Release(id uint64, item *ArenaPoolItem) { +func (p *ArenaPool) Release(item *ArenaPoolItem) { peak := item.Arena.Peak() item.Arena.Reset() @@ -78,7 +81,7 @@ func (p *ArenaPool) Release(id uint64, item *ArenaPoolItem) { defer p.mu.Unlock() // Record the peak usage for this use case - if size, ok := p.sizes[id]; ok { + if size, ok := p.sizes[item.Key]; ok { if size.count == 50 { size.count = 1 size.totalBytes = size.totalBytes / 50 @@ -86,17 +89,51 @@ func (p *ArenaPool) Release(id uint64, item *ArenaPoolItem) { size.count++ size.totalBytes += peak } else { - p.sizes[id] = &arenaPoolItemSize{ + p.sizes[item.Key] = &arenaPoolItemSize{ count: 1, totalBytes: peak, } } + item.Key = 0 + // Add the arena back to the pool using a weak pointer w := weak.Make(item) p.pool = append(p.pool, w) } +func (p *ArenaPool) ReleaseMany(items []*ArenaPoolItem) { + p.mu.Lock() + defer p.mu.Unlock() + + for _, item := range items { + + peak := item.Arena.Peak() + item.Arena.Reset() + + // Record the peak usage for this use case + if size, ok := p.sizes[item.Key]; ok { + if size.count == 50 { + size.count = 1 + size.totalBytes = size.totalBytes / 50 + } + size.count++ + size.totalBytes += peak + } else { + p.sizes[item.Key] = &arenaPoolItemSize{ + count: 1, + totalBytes: peak, + } + } + + item.Key = 0 + + // Add the arena back to the pool using a weak pointer + w := weak.Make(item) + p.pool = append(p.pool, w) + } +} + // getArenaSize returns the optimal arena size for a given use case ID. // If no size is recorded, it defaults to 1MB. 
func (p *ArenaPool) getArenaSize(id uint64) int { diff --git a/v2/pkg/engine/resolve/arena_test.go b/v2/pkg/engine/resolve/arena_test.go index 20c1069b86..c884434f18 100644 --- a/v2/pkg/engine/resolve/arena_test.go +++ b/v2/pkg/engine/resolve/arena_test.go @@ -47,7 +47,7 @@ func TestArenaPool_ReleaseAndAcquire(t *testing.T) { assert.NoError(t, err) // Release it - pool.Release(id, item1) + pool.Release(item1) // Pool should have one item assert.Equal(t, 1, len(pool.pool), "expected pool to have 1 item") @@ -88,7 +88,7 @@ func TestArenaPool_Acquire_ProvesBugFix(t *testing.T) { // Release all while keeping strong references for i := 0; i < numItems; i++ { - pool.Release(id, items[i]) + pool.Release(items[i]) } // Pool should have all items @@ -137,7 +137,7 @@ func TestArenaPool_Release_PeakTracking(t *testing.T) { peak1 := item1.Arena.Peak() assert.Equal(t, peak1, 5) - pool.Release(id, item1) + pool.Release(item1) // Check that size was tracked size, exists := pool.sizes[id] @@ -150,7 +150,7 @@ func TestArenaPool_Release_PeakTracking(t *testing.T) { _, err = buf2.WriteString("larger data") assert.NoError(t, err) - pool.Release(id, item2) + pool.Release(item2) // Check updated tracking assert.Equal(t, 2, size.count, "expected count 2") @@ -170,7 +170,7 @@ func TestArenaPool_GetArenaSize(t *testing.T) { buf := arena.NewArenaBuffer(item.Arena) _, err := buf.WriteString("some data") assert.NoError(t, err) - pool.Release(id, item) + pool.Release(item) size2 := pool.getArenaSize(id) assert.NotEqual(t, 0, size2, "expected non-zero size after usage") @@ -193,7 +193,7 @@ func TestArenaPool_MultipleItemsInPool(t *testing.T) { // Release all while keeping references for i := 0; i < numItems; i++ { - pool.Release(id, items[i]) + pool.Release(items[i]) } // Should have all items in pool @@ -221,7 +221,7 @@ func TestArenaPool_Release_MovingWindow(t *testing.T) { buf := arena.NewArenaBuffer(item.Arena) _, err := buf.WriteString("test data") assert.NoError(t, err) - pool.Release(id, 
item) + pool.Release(item) } // After 50 releases, verify count and total @@ -237,7 +237,7 @@ func TestArenaPool_Release_MovingWindow(t *testing.T) { _, err := buf51.WriteString("test data") assert.NoError(t, err) peak51 := item51.Arena.Peak() - pool.Release(id, item51) + pool.Release(item51) // After 51st release, verify the window was reset // count should be 2 (reset to 1, then incremented) @@ -253,7 +253,7 @@ func TestArenaPool_Release_MovingWindow(t *testing.T) { buf := arena.NewArenaBuffer(item.Arena) _, err := buf.WriteString("more data") assert.NoError(t, err) - pool.Release(id, item) + pool.Release(item) } // After 10 more releases, count should be 12 (2 + 10) diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index b93888a79d..747ee02c4e 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -342,7 +342,7 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe err = t.resolvable.Init(ctx, nil, response.Info.OperationType) if err != nil { r.inboundRequestSingleFlight.FinishErr(inflight, err) - r.resolveArenaPool.Release(ctx.Request.ID, resolveArena) + r.resolveArenaPool.Release(resolveArena) return nil, err } @@ -350,7 +350,7 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe err = t.loader.LoadGraphQLResponseData(ctx, response, t.resolvable) if err != nil { r.inboundRequestSingleFlight.FinishErr(inflight, err) - r.resolveArenaPool.Release(ctx.Request.ID, resolveArena) + r.resolveArenaPool.Release(resolveArena) return nil, err } } @@ -361,14 +361,14 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe err = t.resolvable.Resolve(ctx.ctx, response.Data, response.Fetches, buf) if err != nil { r.inboundRequestSingleFlight.FinishErr(inflight, err) - r.resolveArenaPool.Release(ctx.Request.ID, resolveArena) - r.responseBufferPool.Release(ctx.Request.ID, responseArena) + r.resolveArenaPool.Release(resolveArena) 
+ r.responseBufferPool.Release(responseArena) return nil, err } // first release resolverArena // all data is resolved and written into the response arena - r.resolveArenaPool.Release(ctx.Request.ID, resolveArena) + r.resolveArenaPool.Release(resolveArena) // next we write back to the client // this includes flushing and syscalls // as such, it can take some time @@ -377,7 +377,7 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe r.inboundRequestSingleFlight.FinishOk(inflight, buf.Bytes()) // all data is written to the client // we're safe to release our buffer - r.responseBufferPool.Release(ctx.Request.ID, responseArena) + r.responseBufferPool.Release(responseArena) return resp, err } @@ -515,7 +515,7 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar t := newTools(r.options, r.allowedErrorExtensionFields, r.allowedErrorFields, r.subgraphRequestSingleFlight, resolveArena.Arena) if err := t.resolvable.InitSubscription(resolveCtx, input, sub.resolve.Trigger.PostProcessing); err != nil { - r.resolveArenaPool.Release(resolveCtx.Request.ID, resolveArena) + r.resolveArenaPool.Release(resolveArena) r.asyncErrorWriter.WriteError(resolveCtx, err, sub.resolve.Response, sub.writer) if r.options.Debug { fmt.Printf("resolver:trigger:subscription:init:failed:%d\n", sub.id.SubscriptionID) @@ -527,7 +527,7 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar } if err := t.loader.LoadGraphQLResponseData(resolveCtx, sub.resolve.Response, t.resolvable); err != nil { - r.resolveArenaPool.Release(resolveCtx.Request.ID, resolveArena) + r.resolveArenaPool.Release(resolveArena) r.asyncErrorWriter.WriteError(resolveCtx, err, sub.resolve.Response, sub.writer) if r.options.Debug { fmt.Printf("resolver:trigger:subscription:load:failed:%d\n", sub.id.SubscriptionID) @@ -539,7 +539,7 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar } if err := 
t.resolvable.Resolve(resolveCtx.ctx, sub.resolve.Response.Data, sub.resolve.Response.Fetches, sub.writer); err != nil { - r.resolveArenaPool.Release(resolveCtx.Request.ID, resolveArena) + r.resolveArenaPool.Release(resolveArena) r.asyncErrorWriter.WriteError(resolveCtx, err, sub.resolve.Response, sub.writer) if r.options.Debug { fmt.Printf("resolver:trigger:subscription:resolve:failed:%d\n", sub.id.SubscriptionID) @@ -550,7 +550,7 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar return } - r.resolveArenaPool.Release(resolveCtx.Request.ID, resolveArena) + r.resolveArenaPool.Release(resolveArena) if err := sub.writer.Flush(); err != nil { // If flush fails (e.g. client disconnected), remove the subscription. From dd00412003ab62a9fc5a420792dafdddab34c84d Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Fri, 14 Nov 2025 20:27:01 +0100 Subject: [PATCH 41/57] chore: use sync.Map, cleanup --- .../graphql_datasource_test.go | 7 +- .../grpc_datasource/grpc_datasource_test.go | 4 +- .../resolve/inbound_request_singleflight.go | 42 ++-- v2/pkg/engine/resolve/loader.go | 4 +- .../resolve/subgraph_request_singleflight.go | 176 ++++++++------- .../subgraph_request_singleflight_test.go | 209 ++++++++++++++++++ 6 files changed, 325 insertions(+), 117 deletions(-) create mode 100644 v2/pkg/engine/resolve/subgraph_request_singleflight_test.go diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go index 9654eb1065..cb54ab8a83 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go @@ -8866,8 +8866,10 @@ func TestLoadFiles(t *testing.T) { input = httpclient.SetInputURL(input, []byte(serverUrl)) ctx := context.Background() - _, err = src.LoadWithFiles(ctx, nil, input, []*httpclient.FileUpload{httpclient.NewFileUpload(f.Name(), fileName, 
"variables.file")}) + got, err := src.LoadWithFiles(ctx, nil, input, []*httpclient.FileUpload{httpclient.NewFileUpload(f.Name(), fileName, "variables.file")}) require.NoError(t, err) + require.Equal(t, []byte{}, got) + }) t.Run("multiple files", func(t *testing.T) { @@ -8921,11 +8923,12 @@ func TestLoadFiles(t *testing.T) { assert.NoError(t, err) ctx := context.Background() - _, err = src.LoadWithFiles(ctx, nil, input, + got, err := src.LoadWithFiles(ctx, nil, input, []*httpclient.FileUpload{ httpclient.NewFileUpload(f1.Name(), file1Name, "variables.files.0"), httpclient.NewFileUpload(f2.Name(), file2Name, "variables.files.1")}) require.NoError(t, err) + require.Equal(t, []byte{}, got) }) } diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go index 9a427809a9..de66be94ac 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go @@ -219,10 +219,8 @@ func Test_DataSource_Load(t *testing.T) { require.NoError(t, err) - output, err := ds.Load(context.Background(), nil, []byte(`{"query":"`+query+`","variables":`+variables+`}`)) + _, err = ds.Load(context.Background(), nil, []byte(`{"query":"`+query+`","variables":`+variables+`}`)) require.NoError(t, err) - - fmt.Println(string(output)) } // Test_DataSource_Load_WithMockService tests the datasource.Load method with an actual gRPC server diff --git a/v2/pkg/engine/resolve/inbound_request_singleflight.go b/v2/pkg/engine/resolve/inbound_request_singleflight.go index 66505a36a4..2552a43fd6 100644 --- a/v2/pkg/engine/resolve/inbound_request_singleflight.go +++ b/v2/pkg/engine/resolve/inbound_request_singleflight.go @@ -16,8 +16,7 @@ type InboundRequestSingleFlight struct { } type requestShard struct { - mu sync.Mutex - m map[uint64]*InflightRequest + m sync.Map } const defaultRequestSingleFlightShardCount = 4 @@ -31,20 +30,17 @@ func 
NewRequestSingleFlight(shardCount int) *InboundRequestSingleFlight { r := &InboundRequestSingleFlight{ shards: make([]requestShard, shardCount), } - for i := range r.shards { - r.shards[i] = requestShard{ - m: make(map[uint64]*InflightRequest), - } - } return r } type InflightRequest struct { - Done chan struct{} - Data []byte - Err error - ID uint64 + Done chan struct{} + Data []byte + Err error + ID uint64 + HasFollowers bool + Mu sync.Mutex } // GetOrCreate creates a new InflightRequest or returns an existing (shared) one @@ -75,11 +71,12 @@ func (r *InboundRequestSingleFlight) GetOrCreate(ctx *Context, response *GraphQL key := xxhash.Sum64(b[:]) shard := r.shardFor(key) - shard.mu.Lock() - req, shared := shard.m[key] + req, shared := shard.m.Load(key) if shared { + req := req.(*InflightRequest) + req.Mu.Lock() req.HasFollowers = true - shard.mu.Unlock() + req.Mu.Unlock() select { case <-req.Done: if req.Err != nil { @@ -91,14 +88,13 @@ func (r *InboundRequestSingleFlight) GetOrCreate(ctx *Context, response *GraphQL } } - req = &InflightRequest{ + value := &InflightRequest{ Done: make(chan struct{}), ID: key, } - shard.m[key] = req - shard.mu.Unlock() - return req, nil + shard.m.Store(key, value) + return value, nil } func (r *InboundRequestSingleFlight) FinishOk(req *InflightRequest, data []byte) { @@ -106,10 +102,10 @@ func (r *InboundRequestSingleFlight) FinishOk(req *InflightRequest, data []byte) return } shard := r.shardFor(req.ID) - shard.mu.Lock() - delete(shard.m, req.ID) + shard.m.Delete(req.ID) + req.Mu.Lock() hasFollowers := req.HasFollowers - shard.mu.Unlock() + req.Mu.Unlock() if hasFollowers { // optimization to only copy when we actually have to req.Data = make([]byte, len(data)) @@ -123,9 +119,7 @@ func (r *InboundRequestSingleFlight) FinishErr(req *InflightRequest, err error) return } shard := r.shardFor(req.ID) - shard.mu.Lock() - delete(shard.m, req.ID) - shard.mu.Unlock() + shard.m.Delete(req.ID) req.Err = err close(req.Done) } diff --git 
a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index a33242bc1d..23f0b6d327 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -1662,7 +1662,7 @@ func (l *Loader) loadByContext(ctx context.Context, source DataSource, fetchItem return l.loadByContextDirect(ctx, source, headers, input, res) } - sfKey, fetchKey, item, shared := l.sf.GetOrCreateItem(fetchItem, input, extraKey) + item, shared := l.sf.GetOrCreateItem(fetchItem, input, extraKey) if res.singleFlightStats != nil { res.singleFlightStats.used = true res.singleFlightStats.shared = shared @@ -1686,7 +1686,7 @@ func (l *Loader) loadByContext(ctx context.Context, source DataSource, fetchItem // helps the http client to create buffers at the right size ctx = httpclient.WithHTTPClientSizeHint(ctx, item.sizeHint) - defer l.sf.Finish(sfKey, fetchKey, item) + defer l.sf.Finish(item) // Perform the actual load err := l.loadByContextDirect(ctx, source, headers, input, res) diff --git a/v2/pkg/engine/resolve/subgraph_request_singleflight.go b/v2/pkg/engine/resolve/subgraph_request_singleflight.go index 013d906775..85f73d7423 100644 --- a/v2/pkg/engine/resolve/subgraph_request_singleflight.go +++ b/v2/pkg/engine/resolve/subgraph_request_singleflight.go @@ -10,15 +10,13 @@ import ( // It's hashing the input and adds the pre-computed subgraph headers hash to avoid collisions // In addition to single flight, it provides size hints to create right-sized buffers for subgraph requests type SubgraphRequestSingleFlight struct { - shards []singleFlightShard - xxPool *sync.Pool - cleanup chan func() + shards []singleFlightShard + xxPool *sync.Pool } type singleFlightShard struct { - mu sync.RWMutex - items map[uint64]*SingleFlightItem - sizes map[uint64]*fetchSize + items sync.Map // map[uint64]*SingleFlightItem + sizes sync.Map // map[uint64]*fetchSize } const defaultSingleFlightShardCount = 4 @@ -36,10 +34,15 @@ type SingleFlightItem struct { // this gives a leader a hint on 
how much space it should pre-allocate for buffers when fetching // this reduces memory usage sizeHint int + // SFKey uniquely identifies a single flight request + SFKey uint64 + // FetchKey groups similar fetches for size hinting + FetchKey uint64 } // fetchSize gives an estimate of required buffer size for a given fetchKey when dividing totalBytes / count type fetchSize struct { + mu sync.Mutex // count is the number of fetches tracked count int // totalBytes is the cumulative bytes across tracked fetches @@ -57,74 +60,103 @@ func NewSingleFlight(shardCount int) *SubgraphRequestSingleFlight { return xxhash.New() }, }, - cleanup: make(chan func()), - } - for i := range s.shards { - s.shards[i] = singleFlightShard{ - items: make(map[uint64]*SingleFlightItem), - sizes: make(map[uint64]*fetchSize), - } } return s } -// GetOrCreateItem generates a single flight key (100% identical fetches) and a fetchKey (similar fetches, collisions possible but unproblematic) -// and return a SingleFlightItem as well as an indication if it's shared or not -// If shared == false, the caller is a leader -// If shared == true, the caller is a follower -// item.sizeHint can be used to create an optimal buffer for the fetch in case of a leader -// item.err must always be checked -// item.response must never be mutated -func (s *SubgraphRequestSingleFlight) GetOrCreateItem(fetchItem *FetchItem, input []byte, extraKey uint64) (sfKey, fetchKey uint64, item *SingleFlightItem, shared bool) { - sfKey, fetchKey = s.keys(fetchItem, input, extraKey) +// GetOrCreateItem returns a SingleFlightItem, which contains the single flight key (100% identical fetches), +// a fetchKey (similar fetches, collisions possible but unproblematic because it's only used for size hints), +// and an indication if it is shared or not. +// If not shared, the caller is a leader, otherwise it is a follower. +// item.sizeHint can be used to create an optimal buffer for the fetch in case of a leader. 
+// item.err must always be checked. +// item.response must never be mutated. +func (s *SubgraphRequestSingleFlight) GetOrCreateItem(fetchItem *FetchItem, input []byte, extraKey uint64) (item *SingleFlightItem, shared bool) { + sfKey, fetchKey := s.computeKeys(fetchItem, input, extraKey) // Get shard based on sfKey for items shard := s.shardFor(sfKey) - // First, try to get the item with a read lock on its shard - shard.mu.RLock() - item, exists := shard.items[sfKey] - shard.mu.RUnlock() - if exists { - return sfKey, fetchKey, item, true - } - - // If not exists, acquire a write lock to create the item - shard.mu.Lock() - // Double-check if the item was created while acquiring the write lock - item, exists = shard.items[sfKey] - if exists { - shard.mu.Unlock() - return sfKey, fetchKey, item, true + if existing, ok := shard.items.Load(sfKey); ok { + return existing.(*SingleFlightItem), true } - // Create a new item item = &SingleFlightItem{ // empty chan to indicate to all followers when we're done (close) - loaded: make(chan struct{}), + loaded: make(chan struct{}), + SFKey: sfKey, + FetchKey: fetchKey, } // Read size hint from the same shard (both items and sizes use the same shard now) - if size, ok := shard.sizes[fetchKey]; ok { - item.sizeHint = size.totalBytes / size.count + if sizeValue, ok := shard.sizes.Load(fetchKey); ok { + size := sizeValue.(*fetchSize) + size.mu.Lock() + if size.count > 0 { + item.sizeHint = size.totalBytes / size.count + } + size.mu.Unlock() + } + + actual, loaded := shard.items.LoadOrStore(sfKey, item) + if loaded { + return actual.(*SingleFlightItem), true + } + return item, false +} + +// Finish is for the leader to mark the SingleFlightItem as "done" +// trigger all followers to look at the err & response of the item +// and to update the size estimates +func (s *SubgraphRequestSingleFlight) Finish(item *SingleFlightItem) { + sfKey := item.SFKey + fetchKey := item.FetchKey + close(item.loaded) + // Update sizes in the same shard as 
the item (using sfKey to get the shard) + shard := s.shardFor(sfKey) + + shard.items.Delete(sfKey) + + sizeValue, ok := shard.sizes.Load(fetchKey) + if !ok { + newSize := &fetchSize{} + sizeValue, _ = shard.sizes.LoadOrStore(fetchKey, newSize) + } + size := sizeValue.(*fetchSize) + size.mu.Lock() + if size.count == 0 { + size.count = 1 + size.totalBytes = len(item.response) + size.mu.Unlock() + return + } + if size.count == 50 { + size.count = 1 + size.totalBytes = size.totalBytes / 50 } - shard.items[sfKey] = item - shard.mu.Unlock() - return sfKey, fetchKey, item, false + size.count++ + size.totalBytes += len(item.response) + size.mu.Unlock() } -func (s *SubgraphRequestSingleFlight) keys(fetchItem *FetchItem, input []byte, extraKey uint64) (sfKey, fetchKey uint64) { +func (s *SubgraphRequestSingleFlight) shardFor(key uint64) *singleFlightShard { + idx := int(key % uint64(len(s.shards))) + return &s.shards[idx] +} + +func (s *SubgraphRequestSingleFlight) computeKeys(fetchItem *FetchItem, input []byte, extraKey uint64) (sfKey, fetchKey uint64) { h := s.xxPool.Get().(*xxhash.Digest) - sfKey = s.sfKey(h, fetchItem, input, extraKey) + sfKey = s.computeSFKey(fetchItem, input, extraKey) h.Reset() - fetchKey = s.fetchKey(h, fetchItem) + fetchKey = s.computeFetchKey(fetchItem) h.Reset() s.xxPool.Put(h) return sfKey, fetchKey } -// sfKey returns a key that 100% uniquely identifies a fetch with no collision -// two sfKey are only the same when the fetches are 100% equal -func (s *SubgraphRequestSingleFlight) sfKey(h *xxhash.Digest, fetchItem *FetchItem, input []byte, extraKey uint64) uint64 { +// computeSFKey returns a key that 100% uniquely identifies a fetch with no collision. +// Two sfKey values are only the same when the fetches are 100% equal. 
+func (s *SubgraphRequestSingleFlight) computeSFKey(fetchItem *FetchItem, input []byte, extraKey uint64) uint64 { + h := s.xxPool.Get().(*xxhash.Digest) if fetchItem != nil && fetchItem.Fetch != nil { info := fetchItem.Fetch.FetchInfo() if info != nil { @@ -136,11 +168,12 @@ func (s *SubgraphRequestSingleFlight) sfKey(h *xxhash.Digest, fetchItem *FetchIt return h.Sum64() + extraKey // extraKey in this case is the pre-generated hash for the headers } -// fetchKey is a less robust key compared to sfKey -// the purpose is to create a key from the DataSourceID and root fields to have less cardinality -// the goal is to get an estimate buffer size for similar fetches -// there's no point in hashing headers or the body for this purpose -func (s *SubgraphRequestSingleFlight) fetchKey(h *xxhash.Digest, fetchItem *FetchItem) uint64 { +// computeFetchKey is a less robust key compared to sfKey. +// The purpose is to create a key from the DataSourceID and root fields to have less cardinality. +// The goal is to get an estimate buffer size for similar fetches; hashing headers or the body is not needed. 
+func (s *SubgraphRequestSingleFlight) computeFetchKey(fetchItem *FetchItem) uint64 { + h := s.xxPool.Get().(*xxhash.Digest) + defer s.xxPool.Put(h) if fetchItem == nil || fetchItem.Fetch == nil { return 0 } @@ -158,35 +191,6 @@ func (s *SubgraphRequestSingleFlight) fetchKey(h *xxhash.Digest, fetchItem *Fetc _, _ = h.Write(dot) _, _ = h.WriteString(info.RootFields[i].FieldName) } - return h.Sum64() -} - -// Finish is for the leader to mark the SingleFlightItem as "done" -// trigger all followers to look at the err & response of the item -// and to update the size estimates -func (s *SubgraphRequestSingleFlight) Finish(sfKey, fetchKey uint64, item *SingleFlightItem) { - close(item.loaded) - // Update sizes in the same shard as the item (using sfKey to get the shard) - shard := s.shardFor(sfKey) - shard.mu.Lock() - delete(shard.items, sfKey) - if size, ok := shard.sizes[fetchKey]; ok { - if size.count == 50 { - size.count = 1 - size.totalBytes = size.totalBytes / 50 - } - size.count++ - size.totalBytes += len(item.response) - } else { - shard.sizes[fetchKey] = &fetchSize{ - count: 1, - totalBytes: len(item.response), - } - } - shard.mu.Unlock() -} - -func (s *SubgraphRequestSingleFlight) shardFor(key uint64) *singleFlightShard { - idx := int(key % uint64(len(s.shards))) - return &s.shards[idx] + sum := h.Sum64() + return sum } diff --git a/v2/pkg/engine/resolve/subgraph_request_singleflight_test.go b/v2/pkg/engine/resolve/subgraph_request_singleflight_test.go new file mode 100644 index 0000000000..312236359a --- /dev/null +++ b/v2/pkg/engine/resolve/subgraph_request_singleflight_test.go @@ -0,0 +1,209 @@ +package resolve + +import ( + "bytes" + "fmt" + "testing" +) + +type stubFetch struct { + info *FetchInfo +} + +func (s *stubFetch) FetchKind() FetchKind { + return FetchKindSingle +} + +func (s *stubFetch) Dependencies() *FetchDependencies { + return nil +} + +func (s *stubFetch) FetchInfo() *FetchInfo { + return s.info +} + +type nilInfoFetch struct{} + +func (n 
*nilInfoFetch) FetchKind() FetchKind { + return FetchKindSingle +} + +func (n *nilInfoFetch) Dependencies() *FetchDependencies { + return nil +} + +func (n *nilInfoFetch) FetchInfo() *FetchInfo { + return nil +} + +func newFetchItem(info *FetchInfo) *FetchItem { + return &FetchItem{ + Fetch: &stubFetch{ + info: info, + }, + } +} + +func TestSubgraphRequestSingleFlight_LeaderFollowerSizeHint(t *testing.T) { + flight := NewSingleFlight(2) + fetchInfo := &FetchInfo{ + DataSourceID: "accounts", + RootFields: []GraphCoordinate{ + {TypeName: "Query", FieldName: "viewer"}, + }, + } + fetchItem := newFetchItem(fetchInfo) + + item, shared := flight.GetOrCreateItem(fetchItem, []byte("query { viewer { id } }"), 42) + if shared { + t.Fatalf("expected leader to be first caller") + } + if item == nil { + t.Fatalf("expected item, got nil") + } + if item.sizeHint != 0 { + t.Fatalf("expected empty size hint, got %d", item.sizeHint) + } + + follower, followerShared := flight.GetOrCreateItem(fetchItem, []byte("query { viewer { id } }"), 42) + if !followerShared { + t.Fatalf("expected second caller to be follower") + } + if follower != item { + t.Fatalf("expected follower to receive same item instance") + } + + item.response = []byte("hello") + flight.Finish(item) + + select { + case <-item.loaded: + default: + t.Fatalf("expected leader to close loaded channel") + } + + next, nextShared := flight.GetOrCreateItem(fetchItem, []byte("query { viewer { id } }"), 42) + if nextShared { + t.Fatalf("expected new leader after finish") + } + if next == item { + t.Fatalf("expected new item after finish") + } + if next.sizeHint != len("hello") { + t.Fatalf("expected size hint %d, got %d", len("hello"), next.sizeHint) + } +} + +func TestSubgraphRequestSingleFlight_SimilarFetchesShareFetchKey(t *testing.T) { + flight := NewSingleFlight(1) + fetchInfo := &FetchInfo{ + DataSourceID: "reviews", + RootFields: []GraphCoordinate{ + {TypeName: "Query", FieldName: "reviews"}, + }, + } + fetchItem := 
newFetchItem(fetchInfo) + + item1, shared1 := flight.GetOrCreateItem(fetchItem, []byte("body-1"), 0) + if shared1 { + t.Fatalf("expected first call to be leader") + } + item1.response = []byte("first response") + flight.Finish(item1) + + item2, shared2 := flight.GetOrCreateItem(fetchItem, []byte("body-2"), 0) + if shared2 { + t.Fatalf("expected leader after finishing previous item") + } + if item1.FetchKey != item2.FetchKey { + t.Fatalf("expected identical fetch keys for similar fetches") + } + if item1.SFKey == item2.SFKey { + t.Fatalf("expected different single-flight keys for different request bodies") + } + item2.response = []byte("second response") + flight.Finish(item2) +} + +func TestSubgraphRequestSingleFlight_FetchKeyZeroWithoutFetchInfo(t *testing.T) { + t.Run("nil fetch item", func(t *testing.T) { + flight := NewSingleFlight(1) + item, shared := flight.GetOrCreateItem(nil, []byte("body"), 0) + if shared { + t.Fatalf("expected leader for nil fetch item") + } + if item.FetchKey != 0 { + t.Fatalf("expected fetch key 0, got %d", item.FetchKey) + } + flight.Finish(item) + }) + + t.Run("nil fetch", func(t *testing.T) { + flight := NewSingleFlight(1) + item, shared := flight.GetOrCreateItem(&FetchItem{}, []byte("body"), 0) + if shared { + t.Fatalf("expected leader for nil fetch") + } + if item.FetchKey != 0 { + t.Fatalf("expected fetch key 0, got %d", item.FetchKey) + } + flight.Finish(item) + }) + + t.Run("missing fetch info", func(t *testing.T) { + flight := NewSingleFlight(1) + item, shared := flight.GetOrCreateItem(&FetchItem{Fetch: &nilInfoFetch{}}, []byte("body"), 0) + if shared { + t.Fatalf("expected leader for missing fetch info") + } + if item.FetchKey != 0 { + t.Fatalf("expected fetch key 0, got %d", item.FetchKey) + } + flight.Finish(item) + }) +} + +func TestSubgraphRequestSingleFlight_SizeHintRollingWindow(t *testing.T) { + flight := NewSingleFlight(1) + fetchInfo := &FetchInfo{ + DataSourceID: "products", + RootFields: []GraphCoordinate{ + 
{TypeName: "Query", FieldName: "products"}, + }, + } + fetchItem := newFetchItem(fetchInfo) + + var fetchKey uint64 + for i := 0; i < 50; i++ { + item, shared := flight.GetOrCreateItem(fetchItem, []byte(fmt.Sprintf("body-%d", i)), 0) + if shared { + t.Fatalf("expected leader for iteration %d", i) + } + if i == 0 { + fetchKey = item.FetchKey + } else if item.FetchKey != fetchKey { + t.Fatalf("expected consistent fetch key across iterations, got %d and %d", fetchKey, item.FetchKey) + } + item.response = bytes.Repeat([]byte("a"), 100) + flight.Finish(item) + } + + item, shared := flight.GetOrCreateItem(fetchItem, []byte("body-50"), 0) + if shared { + t.Fatalf("expected leader for rolling window update") + } + if item.FetchKey != fetchKey { + t.Fatalf("expected same fetch key, got %d and %d", fetchKey, item.FetchKey) + } + item.response = bytes.Repeat([]byte("b"), 200) + flight.Finish(item) + + next, nextShared := flight.GetOrCreateItem(fetchItem, []byte("body-51"), 0) + if nextShared { + t.Fatalf("expected leader for new request") + } + expected := 150 + if next.sizeHint != expected { + t.Fatalf("expected rolling average size hint %d, got %d", expected, next.sizeHint) + } +} From 65e3d92fc7ff2fef7e5ee60054fec6e9cdc492fe Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Tue, 18 Nov 2025 12:18:07 +0100 Subject: [PATCH 42/57] chore: use assert.Len --- v2/pkg/engine/resolve/arena_test.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/v2/pkg/engine/resolve/arena_test.go b/v2/pkg/engine/resolve/arena_test.go index c884434f18..4a7d779509 100644 --- a/v2/pkg/engine/resolve/arena_test.go +++ b/v2/pkg/engine/resolve/arena_test.go @@ -14,8 +14,8 @@ func TestNewArenaPool(t *testing.T) { pool := NewArenaPool() require.NotNil(t, pool, "NewArenaPool returned nil") - assert.Equal(t, 0, len(pool.pool), "expected empty pool") - assert.Equal(t, 0, len(pool.sizes), "expected empty sizes map") + assert.Len(t, pool.pool, 0, "expected empty pool") + 
assert.Len(t, pool.sizes, 0, "expected empty sizes map") } func TestArenaPool_Acquire_EmptyPool(t *testing.T) { @@ -31,7 +31,7 @@ func TestArenaPool_Acquire_EmptyPool(t *testing.T) { _, err := buf.WriteString("test") assert.NoError(t, err) - assert.Equal(t, 0, len(pool.pool), "pool should still be empty") + assert.Len(t, pool.pool, 0, "pool should still be empty") } func TestArenaPool_ReleaseAndAcquire(t *testing.T) { @@ -50,7 +50,7 @@ func TestArenaPool_ReleaseAndAcquire(t *testing.T) { pool.Release(item1) // Pool should have one item - assert.Equal(t, 1, len(pool.pool), "expected pool to have 1 item") + assert.Len(t, pool.pool, 1, "expected pool to have 1 item") // Acquire from pool item2 := pool.Acquire(id) @@ -58,7 +58,7 @@ func TestArenaPool_ReleaseAndAcquire(t *testing.T) { require.NotNil(t, item2, "Acquire returned nil") // Pool should be empty again - assert.Equal(t, 0, len(pool.pool), "expected empty pool after acquire") + assert.Len(t, pool.pool, 0, "expected empty pool after acquire") // The acquired arena should be reset and usable buf2 := arena.NewArenaBuffer(item2.Arena) @@ -92,7 +92,7 @@ func TestArenaPool_Acquire_ProvesBugFix(t *testing.T) { } // Pool should have all items - assert.Equal(t, numItems, len(pool.pool), "expected items in pool") + assert.Len(t, pool.pool, numItems, "expected items in pool") // Clear every other item to simulate partial GC for i := 0; i < numItems; i += 2 { @@ -121,7 +121,7 @@ func TestArenaPool_Acquire_ProvesBugFix(t *testing.T) { } // Pool should be empty - assert.Equal(t, 0, len(pool.pool), "expected empty pool") + assert.Len(t, pool.pool, 0, "expected empty pool") } func TestArenaPool_Release_PeakTracking(t *testing.T) { @@ -197,7 +197,7 @@ func TestArenaPool_MultipleItemsInPool(t *testing.T) { } // Should have all items in pool - assert.Equal(t, numItems, len(pool.pool), "expected items in pool") + assert.Len(t, pool.pool, numItems, "expected items in pool") // Acquire all back acquired := 0 From 
a826d980d4be2e15f756ea4a1f2751cf9a6fbc2c Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Tue, 18 Nov 2025 12:23:45 +0100 Subject: [PATCH 43/57] chore: improve file handling --- .../datasource/httpclient/nethttpclient.go | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/v2/pkg/engine/datasource/httpclient/nethttpclient.go b/v2/pkg/engine/datasource/httpclient/nethttpclient.go index 46af845e4f..4f1e9fe09b 100644 --- a/v2/pkg/engine/datasource/httpclient/nethttpclient.go +++ b/v2/pkg/engine/datasource/httpclient/nethttpclient.go @@ -298,6 +298,17 @@ func DoMultipartForm( var tempFiles []*os.File + defer func() { + for _, file := range tempFiles { + if err := file.Close(); err != nil { + continue + } + if err = os.Remove(file.Name()); err != nil { + continue + } + } + }() + fileMap := bytes.NewBuffer(nil) fileMap.WriteString("{") hasWrittenFileName := false @@ -307,15 +318,13 @@ func DoMultipartForm( fileMap.WriteString(",") } hasWrittenFileName = true - _, _ = fmt.Fprintf(fileMap, `"%d":["%s"]`, i, file.variablePath) - key := fmt.Sprintf("%d", i) temporaryFile, err := os.Open(file.Path()) - tempFiles = append(tempFiles, temporaryFile) if err != nil { return nil, err } + tempFiles = append(tempFiles, temporaryFile) formValues[key] = bufio.NewReader(temporaryFile) } fileMap.WriteString("}") @@ -327,15 +336,7 @@ func DoMultipartForm( } defer func() { - multipartBody.Close() - for _, file := range tempFiles { - if err := file.Close(); err != nil { - return - } - if err = os.Remove(file.Name()); err != nil { - return - } - } + _ = multipartBody.Close() }() return makeHTTPRequest(client, ctx, baseHeaders, url, method, headers, queryParams, multipartBody, enableTrace, contentType, 0) From d2dfbdea4474b71dfec7b5d71f985af82690712d Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Tue, 18 Nov 2025 12:38:54 +0100 Subject: [PATCH 44/57] chore: move arena pool into go-arena package --- v2/go.mod | 2 +- v2/go.sum | 2 + 
v2/pkg/engine/resolve/arena.go | 144 --------------- v2/pkg/engine/resolve/arena_test.go | 261 ---------------------------- v2/pkg/engine/resolve/loader.go | 11 +- v2/pkg/engine/resolve/resolve.go | 20 +-- 6 files changed, 19 insertions(+), 421 deletions(-) delete mode 100644 v2/pkg/engine/resolve/arena.go delete mode 100644 v2/pkg/engine/resolve/arena_test.go diff --git a/v2/go.mod b/v2/go.mod index 43ada453b4..ddb6105a77 100644 --- a/v2/go.mod +++ b/v2/go.mod @@ -29,7 +29,7 @@ require ( github.com/tidwall/sjson v1.2.5 github.com/vektah/gqlparser/v2 v2.5.30 github.com/wundergraph/astjson v1.0.0 - github.com/wundergraph/go-arena v1.0.0 + github.com/wundergraph/go-arena v1.1.0 go.uber.org/atomic v1.11.0 go.uber.org/goleak v1.3.0 go.uber.org/zap v1.26.0 diff --git a/v2/go.sum b/v2/go.sum index 6d0fb36360..59916bf7f6 100644 --- a/v2/go.sum +++ b/v2/go.sum @@ -138,6 +138,8 @@ github.com/wundergraph/astjson v1.0.0 h1:rETLJuQkMWWW03HCF6WBttEBOu8gi5vznj5KEUP github.com/wundergraph/astjson v1.0.0/go.mod h1:h12D/dxxnedtLzsKyBLK7/Oe4TAoGpRVC9nDpDrZSWw= github.com/wundergraph/go-arena v1.0.0 h1:RVYWpDkJ1/6851BRHYehBeEcTLKmZygYIZsvBorcOjw= github.com/wundergraph/go-arena v1.0.0/go.mod h1:ROOysEHWJjLQ8FSfNxZCziagb7Qw2nXY3/vgKRh7eWw= +github.com/wundergraph/go-arena v1.1.0 h1:9+wSRkJAkA2vbYHp6s8tEGhPViRGQNGXqPHT0QzhdIc= +github.com/wundergraph/go-arena v1.1.0/go.mod h1:ROOysEHWJjLQ8FSfNxZCziagb7Qw2nXY3/vgKRh7eWw= github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 h1:FnBeRrxr7OU4VvAzt5X7s6266i6cSVkkFPS0TuXWbIg= github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= diff --git a/v2/pkg/engine/resolve/arena.go b/v2/pkg/engine/resolve/arena.go deleted file mode 100644 index 7909460b29..0000000000 --- a/v2/pkg/engine/resolve/arena.go +++ /dev/null @@ -1,144 +0,0 @@ -package resolve - -import ( - "sync" - "weak" - - 
"github.com/wundergraph/go-arena" -) - -// ArenaPool provides a thread-safe pool of arena.Arena instances for memory-efficient allocations. -// It uses weak pointers to allow garbage collection of unused arenas while maintaining -// a pool of reusable arenas for high-frequency allocation patterns. -// -// by storing ArenaPoolItem as weak pointers, the GC can collect them at any time -// before using an ArenaPoolItem, we try to get a strong pointer while removing it from the pool -// once we call Release, we turn the item back to the pool and make it a weak pointer again -// this means that at any time, GC can claim back the memory if required, -// allowing GC to automatically manage an appropriate pool size depending on available memory and GC pressure -type ArenaPool struct { - // pool is a slice of weak pointers to the struct holding the arena.Arena - pool []weak.Pointer[ArenaPoolItem] - sizes map[uint64]*arenaPoolItemSize - mu sync.Mutex -} - -// arenaPoolItemSize is used to track the required memory across the last 50 arenas in the pool -type arenaPoolItemSize struct { - count int - totalBytes int -} - -// ArenaPoolItem wraps an arena.Arena for use in the pool -type ArenaPoolItem struct { - Arena arena.Arena - Key uint64 -} - -// NewArenaPool creates a new ArenaPool instance -func NewArenaPool() *ArenaPool { - return &ArenaPool{ - sizes: make(map[uint64]*arenaPoolItemSize), - } -} - -// Acquire gets an arena from the pool or creates a new one if none are available. -// The id parameter is used to track arena sizes per use case for optimization. 
-func (p *ArenaPool) Acquire(key uint64) *ArenaPoolItem { - p.mu.Lock() - defer p.mu.Unlock() - - // Try to find an available arena in the pool - for len(p.pool) > 0 { - // Pop the last item - lastIdx := len(p.pool) - 1 - wp := p.pool[lastIdx] - p.pool = p.pool[:lastIdx] - - v := wp.Value() - if v != nil { - v.Key = key - return v - } - // If weak pointer was nil (GC collected), continue to next item - } - - // No arena available, create a new one - size := arena.WithMinBufferSize(p.getArenaSize(key)) - return &ArenaPoolItem{ - Arena: arena.NewMonotonicArena(size), - Key: key, - } -} - -// Release returns an arena to the pool for reuse. -// The peak memory usage is recorded to optimize future arena sizes for this use case. -func (p *ArenaPool) Release(item *ArenaPoolItem) { - peak := item.Arena.Peak() - item.Arena.Reset() - - p.mu.Lock() - defer p.mu.Unlock() - - // Record the peak usage for this use case - if size, ok := p.sizes[item.Key]; ok { - if size.count == 50 { - size.count = 1 - size.totalBytes = size.totalBytes / 50 - } - size.count++ - size.totalBytes += peak - } else { - p.sizes[item.Key] = &arenaPoolItemSize{ - count: 1, - totalBytes: peak, - } - } - - item.Key = 0 - - // Add the arena back to the pool using a weak pointer - w := weak.Make(item) - p.pool = append(p.pool, w) -} - -func (p *ArenaPool) ReleaseMany(items []*ArenaPoolItem) { - p.mu.Lock() - defer p.mu.Unlock() - - for _, item := range items { - - peak := item.Arena.Peak() - item.Arena.Reset() - - // Record the peak usage for this use case - if size, ok := p.sizes[item.Key]; ok { - if size.count == 50 { - size.count = 1 - size.totalBytes = size.totalBytes / 50 - } - size.count++ - size.totalBytes += peak - } else { - p.sizes[item.Key] = &arenaPoolItemSize{ - count: 1, - totalBytes: peak, - } - } - - item.Key = 0 - - // Add the arena back to the pool using a weak pointer - w := weak.Make(item) - p.pool = append(p.pool, w) - } -} - -// getArenaSize returns the optimal arena size for a given 
use case ID. -// If no size is recorded, it defaults to 1MB. -func (p *ArenaPool) getArenaSize(id uint64) int { - if size, ok := p.sizes[id]; ok { - return size.totalBytes / size.count - } - return 1024 * 1024 // Default 1MB -} diff --git a/v2/pkg/engine/resolve/arena_test.go b/v2/pkg/engine/resolve/arena_test.go deleted file mode 100644 index 4a7d779509..0000000000 --- a/v2/pkg/engine/resolve/arena_test.go +++ /dev/null @@ -1,261 +0,0 @@ -package resolve - -import ( - "runtime" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/wundergraph/go-arena" -) - -func TestNewArenaPool(t *testing.T) { - pool := NewArenaPool() - - require.NotNil(t, pool, "NewArenaPool returned nil") - assert.Len(t, pool.pool, 0, "expected empty pool") - assert.Len(t, pool.sizes, 0, "expected empty sizes map") -} - -func TestArenaPool_Acquire_EmptyPool(t *testing.T) { - pool := NewArenaPool() - - item := pool.Acquire(1) - - require.NotNil(t, item, "Acquire returned nil") - assert.NotNil(t, item.Arena, "Arena is nil") - - // Verify we can use the arena - buf := arena.NewArenaBuffer(item.Arena) - _, err := buf.WriteString("test") - assert.NoError(t, err) - - assert.Len(t, pool.pool, 0, "pool should still be empty") -} - -func TestArenaPool_ReleaseAndAcquire(t *testing.T) { - pool := NewArenaPool() - id := uint64(42) - - // Acquire first arena - item1 := pool.Acquire(id) - - // Use the arena - buf := arena.NewArenaBuffer(item1.Arena) - _, err := buf.WriteString("test data") - assert.NoError(t, err) - - // Release it - pool.Release(item1) - - // Pool should have one item - assert.Len(t, pool.pool, 1, "expected pool to have 1 item") - - // Acquire from pool - item2 := pool.Acquire(id) - - require.NotNil(t, item2, "Acquire returned nil") - - // Pool should be empty again - assert.Len(t, pool.pool, 0, "expected empty pool after acquire") - - // The acquired arena should be reset and usable - buf2 := arena.NewArenaBuffer(item2.Arena) - _, err 
= buf2.WriteString("new data") - assert.NoError(t, err) - - assert.Equal(t, "new data", buf2.String()) -} - -func TestArenaPool_Acquire_ProvesBugFix(t *testing.T) { - // This test specifically proves the bug fix works - // Creates multiple items, clears some references, then acquires - // to ensure all items are checked without skipping - pool := NewArenaPool() - id := uint64(800) - - numItems := 10 - items := make([]*ArenaPoolItem, numItems) - - // Acquire all items - for i := 0; i < numItems; i++ { - items[i] = pool.Acquire(id) - buf := arena.NewArenaBuffer(items[i].Arena) - _, err := buf.WriteString("item data") - assert.NoError(t, err) - } - - // Release all while keeping strong references - for i := 0; i < numItems; i++ { - pool.Release(items[i]) - } - - // Pool should have all items - assert.Len(t, pool.pool, numItems, "expected items in pool") - - // Clear every other item to simulate partial GC - for i := 0; i < numItems; i += 2 { - items[i] = nil - } - - // Force GC - runtime.GC() - runtime.GC() - - // Acquire items - should process ALL items without skipping - processed := 0 - acquired := 0 - - for len(pool.pool) > 0 && processed < numItems*2 { - poolSizeBefore := len(pool.pool) - item := pool.Acquire(id) - poolSizeAfter := len(pool.pool) - processed++ - - assert.Less(t, poolSizeAfter, poolSizeBefore, "Pool size did not decrease - item not removed properly!") - - if item != nil { - acquired++ - } - } - - // Pool should be empty - assert.Len(t, pool.pool, 0, "expected empty pool") -} - -func TestArenaPool_Release_PeakTracking(t *testing.T) { - pool := NewArenaPool() - id := uint64(200) - - // First arena - item1 := pool.Acquire(id) - buf1 := arena.NewArenaBuffer(item1.Arena) - _, err := buf1.WriteString("small") - assert.NoError(t, err) - - peak1 := item1.Arena.Peak() - assert.Equal(t, peak1, 5) - - pool.Release(item1) - - // Check that size was tracked - size, exists := pool.sizes[id] - require.True(t, exists, "size tracking not created") - 
assert.Equal(t, 1, size.count, "expected count 1") - - // Second arena - item2 := pool.Acquire(id) - buf2 := arena.NewArenaBuffer(item2.Arena) - _, err = buf2.WriteString("larger data") - assert.NoError(t, err) - - pool.Release(item2) - - // Check updated tracking - assert.Equal(t, 2, size.count, "expected count 2") -} - -func TestArenaPool_GetArenaSize(t *testing.T) { - pool := NewArenaPool() - - // Test default size for unknown ID - size1 := pool.getArenaSize(999) - expectedDefault := 1024 * 1024 - assert.Equal(t, expectedDefault, size1, "expected default size") - - // Test calculated size after usage - id := uint64(400) - item := pool.Acquire(id) - buf := arena.NewArenaBuffer(item.Arena) - _, err := buf.WriteString("some data") - assert.NoError(t, err) - pool.Release(item) - - size2 := pool.getArenaSize(id) - assert.NotEqual(t, 0, size2, "expected non-zero size after usage") -} - -func TestArenaPool_MultipleItemsInPool(t *testing.T) { - pool := NewArenaPool() - id := uint64(500) - - // Acquire multiple distinct items - numItems := 3 - items := make([]*ArenaPoolItem, numItems) - - for i := 0; i < numItems; i++ { - items[i] = pool.Acquire(id) - buf := arena.NewArenaBuffer(items[i].Arena) - _, err := buf.WriteString("data") - assert.NoError(t, err) - } - - // Release all while keeping references - for i := 0; i < numItems; i++ { - pool.Release(items[i]) - } - - // Should have all items in pool - assert.Len(t, pool.pool, numItems, "expected items in pool") - - // Acquire all back - acquired := 0 - for len(pool.pool) > 0 { - item := pool.Acquire(id) - if item != nil { - acquired++ - } - } - - assert.Equal(t, numItems, acquired, "expected to acquire all items") -} - -func TestArenaPool_Release_MovingWindow(t *testing.T) { - pool := NewArenaPool() - id := uint64(600) - - // Release exactly 50 items - for i := 0; i < 50; i++ { - item := pool.Acquire(id) - buf := arena.NewArenaBuffer(item.Arena) - _, err := buf.WriteString("test data") - assert.NoError(t, err) - 
pool.Release(item) - } - - // After 50 releases, verify count and total - size := pool.sizes[id] - require.NotNil(t, size, "size tracking should exist") - assert.Equal(t, 50, size.count, "expected count to be 50") - - totalBytesAfter50 := size.totalBytes - - // Release one more item to trigger the window reset - item51 := pool.Acquire(id) - buf51 := arena.NewArenaBuffer(item51.Arena) - _, err := buf51.WriteString("test data") - assert.NoError(t, err) - peak51 := item51.Arena.Peak() - pool.Release(item51) - - // After 51st release, verify the window was reset - // count should be 2 (reset to 1, then incremented) - // totalBytes should be (totalBytesAfter50 / 50) + peak51 - assert.Equal(t, 2, size.count, "expected count to be 2 after window reset") - - expectedTotalBytes := (totalBytesAfter50 / 50) + peak51 - assert.Equal(t, expectedTotalBytes, size.totalBytes, "expected totalBytes to be divided by 50 and new peak added") - - // Verify we can continue releasing and counting works correctly - for i := 0; i < 10; i++ { - item := pool.Acquire(id) - buf := arena.NewArenaBuffer(item.Arena) - _, err := buf.WriteString("more data") - assert.NoError(t, err) - pool.Release(item) - } - - // After 10 more releases, count should be 12 (2 + 10) - assert.Equal(t, 12, size.count, "expected count to continue incrementing after window reset") -} diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 23f0b6d327..e11e43ad41 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -181,9 +181,10 @@ type Loader struct { // If you're not doing this, you will see segfaults // Example of correct usage in func "mergeResult" jsonArena arena.Arena - // sf is the SubgraphRequestSingleFlight object shared across all client requests - // it's thread safe and can be used to de-duplicate subgraph requests - sf *SubgraphRequestSingleFlight + + // singleFlight is the SubgraphRequestSingleFlight object shared across all client requests. 
+ // It's thread safe and can be used to de-duplicate subgraph requests. + singleFlight *SubgraphRequestSingleFlight } func (l *Loader) Free() { @@ -1662,7 +1663,7 @@ func (l *Loader) loadByContext(ctx context.Context, source DataSource, fetchItem return l.loadByContextDirect(ctx, source, headers, input, res) } - item, shared := l.sf.GetOrCreateItem(fetchItem, input, extraKey) + item, shared := l.singleFlight.GetOrCreateItem(fetchItem, input, extraKey) if res.singleFlightStats != nil { res.singleFlightStats.used = true res.singleFlightStats.shared = shared @@ -1686,7 +1687,7 @@ func (l *Loader) loadByContext(ctx context.Context, source DataSource, fetchItem // helps the http client to create buffers at the right size ctx = httpclient.WithHTTPClientSizeHint(ctx, item.sizeHint) - defer l.sf.Finish(item) + defer l.singleFlight.Finish(item) // Perform the actual load err := l.loadByContextDirect(ctx, source, headers, input, res) diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 747ee02c4e..2b05fb8141 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -74,14 +74,14 @@ type Resolver struct { // maxSubscriptionFetchTimeout defines the maximum time a subscription fetch can take before it is considered timed out maxSubscriptionFetchTimeout time.Duration - // resolveArenaPool is the arena pool dedicated for Loader & Resolvable - // ArenaPool automatically adjusts arena buffer sizes per workload - // resolving & response buffering are very different tasks - // as such, it was best to have two arena pools in terms of memory usage - // A single pool for both was much less efficient - resolveArenaPool *ArenaPool + // resolveArenaPool is the arena pool dedicated for Loader & Resolvable. + // ArenaPool automatically adjusts arena buffer sizes per workload. + // Resolving & response buffering are very different tasks; + // as such, it was best to have two arena pools in terms of memory usage. 
+ // A single pool for both was much less efficient. + resolveArenaPool *arena.Pool // responseBufferPool is the arena pool dedicated for response buffering before sending to the client - responseBufferPool *ArenaPool + responseBufferPool *arena.Pool // subgraphRequestSingleFlight is used to de-duplicate subgraph requests subgraphRequestSingleFlight *SubgraphRequestSingleFlight @@ -240,8 +240,8 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { allowedErrorFields: allowedErrorFields, heartbeatInterval: options.SubscriptionHeartbeatInterval, maxSubscriptionFetchTimeout: options.MaxSubscriptionFetchTimeout, - resolveArenaPool: NewArenaPool(), - responseBufferPool: NewArenaPool(), + resolveArenaPool: arena.NewArenaPool(), + responseBufferPool: arena.NewArenaPool(), subgraphRequestSingleFlight: NewSingleFlight(8), inboundRequestSingleFlight: NewRequestSingleFlight(8), } @@ -273,7 +273,7 @@ func newTools(options ResolverOptions, allowedExtensionFields map[string]struct{ apolloRouterCompatibilitySubrequestHTTPError: options.ApolloRouterCompatibilitySubrequestHTTPError, propagateFetchReasons: options.PropagateFetchReasons, validateRequiredExternalFields: options.ValidateRequiredExternalFields, - sf: sf, + singleFlight: sf, jsonArena: a, }, } From 3fa6b287faa589541d13dc0027cdd6b876e0de26 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Tue, 18 Nov 2025 15:46:20 +0100 Subject: [PATCH 45/57] chore: refactor rewriteErrorPaths --- v2/pkg/engine/resolve/loader.go | 21 +++++++-------------- v2/pkg/engine/resolve/loader_test.go | 2 +- 2 files changed, 8 insertions(+), 15 deletions(-) diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index e11e43ad41..618afcf84a 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -922,28 +922,21 @@ func rewriteErrorPaths(a arena.Arena, fetchItem *FetchItem, values []*astjson.Va unsafebytes.BytesToString(item.GetStringBytes()) != "_entities" { continue } - // rewrite 
the path to pathPrefix + pathItems after _entities - newPath := make([]string, 0, len(pathPrefix)+len(pathItems)-i) - newPath = append(newPath, pathPrefix...) + arr := astjson.ArrayValue(a) + for j := range pathPrefix { + astjson.AppendToArray(arr, astjson.StringValue(a, pathPrefix[j])) + } for j := i + 1; j < len(pathItems); j++ { // If the item after _entities is an index (number), we should ignore it. if j == i+1 && pathItems[j].Type() == astjson.TypeNumber { continue } switch pathItems[j].Type() { - case astjson.TypeString: - newPath = append(newPath, unsafebytes.BytesToString(pathItems[j].GetStringBytes())) - case astjson.TypeNumber: - newPath = append(newPath, strconv.Itoa(pathItems[j].GetInt())) - default: + case astjson.TypeString, astjson.TypeNumber: + astjson.AppendToArray(arr, pathItems[j]) } } - newPathJSON, _ := json.Marshal(newPath) - pathBytes, err := astjson.ParseBytesWithArena(a, newPathJSON) - if err != nil { - continue - } - value.Set(a, "path", pathBytes) + value.Set(a, "path", arr) break } } diff --git a/v2/pkg/engine/resolve/loader_test.go b/v2/pkg/engine/resolve/loader_test.go index f88d7227f6..c155dc5c93 100644 --- a/v2/pkg/engine/resolve/loader_test.go +++ b/v2/pkg/engine/resolve/loader_test.go @@ -1463,7 +1463,7 @@ func TestRewriteErrorPaths(t *testing.T) { }, expectedErrors: []*astjson.Value{ mp(`{"message": "nested", "path": ["user", "profile", "address", "street"]}`), - mp(`{"message": "index", "path": ["user", "profile", "reviews", "1", "body"]}`), + mp(`{"message": "index", "path": ["user", "profile", "reviews", 1, "body"]}`), }, }, { From 01ddbb14111bafb940f72c596d26dee71879af21 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Tue, 18 Nov 2025 15:46:30 +0100 Subject: [PATCH 46/57] chore: cleanup --- .../engine/datasource/grpc_datasource/grpc_datasource.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go 
b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go index 6cbc4ca125..4d1330b384 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go @@ -14,6 +14,7 @@ import ( "github.com/cespare/xxhash/v2" "github.com/tidwall/gjson" + "github.com/wundergraph/go-arena" "golang.org/x/sync/errgroup" "google.golang.org/grpc" @@ -47,7 +48,7 @@ type DataSource struct { federationConfigs plan.FederationFieldConfigurations disabled bool - pool *resolve.ArenaPool + pool *arena.Pool } type ProtoConfig struct { @@ -83,7 +84,7 @@ func NewDataSource(client grpc.ClientConnInterface, config DataSourceConfig) (*D mapping: config.Mapping, federationConfigs: config.FederationConfigs, disabled: config.Disabled, - pool: resolve.NewArenaPool(), + pool: arena.NewArenaPool(), }, nil } @@ -98,7 +99,7 @@ func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte variables := gjson.Parse(unsafebytes.BytesToString(input)).Get("body.variables") var ( - poolItems []*resolve.ArenaPoolItem + poolItems []*arena.PoolItem ) defer func() { d.pool.ReleaseMany(poolItems) @@ -190,7 +191,7 @@ func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte return value.MarshalTo(nil), err } -func (d *DataSource) acquirePoolItem(input []byte, index int) *resolve.ArenaPoolItem { +func (d *DataSource) acquirePoolItem(input []byte, index int) *arena.PoolItem { keyGen := xxhash.New() _, _ = keyGen.Write(input) var b [8]byte From f81e2538daddee320c6fdeefd5015f0e2cb3b66c Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Tue, 18 Nov 2025 15:47:24 +0100 Subject: [PATCH 47/57] chore: fmt --- v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go index 4d1330b384..75b1be2dfe 100644 --- 
a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go @@ -14,11 +14,11 @@ import ( "github.com/cespare/xxhash/v2" "github.com/tidwall/gjson" - "github.com/wundergraph/go-arena" "golang.org/x/sync/errgroup" "google.golang.org/grpc" "github.com/wundergraph/astjson" + "github.com/wundergraph/go-arena" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" From 2ccc28c1ecd824066ab35af3a233fda3b5b23261 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Tue, 18 Nov 2025 15:52:25 +0100 Subject: [PATCH 48/57] chore: update comment --- v2/pkg/engine/resolve/loader.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 618afcf84a..970cc60846 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -1891,7 +1891,10 @@ func (l *Loader) compactJSON(data []byte) ([]byte, error) { return nil, err } out := dst.Bytes() - // don't use arena here or segfault + // Don't use arena here to avoid segfaults. + // If we're not keeping the result long-term on the arena, + // we just parse and re-marshal it to deduplicate object keys. + // This is not a hot path so it's fine. 
// it's also not a hot path and not important to optimize // arena requires the parsed content to be on the arena as well v, err := astjson.ParseBytes(out) From 32a3368cbf1514efd2187d96fde2a4d03b20c0b2 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Mon, 24 Nov 2025 09:57:41 +0100 Subject: [PATCH 49/57] chore: address feedback --- .../datasource/httpclient/nethttpclient.go | 8 ++----- .../resolve/inbound_request_singleflight.go | 19 ++++++++------- .../resolve/subgraph_request_singleflight.go | 24 +++++++------------ 3 files changed, 20 insertions(+), 31 deletions(-) diff --git a/v2/pkg/engine/datasource/httpclient/nethttpclient.go b/v2/pkg/engine/datasource/httpclient/nethttpclient.go index 4f1e9fe09b..f0fb36694f 100644 --- a/v2/pkg/engine/datasource/httpclient/nethttpclient.go +++ b/v2/pkg/engine/datasource/httpclient/nethttpclient.go @@ -300,12 +300,8 @@ func DoMultipartForm( defer func() { for _, file := range tempFiles { - if err := file.Close(); err != nil { - continue - } - if err = os.Remove(file.Name()); err != nil { - continue - } + _ = file.Close() + _ = os.Remove(file.Name()) } }() diff --git a/v2/pkg/engine/resolve/inbound_request_singleflight.go b/v2/pkg/engine/resolve/inbound_request_singleflight.go index 2552a43fd6..aa0c079484 100644 --- a/v2/pkg/engine/resolve/inbound_request_singleflight.go +++ b/v2/pkg/engine/resolve/inbound_request_singleflight.go @@ -73,18 +73,19 @@ func (r *InboundRequestSingleFlight) GetOrCreate(ctx *Context, response *GraphQL shard := r.shardFor(key) req, shared := shard.m.Load(key) if shared { - req := req.(*InflightRequest) - req.Mu.Lock() - req.HasFollowers = true - req.Mu.Unlock() + inflightRequest := req.(*InflightRequest) + inflightRequest.Mu.Lock() + inflightRequest.HasFollowers = true + inflightRequest.Mu.Unlock() select { - case <-req.Done: - if req.Err != nil { - return nil, req.Err + case <-inflightRequest.Done: + if inflightRequest.Err != nil { + return nil, inflightRequest.Err } - return req, nil + return 
inflightRequest, nil case <-ctx.ctx.Done(): - return nil, ctx.ctx.Err() + inflightRequest.Err = ctx.ctx.Err() + return nil, inflightRequest.Err } } diff --git a/v2/pkg/engine/resolve/subgraph_request_singleflight.go b/v2/pkg/engine/resolve/subgraph_request_singleflight.go index 85f73d7423..37d3c6941a 100644 --- a/v2/pkg/engine/resolve/subgraph_request_singleflight.go +++ b/v2/pkg/engine/resolve/subgraph_request_singleflight.go @@ -77,16 +77,16 @@ func (s *SubgraphRequestSingleFlight) GetOrCreateItem(fetchItem *FetchItem, inpu // Get shard based on sfKey for items shard := s.shardFor(sfKey) - if existing, ok := shard.items.Load(sfKey); ok { - return existing.(*SingleFlightItem), true - } - item = &SingleFlightItem{ // empty chan to indicate to all followers when we're done (close) loaded: make(chan struct{}), SFKey: sfKey, FetchKey: fetchKey, } + + if existing, ok := shard.items.LoadOrStore(sfKey, item); ok { + return existing.(*SingleFlightItem), true + } // Read size hint from the same shard (both items and sizes use the same shard now) if sizeValue, ok := shard.sizes.Load(fetchKey); ok { size := sizeValue.(*fetchSize) @@ -97,10 +97,6 @@ func (s *SubgraphRequestSingleFlight) GetOrCreateItem(fetchItem *FetchItem, inpu size.mu.Unlock() } - actual, loaded := shard.items.LoadOrStore(sfKey, item) - if loaded { - return actual.(*SingleFlightItem), true - } return item, false } @@ -108,18 +104,14 @@ func (s *SubgraphRequestSingleFlight) GetOrCreateItem(fetchItem *FetchItem, inpu // trigger all followers to look at the err & response of the item // and to update the size estimates func (s *SubgraphRequestSingleFlight) Finish(item *SingleFlightItem) { - sfKey := item.SFKey - fetchKey := item.FetchKey + shard := s.shardFor(item.SFKey) + shard.items.Delete(item.SFKey) close(item.loaded) - // Update sizes in the same shard as the item (using sfKey to get the shard) - shard := s.shardFor(sfKey) - - shard.items.Delete(sfKey) - sizeValue, ok := shard.sizes.Load(fetchKey) + 
sizeValue, ok := shard.sizes.Load(item.FetchKey) if !ok { newSize := &fetchSize{} - sizeValue, _ = shard.sizes.LoadOrStore(fetchKey, newSize) + sizeValue, _ = shard.sizes.LoadOrStore(item.FetchKey, newSize) } size := sizeValue.(*fetchSize) size.mu.Lock() From 9e6c19838d8d3e40521228753292e8eee1b2ba13 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Mon, 24 Nov 2025 10:09:48 +0100 Subject: [PATCH 50/57] chore: merge main --- v2/pkg/engine/resolve/resolve_test.go | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/v2/pkg/engine/resolve/resolve_test.go b/v2/pkg/engine/resolve/resolve_test.go index 094f903abe..259e2130d8 100644 --- a/v2/pkg/engine/resolve/resolve_test.go +++ b/v2/pkg/engine/resolve/resolve_test.go @@ -4975,7 +4975,6 @@ type messageFunc func(counter int) (message string, done bool) var fakeStreamRequestId atomic.Int32 type _fakeStream struct { - uniqueRequestFn func(ctx *Context, input []byte, xxh *xxhash.Digest) (err error) messageFunc messageFunc onStart func(input []byte) delay time.Duration @@ -5711,9 +5710,6 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { } fakeStream := createFakeStream(messageFn, time.Millisecond, onStartFn, subscriptionOnStartFn) - fakeStream.uniqueRequestFn = func(ctx *Context, input []byte, xxh *xxhash.Digest) (err error) { - return nil - } resolver, plan, recorder, id := setup(c, fakeStream) @@ -5813,10 +5809,6 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { } fakeStream := createFakeStream(messageFn, 1*time.Millisecond, onStartFn, subscriptionOnStartFn) - fakeStream.uniqueRequestFn = func(ctx *Context, input []byte, xxh *xxhash.Digest) (err error) { - _, err = xxh.WriteString("unique") - return - } resolver, plan, recorder, id := setup(c, fakeStream) @@ -5994,14 +5986,6 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { }, func(ctx StartupHookContext, input []byte) (err error) { return nil }) - fakeStream.uniqueRequestFn = func(ctx *Context, input []byte, xxh 
*xxhash.Digest) (err error) { - _, err = xxh.WriteString("unique") - if err != nil { - return - } - _, err = xxh.Write(input) - return err - } resolver1, plan1, recorder1, id1 := setup(c, fakeStream) _, _, recorder2, id2 := setup(c, fakeStream) From 5304ba1d58a7709fe34f9b666c2f00b1507f4e7a Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Wed, 26 Nov 2025 12:04:41 +0100 Subject: [PATCH 51/57] chore: improve prepareTrigger --- v2/pkg/engine/resolve/resolve.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index c64e7fd651..03e2948576 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -1183,18 +1183,18 @@ func (r *Resolver) AsyncUnsubscribeClient(connectionID int64) error { // the generated has is the unique triggerID // the headers must be forwarded to the DataSource to create the trigger func (r *Resolver) prepareTrigger(ctx *Context, sourceName string, input []byte) (headers http.Header, triggerID uint64) { + keyGen := pool.Hash64.Get() + _, _ = keyGen.Write(input) if ctx.SubgraphHeadersBuilder != nil { - header, headerHash := ctx.SubgraphHeadersBuilder.HeadersForSubgraph(sourceName) - keyGen := pool.Hash64.Get() - _, _ = keyGen.Write(input) + var headersHash uint64 + headers, headersHash = ctx.SubgraphHeadersBuilder.HeadersForSubgraph(sourceName) var b [8]byte - binary.LittleEndian.PutUint64(b[:], headerHash) + binary.LittleEndian.PutUint64(b[:], headersHash) _, _ = keyGen.Write(b[:]) - triggerID = keyGen.Sum64() - pool.Hash64.Put(keyGen) - return header, triggerID } - return nil, 0 + triggerID = keyGen.Sum64() + pool.Hash64.Put(keyGen) + return headers, triggerID } func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQLSubscription, writer SubscriptionResponseWriter) error { From d3059f447716cbac21af7a0a0df669fb27ef30e8 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Thu, 4 Dec 2025 11:26:01 +0100 
Subject: [PATCH 52/57] chore: merge main --- .../datasource/grpc_datasource/grpc_datasource_test.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go index cac2fdf1ca..cf21460f8a 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go @@ -4020,15 +4020,14 @@ func Test_DataSource_Load_WithEntity_Calls_WithCompositeTypes(t *testing.T) { require.NoError(t, err) // Execute the query through our datasource - output := new(bytes.Buffer) input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) - err = ds.Load(context.Background(), []byte(input), output) + data, err := ds.Load(context.Background(), nil, []byte(input)) require.NoError(t, err) // Parse the response var resp graphqlResponse - err = json.Unmarshal(output.Bytes(), &resp) + err = json.Unmarshal(data, &resp) require.NoError(t, err, "Failed to unmarshal response") tc.validate(t, resp.Data) From a6c9da880f9822cb0b2dd2a8bf2031dbc87d4b62 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Thu, 4 Dec 2025 11:48:20 +0100 Subject: [PATCH 53/57] chore: add ResolveDeduplicated to GraphQLResolveInfo --- v2/pkg/engine/resolve/resolve.go | 5 + v2/pkg/engine/resolve/resolve_test.go | 198 ++++++++++++++++++++++++++ 2 files changed, 203 insertions(+) diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 03e2948576..207a0ad625 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -280,7 +280,11 @@ func newTools(options ResolverOptions, allowedExtensionFields map[string]struct{ } type GraphQLResolveInfo struct { + // ResolveAcquireWaitTime is the time spent waiting to acquire the resolver semaphore + // the semaphore limits the number of concurrent resolve operations ResolveAcquireWaitTime time.Duration + // 
ResolveDeduplicated indicates whether the resolution of the entire operation was deduplicated via single flight + ResolveDeduplicated bool } func (r *Resolver) ResolveGraphQLResponse(ctx *Context, response *GraphQLResponse, data []byte, writer io.Writer) (*GraphQLResolveInfo, error) { @@ -324,6 +328,7 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe } if inflight != nil && inflight.Data != nil { // follower + resp.ResolveDeduplicated = true _, err = writer.Write(inflight.Data) return resp, err } diff --git a/v2/pkg/engine/resolve/resolve_test.go b/v2/pkg/engine/resolve/resolve_test.go index 259e2130d8..0b8d96ea1c 100644 --- a/v2/pkg/engine/resolve/resolve_test.go +++ b/v2/pkg/engine/resolve/resolve_test.go @@ -69,6 +69,86 @@ func fakeDataSourceWithInputCheck(t TestingTB, input []byte, data []byte) *_fake } } +type blockingDataSource struct { + data []byte + ready chan struct{} + release chan struct{} + readyOnce sync.Once + releaseOnce sync.Once +} + +func newBlockingDataSource(data []byte) *blockingDataSource { + return &blockingDataSource{ + data: data, + ready: make(chan struct{}), + release: make(chan struct{}), + } +} + +func (f *blockingDataSource) waitForRelease() { + f.readyOnce.Do(func() { + close(f.ready) + }) + <-f.release +} + +func (f *blockingDataSource) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { + f.waitForRelease() + return f.data, nil +} + +func (f *blockingDataSource) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { + f.waitForRelease() + return f.data, nil +} + +func (f *blockingDataSource) Ready() <-chan struct{} { + return f.ready +} + +func (f *blockingDataSource) Release() { + f.releaseOnce.Do(func() { + close(f.release) + }) +} + +type blockingWriter struct { + buf bytes.Buffer + ready chan struct{} + release chan struct{} + readyOnce sync.Once + releaseOnce sync.Once +} + +func 
newBlockingWriter() *blockingWriter { + return &blockingWriter{ + ready: make(chan struct{}), + release: make(chan struct{}), + } +} + +func (w *blockingWriter) Write(p []byte) (int, error) { + w.readyOnce.Do(func() { + close(w.ready) + }) + <-w.release + return w.buf.Write(p) +} + +func (w *blockingWriter) Ready() <-chan struct{} { + return w.ready +} + +func (w *blockingWriter) Release() { + w.releaseOnce.Do(func() { + close(w.release) + }) +} + +func (w *blockingWriter) String() string { + return w.buf.String() +} + type TestErrorWriter struct { } @@ -4442,6 +4522,124 @@ func TestResolver_ArenaResolveGraphQLResponse(t *testing.T) { })) } +func TestResolver_ArenaResolveGraphQLResponse_RequestDeduplication(t *testing.T) { + rCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + r := newResolver(rCtx) + + ds := newBlockingDataSource([]byte(`{"value":"slow"}`)) + defer ds.Release() + + response := &GraphQLResponse{ + Info: &GraphQLResponseInfo{ + OperationType: ast.OperationTypeQuery, + }, + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{ + DataSource: ds, + }, + }), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("value"), + Value: &String{ + Path: []string{"value"}, + Nullable: false, + }, + }, + }, + }, + } + + ctxTemplate := Context{ + ctx: context.Background(), + Request: Request{ + ID: 42, + }, + VariablesHash: 1337, + } + + const requestCount = 3 + + type result struct { + info *GraphQLResolveInfo + output string + err error + } + + results := make([]result, requestCount) + + var wg sync.WaitGroup + wg.Add(requestCount) + + leaderWriter := newBlockingWriter() + + go func() { + defer wg.Done() + ctx := ctxTemplate + info, err := r.ArenaResolveGraphQLResponse(&ctx, response, leaderWriter) + results[0] = result{info: info, output: leaderWriter.String(), err: err} + }() + + select { + case <-ds.Ready(): + case <-time.After(time.Second): + t.Fatalf("timeout waiting for leader data source load") + } + + 
startFollowers := make(chan struct{}) + followersEntered := make(chan struct{}, requestCount-1) + + for i := 1; i < requestCount; i++ { + go func(i int) { + defer wg.Done() + ctx := ctxTemplate + <-startFollowers + followersEntered <- struct{}{} + buf := &bytes.Buffer{} + info, err := r.ArenaResolveGraphQLResponse(&ctx, response, buf) + results[i] = result{info: info, output: buf.String(), err: err} + }(i) + } + + close(startFollowers) + + for i := 1; i < requestCount; i++ { + select { + case <-followersEntered: + case <-time.After(time.Second): + t.Fatalf("timeout waiting for follower %d to start", i) + } + } + + ds.Release() + + select { + case <-leaderWriter.Ready(): + case <-time.After(time.Second): + t.Fatalf("timeout waiting for leader to start writing response") + } + + leaderWriter.Release() + wg.Wait() + + for _, res := range results { + require.NoError(t, res.err) + require.NotNil(t, res.info) + } + + assert.False(t, results[0].info.ResolveDeduplicated) + + expectedOutput := results[0].output + require.NotEmpty(t, expectedOutput) + + for i := 1; i < requestCount; i++ { + assert.True(t, results[i].info.ResolveDeduplicated) + assert.Equal(t, expectedOutput, results[i].output) + } +} + func TestResolver_ApolloCompatibilityMode_FetchError(t *testing.T) { options := apolloCompatibilityOptions{ valueCompletion: true, From 85774faf3ed4060a6585378c04550729ece9fc4d Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Thu, 4 Dec 2025 11:51:24 +0100 Subject: [PATCH 54/57] chore: fmt --- v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go index 0c8c96e731..ce01dbf2a2 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go @@ -179,7 +179,7 @@ func (d *DataSource) Load(ctx context.Context, headers 
http.Header, input []byte return nil }); err != nil { - return builder.writeErrorBytes(err),nil + return builder.writeErrorBytes(err), nil } value := builder.toDataObject(root) From da53e7bd32cf724734ee95c10749cb8501c7074f Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Mon, 8 Dec 2025 18:46:29 +0100 Subject: [PATCH 55/57] chore: improve hashing of keys --- .../resolve/inbound_request_singleflight.go | 45 ++++++++++--------- .../resolve/subgraph_request_singleflight.go | 32 +++++++------ 2 files changed, 40 insertions(+), 37 deletions(-) diff --git a/v2/pkg/engine/resolve/inbound_request_singleflight.go b/v2/pkg/engine/resolve/inbound_request_singleflight.go index aa0c079484..a796dee4d0 100644 --- a/v2/pkg/engine/resolve/inbound_request_singleflight.go +++ b/v2/pkg/engine/resolve/inbound_request_singleflight.go @@ -4,7 +4,7 @@ import ( "encoding/binary" "sync" - "github.com/cespare/xxhash/v2" + "github.com/wundergraph/graphql-go-tools/v2/pkg/pool" ) // InboundRequestSingleFlight is a sharded goroutine safe single flight implementation to de-couple inbound requests @@ -68,34 +68,39 @@ func (r *InboundRequestSingleFlight) GetOrCreate(ctx *Context, response *GraphQL hh = ctx.SubgraphHeadersBuilder.HashAll() } binary.LittleEndian.PutUint64(b[16:24], hh) - key := xxhash.Sum64(b[:]) + h := pool.Hash64.Get() + _, _ = h.Write(b[:]) + key := h.Sum64() + pool.Hash64.Put(h) shard := r.shardFor(key) - req, shared := shard.m.Load(key) + + //fmt.Printf("key: %d shard: %d\n", key, key%uint64(len(r.shards))) + + request := &InflightRequest{ + Done: make(chan struct{}), + ID: key, + } + + inflight, shared := shard.m.LoadOrStore(key, request) if shared { - inflightRequest := req.(*InflightRequest) - inflightRequest.Mu.Lock() - inflightRequest.HasFollowers = true - inflightRequest.Mu.Unlock() + request = inflight.(*InflightRequest) + request.Mu.Lock() + request.HasFollowers = true + request.Mu.Unlock() select { - case <-inflightRequest.Done: - if inflightRequest.Err != nil { - return 
nil, inflightRequest.Err + case <-request.Done: + if request.Err != nil { + return nil, request.Err } - return inflightRequest, nil + return request, nil case <-ctx.ctx.Done(): - inflightRequest.Err = ctx.ctx.Err() - return nil, inflightRequest.Err + request.Err = ctx.ctx.Err() + return nil, request.Err } } - value := &InflightRequest{ - Done: make(chan struct{}), - ID: key, - } - - shard.m.Store(key, value) - return value, nil + return request, nil } func (r *InboundRequestSingleFlight) FinishOk(req *InflightRequest, data []byte) { diff --git a/v2/pkg/engine/resolve/subgraph_request_singleflight.go b/v2/pkg/engine/resolve/subgraph_request_singleflight.go index 37d3c6941a..e86302857d 100644 --- a/v2/pkg/engine/resolve/subgraph_request_singleflight.go +++ b/v2/pkg/engine/resolve/subgraph_request_singleflight.go @@ -1,9 +1,11 @@ package resolve import ( + "encoding/binary" "sync" "github.com/cespare/xxhash/v2" + "github.com/wundergraph/graphql-go-tools/v2/pkg/pool" ) // SubgraphRequestSingleFlight is a sharded, goroutine safe single flight implementation to de-duplicate subgraph requests @@ -11,7 +13,6 @@ import ( // In addition to single flight, it provides size hints to create right-sized buffers for subgraph requests type SubgraphRequestSingleFlight struct { shards []singleFlightShard - xxPool *sync.Pool } type singleFlightShard struct { @@ -55,11 +56,6 @@ func NewSingleFlight(shardCount int) *SubgraphRequestSingleFlight { } s := &SubgraphRequestSingleFlight{ shards: make([]singleFlightShard, shardCount), - xxPool: &sync.Pool{ - New: func() any { - return xxhash.New() - }, - }, } return s } @@ -136,19 +132,17 @@ func (s *SubgraphRequestSingleFlight) shardFor(key uint64) *singleFlightShard { } func (s *SubgraphRequestSingleFlight) computeKeys(fetchItem *FetchItem, input []byte, extraKey uint64) (sfKey, fetchKey uint64) { - h := s.xxPool.Get().(*xxhash.Digest) - sfKey = s.computeSFKey(fetchItem, input, extraKey) + h := pool.Hash64.Get() + sfKey = s.computeSFKey(h, 
fetchItem, input, extraKey) h.Reset() - fetchKey = s.computeFetchKey(fetchItem) - h.Reset() - s.xxPool.Put(h) + fetchKey = s.computeFetchKey(h, fetchItem) + pool.Hash64.Put(h) return sfKey, fetchKey } // computeSFKey returns a key that 100% uniquely identifies a fetch with no collision. // Two sfKey values are only the same when the fetches are 100% equal. -func (s *SubgraphRequestSingleFlight) computeSFKey(fetchItem *FetchItem, input []byte, extraKey uint64) uint64 { - h := s.xxPool.Get().(*xxhash.Digest) +func (s *SubgraphRequestSingleFlight) computeSFKey(h *xxhash.Digest, fetchItem *FetchItem, input []byte, extraKey uint64) uint64 { if fetchItem != nil && fetchItem.Fetch != nil { info := fetchItem.Fetch.FetchInfo() if info != nil { @@ -157,15 +151,19 @@ func (s *SubgraphRequestSingleFlight) computeSFKey(fetchItem *FetchItem, input [ } } _, _ = h.Write(input) - return h.Sum64() + extraKey // extraKey in this case is the pre-generated hash for the headers + if extraKey != 0 { + // include pre-computed headers hash to avoid collisions + var buf [8]byte + binary.LittleEndian.PutUint64(buf[0:8], extraKey) + _, _ = h.Write(buf[:]) + } + return h.Sum64() } // computeFetchKey is a less robust key compared to sfKey. // The purpose is to create a key from the DataSourceID and root fields to have less cardinality. // The goal is to get an estimate buffer size for similar fetches; hashing headers or the body is not needed. 
-func (s *SubgraphRequestSingleFlight) computeFetchKey(fetchItem *FetchItem) uint64 { - h := s.xxPool.Get().(*xxhash.Digest) - defer s.xxPool.Put(h) +func (s *SubgraphRequestSingleFlight) computeFetchKey(h *xxhash.Digest, fetchItem *FetchItem) uint64 { if fetchItem == nil || fetchItem.Fetch == nil { return 0 } From 5ae1a1686de61b677788c27ec3fd7ede84ad4427 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Mon, 8 Dec 2025 18:46:56 +0100 Subject: [PATCH 56/57] chore: fmt --- v2/pkg/engine/resolve/subgraph_request_singleflight.go | 1 + 1 file changed, 1 insertion(+) diff --git a/v2/pkg/engine/resolve/subgraph_request_singleflight.go b/v2/pkg/engine/resolve/subgraph_request_singleflight.go index e86302857d..1f6fcfaf4d 100644 --- a/v2/pkg/engine/resolve/subgraph_request_singleflight.go +++ b/v2/pkg/engine/resolve/subgraph_request_singleflight.go @@ -5,6 +5,7 @@ import ( "sync" "github.com/cespare/xxhash/v2" + "github.com/wundergraph/graphql-go-tools/v2/pkg/pool" ) From 70eb5187e12c45e206c37c2b75805c8d792f9599 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Mon, 8 Dec 2025 18:57:14 +0100 Subject: [PATCH 57/57] chore: fmt --- v2/pkg/engine/resolve/inbound_request_singleflight.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/v2/pkg/engine/resolve/inbound_request_singleflight.go b/v2/pkg/engine/resolve/inbound_request_singleflight.go index a796dee4d0..3767b31aa6 100644 --- a/v2/pkg/engine/resolve/inbound_request_singleflight.go +++ b/v2/pkg/engine/resolve/inbound_request_singleflight.go @@ -75,8 +75,6 @@ func (r *InboundRequestSingleFlight) GetOrCreate(ctx *Context, response *GraphQL shard := r.shardFor(key) - //fmt.Printf("key: %d shard: %d\n", key, key%uint64(len(r.shards))) - request := &InflightRequest{ Done: make(chan struct{}), ID: key,