Skip to content

Commit

Permalink
fix: always shutdown subscriptions in the same way
Browse files Browse the repository at this point in the history
  • Loading branch information
df-wg committed Jan 20, 2025
1 parent 780797e commit ebc619f
Showing 1 changed file with 34 additions and 75 deletions.
109 changes: 34 additions & 75 deletions v2/pkg/engine/resolve/resolve.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ type Resolver struct {
maxConcurrency chan struct{}

triggers map[uint64]*trigger
heartbeatSubLock *sync.RWMutex
heartbeatSubLock *sync.Mutex
heartbeatSubscriptions map[*Context]*sub
events chan subscriptionEvent
triggerEventsSem *semaphore.Weighted
Expand Down Expand Up @@ -190,7 +190,7 @@ func New(ctx context.Context, options ResolverOptions) *Resolver {
propagateSubgraphStatusCodes: options.PropagateSubgraphStatusCodes,
events: make(chan subscriptionEvent),
triggers: make(map[uint64]*trigger),
heartbeatSubLock: &sync.RWMutex{},
heartbeatSubLock: &sync.Mutex{},
heartbeatSubscriptions: make(map[*Context]*sub),
reporter: options.Reporter,
asyncErrorWriter: options.AsyncErrorWriter,
Expand Down Expand Up @@ -479,36 +479,7 @@ func (r *Resolver) handleTriggerInitialized(triggerID uint64) {
}

func (r *Resolver) handleTriggerDone(triggerID uint64) {
trig, ok := r.triggers[triggerID]
if !ok {
return
}
isInitialized := trig.initialized
wg := trig.inFlight
subscriptionCount := len(trig.subscriptions)

delete(r.triggers, triggerID)

go func() {
if wg != nil {
wg.Wait()
}
for c, s := range trig.subscriptions {
s.writer.Complete()
s.mux.Lock()
r.heartbeatSubLock.Lock()
delete(r.heartbeatSubscriptions, c)
r.heartbeatSubLock.Unlock()
delete(trig.subscriptions, c)
s.mux.Unlock()
}
if r.reporter != nil {
r.reporter.SubscriptionCountDec(subscriptionCount)
if isInitialized {
r.reporter.TriggerCountDec(1)
}
}
}()
r.shutdownTrigger(triggerID)
}

func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription) {
Expand Down Expand Up @@ -655,22 +626,9 @@ func (r *Resolver) handleRemoveSubscription(id SubscriptionIdentifier) {
removed := 0
for u := range r.triggers {
trig := r.triggers[u]
for ctx, s := range trig.subscriptions {
if s.id == id {

if ctx.Context().Err() == nil {
s.writer.Complete()
}
r.heartbeatSubLock.Lock()
delete(r.heartbeatSubscriptions, ctx)
r.heartbeatSubLock.Unlock()
delete(trig.subscriptions, ctx)
if r.options.Debug {
fmt.Printf("resolver:trigger:subscription:removed:%d:%d\n", trig.id, id.SubscriptionID)
}
removed++
}
}
removed += r.shutdownTriggerSubscriptions(u, func(a SubscriptionIdentifier) bool {
return a == id
})
if len(trig.subscriptions) == 0 {
r.shutdownTrigger(trig.id)
}
Expand All @@ -686,22 +644,9 @@ func (r *Resolver) handleRemoveClient(id int64) {
}
removed := 0
for u := range r.triggers {
for c, s := range r.triggers[u].subscriptions {
if s.id.ConnectionID == id && !s.id.internal {

if c.Context().Err() == nil {
s.writer.Complete()
}
r.heartbeatSubLock.Lock()
delete(r.heartbeatSubscriptions, c)
r.heartbeatSubLock.Unlock()
delete(r.triggers[u].subscriptions, c)
if r.options.Debug {
fmt.Printf("resolver:trigger:subscription:done:%d:%d\n", u, s.id.SubscriptionID)
}
removed++
}
}
removed += r.shutdownTriggerSubscriptions(u, func(sID SubscriptionIdentifier) bool {
return sID.ConnectionID == id && !sID.internal
})
if len(r.triggers[u].subscriptions) == 0 {
r.shutdownTrigger(r.triggers[u].id)
}
Expand Down Expand Up @@ -762,7 +707,30 @@ func (r *Resolver) shutdownTrigger(id uint64) {
return
}
count := len(trig.subscriptions)
r.shutdownTriggerSubscriptions(id, nil)
trig.cancel()
delete(r.triggers, id)
if r.options.Debug {
fmt.Printf("resolver:trigger:done:%d\n", trig.id)
}
if r.reporter != nil {
r.reporter.SubscriptionCountDec(count)
if trig.initialized {
r.reporter.TriggerCountDec(1)
}
}
}

func (r *Resolver) shutdownTriggerSubscriptions(id uint64, shutdownMatcher func(a SubscriptionIdentifier) bool) int {
trig, ok := r.triggers[id]
if !ok {
return 0
}
removed := 0
for c, s := range trig.subscriptions {
if shutdownMatcher != nil && !shutdownMatcher(s.id) {
continue
}
if c.Context().Err() == nil {
s.writer.Complete()
}
Expand All @@ -776,18 +744,9 @@ func (r *Resolver) shutdownTrigger(id uint64) {
if r.options.Debug {
fmt.Printf("resolver:trigger:subscription:done:%d:%d\n", trig.id, s.id.SubscriptionID)
}
removed++
}
trig.cancel()
delete(r.triggers, id)
if r.options.Debug {
fmt.Printf("resolver:trigger:done:%d\n", trig.id)
}
if r.reporter != nil {
r.reporter.SubscriptionCountDec(count)
if trig.initialized {
r.reporter.TriggerCountDec(1)
}
}
return removed
}

func (r *Resolver) handleShutdown() {
Expand Down

0 comments on commit ebc619f

Please sign in to comment.