diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index b05b0af26..7ffc23b78 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -528,6 +528,7 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription) id: triggerID, subscriptions: make(map[*Context]*sub), cancel: cancel, + inFlight: &sync.WaitGroup{}, } r.triggers[triggerID] = trig trig.subscriptions[add.ctx] = s @@ -662,8 +663,6 @@ func (r *Resolver) handleTriggerUpdate(id uint64, data []byte) { if r.options.Debug { fmt.Printf("resolver:trigger:update:%d\n", id) } - wg := &sync.WaitGroup{} - trig.inFlight = wg for c, s := range trig.subscriptions { c, s := c, s if err := c.ctx.Err(); err != nil { @@ -677,12 +676,12 @@ func (r *Resolver) handleTriggerUpdate(id uint64, data []byte) { if skip { continue } - wg.Add(1) + trig.inFlight.Add(1) fn := func() { r.executeSubscriptionUpdate(c, s, data) } go func(fn func()) { - defer wg.Done() + defer trig.inFlight.Done() if s.executor != nil { select { case <-r.ctx.Done(): @@ -704,6 +703,7 @@ func (r *Resolver) shutdownTrigger(id uint64) { if !ok { return } + trig.inFlight.Wait() count := len(trig.subscriptions) r.shutdownTriggerSubscriptions(id, nil) trig.cancel() diff --git a/v2/pkg/engine/resolve/resolve_test.go b/v2/pkg/engine/resolve/resolve_test.go index 51febd082..d1bc62852 100644 --- a/v2/pkg/engine/resolve/resolve_test.go +++ b/v2/pkg/engine/resolve/resolve_test.go @@ -4777,6 +4777,7 @@ type SubscriptionRecorder struct { messages []string complete atomic.Bool mux sync.Mutex + onFlush func(p []byte) } func (s *SubscriptionRecorder) AwaitMessages(t *testing.T, count int, timeout time.Duration) { @@ -4834,6 +4835,9 @@ func (s *SubscriptionRecorder) Write(p []byte) (n int, err error) { } func (s *SubscriptionRecorder) Flush() error { + if s.onFlush != nil { + s.onFlush(s.buf.Bytes()) + } s.mux.Lock() defer s.mux.Unlock() s.messages = append(s.messages, s.buf.String()) @@ -5450,6 +5454,55 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { return resolverCompleted.Load() == numSubscriptions }, defaultTimeout, time.Millisecond*100) }) + + t.Run("should wait for all in flight operations to be completed", func(t *testing.T) { + c, cancel := context.WithCancel(context.Background()) + defer cancel() + + var started atomic.Bool + var complete atomic.Bool + + fakeStream := createFakeStream(func(counter int) (message string, done bool) { + defer started.Store(true) + return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), true + }, 0, func(input []byte) { + assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`, string(input)) + }) + + resolver, plan, _, id := setup(c, fakeStream) + recorder := &SubscriptionRecorder{ + buf: &bytes.Buffer{}, + messages: []string{}, + complete: atomic.Bool{}, + onFlush: func(p []byte) { + for !complete.Load() { + time.Sleep(time.Millisecond * 10) + } + }, + } + recorder.complete.Store(false) + + ctx := Context{ + ctx: context.Background(), + } + + err := resolver.AsyncResolveGraphQLSubscription(&ctx, plan, recorder, id) + assert.NoError(t, err) + assert.Eventually(t, func() bool { + return started.Load() + }, defaultTimeout, time.Millisecond*100) + var unsubscribeComplete atomic.Bool + go func() { + defer unsubscribeComplete.Store(true) + err = resolver.AsyncUnsubscribeSubscription(id) + assert.NoError(t, err) + }() + assert.Len(t, resolver.triggers, 1) + complete.Store(true) + assert.Eventually(t, unsubscribeComplete.Load, defaultTimeout, time.Millisecond*100) + recorder.AwaitComplete(t, defaultTimeout) + assert.Len(t, resolver.triggers, 0) + }) } func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) {