diff --git a/v2/pkg/engine/resolve/resolve_test.go b/v2/pkg/engine/resolve/resolve_test.go index ce45bbd4a..51febd082 100644 --- a/v2/pkg/engine/resolve/resolve_test.go +++ b/v2/pkg/engine/resolve/resolve_test.go @@ -88,6 +88,8 @@ func (t *TestErrorWriter) WriteError(ctx *Context, err error, res *GraphQLRespon var multipartSubHeartbeatInterval = 15 * time.Millisecond +const testMaxSubscriptionWorkers = 1 + func newResolver(ctx context.Context) *Resolver { return New(ctx, ResolverOptions{ MaxConcurrency: 1024, @@ -96,6 +98,7 @@ func newResolver(ctx context.Context) *Resolver { PropagateSubgraphStatusCodes: true, AsyncErrorWriter: &TestErrorWriter{}, MultipartSubHeartbeatInterval: multipartSubHeartbeatInterval, + MaxSubscriptionWorkers: testMaxSubscriptionWorkers, }) } @@ -4858,6 +4861,8 @@ func createFakeStream(messageFunc messageFunc, delay time.Duration, onStart func type messageFunc func(counter int) (message string, done bool) +var fakeStreamRequestId atomic.Int32 + type _fakeStream struct { messageFunc messageFunc onStart func(input []byte) @@ -4880,7 +4885,7 @@ func (f *_fakeStream) AwaitIsDone(t *testing.T, timeout time.Duration) { } func (f *_fakeStream) UniqueRequestID(ctx *Context, input []byte, xxh *xxhash.Digest) (err error) { - _, err = xxh.WriteString("fakeStream") + _, err = xxh.WriteString(fmt.Sprintf("%d", fakeStreamRequestId.Add(1))) if err != nil { return } @@ -5391,6 +5396,60 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { `{"data":null,"extensions":{"queryPlan":{"version":"1","kind":"Sequence","trigger":{"kind":"Trigger","path":"countryUpdated","subgraphName":"country","subgraphId":"0","fetchId":0,"query":"subscription { countryUpdated { name time { local } } }"},"children":[{"kind":"Single","fetch":{"kind":"Single","path":"countryUpdated.time","subgraphName":"time","subgraphId":"1","fetchId":1,"dependsOnFetchIds":[0],"query":"query($representations: [_Any!]!){\n _entities(representations: $representations){\n ... on Time {\n __typename\n local\n }\n }\n}"}}]}}}`, }, recorder.Messages()) }) + + t.Run("should successfully allow more than one subscription using http multipart", func(t *testing.T) { + c, cancel := context.WithCancel(context.Background()) + defer cancel() + + fakeStream := createFakeStream(func(counter int) (message string, done bool) { + return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), false + }, 0, func(input []byte) { + assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`, string(input)) + }) + + resolver, plan, _, _ := setup(c, fakeStream) + + ctx := &Context{ + ctx: context.Background(), + ExecutionOptions: ExecutionOptions{ + SendHeartbeat: true, + }, + } + + const numSubscriptions = testMaxSubscriptionWorkers + 1 + var resolverCompleted atomic.Uint32 + var recorderCompleted atomic.Uint32 + for i := 0; i < numSubscriptions; i++ { + recorder := &SubscriptionRecorder{ + buf: &bytes.Buffer{}, + messages: []string{}, + complete: atomic.Bool{}, + } + recorder.complete.Store(false) + + go func() { + defer recorderCompleted.Add(1) + + recorder.AwaitAnyMessageCount(t, defaultTimeout) + }() + + go func() { + defer resolverCompleted.Add(1) + + err := resolver.ResolveGraphQLSubscription(ctx, plan, recorder) + assert.ErrorIs(t, err, context.Canceled) + }() + } + assert.Eventually(t, func() bool { + return recorderCompleted.Load() == numSubscriptions + }, defaultTimeout, time.Millisecond*100) + + cancel() + + assert.Eventually(t, func() bool { + return resolverCompleted.Load() == numSubscriptions + }, defaultTimeout, time.Millisecond*100) + }) } func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) {