From c2c76de2451470d1d9325e083fa03683d84209ce Mon Sep 17 00:00:00 2001 From: Alessandro Pagnin Date: Tue, 11 Feb 2025 15:20:24 +0100 Subject: [PATCH] feat: make router-tests less flaky (#1484) --- .github/workflows/router-ci.yaml | 8 +- router-tests/cache_warmup_test.go | 3 +- router-tests/complexity_limits_test.go | 7 + router-tests/config_hot_reload_test.go | 34 +- .../{ => events}/events_config_test.go | 2 +- router-tests/events/kafka_events_test.go | 445 ++++++++------- router-tests/events/nats_events_test.go | 522 +++++++++++------- router-tests/structured_logging_test.go | 2 +- router-tests/telemetry/telemetry_test.go | 27 +- router-tests/testenv/pubsub.go | 3 + router-tests/testenv/testenv.go | 153 +++-- router-tests/timeout_test.go | 2 +- router-tests/websocket_test.go | 496 ++++++++--------- 13 files changed, 950 insertions(+), 754 deletions(-) rename router-tests/{ => events}/events_config_test.go (98%) diff --git a/.github/workflows/router-ci.yaml b/.github/workflows/router-ci.yaml index 7db9b9c7b5..15039b5106 100644 --- a/.github/workflows/router-ci.yaml +++ b/.github/workflows/router-ci.yaml @@ -176,7 +176,11 @@ jobs: redis-cli -u "redis://cosmo:test@127.0.0.1:$port" ping echo "ACL user 'cosmo' created with full access on port $port" done - - uses: nick-fields/retry@v3 + - name: Run Integration tests + working-directory: ./router-tests + run: make test test_params="-run '^Test[^(Flaky)]' --timeout=5m --parallel 10" + - name: Run Flaky Integration tests + uses: nick-fields/retry@v3 with: timeout_minutes: 30 max_attempts: 5 @@ -184,7 +188,7 @@ jobs: retry_on: error command: | cd router-tests - make test test_params="--timeout=5m" + make test test_params="-run '^TestFlaky' --timeout=5m -p 1 --parallel 1" image_scan: if: github.event.pull_request.head.repo.full_name == github.repository diff --git a/router-tests/cache_warmup_test.go b/router-tests/cache_warmup_test.go index 4fde9cf4db..b0e2ffc56e 100644 --- a/router-tests/cache_warmup_test.go +++ b/router-tests/cache_warmup_test.go @@ -659,7 +659,8 @@ func TestCacheWarmup(t *testing.T) { }) } -func TestCacheWarmupMetrics(t *testing.T) { +// Is set as Flaky so that when running the tests it will be run separately and retried if it fails +func TestFlakyCacheWarmupMetrics(t *testing.T) { t.Run("should emit planning times metrics during warmup", func(t *testing.T) { t.Parallel() diff --git a/router-tests/complexity_limits_test.go b/router-tests/complexity_limits_test.go index 2aabb798cd..85e47d8f23 100644 --- a/router-tests/complexity_limits_test.go +++ b/router-tests/complexity_limits_test.go @@ -3,6 +3,7 @@ package integration import ( "net/http" "testing" + "time" "github.com/stretchr/testify/require" "github.com/wundergraph/cosmo/router-tests/testenv" @@ -152,6 +153,8 @@ func TestComplexityLimits(t *testing.T) { require.Contains(t, testSpan.Attributes(), otel.WgQueryDepth.Int(3)) require.Contains(t, testSpan.Attributes(), otel.WgQueryDepthCacheHit.Bool(false)) exporter.Reset() + // wait to let cache get consistent + time.Sleep(100 * time.Millisecond) failedRes2, _ := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ Query: `{ employee(id:1) { id details { forename surname } } }`, @@ -163,6 +166,8 @@ func TestComplexityLimits(t *testing.T) { require.Contains(t, testSpan2.Attributes(), otel.WgQueryDepth.Int(3)) require.Contains(t, testSpan2.Attributes(), otel.WgQueryDepthCacheHit.Bool(true)) exporter.Reset() + // wait to let cache get consistent + time.Sleep(100 * time.Millisecond) successRes := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ Query: `query { employees { id } }`, @@ -172,6 +177,8 @@ func TestComplexityLimits(t *testing.T) { require.Contains(t, testSpan3.Attributes(), otel.WgQueryDepth.Int(2)) require.Contains(t, testSpan3.Attributes(), otel.WgQueryDepthCacheHit.Bool(false)) exporter.Reset() + // wait to let cache get consistent + time.Sleep(100 * time.Millisecond) successRes2 := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ Query: `query { employees { id } }`, diff --git a/router-tests/config_hot_reload_test.go b/router-tests/config_hot_reload_test.go index f69d3e270c..10618aebb4 100644 --- a/router-tests/config_hot_reload_test.go +++ b/router-tests/config_hot_reload_test.go @@ -10,6 +10,7 @@ import ( "github.com/wundergraph/cosmo/router/pkg/routerconfig" "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/wundergraph/cosmo/router-tests/testenv" "github.com/wundergraph/cosmo/router/core" @@ -238,28 +239,35 @@ func TestConfigHotReload(t *testing.T) { }, }, func(t *testing.T, xEnv *testenv.Environment) { - var done atomic.Bool - + var startedReq atomic.Bool go func() { - defer done.Store(true) - + startedReq.Store(true) res, err := xEnv.MakeGraphQLRequestWithContext(context.Background(), testenv.GraphQLRequest{ Query: `{ employees { id } }`, }) require.NoError(t, err) - require.Equal(t, res.Response.StatusCode, 200) - require.Equal(t, `{"errors":[{"message":"Failed to fetch from Subgraph 'employees'."}],"data":{"employees":null}}`, res.Body) + assert.Equal(t, res.Response.StatusCode, 200) + assert.Equal(t, `{"errors":[{"message":"Failed to fetch from Subgraph 'employees'."}],"data":{"employees":null}}`, res.Body) }() // Let's wait a bit to make sure all requests are in flight // otherwise the shutdown will be too fast and the wait-group will not be done fully + require.Eventually(t, func() bool { + return startedReq.Load() + }, time.Second*10, time.Millisecond*100) time.Sleep(time.Millisecond * 100) - xEnv.Shutdown() + var done atomic.Bool + go func() { + defer done.Store(true) + + err := xEnv.Router.Shutdown(context.Background()) + assert.ErrorContains(t, err, context.DeadlineExceeded.Error()) + }() require.Eventually(t, func() bool { return done.Load() - }, time.Second*5, time.Millisecond*100) + }, time.Second*20, time.Millisecond*100) }) }) @@ -314,13 +322,15 @@ func TestConfigHotReload(t *testing.T) { // Swap config require.NoError(t, pm.updateConfig(pm.initConfig, "old-1")) - err = conn.ReadJSON(&msg) - // Ensure that the connection is closed. In the future, we might want to send a complete message to the client + // If the operation happen fast enough, ensure that the connection is closed. + // In the future, we might want to send a complete message to the client // and wait until in-flight messages are delivered before closing the connection - var wsErr *websocket.CloseError - require.ErrorAs(t, err, &wsErr) + if err != nil { + var wsErr *websocket.CloseError + require.ErrorAs(t, err, &wsErr) + } require.NoError(t, conn.Close()) diff --git a/router-tests/events_config_test.go b/router-tests/events/events_config_test.go similarity index 98% rename from router-tests/events_config_test.go rename to router-tests/events/events_config_test.go index 5fb60fe603..110d0d4cff 100644 --- a/router-tests/events_config_test.go +++ b/router-tests/events/events_config_test.go @@ -1,4 +1,4 @@ -package integration_test +package events_test import ( "github.com/stretchr/testify/assert" diff --git a/router-tests/events/kafka_events_test.go b/router-tests/events/kafka_events_test.go index 9916bfb256..caedd86fba 100644 --- a/router-tests/events/kafka_events_test.go +++ b/router-tests/events/kafka_events_test.go @@ -6,12 +6,14 @@ import ( "context" "encoding/json" "fmt" - "github.com/wundergraph/cosmo/router/core" "net/http" "sync/atomic" "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/wundergraph/cosmo/router/core" + "github.com/hasura/go-graphql-client" "github.com/stretchr/testify/require" "github.com/tidwall/gjson" @@ -20,6 +22,8 @@ import ( "github.com/wundergraph/cosmo/router/pkg/config" ) +const KafkaWaitTimeout = time.Second * 30 + func TestLocalKafka(t *testing.T) { t.Skip("skip only for local testing") @@ -88,17 +92,17 @@ func TestKafkaEvents(t *testing.T) { go func() { require.Eventually(t, func() bool { return counter.Load() == 1 - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) _ = client.Close() }() - xEnv.WaitForSubscriptionCount(1, time.Second*10) + xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) produceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) - xEnv.WaitForMessagesSent(1, time.Second*10) - xEnv.WaitForSubscriptionCount(0, time.Second*10) - xEnv.WaitForConnectionCount(0, time.Second*10) + xEnv.WaitForMessagesSent(1, KafkaWaitTimeout) + xEnv.WaitForSubscriptionCount(0, KafkaWaitTimeout) + xEnv.WaitForConnectionCount(0, KafkaWaitTimeout) }) }) @@ -159,32 +163,32 @@ func TestKafkaEvents(t *testing.T) { require.NoError(t, clientErr) }() - xEnv.WaitForSubscriptionCount(1, time.Second*10) + xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) produceKafkaMessage(t, xEnv, topics[0], ``) // Empty message require.Eventually(t, func() bool { return counter.Load() == 1 - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) produceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // Correct message require.Eventually(t, func() bool { return counter.Load() == 2 - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) produceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","update":{"name":"foo"}}`) // Missing entity = Resolver error require.Eventually(t, func() bool { return counter.Load() == 3 - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) produceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // Correct message require.Eventually(t, func() bool { return counter.Load() == 4 - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) require.NoError(t, client.Close()) - xEnv.WaitForSubscriptionCount(0, time.Second*10) - xEnv.WaitForConnectionCount(0, time.Second*10) + xEnv.WaitForSubscriptionCount(0, KafkaWaitTimeout) + xEnv.WaitForConnectionCount(0, KafkaWaitTimeout) }) }) @@ -238,20 +242,20 @@ func TestKafkaEvents(t *testing.T) { require.NoError(t, clientErr) }() - xEnv.WaitForSubscriptionCount(2, time.Second*10) + xEnv.WaitForSubscriptionCount(2, KafkaWaitTimeout) produceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) - xEnv.WaitForMessagesSent(2, time.Second*10) + xEnv.WaitForMessagesSent(2, KafkaWaitTimeout) require.Eventually(t, func() bool { return counter.Load() == 2 - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) _ = client.Close() - xEnv.WaitForSubscriptionCount(0, time.Second*10) - xEnv.WaitForConnectionCount(0, time.Second*10) + xEnv.WaitForSubscriptionCount(0, KafkaWaitTimeout) + xEnv.WaitForConnectionCount(0, KafkaWaitTimeout) }) }) @@ -330,20 +334,20 @@ func TestKafkaEvents(t *testing.T) { require.NoError(t, clientErr) }() - xEnv.WaitForSubscriptionCount(2, time.Second*10) + xEnv.WaitForSubscriptionCount(2, KafkaWaitTimeout) produceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) produceKafkaMessage(t, xEnv, topics[1], `{"__typename":"Employee","id": 2,"update":{"name":"foo"}}`) require.Eventually(t, func() bool { return counter.Load() == 4 - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) require.NoError(t, client.Close()) - xEnv.WaitForMessagesSent(4, time.Second*10) - xEnv.WaitForSubscriptionCount(0, time.Second*10) - xEnv.WaitForConnectionCount(0, time.Second*10) + xEnv.WaitForMessagesSent(4, KafkaWaitTimeout) + xEnv.WaitForSubscriptionCount(0, KafkaWaitTimeout) + xEnv.WaitForConnectionCount(0, KafkaWaitTimeout) }) }) @@ -398,37 +402,24 @@ func TestKafkaEvents(t *testing.T) { go func() { require.Eventually(t, func() bool { return counter.Load() == 1 - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) _ = client.Close() }() - xEnv.WaitForSubscriptionCount(1, time.Second*10) + xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) produceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) - xEnv.WaitForMessagesSent(1, time.Second*10) - xEnv.WaitForSubscriptionCount(0, time.Second*10) - xEnv.WaitForConnectionCount(0, time.Second*10) + xEnv.WaitForMessagesSent(1, KafkaWaitTimeout) + xEnv.WaitForSubscriptionCount(0, KafkaWaitTimeout) + xEnv.WaitForConnectionCount(0, KafkaWaitTimeout) }) }) t.Run("multipart", func(t *testing.T) { t.Parallel() - assertLineEquals := func(t *testing.T, reader *bufio.Reader, expected string) { - line, _, err := reader.ReadLine() - require.NoError(t, err) - require.Equal(t, expected, string(line)) - } - - assertMultipartPrefix := func(t *testing.T, reader *bufio.Reader) { - assertLineEquals(t, reader, "") - assertLineEquals(t, reader, "--graphql") - assertLineEquals(t, reader, "Content-Type: application/json") - assertLineEquals(t, reader, "") - } - - var multipartHeartbeatInterval = time.Second + var multipartHeartbeatInterval = time.Second * 5 t.Run("subscribe sync", func(t *testing.T) { t.Parallel() @@ -447,6 +438,7 @@ func TestKafkaEvents(t *testing.T) { subscribePayload := []byte(`{"query":"subscription { employeeUpdatedMyKafka(employeeID: 1) { id details { forename surname } }}"}`) + var started atomic.Bool var consumed atomic.Uint32 var produced atomic.Uint32 @@ -460,46 +452,38 @@ func TestKafkaEvents(t *testing.T) { require.Equal(t, http.StatusOK, resp.StatusCode) defer resp.Body.Close() reader := bufio.NewReader(resp.Body) + started.Store(true) - require.Eventually(t, func() bool { + assert.Eventually(t, func() bool { return produced.Load() == 1 - }, time.Second*10, time.Millisecond*100) - assertMultipartPrefix(t, reader) - assertLineEquals(t, reader, "{\"payload\":{\"data\":{\"employeeUpdatedMyKafka\":{\"id\":1,\"details\":{\"forename\":\"Jens\",\"surname\":\"Neuse\"}}}}}") + }, KafkaWaitTimeout, time.Millisecond*100) + assertMultipartValueEventually(t, reader, "{\"payload\":{\"data\":{\"employeeUpdatedMyKafka\":{\"id\":1,\"details\":{\"forename\":\"Jens\",\"surname\":\"Neuse\"}}}}}") consumed.Add(1) - assertMultipartPrefix(t, reader) - assertLineEquals(t, reader, "{}") - consumed.Add(1) - - require.Eventually(t, func() bool { + assert.Eventually(t, func() bool { return produced.Load() == 2 - }, time.Second*10, time.Millisecond*100) - assertMultipartPrefix(t, reader) - assertLineEquals(t, reader, "{\"payload\":{\"data\":{\"employeeUpdatedMyKafka\":{\"id\":1,\"details\":{\"forename\":\"Jens\",\"surname\":\"Neuse\"}}}}}") - + }, KafkaWaitTimeout, time.Millisecond*100) + assertMultipartValueEventually(t, reader, "{\"payload\":{\"data\":{\"employeeUpdatedMyKafka\":{\"id\":1,\"details\":{\"forename\":\"Jens\",\"surname\":\"Neuse\"}}}}}") consumed.Add(1) }() - xEnv.WaitForSubscriptionCount(1, time.Second*5) + xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) + assert.Eventually(t, started.Load, KafkaWaitTimeout, time.Millisecond*100) produceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) - xEnv.WaitForMessagesSent(1, time.Second*5) produced.Add(1) - require.Eventually(t, func() bool { - return consumed.Load() == 2 - }, time.Second*10, time.Millisecond*100) + assert.Eventually(t, func() bool { + return consumed.Load() == 1 + }, KafkaWaitTimeout, time.Millisecond*100) produceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) - xEnv.WaitForMessagesSent(3, time.Second*5) // 2 messages + the empty one produced.Add(1) - require.Eventually(t, func() bool { - return consumed.Load() == 3 - }, time.Second*10, time.Millisecond*100) + // Wait for the client to finish + require.Eventually(t, func() bool { return consumed.Load() == 2 }, KafkaWaitTimeout*2, time.Millisecond*100) - xEnv.WaitForSubscriptionCount(0, time.Second*10) - xEnv.WaitForConnectionCount(0, time.Second*10) + xEnv.WaitForSubscriptionCount(0, KafkaWaitTimeout) + xEnv.WaitForConnectionCount(0, KafkaWaitTimeout) }) }) @@ -528,11 +512,10 @@ func TestKafkaEvents(t *testing.T) { defer resp.Body.Close() reader := bufio.NewReader(resp.Body) - assertMultipartPrefix(t, reader) - assertLineEquals(t, reader, "{\"payload\":{\"errors\":[{\"message\":\"operation type 'subscription' is blocked\"}]}}") + assertMultipartValueEventually(t, reader, "{\"payload\":{\"errors\":[{\"message\":\"operation type 'subscription' is blocked\"}]}}") - xEnv.WaitForSubscriptionCount(0, time.Second*10) - xEnv.WaitForConnectionCount(0, time.Second*10) + xEnv.WaitForSubscriptionCount(0, KafkaWaitTimeout) + xEnv.WaitForConnectionCount(0, KafkaWaitTimeout) }) }) }) @@ -585,16 +568,16 @@ func TestKafkaEvents(t *testing.T) { }() - xEnv.WaitForSubscriptionCount(1, time.Second*5) + xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) produceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) require.Eventually(t, func() bool { return counter.Load() == 1 - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) - xEnv.WaitForSubscriptionCount(0, time.Second*10) - xEnv.WaitForConnectionCount(0, time.Second*10) + xEnv.WaitForSubscriptionCount(0, KafkaWaitTimeout) + xEnv.WaitForConnectionCount(0, KafkaWaitTimeout) }) }) @@ -645,16 +628,16 @@ func TestKafkaEvents(t *testing.T) { require.Equal(t, "", string(line)) }() - xEnv.WaitForSubscriptionCount(1, time.Second*5) + xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) produceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) require.Eventually(t, func() bool { return counter.Load() == 1 - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) - xEnv.WaitForSubscriptionCount(0, time.Second*10) - xEnv.WaitForConnectionCount(0, time.Second*10) + xEnv.WaitForSubscriptionCount(0, KafkaWaitTimeout) + xEnv.WaitForConnectionCount(0, KafkaWaitTimeout) }) }) @@ -697,139 +680,8 @@ func TestKafkaEvents(t *testing.T) { require.NoError(t, err) require.Equal(t, "data: {\"errors\":[{\"message\":\"operation type 'subscription' is blocked\"}]}", string(data)) - xEnv.WaitForSubscriptionCount(0, time.Second*10) - xEnv.WaitForConnectionCount(0, time.Second*10) - }) - }) - - t.Run("subscribe async with filter", func(t *testing.T) { - t.Parallel() - - topics := []string{"employeeUpdated", "employeeUpdatedTwo"} - - testenv.Run(t, &testenv.Config{ - RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, - EnableKafka: true, - }, func(t *testing.T, xEnv *testenv.Environment) { - - ensureTopicExists(t, xEnv, topics...) - - type subscriptionPayload struct { - Data struct { - FilteredEmployeeUpdatedMyKafka struct { - ID float64 `graphql:"id"` - Details struct { - Forename string `graphql:"forename"` - Surname string `graphql:"surname"` - } `graphql:"details"` - } `graphql:"filteredEmployeeUpdatedMyKafka(employeeID: 1)"` - } `json:"data"` - } - - // conn.Close() is called in a cleanup defined in the function - conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil) - err := conn.WriteJSON(&testenv.WebSocketMessage{ - ID: "1", - Type: "subscribe", - Payload: []byte(`{"query":"subscription { filteredEmployeeUpdatedMyKafka(employeeID: 1) { id details { forename, surname } } }"}`), - }) - - require.NoError(t, err) - var msg testenv.WebSocketMessage - var payload subscriptionPayload - - xEnv.WaitForSubscriptionCount(1, time.Second*5) - - var produced atomic.Uint32 - var consumed atomic.Uint32 - const MsgCount = uint32(12) - - go func() { - consumed.Add(1) // the first message is ignored - - require.Eventually(t, func() bool { - return produced.Load() == MsgCount-11 - }, time.Second*5, time.Millisecond*100) - gErr := conn.ReadJSON(&msg) - require.NoError(t, gErr) - require.Equal(t, "1", msg.ID) - require.Equal(t, "next", msg.Type) - gErr = json.Unmarshal(msg.Payload, &payload) - require.NoError(t, gErr) - require.Equal(t, float64(11), payload.Data.FilteredEmployeeUpdatedMyKafka.ID) - require.Equal(t, "Alexandra", payload.Data.FilteredEmployeeUpdatedMyKafka.Details.Forename) - require.Equal(t, "Neuse", payload.Data.FilteredEmployeeUpdatedMyKafka.Details.Surname) - consumed.Add(4) // should arrive to 5th message, with id 7 - - require.Eventually(t, func() bool { - return produced.Load() == MsgCount-7 - }, time.Second*5, time.Millisecond*100) - gErr = conn.ReadJSON(&msg) - require.NoError(t, gErr) - require.Equal(t, "1", msg.ID) - require.Equal(t, "next", msg.Type) - gErr = json.Unmarshal(msg.Payload, &payload) - require.NoError(t, gErr) - require.Equal(t, float64(7), payload.Data.FilteredEmployeeUpdatedMyKafka.ID) - require.Equal(t, "Suvij", payload.Data.FilteredEmployeeUpdatedMyKafka.Details.Forename) - require.Equal(t, "Surya", payload.Data.FilteredEmployeeUpdatedMyKafka.Details.Surname) - consumed.Add(3) // should arrive to 8th message, with id 4 - - require.Eventually(t, func() bool { - return produced.Load() == MsgCount-4 - }, time.Second*5, time.Millisecond*100) - gErr = conn.ReadJSON(&msg) - require.NoError(t, gErr) - require.Equal(t, "1", msg.ID) - require.Equal(t, "next", msg.Type) - gErr = json.Unmarshal(msg.Payload, &payload) - require.NoError(t, gErr) - require.Equal(t, float64(4), payload.Data.FilteredEmployeeUpdatedMyKafka.ID) - require.Equal(t, "Björn", payload.Data.FilteredEmployeeUpdatedMyKafka.Details.Forename) - require.Equal(t, "Schwenzer", payload.Data.FilteredEmployeeUpdatedMyKafka.Details.Surname) - consumed.Add(1) - - require.Eventually(t, func() bool { - return produced.Load() == MsgCount-3 - }, time.Second*5, time.Millisecond*100) - gErr = conn.ReadJSON(&msg) - require.NoError(t, gErr) - require.Equal(t, "1", msg.ID) - require.Equal(t, "next", msg.Type) - gErr = json.Unmarshal(msg.Payload, &payload) - require.NoError(t, gErr) - require.Equal(t, float64(3), payload.Data.FilteredEmployeeUpdatedMyKafka.ID) - require.Equal(t, "Stefan", payload.Data.FilteredEmployeeUpdatedMyKafka.Details.Forename) - require.Equal(t, "Avram", payload.Data.FilteredEmployeeUpdatedMyKafka.Details.Surname) - consumed.Add(2) // should arrive to 10th message, with id 2 - - require.Eventually(t, func() bool { - return produced.Load() == MsgCount-1 - }, time.Second*5, time.Millisecond*100) - gErr = conn.ReadJSON(&msg) - require.NoError(t, gErr) - require.Equal(t, "1", msg.ID) - require.Equal(t, "next", msg.Type) - gErr = json.Unmarshal(msg.Payload, &payload) - require.NoError(t, gErr) - require.Equal(t, float64(1), payload.Data.FilteredEmployeeUpdatedMyKafka.ID) - require.Equal(t, "Jens", payload.Data.FilteredEmployeeUpdatedMyKafka.Details.Forename) - require.Equal(t, "Neuse", payload.Data.FilteredEmployeeUpdatedMyKafka.Details.Surname) - consumed.Add(1) - }() - - // Events 1, 3, 4, 7, and 11 should be included - for i := MsgCount; i > 0; i-- { - require.Eventually(t, func() bool { - return consumed.Load() >= MsgCount-i - }, time.Second*5, time.Millisecond*100) - produceKafkaMessage(t, xEnv, topics[0], fmt.Sprintf(`{"__typename":"Employee","id":%d}`, i)) - produced.Add(1) - } - - require.Eventually(t, func() bool { - return consumed.Load() == MsgCount && produced.Load() == MsgCount - }, time.Second*10, time.Millisecond*100) + xEnv.WaitForSubscriptionCount(0, KafkaWaitTimeout) + xEnv.WaitForConnectionCount(0, KafkaWaitTimeout) }) }) @@ -869,7 +721,7 @@ func TestKafkaEvents(t *testing.T) { var msg testenv.WebSocketMessage var payload subscriptionPayload - xEnv.WaitForSubscriptionCount(1, time.Second*5) + xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) var produced atomic.Uint32 var consumed atomic.Uint32 @@ -877,7 +729,7 @@ func TestKafkaEvents(t *testing.T) { go func() { require.Eventually(t, func() bool { return produced.Load() == 1 - }, time.Second*5, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) gErr := conn.ReadJSON(&msg) require.NoError(t, gErr) require.Equal(t, "1", msg.ID) @@ -891,7 +743,7 @@ func TestKafkaEvents(t *testing.T) { require.Eventually(t, func() bool { return produced.Load() == 2 - }, time.Second*5, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) gErr = conn.ReadJSON(&msg) require.NoError(t, gErr) require.Equal(t, "1", msg.ID) @@ -905,7 +757,7 @@ func TestKafkaEvents(t *testing.T) { require.Eventually(t, func() bool { return produced.Load() == 11 - }, time.Second*5, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) gErr = conn.ReadJSON(&msg) require.NoError(t, gErr) require.Equal(t, "1", msg.ID) @@ -919,7 +771,7 @@ func TestKafkaEvents(t *testing.T) { require.Eventually(t, func() bool { return produced.Load() == 12 - }, time.Second*5, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) gErr = conn.ReadJSON(&msg) require.NoError(t, gErr) require.Equal(t, "1", msg.ID) @@ -936,14 +788,14 @@ func TestKafkaEvents(t *testing.T) { for i := uint32(1); i < 13; i++ { require.Eventually(t, func() bool { return consumed.Load() >= i-1 - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) produceKafkaMessage(t, xEnv, topics[0], fmt.Sprintf(`{"__typename":"Employee","id":%d}`, i)) produced.Add(1) } require.Eventually(t, func() bool { return consumed.Load() == 12 && produced.Load() == 12 - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) }) }) @@ -983,7 +835,7 @@ func TestKafkaEvents(t *testing.T) { var msg testenv.WebSocketMessage var payload subscriptionPayload - xEnv.WaitForSubscriptionCount(1, time.Second*5) + xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) var produced atomic.Uint32 var consumed atomic.Uint32 @@ -1050,7 +902,7 @@ func TestKafkaEvents(t *testing.T) { for i := uint32(1); i < 13; i++ { require.Eventually(t, func() bool { return consumed.Load() >= i-1 - }, time.Second*5, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) produceKafkaMessage(t, xEnv, topics[0], fmt.Sprintf(`{"__typename":"Employee","id":%d}`, i)) produced.Add(1) } @@ -1097,7 +949,7 @@ func TestKafkaEvents(t *testing.T) { var msg testenv.WebSocketMessage var payload subscriptionPayload - xEnv.WaitForSubscriptionCount(1, time.Second*5) + xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) var counter atomic.Uint32 @@ -1125,7 +977,7 @@ func TestKafkaEvents(t *testing.T) { require.Eventually(t, func() bool { return counter.Load() == 1 - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) }) }) @@ -1186,29 +1038,162 @@ func TestKafkaEvents(t *testing.T) { require.NoError(t, clientErr) }() - xEnv.WaitForSubscriptionCount(1, time.Second*10) + xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) produceKafkaMessage(t, xEnv, topics[0], `{asas`) // Invalid message require.Eventually(t, func() bool { return counter.Load() == 1 - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) produceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id":1}`) // Correct message require.Eventually(t, func() bool { return counter.Load() == 2 - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) produceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","update":{"name":"foo"}}`) // Missing entity = Resolver error require.Eventually(t, func() bool { return counter.Load() == 3 - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) produceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // Correct message require.Eventually(t, func() bool { return counter.Load() == 4 - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) require.NoError(t, client.Close()) - xEnv.WaitForSubscriptionCount(0, time.Second*10) - xEnv.WaitForConnectionCount(0, time.Second*10) + xEnv.WaitForSubscriptionCount(0, KafkaWaitTimeout) + xEnv.WaitForConnectionCount(0, KafkaWaitTimeout) + }) + }) +} + +func TestFlakyKafkaEvents(t *testing.T) { + t.Run("subscribe async with filter", func(t *testing.T) { + t.Parallel() + + topics := []string{"employeeUpdated", "employeeUpdatedTwo"} + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + }, func(t *testing.T, xEnv *testenv.Environment) { + + ensureTopicExists(t, xEnv, topics...) + + type subscriptionPayload struct { + Data struct { + FilteredEmployeeUpdatedMyKafka struct { + ID float64 `graphql:"id"` + Details struct { + Forename string `graphql:"forename"` + Surname string `graphql:"surname"` + } `graphql:"details"` + } `graphql:"filteredEmployeeUpdatedMyKafka(employeeID: 1)"` + } `json:"data"` + } + + // conn.Close() is called in a cleanup defined in the function + conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil) + err := conn.WriteJSON(&testenv.WebSocketMessage{ + ID: "1", + Type: "subscribe", + Payload: []byte(`{"query":"subscription { filteredEmployeeUpdatedMyKafka(employeeID: 1) { id details { forename, surname } } }"}`), + }) + + require.NoError(t, err) + var msg testenv.WebSocketMessage + var payload subscriptionPayload + + xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) + + var produced atomic.Uint32 + var consumed atomic.Uint32 + const MsgCount = uint32(12) + + go func() { + consumed.Add(1) // the first message is ignored + + require.Eventually(t, func() bool { + return produced.Load() == MsgCount-11 + }, KafkaWaitTimeout, time.Millisecond*100) + gErr := conn.ReadJSON(&msg) + require.NoError(t, gErr) + require.Equal(t, "1", msg.ID) + require.Equal(t, "next", msg.Type) + gErr = json.Unmarshal(msg.Payload, &payload) + require.NoError(t, gErr) + require.Equal(t, float64(11), payload.Data.FilteredEmployeeUpdatedMyKafka.ID) + require.Equal(t, "Alexandra", payload.Data.FilteredEmployeeUpdatedMyKafka.Details.Forename) + require.Equal(t, "Neuse", payload.Data.FilteredEmployeeUpdatedMyKafka.Details.Surname) + consumed.Add(4) // should arrive to 5th message, with id 7 + + require.Eventually(t, func() bool { + return produced.Load() == MsgCount-7 + }, KafkaWaitTimeout, time.Millisecond*100) + gErr = conn.ReadJSON(&msg) + require.NoError(t, gErr) + require.Equal(t, "1", msg.ID) + require.Equal(t, "next", msg.Type) + gErr = json.Unmarshal(msg.Payload, &payload) + require.NoError(t, gErr) + require.Equal(t, float64(7), payload.Data.FilteredEmployeeUpdatedMyKafka.ID) + require.Equal(t, "Suvij", payload.Data.FilteredEmployeeUpdatedMyKafka.Details.Forename) + require.Equal(t, "Surya", payload.Data.FilteredEmployeeUpdatedMyKafka.Details.Surname) + consumed.Add(3) // should arrive to 8th message, with id 4 + + require.Eventually(t, func() bool { + return produced.Load() == MsgCount-4 + }, KafkaWaitTimeout, time.Millisecond*100) + gErr = conn.ReadJSON(&msg) + require.NoError(t, gErr) + require.Equal(t, "1", msg.ID) + require.Equal(t, "next", msg.Type) + gErr = json.Unmarshal(msg.Payload, &payload) + require.NoError(t, gErr) + require.Equal(t, float64(4), payload.Data.FilteredEmployeeUpdatedMyKafka.ID) + require.Equal(t, "Björn", payload.Data.FilteredEmployeeUpdatedMyKafka.Details.Forename) + require.Equal(t, "Schwenzer", payload.Data.FilteredEmployeeUpdatedMyKafka.Details.Surname) + consumed.Add(1) + + require.Eventually(t, func() bool { + return produced.Load() == MsgCount-3 + }, KafkaWaitTimeout, time.Millisecond*100) + gErr = conn.ReadJSON(&msg) + require.NoError(t, gErr) + require.Equal(t, "1", msg.ID) + require.Equal(t, "next", msg.Type) + gErr = json.Unmarshal(msg.Payload, &payload) + require.NoError(t, gErr) + require.Equal(t, float64(3), payload.Data.FilteredEmployeeUpdatedMyKafka.ID) + require.Equal(t, "Stefan", payload.Data.FilteredEmployeeUpdatedMyKafka.Details.Forename) + require.Equal(t, "Avram", payload.Data.FilteredEmployeeUpdatedMyKafka.Details.Surname) + consumed.Add(2) // should arrive to 10th message, with id 2 + + require.Eventually(t, func() bool { + return produced.Load() == MsgCount-1 + }, KafkaWaitTimeout, time.Millisecond*100) + gErr = conn.ReadJSON(&msg) + require.NoError(t, gErr) + require.Equal(t, "1", msg.ID) + require.Equal(t, "next", msg.Type) + gErr = json.Unmarshal(msg.Payload, &payload) + require.NoError(t, gErr) + require.Equal(t, float64(1), payload.Data.FilteredEmployeeUpdatedMyKafka.ID) + require.Equal(t, "Jens", payload.Data.FilteredEmployeeUpdatedMyKafka.Details.Forename) + require.Equal(t, "Neuse", payload.Data.FilteredEmployeeUpdatedMyKafka.Details.Surname) + consumed.Add(1) + }() + + // Events 1, 3, 4, 7, and 11 should be included + for i := MsgCount; i > 0; i-- { + require.Eventually(t, func() bool { + return consumed.Load() >= MsgCount-i + }, KafkaWaitTimeout, time.Millisecond*100) + produceKafkaMessage(t, xEnv, topics[0], fmt.Sprintf(`{"__typename":"Employee","id":%d}`, i)) + produced.Add(1) + } + + require.Eventually(t, func() bool { + return consumed.Load() == MsgCount && produced.Load() == MsgCount + }, KafkaWaitTimeout, time.Millisecond*100) }) }) } @@ -1252,7 +1237,7 @@ func produceKafkaMessage(t *testing.T, xEnv *testenv.Environment, topicName stri require.Eventually(t, func() bool { return done.Load() - }, time.Second*10, time.Millisecond*100) + }, KafkaWaitTimeout, time.Millisecond*100) require.NoError(t, pErr) diff --git a/router-tests/events/nats_events_test.go b/router-tests/events/nats_events_test.go index e07cd63196..aa7d8a9bc9 100644 --- a/router-tests/events/nats_events_test.go +++ b/router-tests/events/nats_events_test.go @@ -21,10 +21,39 @@ import ( "github.com/wundergraph/cosmo/router/pkg/config" "github.com/hasura/go-graphql-client" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/wundergraph/cosmo/router-tests/testenv" ) +const NatsWaitTimeout = time.Second * 30 + +func assertLineEquals(t *testing.T, reader *bufio.Reader, expected string) { + line, _, err := reader.ReadLine() + assert.NoError(t, err) + assert.Equal(t, expected, string(line)) +} + +func assertMultipartPrefix(t *testing.T, reader *bufio.Reader) { + assertLineEquals(t, reader, "") + assertLineEquals(t, reader, "--graphql") + assertLineEquals(t, reader, "Content-Type: application/json") + assertLineEquals(t, reader, "") +} + +func assertMultipartValueEventually(t *testing.T, reader *bufio.Reader, expected string) { + assert.Eventually(t, func() bool { + assertMultipartPrefix(t, reader) + line, _, err := reader.ReadLine() + assert.NoError(t, err) + if string(line) == "{}" { + return false + } + assert.Equal(t, expected, string(line)) + return true + }, NatsWaitTimeout, time.Millisecond*100) +} + func TestNatsEvents(t *testing.T) { t.Parallel() @@ -71,16 +100,7 @@ func TestNatsEvents(t *testing.T) { require.NoError(t, clientErr) }() - var closed atomic.Bool - go func() { - require.Eventually(t, func() bool { - return subscriptionCalled.Load() == 2 - }, time.Second*20, time.Millisecond*100) - require.NoError(t, client.Close()) - closed.Store(true) - }() - - xEnv.WaitForSubscriptionCount(1, time.Second*10) + xEnv.WaitForSubscriptionCount(1, NatsWaitTimeout) // Send a mutation to trigger the first subscription resOne := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ @@ -88,6 +108,10 @@ func TestNatsEvents(t *testing.T) { }) require.JSONEq(t, `{"data":{"updateAvailability":{"id":3}}}`, resOne.Body) + assert.Eventually(t, func() bool { + return subscriptionCalled.Load() == 1 + }, NatsWaitTimeout, time.Millisecond*100) + // Trigger the first subscription via NATS err = xEnv.NatsConnectionDefault.Publish(xEnv.GetPubSubName("employeeUpdated.3"), []byte(`{"id":3,"__typename": "Employee"}`)) require.NoError(t, err) @@ -95,13 +119,19 @@ func TestNatsEvents(t *testing.T) { err = xEnv.NatsConnectionDefault.Flush() require.NoError(t, err) - require.Eventually(t, func() bool { - return closed.Load() - }, time.Second*20, time.Millisecond*100) + var closed atomic.Bool + go func() { + defer closed.Store(true) + assert.Eventually(t, func() bool { + return subscriptionCalled.Load() == 2 + }, NatsWaitTimeout, time.Millisecond*100) + assert.NoError(t, client.Close()) + }() - xEnv.WaitForMessagesSent(2, time.Second*10) - xEnv.WaitForSubscriptionCount(0, time.Second*10) - xEnv.WaitForConnectionCount(0, time.Second*10) + assert.Eventually(t, closed.Load, NatsWaitTimeout, time.Millisecond*100) + + xEnv.WaitForSubscriptionCount(0, NatsWaitTimeout) + xEnv.WaitForConnectionCount(0, NatsWaitTimeout) natsLogs := xEnv.Observer().FilterMessageSnippet("Nats").All() require.Len(t, natsLogs, 4) @@ -137,7 +167,7 @@ func TestNatsEvents(t *testing.T) { subscriptionOneID, err := client.Subscribe(&subscriptionOne, nil, func(dataValue []byte, errValue error) error { oldCount := counter.Load() - counter.Add(1) + defer counter.Add(1) if oldCount == 0 { var gqlErr graphql.Errors @@ -162,40 +192,46 @@ func TestNatsEvents(t *testing.T) { require.NoError(t, clientErr) }() - xEnv.WaitForSubscriptionCount(1, time.Second*10) + xEnv.WaitForSubscriptionCount(1, NatsWaitTimeout) err = xEnv.NatsConnectionDefault.Publish(xEnv.GetPubSubName("employeeUpdated.3"), []byte(``)) // Empty message require.NoError(t, err) err = xEnv.NatsConnectionDefault.Flush() require.NoError(t, err) - xEnv.WaitForMessagesSent(1, time.Second*10) + require.Eventually(t, func() bool { + return counter.Load() == 1 + }, NatsWaitTimeout, time.Millisecond*100) err = xEnv.NatsConnectionDefault.Publish(xEnv.GetPubSubName("employeeUpdated.3"), []byte(`{"__typename":"Employee","id": 3,"update":{"name":"foo"}}`)) // Correct message require.NoError(t, err) err = xEnv.NatsConnectionDefault.Flush() require.NoError(t, err) - xEnv.WaitForMessagesSent(2, time.Second*10) + require.Eventually(t, func() bool { + return counter.Load() == 2 + }, NatsWaitTimeout, time.Millisecond*100) err = xEnv.NatsConnectionDefault.Publish(xEnv.GetPubSubName("employeeUpdated.3"), []byte(`{"__typename":"Employee","update":{"name":"foo"}}`)) // Missing id require.NoError(t, err) err = xEnv.NatsConnectionDefault.Flush() require.NoError(t, err) - xEnv.WaitForMessagesSent(3, time.Second*10) + + require.Eventually(t, func() bool { + return counter.Load() == 3 + }, NatsWaitTimeout, time.Millisecond*100) err = xEnv.NatsConnectionDefault.Publish(xEnv.GetPubSubName("employeeUpdated.3"), []byte(`{"__typename":"Employee","id": 3,"update":{"name":"foo"}}`)) // Correct message require.NoError(t, err) err = xEnv.NatsConnectionDefault.Flush() require.NoError(t, err) - xEnv.WaitForMessagesSent(4, time.Second*10) require.Eventually(t, func() bool { return counter.Load() == 4 - }, time.Second*10, time.Millisecond*100) + }, NatsWaitTimeout, time.Millisecond*100) require.NoError(t, client.Close()) - xEnv.WaitForSubscriptionCount(0, time.Second*10) - xEnv.WaitForConnectionCount(0, time.Second*10) + xEnv.WaitForSubscriptionCount(0, NatsWaitTimeout) + xEnv.WaitForConnectionCount(0, NatsWaitTimeout) }) }) @@ -244,7 +280,7 @@ func TestNatsEvents(t *testing.T) { require.NoError(t, clientErr) }() - xEnv.WaitForSubscriptionCount(1, time.Second*10) + xEnv.WaitForSubscriptionCount(1, NatsWaitTimeout) // Send a mutation to trigger the subscription @@ -265,32 +301,19 @@ func TestNatsEvents(t *testing.T) { require.Eventually(t, func() bool { return counter.Load() == 2 - }, time.Second*10, time.Millisecond*100) + }, NatsWaitTimeout, time.Millisecond*100) require.NoError(t, client.Close()) - xEnv.WaitForMessagesSent(2, time.Second*10) - xEnv.WaitForSubscriptionCount(0, time.Second*10) - //xEnv.WaitForConnectionCount(0, time.Second*10) flaky + xEnv.WaitForMessagesSent(2, NatsWaitTimeout) + xEnv.WaitForSubscriptionCount(0, NatsWaitTimeout) + xEnv.WaitForConnectionCount(0, NatsWaitTimeout) }) }) t.Run("multipart", func(t *testing.T) { t.Parallel() - assertLineEquals := func(t *testing.T, reader *bufio.Reader, expected string) { - line, _, err := reader.ReadLine() - require.NoError(t, err) - require.Equal(t, expected, string(line)) - } - - assertMultipartPrefix := func(t *testing.T, reader *bufio.Reader) { - assertLineEquals(t, reader, "") - assertLineEquals(t, reader, "--graphql") - assertLineEquals(t, reader, "Content-Type: application/json") - assertLineEquals(t, reader, "") - } - heartbeatInterval := 150 * time.Millisecond t.Run("subscribe with multipart responses", func(t *testing.T) { @@ -315,8 +338,6 @@ func TestNatsEvents(t *testing.T) { var consumed atomic.Uint32 go func() { - defer produced.Add(1) - req := xEnv.MakeGraphQLMultipartRequest(http.MethodPost, bytes.NewReader(subscribePayload)) resp, err := xEnv.RouterClient.Do(req) require.NoError(t, err) @@ -331,24 +352,14 @@ func TestNatsEvents(t *testing.T) { reader := bufio.NewReader(resp.Body) - // Read the first part - assertMultipartPrefix(t, reader) - assertLineEquals(t, reader, "{\"payload\":{\"data\":{\"employeeUpdated\":{\"id\":3,\"details\":{\"forename\":\"Stefan\",\"surname\":\"Avram\"}}}}}") - consumed.Add(1) - - assertMultipartPrefix(t, reader) - assertLineEquals(t, reader, "{}") + assertMultipartValueEventually(t, reader, "{\"payload\":{\"data\":{\"employeeUpdated\":{\"id\":3,\"details\":{\"forename\":\"Stefan\",\"surname\":\"Avram\"}}}}}") consumed.Add(1) - require.Eventually(t, func() bool { - return produced.Load() == 2 - }, time.Second*5, time.Millisecond*100) - assertMultipartPrefix(t, reader) - assertLineEquals(t, reader, "{\"payload\":{\"data\":{\"employeeUpdated\":{\"id\":3,\"details\":{\"forename\":\"Stefan\",\"surname\":\"Avram\"}}}}}") + assertMultipartValueEventually(t, reader, "{\"payload\":{\"data\":{\"employeeUpdated\":{\"id\":3,\"details\":{\"forename\":\"Stefan\",\"surname\":\"Avram\"}}}}}") consumed.Add(1) }() - xEnv.WaitForSubscriptionCount(1, time.Second*5) + xEnv.WaitForSubscriptionCount(1, NatsWaitTimeout) // Send a mutation to trigger the subscription @@ -356,11 +367,10 @@ func TestNatsEvents(t *testing.T) { Query: `mutation { updateAvailability(employeeID: 3, isAvailable: true) { id } }`, }) require.JSONEq(t, `{"data":{"updateAvailability":{"id":3}}}`, res.Body) - produced.Add(1) require.Eventually(t, func() bool { - return consumed.Load() == 2 - }, time.Second*10, time.Millisecond*100) + return consumed.Load() == 1 + }, NatsWaitTimeout, time.Millisecond*100) // Trigger the subscription via NATS err := xEnv.NatsConnectionDefault.Publish(xEnv.GetPubSubName("employeeUpdated.3"), []byte(`{"id":3,"__typename": "Employee"}`)) @@ -371,12 +381,12 @@ func TestNatsEvents(t *testing.T) { produced.Add(1) require.Eventually(t, func() bool { - return consumed.Load() == 3 - }, time.Second*10, time.Millisecond*100) + return consumed.Load() == 2 + }, NatsWaitTimeout, time.Millisecond*100) }) }) - t.Run("subscribe with multipart responses http/1", func(t *testing.T) { + t.Run("subscribe with multipart responses http", func(t *testing.T) { t.Parallel() testenv.Run(t, &testenv.Config{ @@ -413,11 +423,11 @@ func TestNatsEvents(t *testing.T) { assertLineEquals(t, reader, "{}") }() - xEnv.WaitForSubscriptionCount(1, time.Second*5) + xEnv.WaitForSubscriptionCount(1, NatsWaitTimeout) require.Eventually(t, func() bool { return counter.Load() == 1 - }, time.Second*10, time.Millisecond*100) + }, NatsWaitTimeout, time.Millisecond*100) }) }) @@ -431,13 +441,12 @@ func TestNatsEvents(t *testing.T) { subscribePayload := []byte(`{"query":"subscription { countFor(count: 3) }"}`) - var counter atomic.Uint32 + var done atomic.Bool - var client http.Client go func() { - defer counter.Add(1) + defer done.Store(true) - client = http.Client{} + client := http.Client{} req := xEnv.MakeGraphQLMultipartRequest(http.MethodPost, bytes.NewReader(subscribePayload)) resp, err := client.Do(req) require.NoError(t, err) @@ -447,21 +456,15 @@ func TestNatsEvents(t *testing.T) { reader := bufio.NewReader(resp.Body) // Read the first part - assertMultipartPrefix(t, reader) - assertLineEquals(t, reader, "{\"payload\":{\"data\":{\"countFor\":0}}}") - assertMultipartPrefix(t, reader) - assertLineEquals(t, reader, "{\"payload\":{\"data\":{\"countFor\":1}}}") - assertMultipartPrefix(t, reader) - assertLineEquals(t, reader, "{\"payload\":{\"data\":{\"countFor\":2}}}") - assertMultipartPrefix(t, reader) - assertLineEquals(t, reader, "{\"payload\":{\"data\":{\"countFor\":3}}}") + assertMultipartValueEventually(t, reader, "{\"payload\":{\"data\":{\"countFor\":0}}}") + assertMultipartValueEventually(t, reader, "{\"payload\":{\"data\":{\"countFor\":1}}}") + assertMultipartValueEventually(t, reader, "{\"payload\":{\"data\":{\"countFor\":2}}}") + assertMultipartValueEventually(t, reader, "{\"payload\":{\"data\":{\"countFor\":3}}}") assertLineEquals(t, reader, "--graphql--") }() - xEnv.WaitForSubscriptionCount(1, time.Second*5) - require.Eventually(t, func() bool { - return counter.Load() == 1 - }, time.Second*10, time.Millisecond*100) + xEnv.WaitForSubscriptionCount(1, NatsWaitTimeout) + require.Eventually(t, done.Load, NatsWaitTimeout, time.Millisecond*100) }) }) @@ -494,8 +497,7 @@ func TestNatsEvents(t *testing.T) { defer resp.Body.Close() reader := bufio.NewReader(resp.Body) - assertMultipartPrefix(t, reader) - assertLineEquals(t, reader, "{\"payload\":{\"errors\":[{\"message\":\"operation type 'subscription' is blocked\"}]}}") + assertMultipartValueEventually(t, reader, "{\"payload\":{\"errors\":[{\"message\":\"operation type 'subscription' is blocked\"}]}}") } }) }) @@ -547,7 +549,7 @@ func TestNatsEvents(t *testing.T) { require.Error(t, err, io.EOF) // Subscription closed after one time }() - xEnv.WaitForSubscriptionCount(1, time.Second*5) + xEnv.WaitForSubscriptionCount(1, NatsWaitTimeout) // Send a mutation to trigger the subscription @@ -565,7 +567,7 @@ func TestNatsEvents(t *testing.T) { require.Eventually(t, func() bool { return counter.Load() == 1 - }, time.Second*10, time.Millisecond*100) + }, NatsWaitTimeout, time.Millisecond*100) }) }) @@ -579,9 +581,16 @@ func TestNatsEvents(t *testing.T) { subscribePayload := []byte(`{"query":"subscription { employeeUpdated(employeeID: 3) { id details { forename surname } } }"}`) - var requestCompleted atomic.Bool + var done atomic.Bool + var producerDone atomic.Bool + + waitForProducer := func() { + assert.Eventually(t, producerDone.Load, NatsWaitTimeout, time.Millisecond*100) + producerDone.Store(false) + } go func() { + defer done.Store(true) client := http.Client{} req, err := http.NewRequest(http.MethodPost, xEnv.GraphQLRequestURL(), bytes.NewReader(subscribePayload)) require.NoError(t, err) @@ -597,6 +606,7 @@ func TestNatsEvents(t *testing.T) { defer resp.Body.Close() reader := bufio.NewReader(resp.Body) + waitForProducer() eventNext, _, err := reader.ReadLine() require.NoError(t, err) require.Equal(t, "event: next", string(eventNext)) @@ -607,6 +617,7 @@ func TestNatsEvents(t *testing.T) { require.NoError(t, err) require.Equal(t, "", string(line)) + waitForProducer() eventNext, _, err = reader.ReadLine() require.NoError(t, err) require.Equal(t, "event: next", string(eventNext)) @@ -616,11 +627,9 @@ func TestNatsEvents(t *testing.T) { line, _, err = reader.ReadLine() require.NoError(t, err) require.Equal(t, "", string(line)) - - requestCompleted.Store(true) }() - xEnv.WaitForSubscriptionCount(1, time.Second*5) + xEnv.WaitForSubscriptionCount(1, NatsWaitTimeout) // Send a mutation to trigger the subscription @@ -628,17 +637,23 @@ func TestNatsEvents(t *testing.T) { Query: `mutation { updateAvailability(employeeID: 3, isAvailable: true) { id } }`, }) require.JSONEq(t, `{"data":{"updateAvailability":{"id":3}}}`, res.Body) + err := xEnv.NatsConnectionDefault.Flush() + require.NoError(t, err) + producerDone.Store(true) + + assert.Eventually(t, func() bool { + return !producerDone.Load() + }, NatsWaitTimeout, time.Millisecond*100) // Trigger the subscription via NATS - err := xEnv.NatsConnectionDefault.Publish(xEnv.GetPubSubName("employeeUpdated.3"), []byte(`{"id":3,"__typename": "Employee"}`)) + err = xEnv.NatsConnectionDefault.Publish(xEnv.GetPubSubName("employeeUpdated.3"), []byte(`{"id":3,"__typename": "Employee"}`)) require.NoError(t, err) err = xEnv.NatsConnectionDefault.Flush() require.NoError(t, err) + producerDone.Store(true) - require.Eventually(t, func() bool { - return requestCompleted.Load() - }, time.Second*10, time.Millisecond*100) + require.Eventually(t, done.Load, NatsWaitTimeout, time.Millisecond*100) }) }) @@ -755,7 +770,7 @@ func TestNatsEvents(t *testing.T) { require.Equal(t, "", string(line)) }() - xEnv.WaitForSubscriptionCount(1, time.Second*5) + xEnv.WaitForSubscriptionCount(1, NatsWaitTimeout) // Send a mutation to trigger the subscription res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ @@ -772,10 +787,10 @@ func TestNatsEvents(t *testing.T) { require.Eventually(t, func() bool { return counter.Load() == 1 - }, time.Second*10, time.Millisecond*100) + }, NatsWaitTimeout, time.Millisecond*100) - xEnv.WaitForSubscriptionCount(0, time.Second*10) - xEnv.WaitForConnectionCount(0, time.Second*10) + xEnv.WaitForSubscriptionCount(0, NatsWaitTimeout) + xEnv.WaitForConnectionCount(0, NatsWaitTimeout) }) }) @@ -873,73 +888,6 @@ func TestNatsEvents(t *testing.T) { }) }) - t.Run("subscribe to multiple subjects", func(t *testing.T) { - t.Parallel() - - testenv.Run(t, &testenv.Config{ - RouterConfigJSONTemplate: testenv.ConfigWithEdfsNatsJSONTemplate, - EnableNats: true, - ModifyEngineExecutionConfiguration: func(engineExecutionConfiguration *config.EngineExecutionConfiguration) { - engineExecutionConfiguration.WebSocketClientReadTimeout = time.Second - }, - }, func(t *testing.T, xEnv *testenv.Environment) { - type subscriptionPayload struct { - Data struct { - EmployeeUpdatedMyNats struct { - ID float64 `graphql:"id"` - } `graphql:"employeeUpdatedMyNats(id: 12)"` - } `json:"data"` - } - - // conn.Close() is called in a cleanup defined in the function - conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil) - err := conn.WriteJSON(&testenv.WebSocketMessage{ - ID: "1", - Type: "subscribe", - Payload: []byte(`{"query":"subscription { employeeUpdatedMyNats(id: 12) { id }}"}`), - }) - require.NoError(t, err) - var msg testenv.WebSocketMessage - var payload subscriptionPayload - - xEnv.WaitForSubscriptionCount(1, time.Second*20) - - // Trigger the first subscription via NATS - err = xEnv.NatsConnectionMyNats.Publish(xEnv.GetPubSubName("employeeUpdatedMyNats.12"), []byte(`{"id":13,"__typename":"Employee"}`)) - require.NoError(t, err) - - err = xEnv.NatsConnectionMyNats.Flush() - require.NoError(t, err) - - xEnv.WaitForMessagesSent(1, time.Second*10) - - err = conn.ReadJSON(&msg) - require.NoError(t, err) - require.Equal(t, "1", msg.ID) - require.Equal(t, "next", msg.Type) - err = json.Unmarshal(msg.Payload, &payload) - require.NoError(t, err) - require.Equal(t, float64(13), payload.Data.EmployeeUpdatedMyNats.ID) - - // Trigger the first subscription via NATS - err = xEnv.NatsConnectionMyNats.Publish(xEnv.GetPubSubName("employeeUpdatedMyNatsTwo.12"), []byte(`{"id":99,"__typename":"Employee"}`)) - require.NoError(t, err) - - err = xEnv.NatsConnectionMyNats.Flush() - require.NoError(t, err) - - xEnv.WaitForMessagesSent(2, time.Second*10) - - err = conn.ReadJSON(&msg) - require.NoError(t, err) - require.Equal(t, "1", msg.ID) - require.Equal(t, "next", msg.Type) - err = json.Unmarshal(msg.Payload, &payload) - require.NoError(t, err) - require.Equal(t, float64(99), payload.Data.EmployeeUpdatedMyNats.ID) - }) - }) - t.Run("subscribe with stream and consumer", func(t *testing.T) { t.Parallel() @@ -980,7 +928,7 @@ func TestNatsEvents(t *testing.T) { var msg testenv.WebSocketMessage var payload subscriptionPayload - xEnv.WaitForSubscriptionCount(1, time.Second*5) + xEnv.WaitForSubscriptionCount(1, NatsWaitTimeout) // Trigger the first subscription via NATS err = xEnv.NatsConnectionDefault.Publish(xEnv.GetPubSubName("employeeUpdated.12"), []byte(`{"id":13,"__typename":"Employee"}`)) @@ -1003,7 +951,7 @@ func TestNatsEvents(t *testing.T) { Type: "complete", }) require.NoError(t, err) - xEnv.WaitForSubscriptionCount(0, time.Second*10) + xEnv.WaitForSubscriptionCount(0, NatsWaitTimeout) var complete testenv.WebSocketMessage err = conn.ReadJSON(&complete) @@ -1024,7 +972,7 @@ func TestNatsEvents(t *testing.T) { Payload: []byte(`{"query":"subscription { employeeUpdatedNatsStream(id: 12) { id }}"}`), }) require.NoError(t, err) - xEnv.WaitForSubscriptionCount(1, time.Second*10) + xEnv.WaitForSubscriptionCount(1, NatsWaitTimeout) err = conn.ReadJSON(&msg) require.NoError(t, err) @@ -1101,7 +1049,7 @@ func TestNatsEvents(t *testing.T) { require.Eventually(t, func() bool { return counter.Load() == 1 - }, time.Second*10, time.Millisecond*100) + }, NatsWaitTimeout, time.Millisecond*100) }) }) @@ -1139,7 +1087,7 @@ func TestNatsEvents(t *testing.T) { var msg testenv.WebSocketMessage var payload subscriptionPayload - xEnv.WaitForSubscriptionCount(1, time.Second*5) + xEnv.WaitForSubscriptionCount(1, NatsWaitTimeout) var produced atomic.Uint32 var consumed atomic.Uint32 @@ -1147,7 +1095,7 @@ func TestNatsEvents(t *testing.T) { go func() { require.Eventually(t, func() bool { return produced.Load() == 1 - }, time.Second*10, time.Millisecond*100) + }, NatsWaitTimeout, time.Millisecond*100) gErr := conn.ReadJSON(&msg) require.NoError(t, gErr) require.Equal(t, "1", msg.ID) @@ -1161,7 +1109,7 @@ func TestNatsEvents(t *testing.T) { require.Eventually(t, func() bool { return produced.Load() == 2 - }, time.Second*10, time.Millisecond*100) + }, NatsWaitTimeout, time.Millisecond*100) gErr = conn.ReadJSON(&msg) require.NoError(t, gErr) require.Equal(t, "1", msg.ID) @@ -1175,7 +1123,7 @@ func TestNatsEvents(t *testing.T) { require.Eventually(t, func() bool { return produced.Load() == 4 - }, time.Second*10, time.Millisecond*100) + }, NatsWaitTimeout, time.Millisecond*100) gErr = conn.ReadJSON(&msg) require.NoError(t, gErr) require.Equal(t, "1", msg.ID) @@ -1189,7 +1137,7 @@ func TestNatsEvents(t *testing.T) { require.Eventually(t, func() bool { return produced.Load() == 5 - }, time.Second*10, time.Millisecond*100) + }, NatsWaitTimeout, time.Millisecond*100) gErr = conn.ReadJSON(&msg) require.NoError(t, gErr) require.Equal(t, "1", msg.ID) @@ -1203,7 +1151,7 @@ func TestNatsEvents(t *testing.T) { require.Eventually(t, func() bool { return produced.Load() == 6 - }, time.Second*10, time.Millisecond*100) + }, NatsWaitTimeout, time.Millisecond*100) gErr = conn.ReadJSON(&msg) require.NoError(t, gErr) require.Equal(t, "1", msg.ID) @@ -1217,7 +1165,7 @@ func TestNatsEvents(t *testing.T) { require.Eventually(t, func() bool { return produced.Load() == 8 - }, time.Second*10, time.Millisecond*100) + }, NatsWaitTimeout, time.Millisecond*100) gErr = conn.ReadJSON(&msg) require.NoError(t, gErr) require.Equal(t, "1", msg.ID) @@ -1231,7 +1179,7 @@ func TestNatsEvents(t *testing.T) { require.Eventually(t, func() bool { return produced.Load() == 9 - }, time.Second*10, time.Millisecond*100) + }, NatsWaitTimeout, time.Millisecond*100) gErr = conn.ReadJSON(&msg) require.NoError(t, gErr) require.Equal(t, "1", msg.ID) @@ -1241,11 +1189,11 @@ func TestNatsEvents(t *testing.T) { require.Equal(t, float64(8), payload.Data.FilteredEmployeeUpdated.ID) require.Equal(t, "Nithin", payload.Data.FilteredEmployeeUpdated.Details.Forename) require.Equal(t, "Kumar", payload.Data.FilteredEmployeeUpdated.Details.Surname) - consumed.Add(2) // should skip two messages + consumed.Add(3) // should skip two messages require.Eventually(t, func() bool { return produced.Load() == 12 - }, time.Second*10, time.Millisecond*100) + }, NatsWaitTimeout, time.Millisecond*100) gErr = conn.ReadJSON(&msg) require.NoError(t, gErr) require.Equal(t, "1", msg.ID) @@ -1266,9 +1214,8 @@ func TestNatsEvents(t *testing.T) { // Events 1, 3, 4, 5, 7, 8, and 11 should be included for i := uint32(1); i < 13; i++ { require.Eventually(t, func() bool { - return consumed.Load() >= i-1 - }, time.Second*10, time.Millisecond*100) - + return consumed.Load() >= i + }, NatsWaitTimeout, time.Millisecond*100) err = xEnv.NatsConnectionDefault.Publish(xEnv.GetPubSubName("employeeUpdated.1"), []byte(fmt.Sprintf(`{"id":%d,"__typename":"Employee"}`, i))) require.NoError(t, err) err = xEnv.NatsConnectionDefault.Flush() @@ -1277,8 +1224,8 @@ func TestNatsEvents(t *testing.T) { } require.Eventually(t, func() bool { - return consumed.Load() == 11 && produced.Load() == 13 - }, time.Second*10, time.Millisecond*100) + return consumed.Load() == 12 && produced.Load() == 13 + }, NatsWaitTimeout, time.Millisecond*100) }) }) @@ -1292,13 +1239,19 @@ func TestNatsEvents(t *testing.T) { subscribePayload := []byte(`{"query":"subscription { filteredEmployeeUpdated(id: 1) { id details { forename surname } } }"}`) - var requestsDone atomic.Bool + var done atomic.Bool + var producerDone atomic.Bool + + waitForProducer := func() { + assert.Eventually(t, producerDone.Load, NatsWaitTimeout, time.Millisecond*100) + producerDone.Store(false) + } tick := make(chan struct{}, 1) - timeout := time.After(time.Second * 10) + timeout := time.After(NatsWaitTimeout) go func() { - defer requestsDone.Store(true) + defer done.Store(true) client := http.Client{} req, gErr := http.NewRequest(http.MethodPost, xEnv.GraphQLRequestURL(), bytes.NewReader(subscribePayload)) @@ -1321,6 +1274,7 @@ func TestNatsEvents(t *testing.T) { reader := bufio.NewReader(resp.Body) + waitForProducer() eventNext, _, gErr := reader.ReadLine() require.NoError(t, gErr) require.Equal(t, "event: next", string(eventNext)) @@ -1337,6 +1291,7 @@ func TestNatsEvents(t *testing.T) { require.Fail(t, "timeout") } + waitForProducer() eventNext, _, gErr = reader.ReadLine() require.NoError(t, gErr) require.Equal(t, "event: next", string(eventNext)) @@ -1353,6 +1308,7 @@ func TestNatsEvents(t *testing.T) { require.Fail(t, "timeout") } + waitForProducer() eventNext, _, gErr = reader.ReadLine() require.NoError(t, gErr) require.Equal(t, "event: next", string(eventNext)) @@ -1369,6 +1325,7 @@ func TestNatsEvents(t *testing.T) { require.Fail(t, "timeout") } + waitForProducer() eventNext, _, gErr = reader.ReadLine() require.NoError(t, gErr) require.Equal(t, "event: next", string(eventNext)) @@ -1385,6 +1342,7 @@ func TestNatsEvents(t *testing.T) { require.Fail(t, "timeout") } + waitForProducer() eventNext, _, gErr = reader.ReadLine() require.NoError(t, gErr) require.Equal(t, "event: next", string(eventNext)) @@ -1401,6 +1359,7 @@ func TestNatsEvents(t *testing.T) { require.Fail(t, "timeout") } + waitForProducer() eventNext, _, gErr = reader.ReadLine() require.NoError(t, gErr) require.Equal(t, "event: next", string(eventNext)) @@ -1417,6 +1376,7 @@ func TestNatsEvents(t *testing.T) { require.Fail(t, "timeout") } + waitForProducer() eventNext, _, gErr = reader.ReadLine() require.NoError(t, gErr) require.Equal(t, "event: next", string(eventNext)) @@ -1433,6 +1393,7 @@ func TestNatsEvents(t *testing.T) { require.Fail(t, "timeout") } + waitForProducer() eventNext, _, gErr = reader.ReadLine() require.NoError(t, gErr) require.Equal(t, "event: next", string(eventNext)) @@ -1444,7 +1405,7 @@ func TestNatsEvents(t *testing.T) { require.Equal(t, "", string(line)) }() - xEnv.WaitForSubscriptionCount(1, time.Second*5) + xEnv.WaitForSubscriptionCount(1, NatsWaitTimeout) // Trigger the subscription via NATS err := xEnv.NatsConnectionDefault.Publish(xEnv.GetPubSubName("employeeUpdated.1"), []byte(`{"id":1,"__typename": "Employee"}`)) @@ -1453,6 +1414,8 @@ func TestNatsEvents(t *testing.T) { err = xEnv.NatsConnectionDefault.Flush() require.NoError(t, err) + producerDone.Store(true) + // Events 1, 3, 4, 5, 7, 8, and 11 should be included for i := 1; i < 13; i++ { @@ -1460,6 +1423,9 @@ func TestNatsEvents(t *testing.T) { case 1, 3, 4, 5, 7, 8, 11: select { case <-tick: + assert.Eventually(t, func() bool { + return !producerDone.Load() + }, NatsWaitTimeout, time.Millisecond*100) case <-timeout: require.Fail(t, "timeout") } @@ -1471,11 +1437,10 @@ func TestNatsEvents(t *testing.T) { err = xEnv.NatsConnectionDefault.Flush() require.NoError(t, err) + producerDone.Store(true) } - require.Eventually(t, func() bool { - return requestsDone.Load() - }, time.Second*10, time.Millisecond*100) + require.Eventually(t, done.Load, NatsWaitTimeout, time.Millisecond*100) }) }) @@ -1507,19 +1472,25 @@ func TestNatsEvents(t *testing.T) { subscriptionOneID, err := client.Subscribe(&subscriptionOne, nil, func(dataValue []byte, errValue error) error { defer consumed.Add(1) - oldCount := produced.Load() + oldCount := consumed.Load() + require.Eventually(t, func() bool { + return oldCount == produced.Load()-1 + }, NatsWaitTimeout, time.Millisecond*100) - if oldCount == 1 { + if oldCount == 0 { var gqlErr graphql.Errors require.ErrorAs(t, errValue, &gqlErr) - require.Equal(t, "Invalid message received", gqlErr[0].Message) - } else if oldCount == 2 || oldCount == 4 { - require.NoError(t, errValue) - require.JSONEq(t, `{"employeeUpdated":{"id":3,"details":{"forename":"Stefan","surname":"Avram"}}}`, string(dataValue)) - } else if oldCount == 3 { + assert.Equal(t, "Invalid message received", gqlErr[0].Message) + } else if oldCount == 1 { + assert.NoError(t, errValue) + assert.JSONEq(t, `{"employeeUpdated":{"id":3,"details":{"forename":"Stefan","surname":"Avram"}}}`, string(dataValue)) + } else if oldCount == 2 { var gqlErr graphql.Errors require.ErrorAs(t, errValue, &gqlErr) - require.Equal(t, "Cannot return null for non-nullable field 'Subscription.employeeUpdated.id'.", gqlErr[0].Message) + assert.Equal(t, "Cannot return null for non-nullable field 'Subscription.employeeUpdated.id'.", gqlErr[0].Message) + } else if oldCount == 3 { + assert.NoError(t, errValue) + assert.JSONEq(t, `{"employeeUpdated":{"id":3,"details":{"forename":"Stefan","surname":"Avram"}}}`, string(dataValue)) } return nil @@ -1532,53 +1503,184 @@ func TestNatsEvents(t *testing.T) { require.NoError(t, clientErr) }() - xEnv.WaitForSubscriptionCount(1, time.Second*10) + xEnv.WaitForSubscriptionCount(1, NatsWaitTimeout) err = xEnv.NatsConnectionDefault.Publish(xEnv.GetPubSubName("employeeUpdated.3"), []byte(`{asas`)) // Invalid message require.NoError(t, err) err = xEnv.NatsConnectionDefault.Flush() require.NoError(t, err) - xEnv.WaitForMessagesSent(1, time.Second*10) + xEnv.WaitForMessagesSent(1, NatsWaitTimeout) produced.Add(1) require.Eventually(t, func() bool { return consumed.Load() == 1 - }, time.Second*5, time.Millisecond*100) + }, NatsWaitTimeout, time.Millisecond*100) err = xEnv.NatsConnectionDefault.Publish(xEnv.GetPubSubName("employeeUpdated.3"), []byte(`{"__typename":"Employee","id": 3,"update":{"name":"foo"}}`)) // Correct message require.NoError(t, err) err = xEnv.NatsConnectionDefault.Flush() require.NoError(t, err) - xEnv.WaitForMessagesSent(2, time.Second*10) + xEnv.WaitForMessagesSent(2, NatsWaitTimeout) produced.Add(1) require.Eventually(t, func() bool { return consumed.Load() == 2 - }, time.Second*5, time.Millisecond*100) + }, NatsWaitTimeout, time.Millisecond*100) err = xEnv.NatsConnectionDefault.Publish(xEnv.GetPubSubName("employeeUpdated.3"), []byte(`{"__typename":"Employee","update":{"name":"foo"}}`)) // Missing id require.NoError(t, err) err = xEnv.NatsConnectionDefault.Flush() require.NoError(t, err) - xEnv.WaitForMessagesSent(3, time.Second*10) + xEnv.WaitForMessagesSent(3, NatsWaitTimeout) produced.Add(1) require.Eventually(t, func() bool { return consumed.Load() == 3 - }, time.Second*5, time.Millisecond*100) + }, NatsWaitTimeout, time.Millisecond*100) err = xEnv.NatsConnectionDefault.Publish(xEnv.GetPubSubName("employeeUpdated.3"), []byte(`{"__typename":"Employee","id": 3,"update":{"name":"foo"}}`)) // Correct message require.NoError(t, err) err = xEnv.NatsConnectionDefault.Flush() require.NoError(t, err) - xEnv.WaitForMessagesSent(4, time.Second*10) + xEnv.WaitForMessagesSent(4, NatsWaitTimeout) produced.Add(1) require.Eventually(t, func() bool { return consumed.Load() == 4 - }, time.Second*5, time.Millisecond*100) + }, NatsWaitTimeout, time.Millisecond*100) require.NoError(t, client.Close()) - xEnv.WaitForSubscriptionCount(0, time.Second*10) - xEnv.WaitForConnectionCount(0, time.Second*10) + xEnv.WaitForSubscriptionCount(0, NatsWaitTimeout) + xEnv.WaitForConnectionCount(0, NatsWaitTimeout) + }) + }) + + t.Run("shutdown doesn't wait indefinitely", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsNatsJSONTemplate, + EnableNats: true, + Subgraphs: testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + Delay: time.Minute, + }, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + var subscription struct { + employeeUpdated struct { + ID float64 `graphql:"id"` + Details struct { + Forename string `graphql:"forename"` + Surname string `graphql:"surname"` + } `graphql:"details"` + } `graphql:"employeeUpdated(employeeID: 3)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + t.Cleanup(func() { + _ = client.Close() + }) + + var consumed atomic.Uint32 + + subscriptionID, err := client.Subscribe(&subscription, nil, func(dataValue []byte, errValue error) error { + defer consumed.Add(1) + return nil + }) + require.NoError(t, err) + require.NotEqual(t, "", subscriptionID) + + go func() { + clientErr := client.Run() + require.NoError(t, clientErr) + }() + + xEnv.WaitForSubscriptionCount(1, NatsWaitTimeout) + + err = xEnv.NatsConnectionDefault.Publish(xEnv.GetPubSubName("employeeUpdated.3"), []byte(`{"__typename":"Employee","id": 3,"update":{"name":"foo"}}`)) // Correct message + require.NoError(t, err) + err = xEnv.NatsConnectionDefault.Flush() + require.NoError(t, err) + + assert.NoError(t, client.Close()) + + var completed atomic.Bool + go func() { + defer completed.Store(true) + xEnv.Shutdown() + assert.NoError(t, err) + }() + + assert.Eventually(t, completed.Load, NatsWaitTimeout, time.Millisecond*100) + }) + }) +} + +func TestFlakyNatsEvents(t *testing.T) { + t.Run("subscribe to multiple subjects", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsNatsJSONTemplate, + EnableNats: true, + ModifyEngineExecutionConfiguration: func(engineExecutionConfiguration *config.EngineExecutionConfiguration) { + engineExecutionConfiguration.WebSocketClientReadTimeout = time.Second + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + type subscriptionPayload struct { + Data struct { + EmployeeUpdatedMyNats struct { + ID float64 `graphql:"id"` + } `graphql:"employeeUpdatedMyNats(id: 12)"` + } `json:"data"` + } + + // conn.Close() is called in a cleanup defined in the function + conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil) + err := conn.WriteJSON(&testenv.WebSocketMessage{ + ID: "1", + Type: "subscribe", + Payload: []byte(`{"query":"subscription { employeeUpdatedMyNats(id: 12) { id }}"}`), + }) + require.NoError(t, err) + var msg testenv.WebSocketMessage + var payload subscriptionPayload + + xEnv.WaitForSubscriptionCount(1, NatsWaitTimeout) + + // Trigger the first subscription via NATS + err = xEnv.NatsConnectionMyNats.Publish(xEnv.GetPubSubName("employeeUpdatedMyNats.12"), []byte(`{"id":13,"__typename":"Employee"}`)) + require.NoError(t, err) + + err = xEnv.NatsConnectionMyNats.Flush() + require.NoError(t, err) + + xEnv.WaitForMessagesSent(1, NatsWaitTimeout) + + err = conn.ReadJSON(&msg) + require.NoError(t, err) + require.Equal(t, "1", msg.ID) + require.Equal(t, "next", msg.Type) + err = json.Unmarshal(msg.Payload, &payload) + require.NoError(t, err) + require.Equal(t, float64(13), payload.Data.EmployeeUpdatedMyNats.ID) + + // Trigger the first subscription via NATS + err = xEnv.NatsConnectionMyNats.Publish(xEnv.GetPubSubName("employeeUpdatedMyNatsTwo.12"), []byte(`{"id":99,"__typename":"Employee"}`)) + require.NoError(t, err) + + err = xEnv.NatsConnectionMyNats.Flush() + require.NoError(t, err) + + xEnv.WaitForMessagesSent(2, NatsWaitTimeout) + + err = conn.ReadJSON(&msg) + require.NoError(t, err) + require.Equal(t, "1", msg.ID) + require.Equal(t, "next", msg.Type) + err = json.Unmarshal(msg.Payload, &payload) + require.NoError(t, err) + require.Equal(t, float64(99), payload.Data.EmployeeUpdatedMyNats.ID) }) }) } diff --git a/router-tests/structured_logging_test.go b/router-tests/structured_logging_test.go index 564bfb4732..e8565e3789 100644 --- a/router-tests/structured_logging_test.go +++ b/router-tests/structured_logging_test.go @@ -295,7 +295,7 @@ func TestAccessLogsFileOutput(t *testing.T) { }) } -func TestAccessLogs(t *testing.T) { +func TestFlakyAccessLogs(t *testing.T) { t.Parallel() t.Run("Simple without custom attributes", func(t *testing.T) { diff --git a/router-tests/telemetry/telemetry_test.go b/router-tests/telemetry/telemetry_test.go index 3b8c5542eb..a96d21fe0f 100644 --- a/router-tests/telemetry/telemetry_test.go +++ b/router-tests/telemetry/telemetry_test.go @@ -10,6 +10,8 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/wundergraph/cosmo/router-tests/testenv" "github.com/wundergraph/cosmo/router/core" @@ -30,7 +32,7 @@ import ( integration "github.com/wundergraph/cosmo/router-tests" ) -func TestEngineStatisticsTelemetry(t *testing.T) { +func TestFlakyEngineStatisticsTelemetry(t *testing.T) { t.Parallel() t.Run("Should provide correct metrics for one subscription over SSE", func(t *testing.T) { @@ -209,6 +211,7 @@ func TestEngineStatisticsTelemetry(t *testing.T) { err = conn.ReadJSON(&res) require.NoError(t, err) + xEnv.WaitForMinMessagesSent(1, time.Second*5) xEnv.AssertEngineStatistics(t, metricReader, testenv.EngineStatisticAssertion{ Subscriptions: 1, Connections: 1, @@ -220,6 +223,7 @@ func TestEngineStatisticsTelemetry(t *testing.T) { err = conn.ReadJSON(&complete) require.NoError(t, err) + xEnv.WaitForMinMessagesSent(2, time.Second*5) xEnv.AssertEngineStatistics(t, metricReader, testenv.EngineStatisticAssertion{ Subscriptions: 1, Connections: 1, @@ -288,6 +292,7 @@ func TestEngineStatisticsTelemetry(t *testing.T) { wg.Wait() xEnv.WaitForSubscriptionCount(2, time.Second*5) + xEnv.WaitForTriggerCount(1, time.Second*5) xEnv.AssertEngineStatistics(t, metricReader, testenv.EngineStatisticAssertion{ Subscriptions: 2, @@ -300,6 +305,7 @@ func TestEngineStatisticsTelemetry(t *testing.T) { err := conn1.ReadJSON(&res) require.NoError(t, err) + xEnv.WaitForMinMessagesSent(1, time.Second*5) xEnv.AssertEngineStatistics(t, metricReader, testenv.EngineStatisticAssertion{ Subscriptions: 2, Connections: 2, @@ -310,6 +316,7 @@ func TestEngineStatisticsTelemetry(t *testing.T) { err = conn2.ReadJSON(&res) require.NoError(t, err) + xEnv.WaitForMinMessagesSent(2, time.Second*5) xEnv.AssertEngineStatistics(t, metricReader, testenv.EngineStatisticAssertion{ Subscriptions: 2, Connections: 2, @@ -324,6 +331,7 @@ func TestEngineStatisticsTelemetry(t *testing.T) { err = conn2.ReadJSON(&complete) require.NoError(t, err) + xEnv.WaitForMinMessagesSent(4, time.Second*5) xEnv.AssertEngineStatistics(t, metricReader, testenv.EngineStatisticAssertion{ Subscriptions: 2, Connections: 2, @@ -460,7 +468,8 @@ func TestEngineStatisticsTelemetry(t *testing.T) { }) } -func TestOperationCacheTelemetry(t *testing.T) { +// Is set as Flaky so that when running the tests it will be run separately and retried if it fails +func TestFlakyOperationCacheTelemetry(t *testing.T) { t.Parallel() const ( @@ -2522,7 +2531,8 @@ func TestOperationCacheTelemetry(t *testing.T) { }) } -func TestRuntimeTelemetry(t *testing.T) { +// Is set as Flaky so that when running the tests it will be run separately and retried if it fails +func TestFlakyRuntimeTelemetry(t *testing.T) { t.Parallel() const employeesIDData = `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}` @@ -2899,7 +2909,8 @@ func TestRuntimeTelemetry(t *testing.T) { }) } -func TestTelemetry(t *testing.T) { +// Is set as Flaky so that when running the tests it will be run separately and retried if it fails +func TestFlakyTelemetry(t *testing.T) { t.Parallel() const employeesIDData = `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}` @@ -4033,7 +4044,7 @@ func TestTelemetry(t *testing.T) { }) require.NoError(t, err) require.Equal(t, `{"data":{"rootFieldWithListArg":["a"]}}`, res.Body) - require.Equal(t, "HIT", res.Response.Header.Get(core.PersistedOperationCacheHeader)) + assert.Equal(t, "HIT", res.Response.Header.Get(core.PersistedOperationCacheHeader)) sn = exporter.GetSpans().Snapshots() @@ -8593,8 +8604,10 @@ func TestTelemetry(t *testing.T) { require.Equal(t, `{"errors":[{"message":"The total number of fields 2 exceeds the limit allowed (1)"}]}`, failedRes2.Body) testSpan2 := integration.RequireSpanWithName(t, exporter, "Operation - Validate") - require.Contains(t, testSpan2.Attributes(), otel.WgQueryTotalFields.Int(2)) - require.Contains(t, testSpan2.Attributes(), otel.WgQueryDepthCacheHit.Bool(true)) + assert.Contains(t, testSpan2.Attributes(), otel.WgQueryTotalFields.Int(2)) + assert.Contains(t, testSpan2.Attributes(), otel.WgQueryDepthCacheHit.Bool(true)) + assert.Equal(t, codes.Unset, testSpan2.Status().Code) + assert.Equal(t, []sdktrace.Event(nil), testSpan2.Events()) exporter.Reset() successRes := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ diff --git a/router-tests/testenv/pubsub.go b/router-tests/testenv/pubsub.go index c44c3389d2..1bf3dcd14e 100644 --- a/router-tests/testenv/pubsub.go +++ b/router-tests/testenv/pubsub.go @@ -96,6 +96,9 @@ func setupNatsData(t testing.TB) (*NatsData, error) { nats.MaxReconnects(10), nats.ReconnectWait(1*time.Second), nats.Timeout(5*time.Second), + nats.ErrorHandler(func(conn *nats.Conn, subscription *nats.Subscription, err error) { + t.Log(err) + }), ) if err != nil { return nil, err diff --git a/router-tests/testenv/testenv.go b/router-tests/testenv/testenv.go index 26ad12010b..cb8a8aa68a 100644 --- a/router-tests/testenv/testenv.go +++ b/router-tests/testenv/testenv.go @@ -92,10 +92,12 @@ func init() { func Run(t *testing.T, cfg *Config, f func(t *testing.T, xEnv *Environment)) { t.Helper() env, err := createTestEnv(t, cfg) + if env != nil { + t.Cleanup(env.Shutdown) + } if err != nil { t.Fatalf("could not create environment: %s", err) } - t.Cleanup(env.Shutdown) f(t, env) if cfg.AssertCacheMetrics != nil { assertCacheMetrics(t, env, cfg.AssertCacheMetrics.BaseGraphAssertions, "") @@ -120,10 +122,12 @@ func FailsOnStartup(t *testing.T, cfg *Config, f func(t *testing.T, err error)) func RunWithError(t *testing.T, cfg *Config, f func(t *testing.T, xEnv *Environment)) error { t.Helper() env, err := createTestEnv(t, cfg) + if env != nil { + t.Cleanup(env.Shutdown) + } if err != nil { return err } - t.Cleanup(env.Shutdown) f(t, env) if cfg.AssertCacheMetrics != nil { assertCacheMetrics(t, env, cfg.AssertCacheMetrics.BaseGraphAssertions, "") @@ -136,10 +140,12 @@ func Bench(b *testing.B, cfg *Config, f func(b *testing.B, xEnv *Environment)) { b.Helper() b.StopTimer() env, err := createTestEnv(b, cfg) + if env != nil { + b.Cleanup(env.Shutdown) + } if err != nil { b.Fatalf("could not create environment: %s", err) } - b.Cleanup(env.Shutdown) b.StartTimer() f(b, env) if cfg.AssertCacheMetrics != nil { @@ -401,6 +407,27 @@ func createTestEnv(t testing.TB, cfg *Config) (*Environment, error) { ctx, cancel := context.WithCancelCause(context.Background()) + var ( + logObserver *observer.ObservedLogs + ) + + if oc := cfg.LogObservation; oc.Enabled { + var zCore zapcore.Core + zCore, logObserver = observer.New(oc.LogLevel) + cfg.Logger = logging.NewZapLoggerWithCore(zCore, true) + } else { + ec := zap.NewProductionEncoderConfig() + ec.EncodeDuration = zapcore.SecondsDurationEncoder + ec.TimeKey = "time" + + syncer := zapcore.AddSync(os.Stderr) + cfg.Logger = logging.NewZapLogger(syncer, false, true, zapcore.WarnLevel) + } + + if cfg.AccessLogger == nil { + cfg.AccessLogger = cfg.Logger + } + counters := &SubgraphRequestCount{ Global: atomic.NewInt64(0), Employees: atomic.NewInt64(0), @@ -425,7 +452,7 @@ func createTestEnv(t testing.TB, cfg *Config) (*Environment, error) { getPubSubName := GetPubSubNameFn(pubSubPrefix) employees := &Subgraph{ - handler: subgraphs.EmployeesHandler(subgraphOptions(ctx, t, natsSetup, getPubSubName)), + handler: subgraphs.EmployeesHandler(subgraphOptions(ctx, t, cfg.Logger, natsSetup, getPubSubName)), middleware: cfg.Subgraphs.Employees.Middleware, globalMiddleware: cfg.Subgraphs.GlobalMiddleware, globalCounter: counters.Global, @@ -435,7 +462,7 @@ func createTestEnv(t testing.TB, cfg *Config) (*Environment, error) { } family := &Subgraph{ - handler: subgraphs.FamilyHandler(subgraphOptions(ctx, t, natsSetup, getPubSubName)), + handler: subgraphs.FamilyHandler(subgraphOptions(ctx, t, cfg.Logger, natsSetup, getPubSubName)), middleware: cfg.Subgraphs.Family.Middleware, globalMiddleware: cfg.Subgraphs.GlobalMiddleware, globalCounter: counters.Global, @@ -445,7 +472,7 @@ func createTestEnv(t testing.TB, cfg *Config) (*Environment, error) { } hobbies := &Subgraph{ - handler: subgraphs.HobbiesHandler(subgraphOptions(ctx, t, natsSetup, getPubSubName)), + handler: subgraphs.HobbiesHandler(subgraphOptions(ctx, t, cfg.Logger, natsSetup, getPubSubName)), middleware: cfg.Subgraphs.Hobbies.Middleware, globalMiddleware: cfg.Subgraphs.GlobalMiddleware, globalCounter: counters.Global, @@ -455,7 +482,7 @@ func createTestEnv(t testing.TB, cfg *Config) (*Environment, error) { } products := &Subgraph{ - handler: subgraphs.ProductsHandler(subgraphOptions(ctx, t, natsSetup, getPubSubName)), + handler: subgraphs.ProductsHandler(subgraphOptions(ctx, t, cfg.Logger, natsSetup, getPubSubName)), middleware: cfg.Subgraphs.Products.Middleware, globalMiddleware: cfg.Subgraphs.GlobalMiddleware, globalCounter: counters.Global, @@ -465,7 +492,7 @@ func createTestEnv(t testing.TB, cfg *Config) (*Environment, error) { } productsFg := &Subgraph{ - handler: subgraphs.ProductsFGHandler(subgraphOptions(ctx, t, natsSetup, getPubSubName)), + handler: subgraphs.ProductsFGHandler(subgraphOptions(ctx, t, cfg.Logger, natsSetup, getPubSubName)), middleware: cfg.Subgraphs.ProductsFg.Middleware, globalMiddleware: cfg.Subgraphs.GlobalMiddleware, globalCounter: counters.Global, @@ -475,7 +502,7 @@ func createTestEnv(t testing.TB, cfg *Config) (*Environment, error) { } test1 := &Subgraph{ - handler: subgraphs.Test1Handler(subgraphOptions(ctx, t, natsSetup, getPubSubName)), + handler: subgraphs.Test1Handler(subgraphOptions(ctx, t, cfg.Logger, natsSetup, getPubSubName)), middleware: cfg.Subgraphs.Test1.Middleware, globalMiddleware: cfg.Subgraphs.GlobalMiddleware, globalCounter: counters.Global, @@ -485,7 +512,7 @@ func createTestEnv(t testing.TB, cfg *Config) (*Environment, error) { } availability := &Subgraph{ - handler: subgraphs.AvailabilityHandler(subgraphOptions(ctx, t, natsSetup, getPubSubName)), + handler: subgraphs.AvailabilityHandler(subgraphOptions(ctx, t, cfg.Logger, natsSetup, getPubSubName)), middleware: cfg.Subgraphs.Availability.Middleware, globalMiddleware: cfg.Subgraphs.GlobalMiddleware, globalCounter: counters.Global, @@ -495,7 +522,7 @@ func createTestEnv(t testing.TB, cfg *Config) (*Environment, error) { } mood := &Subgraph{ - handler: subgraphs.MoodHandler(subgraphOptions(ctx, t, natsSetup, getPubSubName)), + handler: subgraphs.MoodHandler(subgraphOptions(ctx, t, cfg.Logger, natsSetup, getPubSubName)), middleware: cfg.Subgraphs.Mood.Middleware, globalMiddleware: cfg.Subgraphs.GlobalMiddleware, globalCounter: counters.Global, @@ -505,7 +532,7 @@ func createTestEnv(t testing.TB, cfg *Config) (*Environment, error) { } countries := &Subgraph{ - handler: subgraphs.CountriesHandler(subgraphOptions(ctx, t, natsSetup, getPubSubName)), + handler: subgraphs.CountriesHandler(subgraphOptions(ctx, t, cfg.Logger, natsSetup, getPubSubName)), middleware: cfg.Subgraphs.Countries.Middleware, globalMiddleware: cfg.Subgraphs.GlobalMiddleware, globalCounter: counters.Global, @@ -580,27 +607,6 @@ func createTestEnv(t testing.TB, cfg *Config) (*Environment, error) { client = retryClient.StandardClient() } - var ( - logObserver *observer.ObservedLogs - ) - - if oc := cfg.LogObservation; oc.Enabled { - var zCore zapcore.Core - zCore, logObserver = observer.New(oc.LogLevel) - cfg.Logger = logging.NewZapLoggerWithCore(zCore, true) - } else { - ec := zap.NewProductionEncoderConfig() - ec.EncodeDuration = zapcore.SecondsDurationEncoder - ec.TimeKey = "time" - - syncer := zapcore.AddSync(os.Stderr) - cfg.Logger = logging.NewZapLogger(syncer, false, true, zapcore.ErrorLevel) - } - - if cfg.AccessLogger == nil { - cfg.AccessLogger = cfg.Logger - } - kafkaStarted.Wait() rr, err := configureRouter(listenerAddr, cfg, &routerConfig, cdn, natsSetup) @@ -1137,10 +1143,13 @@ func (e *Environment) Shutdown() { ctx, cancel := context.WithTimeout(e.Context, e.shutdownDelay) defer cancel() + // Terminate test server resources + e.cancel(ErrEnvironmentClosed) + // Gracefully shutdown router if e.Router != nil { err := e.Router.Shutdown(ctx) - if err != nil && !errors.Is(err, context.DeadlineExceeded) { + if err != nil && !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { e.t.Errorf("could not shutdown router: %s", err) } } @@ -1150,9 +1159,6 @@ func (e *Environment) Shutdown() { s.CloseClientConnections() } - // Terminate test server resources - e.cancel(ErrEnvironmentClosed) - for _, s := range e.Servers { // Do not call s.Close() here, as it will get stuck on connections left open! lErr := s.Listener.Close() @@ -1907,7 +1913,71 @@ func (e *Environment) WaitForTriggerCount(desiredCount uint64, timeout time.Dura } } -func subgraphOptions(ctx context.Context, t testing.TB, natsData *NatsData, pubSubName func(string) string) *subgraphs.SubgraphOptions { +func DeflakeWSReadMessage(t testing.TB, conn *websocket.Conn) (messageType int, p []byte, err error) { + for i := 0; i < 5; i++ { + messageType, p, err = conn.ReadMessage() + if err != nil && strings.Contains(err.Error(), "connection reset by peer") { + t.Log("connection reset by peer found, retrying...") + err = conn.SetReadDeadline(time.Now().Add(1 * time.Second)) + require.NoError(t, err) + time.Sleep(time.Duration(i*200) * time.Millisecond) + continue + } + break + } + + return messageType, p, err +} + +func DeflakeWSReadJSON(t testing.TB, conn *websocket.Conn, v interface{}) (err error) { + for i := 0; i < 5; i++ { + err = conn.ReadJSON(v) + if err != nil && strings.Contains(err.Error(), "connection reset by peer") { + t.Log("connection reset by peer found, retrying...") + err = conn.SetReadDeadline(time.Now().Add(1 * time.Second)) + require.NoError(t, err) + time.Sleep(time.Duration(i*200) * time.Millisecond) + continue + } + break + } + + return err +} + +func DeflakeWSWriteMessage(t testing.TB, conn *websocket.Conn, messageType int, data []byte) (err error) { + for i := 0; i < 5; i++ { + err = conn.WriteMessage(messageType, data) + if err != nil && strings.Contains(err.Error(), "connection reset by peer") { + t.Log("connection reset by peer found, retrying...") + err = conn.SetReadDeadline(time.Now().Add(1 * time.Second)) + require.NoError(t, err) + time.Sleep(time.Duration(i*200) * time.Millisecond) + continue + } + break + } + + return err +} + +func DeflakeWSWriteJSON(t testing.TB, conn *websocket.Conn, v interface{}) (err error) { + for i := 0; i < 5; i++ { + err = conn.WriteJSON(v) + if err != nil && strings.Contains(err.Error(), "connection reset by peer") { + t.Log("connection reset by peer found, retrying...") + err = conn.SetReadDeadline(time.Now().Add(1 * time.Second)) + require.NoError(t, err) + time.Sleep(time.Duration(i*200) * time.Millisecond) + continue + } + break + } + + return err +} + +func subgraphOptions(ctx context.Context, t testing.TB, logger *zap.Logger, natsData *NatsData, pubSubName func(string) string) *subgraphs.SubgraphOptions { if natsData == nil { return &subgraphs.SubgraphOptions{ NatsPubSubByProviderID: map[string]pubsub_datasource.NatsPubSub{}, @@ -1916,13 +1986,10 @@ func subgraphOptions(ctx context.Context, t testing.TB, natsData *NatsData, pubS } natsPubSubByProviderID := make(map[string]pubsub_datasource.NatsPubSub, len(demoNatsProviders)) for _, sourceName := range demoNatsProviders { - natsConnection, err := nats.Connect(natsData.Server.ClientURL()) - require.NoError(t, err) - - js, err := jetstream.New(natsConnection) + js, err := jetstream.New(natsData.Connections[0]) require.NoError(t, err) - natsPubSubByProviderID[sourceName] = pubsubNats.NewConnector(zap.NewNop(), natsConnection, js, "hostname", "listenaddr").New(ctx) + natsPubSubByProviderID[sourceName] = pubsubNats.NewConnector(logger, natsData.Connections[0], js, "hostname", "listenaddr").New(ctx) } return &subgraphs.SubgraphOptions{ diff --git a/router-tests/timeout_test.go b/router-tests/timeout_test.go index aba2814f8b..07934cb97e 100644 --- a/router-tests/timeout_test.go +++ b/router-tests/timeout_test.go @@ -13,7 +13,7 @@ import ( "github.com/wundergraph/cosmo/router/pkg/config" ) -func TestTimeouts(t *testing.T) { +func TestFlakyTimeouts(t *testing.T) { t.Parallel() const queryEmployeeWithHobby = `{ diff --git a/router-tests/websocket_test.go b/router-tests/websocket_test.go index e3d174a812..a0f60f7c54 100644 --- a/router-tests/websocket_test.go +++ b/router-tests/websocket_test.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "go.uber.org/zap" "io" "math/big" "net" @@ -16,6 +15,8 @@ import ( "testing" "time" + "go.uber.org/zap" + "github.com/buger/jsonparser" "github.com/gorilla/websocket" "github.com/hasura/go-graphql-client" @@ -49,20 +50,20 @@ func TestWebSockets(t *testing.T) { testenv.Run(t, &testenv.Config{}, func(t *testing.T, xEnv *testenv.Environment) { conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil) - err := conn.WriteJSON(testenv.WebSocketMessage{ + err := testenv.DeflakeWSWriteJSON(t, conn, testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"{ employees { id } }"}`), }) require.NoError(t, err) var res testenv.WebSocketMessage - err = conn.ReadJSON(&res) + err = testenv.DeflakeWSReadJSON(t, conn, &res) require.NoError(t, err) require.Equal(t, "next", res.Type) require.Equal(t, "1", res.ID) require.JSONEq(t, `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}`, string(res.Payload)) var complete testenv.WebSocketMessage - err = conn.ReadJSON(&complete) + err = testenv.DeflakeWSReadJSON(t, conn, &complete) require.NoError(t, err) require.Equal(t, "complete", complete.Type) require.Equal(t, "1", complete.ID) @@ -98,20 +99,20 @@ func TestWebSockets(t *testing.T) { "Authorization": []string{"Bearer " + token}, } conn := xEnv.InitGraphQLWebSocketConnection(header, nil, nil) - err = conn.WriteJSON(testenv.WebSocketMessage{ + err = testenv.DeflakeWSWriteJSON(t, conn, testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"{ employees { id startDate } }"}`), }) require.NoError(t, err) var res testenv.WebSocketMessage - err = conn.ReadJSON(&res) + err = testenv.DeflakeWSReadJSON(t, conn, &res) require.NoError(t, err) require.Equal(t, "error", res.Type) require.Equal(t, "1", res.ID) require.Equal(t, `[{"message":"Unauthorized"}]`, string(res.Payload)) var complete testenv.WebSocketMessage - err = conn.ReadJSON(&complete) + err = testenv.DeflakeWSReadJSON(t, conn, &complete) require.NoError(t, err) require.Equal(t, "complete", complete.Type) require.Equal(t, "1", complete.ID) @@ -147,20 +148,20 @@ func TestWebSockets(t *testing.T) { "Authorization": []string{"Bearer " + token}, } conn := xEnv.InitGraphQLWebSocketConnection(header, nil, nil) - err = conn.WriteJSON(testenv.WebSocketMessage{ + err = testenv.DeflakeWSWriteJSON(t, conn, testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"{ employees { id startDate } }"}`), }) require.NoError(t, err) var res testenv.WebSocketMessage - err = conn.ReadJSON(&res) + err = testenv.DeflakeWSReadJSON(t, conn, &res) require.NoError(t, err) require.Equal(t, "error", res.Type) require.Equal(t, "1", res.ID) require.Equal(t, `[{"message":"Unauthorized to load field 'Query.employees.startDate', Reason: not authenticated.","path":["employees",0,"startDate"]},{"message":"Unauthorized to load field 'Query.employees.startDate', Reason: not authenticated.","path":["employees",1,"startDate"]},{"message":"Unauthorized to load field 'Query.employees.startDate', Reason: not authenticated.","path":["employees",2,"startDate"]},{"message":"Unauthorized to load field 'Query.employees.startDate', Reason: not authenticated.","path":["employees",3,"startDate"]},{"message":"Unauthorized to load field 'Query.employees.startDate', Reason: not authenticated.","path":["employees",4,"startDate"]},{"message":"Unauthorized to load field 'Query.employees.startDate', Reason: not authenticated.","path":["employees",5,"startDate"]},{"message":"Unauthorized to load field 'Query.employees.startDate', Reason: not authenticated.","path":["employees",6,"startDate"]},{"message":"Unauthorized to load field 'Query.employees.startDate', Reason: not authenticated.","path":["employees",7,"startDate"]},{"message":"Unauthorized to load field 'Query.employees.startDate', Reason: not authenticated.","path":["employees",8,"startDate"]},{"message":"Unauthorized to load field 'Query.employees.startDate', Reason: not authenticated.","path":["employees",9,"startDate"]}]`, string(res.Payload)) var complete testenv.WebSocketMessage - err = conn.ReadJSON(&complete) + err = testenv.DeflakeWSReadJSON(t, conn, &complete) require.NoError(t, err) require.Equal(t, "complete", complete.Type) require.Equal(t, "1", complete.ID) @@ -198,7 +199,7 @@ func TestWebSockets(t *testing.T) { "Authorization": []string{"Bearer " + token}, } conn := xEnv.InitGraphQLWebSocketConnection(header, nil, nil) - err = conn.WriteJSON(testenv.WebSocketMessage{ + err = testenv.DeflakeWSWriteJSON(t, conn, testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { employeeUpdated(employeeID: 3) { id details { forename surname } startDate }}"}`), @@ -217,7 +218,7 @@ func TestWebSockets(t *testing.T) { }() var res testenv.WebSocketMessage - err = conn.ReadJSON(&res) + err = testenv.DeflakeWSReadJSON(t, conn, &res) require.NoError(t, err) require.Equal(t, "error", res.Type) require.Equal(t, "1", res.ID) @@ -257,7 +258,7 @@ func TestWebSockets(t *testing.T) { "Authorization": []string{"Bearer " + token}, } conn := xEnv.InitGraphQLWebSocketConnection(header, nil, nil) - err = conn.WriteJSON(testenv.WebSocketMessage{ + err = testenv.DeflakeWSWriteJSON(t, conn, testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { employeeUpdated(employeeID: 3) { id details { forename surname } startDate }}"}`), @@ -274,7 +275,7 @@ func TestWebSockets(t *testing.T) { require.NoError(t, err) }() var res testenv.WebSocketMessage - err = conn.ReadJSON(&res) + err = testenv.DeflakeWSReadJSON(t, conn, &res) require.NoError(t, err) require.Equal(t, "error", res.Type) require.Equal(t, "1", res.ID) @@ -317,7 +318,7 @@ func TestWebSockets(t *testing.T) { require.NoError(t, err) initialPayload := []byte(`{"Authorization":"Bearer ` + token + `"}`) conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, initialPayload) - err = conn.WriteJSON(testenv.WebSocketMessage{ + err = testenv.DeflakeWSWriteJSON(t, conn, testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { employeeUpdated(employeeID: 3) { id }}"}`), @@ -336,7 +337,7 @@ func TestWebSockets(t *testing.T) { }() var res testenv.WebSocketMessage - err = conn.ReadJSON(&res) + err = testenv.DeflakeWSReadJSON(t, conn, &res) require.NoError(t, err) require.Equal(t, "next", res.Type) require.Equal(t, "1", res.ID) @@ -375,7 +376,7 @@ func TestWebSockets(t *testing.T) { }, func(t *testing.T, xEnv *testenv.Environment) { require.NoError(t, err) conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil) - err = conn.WriteJSON(testenv.WebSocketMessage{ + err = testenv.DeflakeWSWriteJSON(t, conn, testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { employeeUpdated(employeeID: 3) { id }}"}`), @@ -383,7 +384,7 @@ func TestWebSockets(t *testing.T) { require.NoError(t, err) var res testenv.WebSocketMessage - err = conn.ReadJSON(&res) + err = testenv.DeflakeWSReadJSON(t, conn, &res) require.NoError(t, err) require.Equal(t, "error", res.Type) payload, err := json.Marshal(res.Payload) @@ -425,7 +426,7 @@ func TestWebSockets(t *testing.T) { require.NoError(t, err) initialPayload := []byte(`{"Authorization": true }`) conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, initialPayload) - err = conn.WriteJSON(testenv.WebSocketMessage{ + err = testenv.DeflakeWSWriteJSON(t, conn, testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { employeeUpdated(employeeID: 3) { id }}"}`), @@ -433,7 +434,7 @@ func TestWebSockets(t *testing.T) { require.NoError(t, err) var res testenv.WebSocketMessage - err = conn.ReadJSON(&res) + err = testenv.DeflakeWSReadJSON(t, conn, &res) require.NoError(t, err) require.Equal(t, "error", res.Type) payload, err := json.Marshal(res.Payload) @@ -464,7 +465,7 @@ func TestWebSockets(t *testing.T) { require.NoError(t, err) initialPayload := []byte(`{"Authorization":"` + token + `"}`) conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, initialPayload) - err = conn.WriteJSON(testenv.WebSocketMessage{ + err = testenv.DeflakeWSWriteJSON(t, conn, testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { employeeUpdated(employeeID: 3) { id }}"}`), @@ -486,7 +487,7 @@ func TestWebSockets(t *testing.T) { require.Eventually(t, done.Load, time.Second*5, time.Millisecond*100) var res testenv.WebSocketMessage - err = conn.ReadJSON(&res) + err = testenv.DeflakeWSReadJSON(t, conn, &res) require.NoError(t, err) require.Equal(t, "next", res.Type) require.Equal(t, "1", res.ID) @@ -515,18 +516,6 @@ func TestWebSockets(t *testing.T) { expectConnectAndReadCurrentTime(t, xEnv) }) }) - t.Run("subscription with multiple reconnects and netPoll disabled", func(t *testing.T) { - t.Parallel() - - testenv.Run(t, &testenv.Config{ - ModifyEngineExecutionConfiguration: func(engineExecutionConfiguration *config.EngineExecutionConfiguration) { - engineExecutionConfiguration.EnableNetPoll = false - }, - }, func(t *testing.T, xEnv *testenv.Environment) { - expectConnectAndReadCurrentTime(t, xEnv) - expectConnectAndReadCurrentTime(t, xEnv) - }) - }) t.Run("subscription with header propagation", func(t *testing.T) { t.Parallel() @@ -575,27 +564,27 @@ func TestWebSockets(t *testing.T) { require.NoError(t, err) defer conn.Close() - _, message, err := conn.ReadMessage() + _, message, err := testenv.DeflakeWSReadMessage(t, conn) require.NoError(t, err) require.Equal(t, `{"type":"connection_init","payload":{"Custom-Auth":"test","extensions":{"upgradeHeaders":{"Authorization":"Bearer test","Canonical-Header-Name":"matches","Reverse-Canonical-Header-Name":"matches as well","X-Custom-Auth":"customAuth"},"upgradeQueryParams":{"token":"Bearer Something"},"initialPayload":{"Custom-Auth":"test"}}}}`, string(message)) - err = conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"connection_ack"}`)) + err = testenv.DeflakeWSWriteMessage(t, conn, websocket.TextMessage, []byte(`{"type":"connection_ack"}`)) require.NoError(t, err) - _, message, err = conn.ReadMessage() + _, message, err = testenv.DeflakeWSReadMessage(t, conn) require.NoError(t, err) require.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription{currentTime {unixTime timeStamp}}","extensions":{"upgradeHeaders":{"Authorization":"Bearer test","Canonical-Header-Name":"matches","Reverse-Canonical-Header-Name":"matches as well","X-Custom-Auth":"customAuth"},"upgradeQueryParams":{"token":"Bearer Something"},"initialPayload":{"Custom-Auth":"test"}}}}`, string(message)) - err = conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"next","id":"1","payload":{"data":{"currentTime":{"unixTime":1,"timeStamp":"2021-09-01T12:00:00Z"}}}}`)) + err = testenv.DeflakeWSWriteMessage(t, conn, websocket.TextMessage, []byte(`{"type":"next","id":"1","payload":{"data":{"currentTime":{"unixTime":1,"timeStamp":"2021-09-01T12:00:00Z"}}}}`)) require.NoError(t, err) - _, message, err = conn.ReadMessage() + _, message, err = testenv.DeflakeWSReadMessage(t, conn) if errors.Is(err, websocket.ErrCloseSent) { return } require.Equal(t, `{"id":"1","type":"complete"}`, string(message)) - err = conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"complete","id":"1"}`)) + err = testenv.DeflakeWSWriteMessage(t, conn, websocket.TextMessage, []byte(`{"type":"complete","id":"1"}`)) require.NoError(t, err) }) }, @@ -625,7 +614,7 @@ func TestWebSockets(t *testing.T) { }, []byte(`{"Custom-Auth":"test"}`), ) - err := conn.WriteJSON(&testenv.WebSocketMessage{ + err := testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { currentTime { unixTime timeStamp }}"}`), @@ -635,7 +624,7 @@ func TestWebSockets(t *testing.T) { var payload currentTimePayload // Read a result and store its timestamp, next result should be 1 second later - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) require.Equal(t, "1", msg.ID) require.Equal(t, "next", msg.Type) @@ -644,7 +633,7 @@ func TestWebSockets(t *testing.T) { require.Equal(t, float64(1), payload.Data.CurrentTime.UnixTime) // Sending a complete must stop the subscription - err = conn.WriteJSON(&testenv.WebSocketMessage{ + err = testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "complete", }) @@ -653,7 +642,7 @@ func TestWebSockets(t *testing.T) { var complete testenv.WebSocketMessage err = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) require.NoError(t, err) - err = conn.ReadJSON(&complete) + err = testenv.DeflakeWSReadJSON(t, conn, &complete) require.NoError(t, err) require.Equal(t, "1", complete.ID) require.Equal(t, "complete", complete.Type) @@ -716,29 +705,29 @@ func TestWebSockets(t *testing.T) { require.NoError(t, err) defer conn.Close() - _, message, err := conn.ReadMessage() + _, message, err := testenv.DeflakeWSReadMessage(t, conn) require.NoError(t, err) message = jsonparser.Delete(message, "payload", "extensions", "upgradeHeaders", "Sec-Websocket-Key") // Sec-Websocket-Key is a random value require.Equal(t, `{"type":"connection_init","payload":{"Custom-Auth":"test","extensions":{"upgradeHeaders":{"Authorization":"Bearer test","Canonical-Header-Name":"matches","Connection":"Upgrade","Ignored":"ignored","Not-Allowlisted-But-Forwarded":"but still part of the origin upgrade request","Reverse-Canonical-Header-Name":"matches as well","Sec-Websocket-Protocol":"graphql-transport-ws","Sec-Websocket-Version":"13","Upgrade":"websocket","User-Agent":"Go-http-client/1.1","X-Custom-Auth":"customAuth"},"upgradeQueryParams":{"ignored":"ignored","token":"Bearer Something","x-custom-auth":"customAuth"},"initialPayload":{"Custom-Auth":"test"}}}}`, string(message)) - err = conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"connection_ack"}`)) + err = testenv.DeflakeWSWriteMessage(t, conn, websocket.TextMessage, []byte(`{"type":"connection_ack"}`)) require.NoError(t, err) - _, message, err = conn.ReadMessage() + _, message, err = testenv.DeflakeWSReadMessage(t, conn) require.NoError(t, err) message = jsonparser.Delete(message, "payload", "extensions", "upgradeHeaders", "Sec-Websocket-Key") // Sec-Websocket-Key is a random value require.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription{currentTime {unixTime timeStamp}}","extensions":{"upgradeHeaders":{"Authorization":"Bearer test","Canonical-Header-Name":"matches","Connection":"Upgrade","Ignored":"ignored","Not-Allowlisted-But-Forwarded":"but still part of the origin upgrade request","Reverse-Canonical-Header-Name":"matches as well","Sec-Websocket-Protocol":"graphql-transport-ws","Sec-Websocket-Version":"13","Upgrade":"websocket","User-Agent":"Go-http-client/1.1","X-Custom-Auth":"customAuth"},"upgradeQueryParams":{"ignored":"ignored","token":"Bearer Something","x-custom-auth":"customAuth"},"initialPayload":{"Custom-Auth":"test"}}}}`, string(message)) - err = conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"next","id":"1","payload":{"data":{"currentTime":{"unixTime":1,"timeStamp":"2021-09-01T12:00:00Z"}}}}`)) + err = testenv.DeflakeWSWriteMessage(t, conn, websocket.TextMessage, []byte(`{"type":"next","id":"1","payload":{"data":{"currentTime":{"unixTime":1,"timeStamp":"2021-09-01T12:00:00Z"}}}}`)) require.NoError(t, err) - _, message, err = conn.ReadMessage() + _, message, err = testenv.DeflakeWSReadMessage(t, conn) if errors.Is(err, websocket.ErrCloseSent) { return } require.Equal(t, `{"id":"1","type":"complete"}`, string(message)) - err = conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"complete","id":"1"}`)) + err = testenv.DeflakeWSWriteMessage(t, conn, websocket.TextMessage, []byte(`{"type":"complete","id":"1"}`)) require.NoError(t, err) }) }, @@ -769,7 +758,7 @@ func TestWebSockets(t *testing.T) { }, []byte(`{"Custom-Auth":"test"}`), ) - err := conn.WriteJSON(&testenv.WebSocketMessage{ + err := testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { currentTime { unixTime timeStamp }}"}`), @@ -779,7 +768,7 @@ func TestWebSockets(t *testing.T) { var payload currentTimePayload // Read a result and store its timestamp, next result should be 1 second later - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) require.Equal(t, "1", msg.ID) require.Equal(t, "next", msg.Type) @@ -788,7 +777,7 @@ func TestWebSockets(t *testing.T) { require.Equal(t, float64(1), payload.Data.CurrentTime.UnixTime) // Sending a complete must stop the subscription - err = conn.WriteJSON(&testenv.WebSocketMessage{ + err = testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "complete", }) @@ -797,7 +786,7 @@ func TestWebSockets(t *testing.T) { var complete testenv.WebSocketMessage err = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) require.NoError(t, err) - err = conn.ReadJSON(&complete) + err = testenv.DeflakeWSReadJSON(t, conn, &complete) require.NoError(t, err) require.Equal(t, "1", complete.ID) require.Equal(t, "complete", complete.Type) @@ -879,7 +868,7 @@ func TestWebSockets(t *testing.T) { conn := xEnv.InitGraphQLWebSocketConnection(http.Header{ "Authorization": []string{"Bearer test"}, }, nil, []byte(`{"Custom-Auth":"test"}`)) - err := conn.WriteJSON(&testenv.WebSocketMessage{ + err := testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { currentTime { unixTime timeStamp }}"}`), @@ -889,7 +878,7 @@ func TestWebSockets(t *testing.T) { var payload currentTimePayload // Read a result and store its timestamp, next result should be 1 second later - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) require.Equal(t, "1", msg.ID) require.Equal(t, "next", msg.Type) @@ -898,7 +887,7 @@ func TestWebSockets(t *testing.T) { require.Equal(t, float64(1), payload.Data.CurrentTime.UnixTime) // Sending a complete must stop the subscription - err = conn.WriteJSON(&testenv.WebSocketMessage{ + err = testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "complete", }) @@ -907,14 +896,14 @@ func TestWebSockets(t *testing.T) { var complete testenv.WebSocketMessage err = conn.SetReadDeadline(time.Now().Add(1 * time.Second)) require.NoError(t, err) - err = conn.ReadJSON(&complete) + err = testenv.DeflakeWSReadJSON(t, conn, &complete) require.NoError(t, err) require.Equal(t, "1", complete.ID) require.Equal(t, "complete", complete.Type) err = conn.SetReadDeadline(time.Now().Add(1 * time.Second)) require.NoError(t, err) - _, _, err = conn.ReadMessage() + _, _, err = testenv.DeflakeWSReadMessage(t, conn) require.Error(t, err) var netErr net.Error if errors.As(err, &netErr) { @@ -995,7 +984,7 @@ func TestWebSockets(t *testing.T) { conn := xEnv.InitGraphQLWebSocketConnection(http.Header{ "Authorization": []string{"Bearer test"}, }, nil, []byte(`{"Custom-Auth":"test"}`)) - err := conn.WriteJSON(&testenv.WebSocketMessage{ + err := testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { currentTime { unixTime timeStamp }}"}`), @@ -1005,7 +994,7 @@ func TestWebSockets(t *testing.T) { var payload currentTimePayload // Read a result and store its timestamp, next result should be 1 second later - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) require.Equal(t, "1", msg.ID) require.Equal(t, "next", msg.Type) @@ -1014,7 +1003,7 @@ func TestWebSockets(t *testing.T) { require.Equal(t, float64(1), payload.Data.CurrentTime.UnixTime) // Sending a complete must stop the subscription - err = conn.WriteJSON(&testenv.WebSocketMessage{ + err = testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "complete", }) @@ -1023,14 +1012,14 @@ func TestWebSockets(t *testing.T) { var complete testenv.WebSocketMessage err = conn.SetReadDeadline(time.Now().Add(1 * time.Second)) require.NoError(t, err) - err = conn.ReadJSON(&complete) + err = testenv.DeflakeWSReadJSON(t, conn, &complete) require.NoError(t, err) require.Equal(t, "1", complete.ID) require.Equal(t, "complete", complete.Type) err = conn.SetReadDeadline(time.Now().Add(1 * time.Second)) require.NoError(t, err) - _, _, err = conn.ReadMessage() + _, _, err = testenv.DeflakeWSReadMessage(t, conn) require.Error(t, err) var netErr net.Error if errors.As(err, &netErr) { @@ -1060,7 +1049,7 @@ func TestWebSockets(t *testing.T) { }, func(t *testing.T, xEnv *testenv.Environment) { conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil) - err := conn.WriteJSON(&testenv.WebSocketMessage{ + err := testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { currentTime { unixTime timeStamp }}"}`), @@ -1069,7 +1058,7 @@ func TestWebSockets(t *testing.T) { var msg testenv.WebSocketMessage // Read a result and store its timestamp, next result should be 1 second later - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) require.Equal(t, "1", msg.ID) require.Equal(t, "error", msg.Type) @@ -1099,7 +1088,7 @@ func TestWebSockets(t *testing.T) { }, func(t *testing.T, xEnv *testenv.Environment) { conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil) - err := conn.WriteJSON(&testenv.WebSocketMessage{ + err := testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { currentTime { unixTime timeStamp }}"}`), @@ -1108,7 +1097,7 @@ func TestWebSockets(t *testing.T) { var msg testenv.WebSocketMessage // Read a result and store its timestamp, next result should be 1 second later - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) require.Equal(t, "1", msg.ID) require.Equal(t, "error", msg.Type) @@ -1135,7 +1124,7 @@ func TestWebSockets(t *testing.T) { }, func(t *testing.T, xEnv *testenv.Environment) { conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil) - err := conn.WriteJSON(&testenv.WebSocketMessage{ + err := testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { returnsError }"}`), @@ -1144,7 +1133,7 @@ func TestWebSockets(t *testing.T) { var msg testenv.WebSocketMessage // Read a result and store its timestamp, next result should be 1 second later - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) require.Equal(t, "1", msg.ID) require.Equal(t, "error", msg.Type) @@ -1174,7 +1163,7 @@ func TestWebSockets(t *testing.T) { }, func(t *testing.T, xEnv *testenv.Environment) { conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil) - err := conn.WriteJSON(&testenv.WebSocketMessage{ + err := testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { returnsError }"}`), @@ -1183,151 +1172,19 @@ func TestWebSockets(t *testing.T) { var msg testenv.WebSocketMessage // Read a result and store its timestamp, next result should be 1 second later - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) require.Equal(t, "1", msg.ID) require.Equal(t, "error", msg.Type) require.Equal(t, `[{"message":"Unable to subscribe"}]`, string(msg.Payload)) }) }) - t.Run("multiple subscriptions one connection", func(t *testing.T) { - t.Parallel() - - testenv.Run(t, &testenv.Config{ - ModifyEngineExecutionConfiguration: func(engineExecutionConfiguration *config.EngineExecutionConfiguration) { - engineExecutionConfiguration.WebSocketClientReadTimeout = time.Second - }, - }, func(t *testing.T, xEnv *testenv.Environment) { - client := graphql.NewSubscriptionClient(xEnv.GraphQLWebSocketSubscriptionURL()). - WithProtocol(graphql.GraphQLWS) - - var wg sync.WaitGroup - - var subscriptionCountEmp struct { - CountEmp int `graphql:"countEmp(max: $max, intervalMilliseconds: $interval)"` - } - var ( - firstCountEmpID, countEmpID, countEmp2ID, countHobID string - firstCountEmp, countEmp, countEmp2, countHob int - err error - variables = map[string]interface{}{ - "max": 10, - "interval": 200, - } - ) - - wg.Add(1) - - firstCountEmpID, err = client.Subscribe(&subscriptionCountEmp, map[string]interface{}{ - "max": 5, - "interval": 100, - }, func(dataValue []byte, errValue error) error { - require.NoError(t, errValue) - data := subscriptionCountEmp - err := jsonutil.UnmarshalGraphQL(dataValue, &data) - require.NoError(t, err) - require.Equal(t, firstCountEmp, data.CountEmp) - if firstCountEmp == 5 { - wg.Done() - err = client.Unsubscribe(firstCountEmpID) - require.NoError(t, err) - } - firstCountEmp++ - - return nil - }) - require.NoError(t, err) - require.NotEqual(t, "", firstCountEmpID) - - wg.Add(1) - - countEmpID, err = client.Subscribe(&subscriptionCountEmp, variables, func(dataValue []byte, errValue error) error { - require.NoError(t, errValue) - data := subscriptionCountEmp - err := jsonutil.UnmarshalGraphQL(dataValue, &data) - require.NoError(t, err) - require.Equal(t, countEmp, data.CountEmp) - if countEmp == 5 { - wg.Done() - err = client.Unsubscribe(countEmpID) - require.NoError(t, err) - } - countEmp++ - - return nil - }) - require.NoError(t, err) - require.NotEqual(t, "", countEmpID) - - var subscriptionCountEmp2 struct { - CountEmp int `graphql:"countEmp2(max: $max, intervalMilliseconds: $interval)"` - } - - wg.Add(1) - - countEmp2ID, err = client.Subscribe(&subscriptionCountEmp2, variables, func(dataValue []byte, errValue error) error { - require.NoError(t, errValue) - data := subscriptionCountEmp2 - err := jsonutil.UnmarshalGraphQL(dataValue, &data) - require.NoError(t, err) - require.Equal(t, countEmp2, data.CountEmp) - if countEmp2 == 5 { - wg.Done() - err = client.Unsubscribe(countEmp2ID) - require.NoError(t, err) - } - countEmp2++ - - return nil - }) - require.NoError(t, err) - require.NotEqual(t, "", countEmp2ID) - - var subscriptionCountHob struct { - CountHob int `graphql:"countHob(max: $max, intervalMilliseconds: $interval)"` - } - - wg.Add(1) - - countHobID, err = client.Subscribe(&subscriptionCountHob, variables, func(dataValue []byte, errValue error) error { - require.NoError(t, errValue) - data := subscriptionCountHob - err := jsonutil.UnmarshalGraphQL(dataValue, &data) - require.NoError(t, err) - require.Equal(t, countHob, data.CountHob) - if countHob == 5 { - wg.Done() - err = client.Unsubscribe(countHobID) - require.NoError(t, err) - } - countHob++ - - return nil - }) - require.NoError(t, err) - require.NotEqual(t, "", countHobID) - - wg.Add(1) - go func() { - defer wg.Done() - require.NoError(t, client.Run()) - }() - - wg.Wait() - - xEnv.WaitForSubscriptionCount(0, time.Second*5) - xEnv.WaitForConnectionCount(0, time.Second*5) - xEnv.WaitForTriggerCount(0, time.Second*5) - - require.NoError(t, client.Close()) - }) - }) t.Run("error", func(t *testing.T) { t.Parallel() testenv.Run(t, &testenv.Config{}, func(t *testing.T, xEnv *testenv.Environment) { conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil) - err := conn.WriteJSON(&testenv.WebSocketMessage{ + err := testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { does_not_exist }"}`), @@ -1336,7 +1193,7 @@ func TestWebSockets(t *testing.T) { err = conn.SetReadDeadline(time.Now().Add(5 * time.Second)) require.NoError(t, err) var msg testenv.WebSocketMessage - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) require.Equal(t, "error", msg.Type) // Payload should be an array of GraphQLError @@ -1535,14 +1392,14 @@ func TestWebSockets(t *testing.T) { }, func(t *testing.T, xEnv *testenv.Environment) { conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, []byte(`{"123": 456, "extensions": {"hello": "world"}}`)) var err error - err = conn.WriteJSON(&testenv.WebSocketMessage{ + err = testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { initialPayload(repeat:3) }"}`), }) require.NoError(t, err) var msg testenv.WebSocketMessage - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) require.Equal(t, `{"data":{"initialPayload":{"123":456,"extensions":{"initialPayload":{"123":456,"extensions":{"hello":"world"}}}}}}`, string(msg.Payload)) }) @@ -1555,14 +1412,14 @@ func TestWebSockets(t *testing.T) { }, func(t *testing.T, xEnv *testenv.Environment) { conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, []byte(`{"123": 456, "extensions": {"hello": "world"}}`)) var err error - err = conn.WriteJSON(&testenv.WebSocketMessage{ + err = testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { initialPayload(repeat:3) }"}`), }) require.NoError(t, err) var msg testenv.WebSocketMessage - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) require.Equal(t, `{"data":{"initialPayload":{"123":456,"extensions":{"initialPayload":{"123":456,"extensions":{"hello":"world"}}}}}}`, string(msg.Payload)) }) @@ -1577,20 +1434,20 @@ func TestWebSockets(t *testing.T) { conn := xEnv.InitGraphQLWebSocketConnection(map[string][]string{ "X-Feature-Flag": {"myff"}, }, nil, nil) - err := conn.WriteJSON(testenv.WebSocketMessage{ + err := testenv.DeflakeWSWriteJSON(t, conn, testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"{ employees { id productCount } }"}`), }) require.NoError(t, err) var res testenv.WebSocketMessage - err = conn.ReadJSON(&res) + err = testenv.DeflakeWSReadJSON(t, conn, &res) require.NoError(t, err) require.Equal(t, "next", res.Type) require.Equal(t, "1", res.ID) require.JSONEq(t, `{"data":{"employees":[{"id":1,"productCount":5},{"id":2,"productCount":2},{"id":3,"productCount":2},{"id":4,"productCount":3},{"id":5,"productCount":2},{"id":7,"productCount":0},{"id":8,"productCount":2},{"id":10,"productCount":3},{"id":11,"productCount":1},{"id":12,"productCount":4}]}}`, string(res.Payload)) var complete testenv.WebSocketMessage - err = conn.ReadJSON(&complete) + err = testenv.DeflakeWSReadJSON(t, conn, &complete) require.NoError(t, err) require.Equal(t, "complete", complete.Type) require.Equal(t, "1", complete.ID) @@ -1603,14 +1460,14 @@ func TestWebSockets(t *testing.T) { testenv.Run(t, &testenv.Config{}, func(t *testing.T, xEnv *testenv.Environment) { conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil) - err := conn.WriteJSON(testenv.WebSocketMessage{ + err := testenv.DeflakeWSWriteJSON(t, conn, testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"{ employees { id productCount } }"}`), }) require.NoError(t, err) var res testenv.WebSocketMessage - err = conn.ReadJSON(&res) + err = testenv.DeflakeWSReadJSON(t, conn, &res) require.NoError(t, err) require.Equal(t, "error", res.Type) require.Equal(t, "1", res.ID) @@ -1631,7 +1488,7 @@ func TestWebSockets(t *testing.T) { }, }, func(t *testing.T, xEnv *testenv.Environment) { conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil) - err := conn.WriteJSON(&testenv.WebSocketMessage{ + err := testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { does_not_exist }"}`), @@ -1639,7 +1496,7 @@ func TestWebSockets(t *testing.T) { require.NoError(t, err) // Discard the first message var msg testenv.WebSocketMessage - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) xEnv.Shutdown() _, _, err = conn.NextReader() @@ -1660,7 +1517,7 @@ func TestWebSockets(t *testing.T) { }, }, func(t *testing.T, xEnv *testenv.Environment) { conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil) - err := conn.WriteJSON(&testenv.WebSocketMessage{ + err := testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { does_not_exist }"}`), @@ -1668,7 +1525,7 @@ func TestWebSockets(t *testing.T) { require.NoError(t, err) // Discard the first message var msg testenv.WebSocketMessage - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) xEnv.Shutdown() _, _, err = conn.NextReader() @@ -1688,14 +1545,14 @@ func TestWebSockets(t *testing.T) { }, }, func(t *testing.T, xEnv *testenv.Environment) { conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, []byte(`{"123":456,"extensions":{"hello":"world"}}`)) - err := conn.WriteJSON(&testenv.WebSocketMessage{ + err := testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { initialPayload(repeat:3) }"}`), }) require.NoError(t, err) var msg testenv.WebSocketMessage - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) require.Equal(t, `{"data":{"initialPayload":{"123":456,"extensions":{"initialPayload":{"123":456,"extensions":{"hello":"world"}}}}}}`, string(msg.Payload)) }) @@ -1710,14 +1567,14 @@ func TestWebSockets(t *testing.T) { }, func(t *testing.T, xEnv *testenv.Environment) { // "extensions" in the request should override the "extensions" in initial payload conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, []byte(`{"123":456,"extensions":{"hello":"world"}}`)) - err := conn.WriteJSON(&testenv.WebSocketMessage{ + err := testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { initialPayload(repeat:3) }","extensions":{"hello":"world2"}}`), }) require.NoError(t, err) var msg testenv.WebSocketMessage - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) require.Equal(t, `{"data":{"initialPayload":{"123":456,"extensions":{"hello":"world2","initialPayload":{"123":456,"extensions":{"hello":"world"}}}}}}`, string(msg.Payload)) }) @@ -1739,7 +1596,7 @@ func TestWebSockets(t *testing.T) { Type: "subscribe", Payload: []byte(`{"query":"subscription { employeeUpdated(employeeID: 3) { id details { forename surname } } }"}`), } - err := conn.WriteJSON(&sub1) + err := testenv.DeflakeWSWriteJSON(t, conn, &sub1) require.NoError(t, err) sub2 := testenv.WebSocketMessage{ @@ -1747,7 +1604,7 @@ func TestWebSockets(t *testing.T) { Type: "subscribe", Payload: []byte(`{"query":"subscription { currentTime { unixTime timeStamp }}"}`), } - err = conn.WriteJSON(&sub2) + err = testenv.DeflakeWSWriteJSON(t, conn, &sub2) require.NoError(t, err) xEnv.WaitForSubscriptionCount(2, time.Second*5) @@ -1765,7 +1622,7 @@ func TestWebSockets(t *testing.T) { var msg testenv.WebSocketMessage for { - err := conn.ReadJSON(&msg) + err := testenv.DeflakeWSReadJSON(t, conn, &msg) if err != nil { return } @@ -1782,10 +1639,10 @@ func TestWebSockets(t *testing.T) { ID: "1", Type: "complete", } - err = conn.WriteJSON(&stop) + err = testenv.DeflakeWSWriteJSON(t, conn, &stop) require.NoError(t, err) var complete testenv.WebSocketMessage - err = conn.ReadJSON(&complete) + err = testenv.DeflakeWSReadJSON(t, conn, &complete) require.NoError(t, err) require.Equal(t, "1", complete.ID) require.Equal(t, "complete", complete.Type) @@ -1800,10 +1657,10 @@ func TestWebSockets(t *testing.T) { ID: "2", Type: "complete", } - err = conn.WriteJSON(&stop) + err = testenv.DeflakeWSWriteJSON(t, conn, &stop) require.NoError(t, err) var complete testenv.WebSocketMessage - err = conn.ReadJSON(&complete) + err = testenv.DeflakeWSReadJSON(t, conn, &complete) require.NoError(t, err) require.Equal(t, "2", complete.ID) require.Equal(t, "complete", complete.Type) @@ -1814,7 +1671,7 @@ func TestWebSockets(t *testing.T) { terminate := testenv.WebSocketMessage{ Type: "connection_terminate", } - err = conn.WriteJSON(&terminate) + err = testenv.DeflakeWSWriteJSON(t, conn, &terminate) require.NoError(t, err) _, _, err = conn.NextReader() require.Error(t, err) @@ -1892,19 +1749,19 @@ func TestWebSockets(t *testing.T) { } conn := xEnv.InitAbsintheWebSocketConnection(nil, json.RawMessage(`["1", "1", "__absinthe__:control", "phx_join", {}]`)) - err := conn.WriteJSON(json.RawMessage(`["1", "1", "__absinthe__:control", "doc", {"query":"subscription { currentTime { unixTime timeStamp }}" }]`)) + err := testenv.DeflakeWSWriteJSON(t, conn, json.RawMessage(`["1", "1", "__absinthe__:control", "doc", {"query":"subscription { currentTime { unixTime timeStamp }}" }]`)) require.NoError(t, err) var msg json.RawMessage var payload currentTimePayload // Read a result and store its timestamp, next result should be 1 second later - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) h := sha256.New() h.Write([]byte("1")) operationId := new(big.Int).SetBytes(h.Sum(nil)) require.Equal(t, string(msg), fmt.Sprintf(`["1","1","__absinthe__:control","phx_reply",{"status":"ok","response":{"subscriptionId":"__absinthe__:doc:1:%s"}}]`, operationId)) - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) require.Contains(t, string(msg), `["1","1","__absinthe__:control","subscription:data"`) var data []json.RawMessage @@ -1916,7 +1773,7 @@ func TestWebSockets(t *testing.T) { unix1 := payload.Result.Data.CurrentTime.UnixTime - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) require.Contains(t, string(msg), `["1","1","__absinthe__:control","subscription:data"`) err = json.Unmarshal(msg, &data) @@ -1929,19 +1786,19 @@ func TestWebSockets(t *testing.T) { require.Greater(t, unix2, unix1) // Sending a complete must stop the subscription - err = conn.WriteJSON(json.RawMessage(`["1", "1", "__absinthe__:control", "phx_leave", {}]`)) + err = testenv.DeflakeWSWriteJSON(t, conn, json.RawMessage(`["1", "1", "__absinthe__:control", "phx_leave", {}]`)) require.NoError(t, err) var complete json.RawMessage err = conn.SetReadDeadline(time.Now().Add(1 * time.Second)) require.NoError(t, err) - err = conn.ReadJSON(&complete) + err = testenv.DeflakeWSReadJSON(t, conn, &complete) require.NoError(t, err) require.Equal(t, string(complete), fmt.Sprintf(`["1","","__absinthe__:control","phx_reply",{"status":"ok","response":{"subscriptionId":"__absinthe__:doc:1:%s"}}]`, operationId)) err = conn.SetReadDeadline(time.Now().Add(1 * time.Second)) require.NoError(t, err) - _, _, err = conn.ReadMessage() + _, _, err = testenv.DeflakeWSReadMessage(t, conn) require.Error(t, err) var netErr net.Error if errors.As(err, &netErr) { @@ -1968,7 +1825,7 @@ func TestWebSockets(t *testing.T) { })}, }, func(t *testing.T, xEnv *testenv.Environment) { conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil) - err := conn.WriteJSON(testenv.WebSocketMessage{ + err := testenv.DeflakeWSWriteJSON(t, conn, testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { currentTime { unixTime } }"}`), @@ -1976,7 +1833,7 @@ func TestWebSockets(t *testing.T) { require.NoError(t, err) var res testenv.WebSocketMessage - err = conn.ReadJSON(&res) + err = testenv.DeflakeWSReadJSON(t, conn, &res) require.NoError(t, err) require.Equal(t, "next", res.Type) require.Equal(t, "1", res.ID) @@ -1987,6 +1844,153 @@ func TestWebSockets(t *testing.T) { }) } +func TestFlakyWebSockets(t *testing.T) { + t.Run("subscription with multiple reconnects and netPoll disabled", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, &testenv.Config{ + ModifyEngineExecutionConfiguration: func(engineExecutionConfiguration *config.EngineExecutionConfiguration) { + engineExecutionConfiguration.EnableNetPoll = false + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + expectConnectAndReadCurrentTime(t, xEnv) + expectConnectAndReadCurrentTime(t, xEnv) + }) + }) + t.Run("multiple subscriptions one connection", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, &testenv.Config{ + ModifyEngineExecutionConfiguration: func(engineExecutionConfiguration *config.EngineExecutionConfiguration) { + engineExecutionConfiguration.WebSocketClientReadTimeout = time.Second + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + client := graphql.NewSubscriptionClient(xEnv.GraphQLWebSocketSubscriptionURL()). + WithProtocol(graphql.GraphQLWS) + + var wg sync.WaitGroup + + var subscriptionCountEmp struct { + CountEmp int `graphql:"countEmp(max: $max, intervalMilliseconds: $interval)"` + } + var ( + firstCountEmpID, countEmpID, countEmp2ID, countHobID string + firstCountEmp, countEmp, countEmp2, countHob int + err error + variables = map[string]interface{}{ + "max": 10, + "interval": 200, + } + ) + + wg.Add(1) + + firstCountEmpID, err = client.Subscribe(&subscriptionCountEmp, map[string]interface{}{ + "max": 5, + "interval": 100, + }, func(dataValue []byte, errValue error) error { + require.NoError(t, errValue) + data := subscriptionCountEmp + err := jsonutil.UnmarshalGraphQL(dataValue, &data) + require.NoError(t, err) + require.Equal(t, firstCountEmp, data.CountEmp) + if firstCountEmp == 5 { + wg.Done() + err = client.Unsubscribe(firstCountEmpID) + require.NoError(t, err) + } + firstCountEmp++ + + return nil + }) + require.NoError(t, err) + require.NotEqual(t, "", firstCountEmpID) + + wg.Add(1) + + countEmpID, err = client.Subscribe(&subscriptionCountEmp, variables, func(dataValue []byte, errValue error) error { + require.NoError(t, errValue) + data := subscriptionCountEmp + err := jsonutil.UnmarshalGraphQL(dataValue, &data) + require.NoError(t, err) + require.Equal(t, countEmp, data.CountEmp) + if countEmp == 5 { + wg.Done() + err = client.Unsubscribe(countEmpID) + require.NoError(t, err) + } + countEmp++ + + return nil + }) + require.NoError(t, err) + require.NotEqual(t, "", countEmpID) + + var subscriptionCountEmp2 struct { + CountEmp int `graphql:"countEmp2(max: $max, intervalMilliseconds: $interval)"` + } + + wg.Add(1) + + countEmp2ID, err = client.Subscribe(&subscriptionCountEmp2, variables, func(dataValue []byte, errValue error) error { + require.NoError(t, errValue) + data := subscriptionCountEmp2 + err := jsonutil.UnmarshalGraphQL(dataValue, &data) + require.NoError(t, err) + require.Equal(t, countEmp2, data.CountEmp) + if countEmp2 == 5 { + wg.Done() + err = client.Unsubscribe(countEmp2ID) + require.NoError(t, err) + } + countEmp2++ + + return nil + }) + require.NoError(t, err) + require.NotEqual(t, "", countEmp2ID) + + var subscriptionCountHob struct { + CountHob int `graphql:"countHob(max: $max, intervalMilliseconds: $interval)"` + } + + wg.Add(1) + + countHobID, err = client.Subscribe(&subscriptionCountHob, variables, func(dataValue []byte, errValue error) error { + require.NoError(t, errValue) + data := subscriptionCountHob + err := jsonutil.UnmarshalGraphQL(dataValue, &data) + require.NoError(t, err) + require.Equal(t, countHob, data.CountHob) + if countHob == 5 { + wg.Done() + err = client.Unsubscribe(countHobID) + require.NoError(t, err) + } + countHob++ + + return nil + }) + require.NoError(t, err) + require.NotEqual(t, "", countHobID) + + wg.Add(1) + go func() { + defer wg.Done() + require.NoError(t, client.Run()) + }() + + wg.Wait() + + xEnv.WaitForSubscriptionCount(0, time.Second*5) + xEnv.WaitForConnectionCount(0, time.Second*5) + xEnv.WaitForTriggerCount(0, time.Second*5) + + require.NoError(t, client.Close()) + }) + }) +} + func expectConnectAndReadCurrentTime(t *testing.T, xEnv *testenv.Environment) { type currentTimePayload struct { Data struct { @@ -2000,7 +2004,7 @@ func expectConnectAndReadCurrentTime(t *testing.T, xEnv *testenv.Environment) { conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil) defer conn.Close() - err := conn.WriteJSON(&testenv.WebSocketMessage{ + err := testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "subscribe", Payload: []byte(`{"query":"subscription { currentTime { unixTime timeStamp }}"}`), @@ -2010,7 +2014,7 @@ func expectConnectAndReadCurrentTime(t *testing.T, xEnv *testenv.Environment) { var payload currentTimePayload // Read a result and store its timestamp, next result should be 1 second later - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) require.Equal(t, "1", msg.ID) require.Equal(t, "next", msg.Type) @@ -2019,7 +2023,7 @@ func expectConnectAndReadCurrentTime(t *testing.T, xEnv *testenv.Environment) { unix1 := payload.Data.CurrentTime.UnixTime - err = conn.ReadJSON(&msg) + err = testenv.DeflakeWSReadJSON(t, conn, &msg) require.NoError(t, err) require.Equal(t, "1", msg.ID) require.Equal(t, "next", msg.Type) @@ -2030,7 +2034,7 @@ func expectConnectAndReadCurrentTime(t *testing.T, xEnv *testenv.Environment) { require.Greater(t, unix2, unix1) // Sending a complete must stop the subscription - err = conn.WriteJSON(&testenv.WebSocketMessage{ + err = testenv.DeflakeWSWriteJSON(t, conn, &testenv.WebSocketMessage{ ID: "1", Type: "complete", }) @@ -2039,14 +2043,14 @@ func expectConnectAndReadCurrentTime(t *testing.T, xEnv *testenv.Environment) { var complete testenv.WebSocketMessage err = conn.SetReadDeadline(time.Now().Add(1 * time.Second)) require.NoError(t, err) - err = conn.ReadJSON(&complete) + err = testenv.DeflakeWSReadJSON(t, conn, &complete) require.NoError(t, err) require.Equal(t, "1", complete.ID) require.Equal(t, "complete", complete.Type) err = conn.SetReadDeadline(time.Now().Add(1 * time.Second)) require.NoError(t, err) - _, _, err = conn.ReadMessage() + _, _, err = testenv.DeflakeWSReadMessage(t, conn) require.Error(t, err) var netErr net.Error if errors.As(err, &netErr) {