Skip to content

Commit f7492d3

Browse files
authored
fix: delete leftover heartbeat connections (#1033)
Users reported seeing the below sigfault when a connection to a subgraph is interrupted over multipart: ``` cosmo-router | 21:07:16 PM ERROR core/graphql_handler.go:380 Unable to write error response {"hostname": "ad1991cdcfd2", "pid": 1, "component": "@wundergraph/router", "service_version": "0.158.0", "request_id": "ad1991cdcfd2/RTnCLg1J8M-000022", "trace_id": "77476c6fe295cd7dafeb71e00183879b", "error": "context canceled"} cosmo-router | github.com/wundergraph/cosmo/router/core.(*GraphQLHandler).WriteError cosmo-router | github.com/wundergraph/cosmo/router/core/graphql_handler.go:380 cosmo-router | github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve.(*Resolver).handleHeartbeat.func1 cosmo-router | github.com/wundergraph/graphql-go-tools/v2@v2.0.0-rc.136/pkg/engine/resolve/resolve.go:431 ``` After investigating, the cause seemed to be a number of times where we deleted the subscription trigger but didn't clean up the heartbeat (which is running in a separate thread), causing it to write on a non-existent context. This PR cleans that up
1 parent a9c873f commit f7492d3

File tree

2 files changed

+104
-69
lines changed

2 files changed

+104
-69
lines changed

v2/pkg/engine/resolve/resolve.go

Lines changed: 48 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ type Resolver struct {
5151
maxConcurrency chan struct{}
5252

5353
triggers map[uint64]*trigger
54+
heartbeatSubLock *sync.Mutex
5455
heartbeatSubscriptions map[*Context]*sub
5556
events chan subscriptionEvent
5657
triggerEventsSem *semaphore.Weighted
@@ -189,6 +190,7 @@ func New(ctx context.Context, options ResolverOptions) *Resolver {
189190
propagateSubgraphStatusCodes: options.PropagateSubgraphStatusCodes,
190191
events: make(chan subscriptionEvent),
191192
triggers: make(map[uint64]*trigger),
193+
heartbeatSubLock: &sync.Mutex{},
192194
heartbeatSubscriptions: make(map[*Context]*sub),
193195
reporter: options.Reporter,
194196
asyncErrorWriter: options.AsyncErrorWriter,
@@ -407,6 +409,9 @@ func (r *Resolver) handleEvent(event subscriptionEvent) {
407409
}
408410

409411
func (r *Resolver) handleHeartbeat(data []byte) {
412+
r.heartbeatSubLock.Lock()
413+
defer r.heartbeatSubLock.Unlock()
414+
410415
if r.options.Debug {
411416
fmt.Printf("resolver:heartbeat:%d\n", len(r.heartbeatSubscriptions))
412417
}
@@ -417,7 +422,7 @@ func (r *Resolver) handleHeartbeat(data []byte) {
417422
s.mux.Lock()
418423
skipHeartbeat := now.Sub(s.lastWrite) < r.multipartSubHeartbeatInterval
419424
s.mux.Unlock()
420-
if skipHeartbeat {
425+
if skipHeartbeat || (c.Context().Err() != nil && errors.Is(c.Context().Err(), context.Canceled)) {
421426
continue
422427
}
423428

@@ -428,6 +433,12 @@ func (r *Resolver) handleHeartbeat(data []byte) {
428433

429434
s.mux.Lock()
430435
if _, err := s.writer.Write(data); err != nil {
436+
if errors.Is(err, context.Canceled) {
437+
// client disconnected
438+
s.mux.Unlock()
439+
_ = r.AsyncUnsubscribeSubscription(s.id)
440+
return
441+
}
431442
r.asyncErrorWriter.WriteError(c, err, nil, s.writer)
432443
}
433444
err := s.writer.Flush()
@@ -468,30 +479,7 @@ func (r *Resolver) handleTriggerInitialized(triggerID uint64) {
468479
}
469480

470481
func (r *Resolver) handleTriggerDone(triggerID uint64) {
471-
trig, ok := r.triggers[triggerID]
472-
if !ok {
473-
return
474-
}
475-
isInitialized := trig.initialized
476-
wg := trig.inFlight
477-
subscriptionCount := len(trig.subscriptions)
478-
479-
delete(r.triggers, triggerID)
480-
481-
go func() {
482-
if wg != nil {
483-
wg.Wait()
484-
}
485-
for _, s := range trig.subscriptions {
486-
s.writer.Complete()
487-
}
488-
if r.reporter != nil {
489-
r.reporter.SubscriptionCountDec(subscriptionCount)
490-
if isInitialized {
491-
r.reporter.TriggerCountDec(1)
492-
}
493-
}
494-
}()
482+
r.shutdownTrigger(triggerID)
495483
}
496484

497485
func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription) {
@@ -510,7 +498,9 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription)
510498
executor: add.executor,
511499
}
512500
if add.ctx.ExecutionOptions.SendHeartbeat {
501+
r.heartbeatSubLock.Lock()
513502
r.heartbeatSubscriptions[add.ctx] = s
503+
r.heartbeatSubLock.Unlock()
514504
}
515505
trig, ok := r.triggers[triggerID]
516506
if ok {
@@ -636,20 +626,9 @@ func (r *Resolver) handleRemoveSubscription(id SubscriptionIdentifier) {
636626
removed := 0
637627
for u := range r.triggers {
638628
trig := r.triggers[u]
639-
for ctx, s := range trig.subscriptions {
640-
if s.id == id {
641-
642-
if ctx.Context().Err() == nil {
643-
s.writer.Complete()
644-
}
645-
delete(r.heartbeatSubscriptions, ctx)
646-
delete(trig.subscriptions, ctx)
647-
if r.options.Debug {
648-
fmt.Printf("resolver:trigger:subscription:removed:%d:%d\n", trig.id, id.SubscriptionID)
649-
}
650-
removed++
651-
}
652-
}
629+
removed += r.shutdownTriggerSubscriptions(u, func(sID SubscriptionIdentifier) bool {
630+
return sID == id
631+
})
653632
if len(trig.subscriptions) == 0 {
654633
r.shutdownTrigger(trig.id)
655634
}
@@ -665,20 +644,9 @@ func (r *Resolver) handleRemoveClient(id int64) {
665644
}
666645
removed := 0
667646
for u := range r.triggers {
668-
for c, s := range r.triggers[u].subscriptions {
669-
if s.id.ConnectionID == id && !s.id.internal {
670-
671-
if c.Context().Err() == nil {
672-
s.writer.Complete()
673-
}
674-
675-
delete(r.triggers[u].subscriptions, c)
676-
if r.options.Debug {
677-
fmt.Printf("resolver:trigger:subscription:done:%d:%d\n", u, s.id.SubscriptionID)
678-
}
679-
removed++
680-
}
681-
}
647+
removed += r.shutdownTriggerSubscriptions(u, func(sID SubscriptionIdentifier) bool {
648+
return sID.ConnectionID == id && !sID.internal
649+
})
682650
if len(r.triggers[u].subscriptions) == 0 {
683651
r.shutdownTrigger(r.triggers[u].id)
684652
}
@@ -739,30 +707,46 @@ func (r *Resolver) shutdownTrigger(id uint64) {
739707
return
740708
}
741709
count := len(trig.subscriptions)
710+
r.shutdownTriggerSubscriptions(id, nil)
711+
trig.cancel()
712+
delete(r.triggers, id)
713+
if r.options.Debug {
714+
fmt.Printf("resolver:trigger:done:%d\n", trig.id)
715+
}
716+
if r.reporter != nil {
717+
r.reporter.SubscriptionCountDec(count)
718+
if trig.initialized {
719+
r.reporter.TriggerCountDec(1)
720+
}
721+
}
722+
}
723+
724+
func (r *Resolver) shutdownTriggerSubscriptions(id uint64, shutdownMatcher func(a SubscriptionIdentifier) bool) int {
725+
trig, ok := r.triggers[id]
726+
if !ok {
727+
return 0
728+
}
729+
removed := 0
742730
for c, s := range trig.subscriptions {
731+
if shutdownMatcher != nil && !shutdownMatcher(s.id) {
732+
continue
733+
}
743734
if c.Context().Err() == nil {
744735
s.writer.Complete()
745736
}
746737
if s.completed != nil {
747738
close(s.completed)
748739
}
740+
r.heartbeatSubLock.Lock()
749741
delete(r.heartbeatSubscriptions, c)
742+
r.heartbeatSubLock.Unlock()
750743
delete(trig.subscriptions, c)
751744
if r.options.Debug {
752745
fmt.Printf("resolver:trigger:subscription:done:%d:%d\n", trig.id, s.id.SubscriptionID)
753746
}
747+
removed++
754748
}
755-
trig.cancel()
756-
delete(r.triggers, id)
757-
if r.options.Debug {
758-
fmt.Printf("resolver:trigger:done:%d\n", trig.id)
759-
}
760-
if r.reporter != nil {
761-
r.reporter.SubscriptionCountDec(count)
762-
if trig.initialized {
763-
r.reporter.TriggerCountDec(1)
764-
}
765-
}
749+
return removed
766750
}
767751

768752
func (r *Resolver) handleShutdown() {

v2/pkg/engine/resolve/resolve_test.go

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,16 @@ func (t *TestErrorWriter) WriteError(ctx *Context, err error, res *GraphQLRespon
8686
}
8787
}
8888

89+
var multipartSubHeartbeatInterval = 15 * time.Millisecond
90+
8991
func newResolver(ctx context.Context) *Resolver {
9092
return New(ctx, ResolverOptions{
91-
MaxConcurrency: 1024,
92-
Debug: false,
93-
PropagateSubgraphErrors: true,
94-
PropagateSubgraphStatusCodes: true,
95-
AsyncErrorWriter: &TestErrorWriter{},
93+
MaxConcurrency: 1024,
94+
Debug: false,
95+
PropagateSubgraphErrors: true,
96+
PropagateSubgraphStatusCodes: true,
97+
AsyncErrorWriter: &TestErrorWriter{},
98+
MultipartSubHeartbeatInterval: multipartSubHeartbeatInterval,
9699
})
97100
}
98101

@@ -5164,19 +5167,67 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) {
51645167

51655168
ctx := &Context{
51665169
ctx: context.Background(),
5170+
ExecutionOptions: ExecutionOptions{
5171+
SendHeartbeat: true,
5172+
},
51675173
}
51685174

51695175
err := resolver.AsyncResolveGraphQLSubscription(ctx, plan, recorder, id)
51705176
assert.NoError(t, err)
5177+
51715178
recorder.AwaitComplete(t, defaultTimeout)
51725179
assert.Equal(t, 3, len(recorder.Messages()))
5180+
time.Sleep(2 * resolver.multipartSubHeartbeatInterval)
5181+
// Validate that despite the time, we don't see any heartbeats sent
51735182
assert.ElementsMatch(t, []string{
51745183
`{"data":{"counter":0}}`,
51755184
`{"data":{"counter":1}}`,
51765185
`{"data":{"counter":2}}`,
51775186
}, recorder.Messages())
51785187
})
51795188

5189+
t.Run("should successfully delete multiple finished subscriptions", func(t *testing.T) {
5190+
c, cancel := context.WithCancel(context.Background())
5191+
defer cancel()
5192+
5193+
fakeStream := createFakeStream(func(counter int) (message string, done bool) {
5194+
return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), counter == 1
5195+
}, 1*time.Millisecond, func(input []byte) {
5196+
assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`, string(input))
5197+
})
5198+
5199+
resolver, plan, recorder, id := setup(c, fakeStream)
5200+
5201+
ctx := &Context{
5202+
ctx: context.Background(),
5203+
ExecutionOptions: ExecutionOptions{
5204+
SendHeartbeat: true,
5205+
},
5206+
}
5207+
5208+
for i := 1; i <= 10; i++ {
5209+
id.ConnectionID = int64(i)
5210+
id.SubscriptionID = int64(i)
5211+
recorder.complete.Store(false)
5212+
err := resolver.AsyncResolveGraphQLSubscription(ctx, plan, recorder, id)
5213+
assert.NoError(t, err)
5214+
recorder.AwaitComplete(t, defaultTimeout)
5215+
}
5216+
5217+
recorder.AwaitComplete(t, defaultTimeout)
5218+
time.Sleep(2 * resolver.multipartSubHeartbeatInterval)
5219+
5220+
assert.Equal(t, 20, len(recorder.Messages()))
5221+
// Validate that despite the time, we don't see any heartbeats sent
5222+
assert.ElementsMatch(t, []string{
5223+
`{"data":{"counter":0}}`, `{"data":{"counter":1}}`, `{"data":{"counter":0}}`, `{"data":{"counter":1}}`,
5224+
`{"data":{"counter":0}}`, `{"data":{"counter":1}}`, `{"data":{"counter":0}}`, `{"data":{"counter":1}}`,
5225+
`{"data":{"counter":0}}`, `{"data":{"counter":1}}`, `{"data":{"counter":0}}`, `{"data":{"counter":1}}`,
5226+
`{"data":{"counter":0}}`, `{"data":{"counter":1}}`, `{"data":{"counter":0}}`, `{"data":{"counter":1}}`,
5227+
`{"data":{"counter":0}}`, `{"data":{"counter":1}}`, `{"data":{"counter":0}}`, `{"data":{"counter":1}}`,
5228+
}, recorder.Messages())
5229+
})
5230+
51805231
t.Run("should propagate extensions to stream", func(t *testing.T) {
51815232
c, cancel := context.WithCancel(context.Background())
51825233
defer cancel()

0 commit comments

Comments
 (0)