Skip to content

Commit

Permalink
feat: execute subscription writes on main goroutine in synchronous re…
Browse files Browse the repository at this point in the history
…solve subscriptions
  • Loading branch information
jensneuse committed Dec 2, 2024
1 parent f7a31e8 commit acdaf47
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 17 deletions.
61 changes: 44 additions & 17 deletions v2/pkg/engine/resolve/resolve.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,10 @@ type sub struct {
id SubscriptionIdentifier
completed chan struct{}
lastWrite time.Time
// executor is an optional argument that allows us to "schedule" the execution of an update on another thread
// e.g. if we're using SSE/Multipart Fetch, we can run the execution on the goroutine of the http request
// this ensures that ctx cancellation works properly when a client disconnects
executor chan func()
}

func (r *Resolver) executeSubscriptionUpdate(ctx *Context, sub *sub, sharedInput []byte) {
Expand Down Expand Up @@ -495,6 +499,7 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription)
id: add.id,
completed: add.completed,
lastWrite: time.Now(),
executor: add.executor,
}
if add.ctx.ExecutionOptions.SendHeartbeat {
r.heartbeatSubscriptions[add.ctx] = s
Expand Down Expand Up @@ -687,6 +692,9 @@ func (r *Resolver) handleTriggerUpdate(id uint64, data []byte) {
trig.inFlight = wg
for c, s := range trig.subscriptions {
c, s := c, s
if err := c.ctx.Err(); err != nil {
continue // no need to schedule an event update when the client already disconnected
}
skip, err := s.resolve.Filter.SkipEvent(c, data, r.triggerUpdateBuf)
if err != nil {
r.asyncErrorWriter.WriteError(c, err, s.resolve.Response, s.writer)
Expand All @@ -695,12 +703,22 @@ func (r *Resolver) handleTriggerUpdate(id uint64, data []byte) {
if skip {
continue
}

wg.Add(1)
go func() {
defer wg.Done()
fn := func() {
r.executeSubscriptionUpdate(c, s, data)
}()
}
go func(fn func()) {
defer wg.Done()
if s.executor != nil {
select {
case <-r.ctx.Done():
case <-c.ctx.Done():
case s.executor <- fn:
}
} else {
fn()
}
}(fn)
}
}

Expand Down Expand Up @@ -825,6 +843,7 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ
fmt.Printf("resolver:trigger:subscribe:sync:%d:%d\n", uniqueID, id.SubscriptionID)
}
completed := make(chan struct{})
executor := make(chan func())
select {
case <-r.ctx.Done():
return r.ctx.Err()
Expand All @@ -838,25 +857,32 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ
writer: writer,
id: id,
completed: completed,
executor: executor,
},
}:
}
select {
case <-r.ctx.Done():
// the resolver ctx was canceled
// this will trigger the shutdown of the trigger (on another goroutine)
// as such, we need to wait for the trigger to be shutdown
// otherwise we might experience a data race between trigger shutdown write (Complete) and reading bytes written to the writer
// as the shutdown happens asynchronously, we want to wait here for at most 5 seconds or until the client ctx is done
Loop: // execute fn on the main thread of the incoming request until ctx is done
for {
select {
case <-completed:
return r.ctx.Err()
case <-time.After(time.Second * 5):
return r.ctx.Err()
case <-r.ctx.Done():
// the resolver ctx was canceled
// this will trigger the shutdown of the trigger (on another goroutine)
// as such, we need to wait for the trigger to be shutdown
// otherwise we might experience a data race between trigger shutdown write (Complete) and reading bytes written to the writer
// as the shutdown happens asynchronously, we want to wait here for at most 5 seconds or until the client ctx is done
select {
case <-completed:
return r.ctx.Err()
case <-time.After(time.Second * 5):
return r.ctx.Err()
case <-ctx.Context().Done():
return ctx.Context().Err()
}
case <-ctx.Context().Done():
return ctx.Context().Err()
break Loop
case fn := <-executor:
fn()
}
case <-ctx.Context().Done():
}
if r.options.Debug {
fmt.Printf("resolver:trigger:unsubscribe:sync:%d:%d\n", uniqueID, id.SubscriptionID)
Expand Down Expand Up @@ -1008,6 +1034,7 @@ type addSubscription struct {
writer SubscriptionResponseWriter
id SubscriptionIdentifier
completed chan struct{}
executor chan func()
}

type subscriptionEventKind int
Expand Down
5 changes: 5 additions & 0 deletions v2/pkg/engine/resolve/resolve_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5201,6 +5201,7 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) {
resolver := newResolver(c)

ctx := &Context{
ctx: context.Background(),
Variables: astjson.MustParseBytes([]byte(`{"id":1}`)),
}

Expand Down Expand Up @@ -5296,6 +5297,7 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) {
resolver := newResolver(c)

ctx := &Context{
ctx: context.Background(),
Variables: astjson.MustParseBytes([]byte(`{"id":2}`)),
}

Expand Down Expand Up @@ -5389,6 +5391,7 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) {
resolver := newResolver(c)

ctx := &Context{
ctx: context.Background(),
Variables: astjson.MustParseBytes([]byte(`{"ids":[1,2]}`)),
}

Expand Down Expand Up @@ -5487,6 +5490,7 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) {
resolver := newResolver(c)

ctx := &Context{
ctx: context.Background(),
Variables: astjson.MustParseBytes([]byte(`{"ids":["2","3"]}`)),
}

Expand Down Expand Up @@ -5595,6 +5599,7 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) {
resolver := newResolver(c)

ctx := &Context{
ctx: context.Background(),
Variables: astjson.MustParseBytes([]byte(`{"a":[1,2],"b":[3,4]}`)),
}

Expand Down

0 comments on commit acdaf47

Please sign in to comment.