diff --git a/release-please-manifest.json b/release-please-manifest.json index 69cb102bb..3c136e44f 100644 --- a/release-please-manifest.json +++ b/release-please-manifest.json @@ -1,4 +1,4 @@ { - "v2": "2.0.0-rc.142", + "v2": "2.0.0-rc.145", "execution": "1.2.0" } diff --git a/v2/CHANGELOG.md b/v2/CHANGELOG.md index 02f3771c9..b14318dd1 100644 --- a/v2/CHANGELOG.md +++ b/v2/CHANGELOG.md @@ -1,5 +1,26 @@ # Changelog +## [2.0.0-rc.145](https://github.com/wundergraph/graphql-go-tools/compare/v2.0.0-rc.144...v2.0.0-rc.145) (2025-01-27) + + +### Features + +* add normalizedQuery to query plan and request info to trace ([#1045](https://github.com/wundergraph/graphql-go-tools/issues/1045)) ([e75a1dd](https://github.com/wundergraph/graphql-go-tools/commit/e75a1dd24d5255b6cc990269c5c7922f851f4fc1)) + +## [2.0.0-rc.144](https://github.com/wundergraph/graphql-go-tools/compare/v2.0.0-rc.143...v2.0.0-rc.144) (2025-01-23) + + +### Bug Fixes + +* remove semaphore from ResolveGraphQLSubscription ([#1043](https://github.com/wundergraph/graphql-go-tools/issues/1043)) ([76d644e](https://github.com/wundergraph/graphql-go-tools/commit/76d644eb2316bfc71ae3a09cd4a5614998f26f43)) + +## [2.0.0-rc.143](https://github.com/wundergraph/graphql-go-tools/compare/v2.0.0-rc.142...v2.0.0-rc.143) (2025-01-23) + + +### Bug Fixes + +* delete leftover heartbeat connections ([#1033](https://github.com/wundergraph/graphql-go-tools/issues/1033)) ([f7492d3](https://github.com/wundergraph/graphql-go-tools/commit/f7492d39b044f4901f695fb1e7718c9fe912504c)) + ## [2.0.0-rc.142](https://github.com/wundergraph/graphql-go-tools/compare/v2.0.0-rc.141...v2.0.0-rc.142) (2025-01-19) diff --git a/v2/pkg/engine/resolve/context.go b/v2/pkg/engine/resolve/context.go index 8e97d714d..6d5b9559b 100644 --- a/v2/pkg/engine/resolve/context.go +++ b/v2/pkg/engine/resolve/context.go @@ -193,6 +193,8 @@ type PhaseStats struct { DurationSinceStartPretty string `json:"duration_since_start_pretty"` } +type requestContextKey struct{} + func SetTraceStart(ctx context.Context, predictableDebugTimings bool) context.Context { info := &TraceInfo{} if predictableDebugTimings { @@ -267,3 +269,16 @@ func SetPlannerStats(ctx context.Context, stats PhaseStats) { } info.PlannerStats = SetDebugStats(info, stats, 4) } + +func GetRequest(ctx context.Context) *RequestData { + // The context might not have trace info, in that case we return nil + req, ok := ctx.Value(requestContextKey{}).(*RequestData) + if !ok { + return nil + } + return req +} + +func SetRequest(ctx context.Context, r *RequestData) context.Context { + return context.WithValue(ctx, requestContextKey{}, r) +} diff --git a/v2/pkg/engine/resolve/fetchtree.go b/v2/pkg/engine/resolve/fetchtree.go index dddb019a5..5c73d291d 100644 --- a/v2/pkg/engine/resolve/fetchtree.go +++ b/v2/pkg/engine/resolve/fetchtree.go @@ -8,9 +8,10 @@ import ( type FetchTreeNode struct { Kind FetchTreeNodeKind `json:"kind"` // Only set for subscription - Trigger *FetchTreeNode `json:"trigger"` - Item *FetchItem `json:"item"` - ChildNodes []*FetchTreeNode `json:"child_nodes"` + Trigger *FetchTreeNode `json:"trigger"` + Item *FetchItem `json:"item"` + ChildNodes []*FetchTreeNode `json:"child_nodes"` + NormalizedQuery string `json:"normalized_query"` } type FetchTreeNodeKind string @@ -147,11 +148,12 @@ func (n *FetchTreeNode) Trace() *FetchTreeTraceNode { } type FetchTreeQueryPlanNode struct { - Version string `json:"version,omitempty"` - Kind FetchTreeNodeKind `json:"kind"` - Trigger *FetchTreeQueryPlan `json:"trigger,omitempty"` - Children []*FetchTreeQueryPlanNode `json:"children,omitempty"` - Fetch *FetchTreeQueryPlan `json:"fetch,omitempty"` + Version string `json:"version,omitempty"` + Kind FetchTreeNodeKind `json:"kind"` + Trigger *FetchTreeQueryPlan `json:"trigger,omitempty"` + Children []*FetchTreeQueryPlanNode `json:"children,omitempty"` + Fetch *FetchTreeQueryPlan `json:"fetch,omitempty"` + NormalizedQuery string `json:"normalizedQuery,omitempty"` } type FetchTreeQueryPlan struct { @@ -194,7 +196,8 @@ func (n *FetchTreeNode) queryPlan() *FetchTreeQueryPlanNode { return nil } queryPlan := &FetchTreeQueryPlanNode{ - Kind: n.Kind, + Kind: n.Kind, + NormalizedQuery: n.NormalizedQuery, } switch n.Kind { case FetchTreeNodeKindSingle: diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 9fa0f52f8..b05b0af26 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -51,6 +51,7 @@ type Resolver struct { maxConcurrency chan struct{} triggers map[uint64]*trigger + heartbeatSubLock *sync.Mutex heartbeatSubscriptions map[*Context]*sub events chan subscriptionEvent triggerEventsSem *semaphore.Weighted @@ -189,6 +190,7 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { propagateSubgraphStatusCodes: options.PropagateSubgraphStatusCodes, events: make(chan subscriptionEvent), triggers: make(map[uint64]*trigger), + heartbeatSubLock: &sync.Mutex{}, heartbeatSubscriptions: make(map[*Context]*sub), reporter: options.Reporter, asyncErrorWriter: options.AsyncErrorWriter, @@ -407,6 +409,9 @@ func (r *Resolver) handleEvent(event subscriptionEvent) { } func (r *Resolver) handleHeartbeat(data []byte) { + r.heartbeatSubLock.Lock() + defer r.heartbeatSubLock.Unlock() + if r.options.Debug { fmt.Printf("resolver:heartbeat:%d\n", len(r.heartbeatSubscriptions)) } @@ -417,7 +422,7 @@ func (r *Resolver) handleHeartbeat(data []byte) { s.mux.Lock() skipHeartbeat := now.Sub(s.lastWrite) < r.multipartSubHeartbeatInterval s.mux.Unlock() - if skipHeartbeat { + if skipHeartbeat || (c.Context().Err() != nil && errors.Is(c.Context().Err(), context.Canceled)) { continue } @@ -427,6 +432,12 @@ func (r *Resolver) handleHeartbeat(data []byte) { s.mux.Lock() if _, err := s.writer.Write(data); err != nil { + if errors.Is(err, context.Canceled) { + // client disconnected + s.mux.Unlock() + _ = r.AsyncUnsubscribeSubscription(s.id) + return + } r.asyncErrorWriter.WriteError(c, err, nil, s.writer) } err := s.writer.Flush() @@ -466,30 +477,7 @@ func (r *Resolver) handleTriggerInitialized(triggerID uint64) { } func (r *Resolver) handleTriggerDone(triggerID uint64) { - trig, ok := r.triggers[triggerID] - if !ok { - return - } - isInitialized := trig.initialized - wg := trig.inFlight - subscriptionCount := len(trig.subscriptions) - - delete(r.triggers, triggerID) - - go func() { - if wg != nil { - wg.Wait() - } - for _, s := range trig.subscriptions { - s.writer.Complete() - } - if r.reporter != nil { - r.reporter.SubscriptionCountDec(subscriptionCount) - if isInitialized { - r.reporter.TriggerCountDec(1) - } - } - }() + r.shutdownTrigger(triggerID) } func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription) { @@ -508,7 +496,9 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription) executor: add.executor, } if add.ctx.ExecutionOptions.SendHeartbeat { + r.heartbeatSubLock.Lock() r.heartbeatSubscriptions[add.ctx] = s + r.heartbeatSubLock.Unlock() } trig, ok := r.triggers[triggerID] if ok { @@ -634,20 +624,9 @@ func (r *Resolver) handleRemoveSubscription(id SubscriptionIdentifier) { removed := 0 for u := range r.triggers { trig := r.triggers[u] - for ctx, s := range trig.subscriptions { - if s.id == id { - - if ctx.Context().Err() == nil { - s.writer.Complete() - } - delete(r.heartbeatSubscriptions, ctx) - delete(trig.subscriptions, ctx) - if r.options.Debug { - fmt.Printf("resolver:trigger:subscription:removed:%d:%d\n", trig.id, id.SubscriptionID) - } - removed++ - } - } + removed += r.shutdownTriggerSubscriptions(u, func(sID SubscriptionIdentifier) bool { + return sID == id + }) if len(trig.subscriptions) == 0 { r.shutdownTrigger(trig.id) } @@ -663,20 +642,9 @@ func (r *Resolver) handleRemoveClient(id int64) { } removed := 0 for u := range r.triggers { - for c, s := range r.triggers[u].subscriptions { - if s.id.ConnectionID == id && !s.id.internal { - - if c.Context().Err() == nil { - s.writer.Complete() - } - - delete(r.triggers[u].subscriptions, c) - if r.options.Debug { - fmt.Printf("resolver:trigger:subscription:done:%d:%d\n", u, s.id.SubscriptionID) - } - removed++ - } - } + removed += r.shutdownTriggerSubscriptions(u, func(sID SubscriptionIdentifier) bool { + return sID.ConnectionID == id && !sID.internal + }) if len(r.triggers[u].subscriptions) == 0 { r.shutdownTrigger(r.triggers[u].id) } @@ -737,30 +705,46 @@ func (r *Resolver) shutdownTrigger(id uint64) { return } count := len(trig.subscriptions) + r.shutdownTriggerSubscriptions(id, nil) + trig.cancel() + delete(r.triggers, id) + if r.options.Debug { + fmt.Printf("resolver:trigger:done:%d\n", trig.id) + } + if r.reporter != nil { + r.reporter.SubscriptionCountDec(count) + if trig.initialized { + r.reporter.TriggerCountDec(1) + } + } +} + +func (r *Resolver) shutdownTriggerSubscriptions(id uint64, shutdownMatcher func(a SubscriptionIdentifier) bool) int { + trig, ok := r.triggers[id] + if !ok { + return 0 + } + removed := 0 for c, s := range trig.subscriptions { + if shutdownMatcher != nil && !shutdownMatcher(s.id) { + continue + } if c.Context().Err() == nil { s.writer.Complete() } if s.completed != nil { close(s.completed) } + r.heartbeatSubLock.Lock() delete(r.heartbeatSubscriptions, c) + r.heartbeatSubLock.Unlock() delete(trig.subscriptions, c) if r.options.Debug { fmt.Printf("resolver:trigger:subscription:done:%d:%d\n", trig.id, s.id.SubscriptionID) } + removed++ } - trig.cancel() - delete(r.triggers, id) - if r.options.Debug { - fmt.Printf("resolver:trigger:done:%d\n", trig.id) - } - if r.reporter != nil { - r.reporter.SubscriptionCountDec(count) - if trig.initialized { - r.reporter.TriggerCountDec(1) - } - } + return removed } func (r *Resolver) handleShutdown() { @@ -819,11 +803,6 @@ func (r *Resolver) AsyncUnsubscribeClient(connectionID int64) error { } func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQLSubscription, writer SubscriptionResponseWriter) error { - if err := r.triggerEventsSem.Acquire(r.ctx, 1); err != nil { - return err - } - defer r.triggerEventsSem.Release(1) - if subscription.Trigger.Source == nil { return errors.New("no data source found") } diff --git a/v2/pkg/engine/resolve/resolve_test.go b/v2/pkg/engine/resolve/resolve_test.go index 54727feac..51febd082 100644 --- a/v2/pkg/engine/resolve/resolve_test.go +++ b/v2/pkg/engine/resolve/resolve_test.go @@ -86,13 +86,19 @@ func (t *TestErrorWriter) WriteError(ctx *Context, err error, res *GraphQLRespon } } +var multipartSubHeartbeatInterval = 15 * time.Millisecond + +const testMaxSubscriptionWorkers = 1 + func newResolver(ctx context.Context) *Resolver { return New(ctx, ResolverOptions{ - MaxConcurrency: 1024, - Debug: false, - PropagateSubgraphErrors: true, - PropagateSubgraphStatusCodes: true, - AsyncErrorWriter: &TestErrorWriter{}, + MaxConcurrency: 1024, + Debug: false, + PropagateSubgraphErrors: true, + PropagateSubgraphStatusCodes: true, + AsyncErrorWriter: &TestErrorWriter{}, + MultipartSubHeartbeatInterval: multipartSubHeartbeatInterval, + MaxSubscriptionWorkers: testMaxSubscriptionWorkers, }) } @@ -4855,6 +4861,8 @@ func createFakeStream(messageFunc messageFunc, delay time.Duration, onStart func type messageFunc func(counter int) (message string, done bool) +var fakeStreamRequestId atomic.Int32 + type _fakeStream struct { messageFunc messageFunc onStart func(input []byte) @@ -4877,7 +4885,7 @@ func (f *_fakeStream) AwaitIsDone(t *testing.T, timeout time.Duration) { } func (f *_fakeStream) UniqueRequestID(ctx *Context, input []byte, xxh *xxhash.Digest) (err error) { - _, err = xxh.WriteString("fakeStream") + _, err = xxh.WriteString(fmt.Sprintf("%d", fakeStreamRequestId.Add(1))) if err != nil { return } @@ -5164,12 +5172,18 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { ctx := &Context{ ctx: context.Background(), + ExecutionOptions: ExecutionOptions{ + SendHeartbeat: true, + }, } err := resolver.AsyncResolveGraphQLSubscription(ctx, plan, recorder, id) assert.NoError(t, err) + recorder.AwaitComplete(t, defaultTimeout) assert.Equal(t, 3, len(recorder.Messages())) + time.Sleep(2 * resolver.multipartSubHeartbeatInterval) + // Validate that despite the time, we don't see any heartbeats sent assert.ElementsMatch(t, []string{ `{"data":{"counter":0}}`, `{"data":{"counter":1}}`, @@ -5177,6 +5191,48 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { }, recorder.Messages()) }) + t.Run("should successfully delete multiple finished subscriptions", func(t *testing.T) { + c, cancel := context.WithCancel(context.Background()) + defer cancel() + + fakeStream := createFakeStream(func(counter int) (message string, done bool) { + return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), counter == 1 + }, 1*time.Millisecond, func(input []byte) { + assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`, string(input)) + }) + + resolver, plan, recorder, id := setup(c, fakeStream) + + ctx := &Context{ + ctx: context.Background(), + ExecutionOptions: ExecutionOptions{ + SendHeartbeat: true, + }, + } + + for i := 1; i <= 10; i++ { + id.ConnectionID = int64(i) + id.SubscriptionID = int64(i) + recorder.complete.Store(false) + err := resolver.AsyncResolveGraphQLSubscription(ctx, plan, recorder, id) + assert.NoError(t, err) + recorder.AwaitComplete(t, defaultTimeout) + } + + recorder.AwaitComplete(t, defaultTimeout) + time.Sleep(2 * resolver.multipartSubHeartbeatInterval) + + assert.Equal(t, 20, len(recorder.Messages())) + // Validate that despite the time, we don't see any heartbeats sent + assert.ElementsMatch(t, []string{ + `{"data":{"counter":0}}`, `{"data":{"counter":1}}`, `{"data":{"counter":0}}`, `{"data":{"counter":1}}`, + `{"data":{"counter":0}}`, `{"data":{"counter":1}}`, `{"data":{"counter":0}}`, `{"data":{"counter":1}}`, + `{"data":{"counter":0}}`, `{"data":{"counter":1}}`, `{"data":{"counter":0}}`, `{"data":{"counter":1}}`, + `{"data":{"counter":0}}`, `{"data":{"counter":1}}`, `{"data":{"counter":0}}`, `{"data":{"counter":1}}`, + `{"data":{"counter":0}}`, `{"data":{"counter":1}}`, `{"data":{"counter":0}}`, `{"data":{"counter":1}}`, + }, recorder.Messages()) + }) + t.Run("should propagate extensions to stream", func(t *testing.T) { c, cancel := context.WithCancel(context.Background()) defer cancel() @@ -5340,6 +5396,60 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { `{"data":null,"extensions":{"queryPlan":{"version":"1","kind":"Sequence","trigger":{"kind":"Trigger","path":"countryUpdated","subgraphName":"country","subgraphId":"0","fetchId":0,"query":"subscription { countryUpdated { name time { local } } }"},"children":[{"kind":"Single","fetch":{"kind":"Single","path":"countryUpdated.time","subgraphName":"time","subgraphId":"1","fetchId":1,"dependsOnFetchIds":[0],"query":"query($representations: [_Any!]!){\n _entities(representations: $representations){\n ... on Time {\n __typename\n local\n }\n }\n}"}}]}}}`, }, recorder.Messages()) }) + + t.Run("should successfully allow more than one subscription using http multipart", func(t *testing.T) { + c, cancel := context.WithCancel(context.Background()) + defer cancel() + + fakeStream := createFakeStream(func(counter int) (message string, done bool) { + return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), false + }, 0, func(input []byte) { + assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`, string(input)) + }) + + resolver, plan, _, _ := setup(c, fakeStream) + + ctx := &Context{ + ctx: context.Background(), + ExecutionOptions: ExecutionOptions{ + SendHeartbeat: true, + }, + } + + const numSubscriptions = testMaxSubscriptionWorkers + 1 + var resolverCompleted atomic.Uint32 + var recorderCompleted atomic.Uint32 + for i := 0; i < numSubscriptions; i++ { + recorder := &SubscriptionRecorder{ + buf: &bytes.Buffer{}, + messages: []string{}, + complete: atomic.Bool{}, + } + recorder.complete.Store(false) + + go func() { + defer recorderCompleted.Add(1) + + recorder.AwaitAnyMessageCount(t, defaultTimeout) + }() + + go func() { + defer resolverCompleted.Add(1) + + err := resolver.ResolveGraphQLSubscription(ctx, plan, recorder) + assert.ErrorIs(t, err, context.Canceled) + }() + } + assert.Eventually(t, func() bool { + return recorderCompleted.Load() == numSubscriptions + }, defaultTimeout, time.Millisecond*100) + + cancel() + + assert.Eventually(t, func() bool { + return resolverCompleted.Load() == numSubscriptions + }, defaultTimeout, time.Millisecond*100) + }) } func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) { diff --git a/v2/pkg/engine/resolve/trace.go b/v2/pkg/engine/resolve/trace.go index 6043ae327..ea04e73ec 100644 --- a/v2/pkg/engine/resolve/trace.go +++ b/v2/pkg/engine/resolve/trace.go @@ -2,6 +2,8 @@ package resolve import ( "context" + "encoding/json" + "net/http" ) type TraceOptions struct { @@ -59,16 +61,36 @@ func (r *TraceOptions) DisableAll() { r.IncludeTraceOutputInResponseExtensions = false } +type BodyData struct { + Query string `json:"query,omitempty"` + OperationName string `json:"operationName,omitempty"` + Variables json.RawMessage `json:"variables,omitempty"` +} + +type RequestData struct { + Method string `json:"method"` + URL string `json:"url"` + Headers http.Header `json:"headers"` + Body BodyData `json:"body,omitempty"` +} + type TraceData struct { Version string `json:"version"` Info *TraceInfo `json:"info"` Fetches *FetchTreeTraceNode `json:"fetches"` + Request *RequestData `json:"request,omitempty"` } func GetTrace(ctx context.Context, fetchTree *FetchTreeNode) TraceData { - return TraceData{ + trace := TraceData{ Version: "1", Info: GetTraceInfo(ctx), Fetches: fetchTree.Trace(), } + + if req := GetRequest(ctx); req != nil { + trace.Request = req + } + + return trace }