diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 7b5f5d6f79..0167287198 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -51,7 +51,7 @@ type Resolver struct { maxConcurrency chan struct{} triggers map[uint64]*trigger - heartbeatSubLock *sync.RWMutex + heartbeatSubLock *sync.Mutex heartbeatSubscriptions map[*Context]*sub events chan subscriptionEvent triggerEventsSem *semaphore.Weighted @@ -190,7 +190,7 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { propagateSubgraphStatusCodes: options.PropagateSubgraphStatusCodes, events: make(chan subscriptionEvent), triggers: make(map[uint64]*trigger), - heartbeatSubLock: &sync.RWMutex{}, + heartbeatSubLock: &sync.Mutex{}, heartbeatSubscriptions: make(map[*Context]*sub), reporter: options.Reporter, asyncErrorWriter: options.AsyncErrorWriter, @@ -479,36 +479,7 @@ func (r *Resolver) handleTriggerInitialized(triggerID uint64) { } func (r *Resolver) handleTriggerDone(triggerID uint64) { - trig, ok := r.triggers[triggerID] - if !ok { - return - } - isInitialized := trig.initialized - wg := trig.inFlight - subscriptionCount := len(trig.subscriptions) - - delete(r.triggers, triggerID) - - go func() { - if wg != nil { - wg.Wait() - } - for c, s := range trig.subscriptions { - s.writer.Complete() - s.mux.Lock() - r.heartbeatSubLock.Lock() - delete(r.heartbeatSubscriptions, c) - r.heartbeatSubLock.Unlock() - delete(trig.subscriptions, c) - s.mux.Unlock() - } - if r.reporter != nil { - r.reporter.SubscriptionCountDec(subscriptionCount) - if isInitialized { - r.reporter.TriggerCountDec(1) - } - } - }() + r.shutdownTrigger(triggerID) } func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription) { @@ -655,22 +626,9 @@ func (r *Resolver) handleRemoveSubscription(id SubscriptionIdentifier) { removed := 0 for u := range r.triggers { trig := r.triggers[u] - for ctx, s := range trig.subscriptions { - if s.id == id { - - if ctx.Context().Err() == nil { - s.writer.Complete() - } - r.heartbeatSubLock.Lock() - delete(r.heartbeatSubscriptions, ctx) - r.heartbeatSubLock.Unlock() - delete(trig.subscriptions, ctx) - if r.options.Debug { - fmt.Printf("resolver:trigger:subscription:removed:%d:%d\n", trig.id, id.SubscriptionID) - } - removed++ - } - } + removed += r.shutdownTriggerSubscriptions(u, func(a SubscriptionIdentifier) bool { + return a == id + }) if len(trig.subscriptions) == 0 { r.shutdownTrigger(trig.id) } @@ -686,22 +644,9 @@ func (r *Resolver) handleRemoveClient(id int64) { } removed := 0 for u := range r.triggers { - for c, s := range r.triggers[u].subscriptions { - if s.id.ConnectionID == id && !s.id.internal { - - if c.Context().Err() == nil { - s.writer.Complete() - } - r.heartbeatSubLock.Lock() - delete(r.heartbeatSubscriptions, c) - r.heartbeatSubLock.Unlock() - delete(r.triggers[u].subscriptions, c) - if r.options.Debug { - fmt.Printf("resolver:trigger:subscription:done:%d:%d\n", u, s.id.SubscriptionID) - } - removed++ - } - } + removed += r.shutdownTriggerSubscriptions(u, func(sID SubscriptionIdentifier) bool { + return sID.ConnectionID == id && !sID.internal + }) if len(r.triggers[u].subscriptions) == 0 { r.shutdownTrigger(r.triggers[u].id) } @@ -762,7 +707,30 @@ func (r *Resolver) shutdownTrigger(id uint64) { return } count := len(trig.subscriptions) + r.shutdownTriggerSubscriptions(id, nil) + trig.cancel() + delete(r.triggers, id) + if r.options.Debug { + fmt.Printf("resolver:trigger:done:%d\n", trig.id) + } + if r.reporter != nil { + r.reporter.SubscriptionCountDec(count) + if trig.initialized { + r.reporter.TriggerCountDec(1) + } + } +} + +func (r *Resolver) shutdownTriggerSubscriptions(id uint64, shutdownMatcher func(a SubscriptionIdentifier) bool) int { + trig, ok := r.triggers[id] + if !ok { + return 0 + } + removed := 0 for c, s := range trig.subscriptions { + if shutdownMatcher != nil && !shutdownMatcher(s.id) { + continue + } if c.Context().Err() == nil { s.writer.Complete() } @@ -776,18 +744,9 @@ func (r *Resolver) shutdownTrigger(id uint64) { if r.options.Debug { fmt.Printf("resolver:trigger:subscription:done:%d:%d\n", trig.id, s.id.SubscriptionID) } + removed++ } - trig.cancel() - delete(r.triggers, id) - if r.options.Debug { - fmt.Printf("resolver:trigger:done:%d\n", trig.id) - } - if r.reporter != nil { - r.reporter.SubscriptionCountDec(count) - if trig.initialized { - r.reporter.TriggerCountDec(1) - } - } + return removed } func (r *Resolver) handleShutdown() {