@@ -51,6 +51,7 @@ type Resolver struct {
51
51
maxConcurrency chan struct {}
52
52
53
53
triggers map [uint64 ]* trigger
54
+ heartbeatSubLock * sync.Mutex
54
55
heartbeatSubscriptions map [* Context ]* sub
55
56
events chan subscriptionEvent
56
57
triggerEventsSem * semaphore.Weighted
@@ -189,6 +190,7 @@ func New(ctx context.Context, options ResolverOptions) *Resolver {
189
190
propagateSubgraphStatusCodes : options .PropagateSubgraphStatusCodes ,
190
191
events : make (chan subscriptionEvent ),
191
192
triggers : make (map [uint64 ]* trigger ),
193
+ heartbeatSubLock : & sync.Mutex {},
192
194
heartbeatSubscriptions : make (map [* Context ]* sub ),
193
195
reporter : options .Reporter ,
194
196
asyncErrorWriter : options .AsyncErrorWriter ,
@@ -407,6 +409,9 @@ func (r *Resolver) handleEvent(event subscriptionEvent) {
407
409
}
408
410
409
411
func (r * Resolver ) handleHeartbeat (data []byte ) {
412
+ r .heartbeatSubLock .Lock ()
413
+ defer r .heartbeatSubLock .Unlock ()
414
+
410
415
if r .options .Debug {
411
416
fmt .Printf ("resolver:heartbeat:%d\n " , len (r .heartbeatSubscriptions ))
412
417
}
@@ -417,7 +422,7 @@ func (r *Resolver) handleHeartbeat(data []byte) {
417
422
s .mux .Lock ()
418
423
skipHeartbeat := now .Sub (s .lastWrite ) < r .multipartSubHeartbeatInterval
419
424
s .mux .Unlock ()
420
- if skipHeartbeat {
425
+ if skipHeartbeat || ( c . Context (). Err () != nil && errors . Is ( c . Context (). Err (), context . Canceled )) {
421
426
continue
422
427
}
423
428
@@ -428,6 +433,12 @@ func (r *Resolver) handleHeartbeat(data []byte) {
428
433
429
434
s .mux .Lock ()
430
435
if _ , err := s .writer .Write (data ); err != nil {
436
+ if errors .Is (err , context .Canceled ) {
437
+ // client disconnected
438
+ s .mux .Unlock ()
439
+ _ = r .AsyncUnsubscribeSubscription (s .id )
440
+ return
441
+ }
431
442
r .asyncErrorWriter .WriteError (c , err , nil , s .writer )
432
443
}
433
444
err := s .writer .Flush ()
@@ -468,30 +479,7 @@ func (r *Resolver) handleTriggerInitialized(triggerID uint64) {
468
479
}
469
480
470
481
func (r * Resolver ) handleTriggerDone (triggerID uint64 ) {
471
- trig , ok := r .triggers [triggerID ]
472
- if ! ok {
473
- return
474
- }
475
- isInitialized := trig .initialized
476
- wg := trig .inFlight
477
- subscriptionCount := len (trig .subscriptions )
478
-
479
- delete (r .triggers , triggerID )
480
-
481
- go func () {
482
- if wg != nil {
483
- wg .Wait ()
484
- }
485
- for _ , s := range trig .subscriptions {
486
- s .writer .Complete ()
487
- }
488
- if r .reporter != nil {
489
- r .reporter .SubscriptionCountDec (subscriptionCount )
490
- if isInitialized {
491
- r .reporter .TriggerCountDec (1 )
492
- }
493
- }
494
- }()
482
+ r .shutdownTrigger (triggerID )
495
483
}
496
484
497
485
func (r * Resolver ) handleAddSubscription (triggerID uint64 , add * addSubscription ) {
@@ -510,7 +498,9 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription)
510
498
executor : add .executor ,
511
499
}
512
500
if add .ctx .ExecutionOptions .SendHeartbeat {
501
+ r .heartbeatSubLock .Lock ()
513
502
r .heartbeatSubscriptions [add .ctx ] = s
503
+ r .heartbeatSubLock .Unlock ()
514
504
}
515
505
trig , ok := r .triggers [triggerID ]
516
506
if ok {
@@ -636,20 +626,9 @@ func (r *Resolver) handleRemoveSubscription(id SubscriptionIdentifier) {
636
626
removed := 0
637
627
for u := range r .triggers {
638
628
trig := r .triggers [u ]
639
- for ctx , s := range trig .subscriptions {
640
- if s .id == id {
641
-
642
- if ctx .Context ().Err () == nil {
643
- s .writer .Complete ()
644
- }
645
- delete (r .heartbeatSubscriptions , ctx )
646
- delete (trig .subscriptions , ctx )
647
- if r .options .Debug {
648
- fmt .Printf ("resolver:trigger:subscription:removed:%d:%d\n " , trig .id , id .SubscriptionID )
649
- }
650
- removed ++
651
- }
652
- }
629
+ removed += r .shutdownTriggerSubscriptions (u , func (sID SubscriptionIdentifier ) bool {
630
+ return sID == id
631
+ })
653
632
if len (trig .subscriptions ) == 0 {
654
633
r .shutdownTrigger (trig .id )
655
634
}
@@ -665,20 +644,9 @@ func (r *Resolver) handleRemoveClient(id int64) {
665
644
}
666
645
removed := 0
667
646
for u := range r .triggers {
668
- for c , s := range r .triggers [u ].subscriptions {
669
- if s .id .ConnectionID == id && ! s .id .internal {
670
-
671
- if c .Context ().Err () == nil {
672
- s .writer .Complete ()
673
- }
674
-
675
- delete (r .triggers [u ].subscriptions , c )
676
- if r .options .Debug {
677
- fmt .Printf ("resolver:trigger:subscription:done:%d:%d\n " , u , s .id .SubscriptionID )
678
- }
679
- removed ++
680
- }
681
- }
647
+ removed += r .shutdownTriggerSubscriptions (u , func (sID SubscriptionIdentifier ) bool {
648
+ return sID .ConnectionID == id && ! sID .internal
649
+ })
682
650
if len (r .triggers [u ].subscriptions ) == 0 {
683
651
r .shutdownTrigger (r .triggers [u ].id )
684
652
}
@@ -739,30 +707,46 @@ func (r *Resolver) shutdownTrigger(id uint64) {
739
707
return
740
708
}
741
709
count := len (trig .subscriptions )
710
+ r .shutdownTriggerSubscriptions (id , nil )
711
+ trig .cancel ()
712
+ delete (r .triggers , id )
713
+ if r .options .Debug {
714
+ fmt .Printf ("resolver:trigger:done:%d\n " , trig .id )
715
+ }
716
+ if r .reporter != nil {
717
+ r .reporter .SubscriptionCountDec (count )
718
+ if trig .initialized {
719
+ r .reporter .TriggerCountDec (1 )
720
+ }
721
+ }
722
+ }
723
+
724
+ func (r * Resolver ) shutdownTriggerSubscriptions (id uint64 , shutdownMatcher func (a SubscriptionIdentifier ) bool ) int {
725
+ trig , ok := r .triggers [id ]
726
+ if ! ok {
727
+ return 0
728
+ }
729
+ removed := 0
742
730
for c , s := range trig .subscriptions {
731
+ if shutdownMatcher != nil && ! shutdownMatcher (s .id ) {
732
+ continue
733
+ }
743
734
if c .Context ().Err () == nil {
744
735
s .writer .Complete ()
745
736
}
746
737
if s .completed != nil {
747
738
close (s .completed )
748
739
}
740
+ r .heartbeatSubLock .Lock ()
749
741
delete (r .heartbeatSubscriptions , c )
742
+ r .heartbeatSubLock .Unlock ()
750
743
delete (trig .subscriptions , c )
751
744
if r .options .Debug {
752
745
fmt .Printf ("resolver:trigger:subscription:done:%d:%d\n " , trig .id , s .id .SubscriptionID )
753
746
}
747
+ removed ++
754
748
}
755
- trig .cancel ()
756
- delete (r .triggers , id )
757
- if r .options .Debug {
758
- fmt .Printf ("resolver:trigger:done:%d\n " , trig .id )
759
- }
760
- if r .reporter != nil {
761
- r .reporter .SubscriptionCountDec (count )
762
- if trig .initialized {
763
- r .reporter .TriggerCountDec (1 )
764
- }
765
- }
749
+ return removed
766
750
}
767
751
768
752
func (r * Resolver ) handleShutdown () {
0 commit comments