diff --git a/gateway/mw_streaming.go b/gateway/mw_streaming.go index 4d0235ede71..e461005a997 100644 --- a/gateway/mw_streaming.go +++ b/gateway/mw_streaming.go @@ -2,6 +2,7 @@ package gateway import ( "context" + "crypto/sha256" "encoding/json" "errors" "fmt" @@ -22,6 +23,8 @@ import ( const ( // ExtensionTykStreaming is the oas extension for tyk streaming ExtensionTykStreaming = "x-tyk-streaming" + StreamGCInterval = 1 * time.Minute + StreamInactiveLimit = 10 * time.Minute ) // StreamsConfig represents a stream configuration @@ -38,11 +41,18 @@ var globalStreamCounter atomic.Int64 // StreamingMiddleware is a middleware that handles streaming functionality type StreamingMiddleware struct { *BaseMiddleware + + createStreamManagerLock sync.Mutex + streamManagerCache sync.Map // Map of payload hash to StreamManager + streamManagers sync.Map // Map of consumer group IDs to StreamManager ctx context.Context cancel context.CancelFunc allowedUnsafe []string defaultStreamManager *StreamManager + + lastActivity sync.Map // Map of stream IDs to last activity time + } // StreamManager is responsible for creating a single stream @@ -53,6 +63,8 @@ type StreamManager struct { mw *StreamingMiddleware dryRun bool listenPaths []string + + lastActivity atomic.Value // Last activity time for the StreamManager } func (sm *StreamManager) initStreams(r *http.Request, config *StreamsConfig) { @@ -114,6 +126,34 @@ func (sm *StreamManager) removeStream(streamID string) error { return nil } +func (s *StreamingMiddleware) garbageCollect() { + s.Logger().Debug("Starting garbage collection for inactive stream managers") + now := time.Now() + + s.streamManagerCache.Range(func(key, value interface{}) bool { + manager := value.(*StreamManager) + if manager == s.defaultStreamManager { + return true + } + + lastActivityTime := manager.lastActivity.Load().(time.Time) + if now.Sub(lastActivityTime) > StreamInactiveLimit { + s.Logger().Infof("Removing inactive stream manager: %v", key) + 
manager.streams.Range(func(streamKey, streamValue interface{}) bool { + streamID := streamKey.(string) + err := manager.removeStream(streamID) + if err != nil { + s.Logger().Errorf("Error removing stream %s: %v", streamID, err) + } + return true + }) + s.streamManagerCache.Delete(key) + } + + return true + }) +} + // Name is StreamingMiddleware func (s *StreamingMiddleware) Name() string { return "StreamingMiddleware" @@ -150,20 +190,56 @@ func (s *StreamingMiddleware) Init() { s.Logger().Debug("Initializing default stream manager") s.defaultStreamManager = s.createStreamManager(nil) + + // Start garbage collection routine + go func() { + ticker := time.NewTicker(StreamGCInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + s.garbageCollect() + case <-s.ctx.Done(): + return + } + } + }() } func (s *StreamingMiddleware) createStreamManager(r *http.Request) *StreamManager { - newStreamManager := &StreamManager{ - muxer: mux.NewRouter(), - mw: s, - dryRun: r == nil, + streamsConfig := s.getStreamsConfig(r) + configJSON, _ := json.Marshal(streamsConfig) + cacheKey := fmt.Sprintf("%x", sha256.Sum256(configJSON)) + + // Critical section starts here + // This section is called by the ProcessRequest method of the middleware implementation + // Concurrent requests can call this method at the same time and those requests + // create new StreamManagers and store them concurrently; as a result, + // the returned stream manager may be overwritten by a different one, leaking + // the previously stored StreamManager. 
+ s.createStreamManagerLock.Lock() + defer s.createStreamManagerLock.Unlock() + + s.Logger().Debug("Attempting to load stream manager from cache") + s.Logger().Debugf("Cache key: %s", cacheKey) + if cachedManager, found := s.streamManagerCache.Load(cacheKey); found { + s.Logger().Debug("Found cached stream manager") + return cachedManager.(*StreamManager) } - streamID := fmt.Sprintf("_%d", time.Now().UnixNano()) - s.streamManagers.Store(streamID, newStreamManager) - // Call initStreams for the new StreamManager - newStreamManager.initStreams(r, s.getStreamsConfig(r)) + newStreamManager := &StreamManager{ + muxer: mux.NewRouter(), + mw: s, + dryRun: r == nil, + lastActivity: atomic.Value{}, + } + newStreamManager.lastActivity.Store(time.Now()) + newStreamManager.initStreams(r, streamsConfig) + if r != nil { + s.streamManagerCache.Store(cacheKey, newStreamManager) + } return newStreamManager } @@ -372,11 +448,8 @@ func (s *StreamingMiddleware) Unload() { s.cancel() s.Logger().Debug("Closing active streams") - s.streamManagers.Range(func(_, value interface{}) bool { - manager, ok := value.(*StreamManager) - if !ok { - return true - } + s.streamManagerCache.Range(func(_, value interface{}) bool { + manager := value.(*StreamManager) manager.streams.Range(func(_, streamValue interface{}) bool { if stream, ok := streamValue.(*streaming.Stream); ok { if err := stream.Reset(); err != nil { @@ -389,6 +462,7 @@ func (s *StreamingMiddleware) Unload() { }) s.streamManagers = sync.Map{} + s.streamManagerCache = sync.Map{} s.Logger().Info("All streams successfully removed") } @@ -411,14 +485,9 @@ func (h *handleFuncAdapter) HandleFunc(path string, f func(http.ResponseWriter, h.sm.routeLock.Lock() h.muxer.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) { - defer func() { - // Stop the stream when the HTTP request finishes - if err := h.sm.removeStream(h.streamID); err != nil { - h.logger.Errorf("Failed to stop stream %s: %v", h.streamID, err) - } - }() - + 
h.sm.lastActivity.Store(time.Now()) f(w, r) + h.sm.lastActivity.Store(time.Now()) }) h.sm.routeLock.Unlock() h.logger.Debugf("Registered handler for path: %s", path) diff --git a/gateway/mw_streaming_test.go b/gateway/mw_streaming_test.go index c55ef7641dc..f0909225368 100644 --- a/gateway/mw_streaming_test.go +++ b/gateway/mw_streaming_test.go @@ -1,6 +1,7 @@ package gateway import ( + "bytes" "context" "crypto/tls" "encoding/json" @@ -147,6 +148,17 @@ streams: static_fields: '@service': benthos ` +const bentoHTTPServerTemplate = ` +streams: + test: + input: + http_server: + path: /post + timeout: 1s + output: + http_server: + ws_path: /subscribe +` func TestStreamingAPISingleClient(t *testing.T) { ctx := context.Background() @@ -273,23 +285,39 @@ func TestStreamingAPIMultipleClients(t *testing.T) { t.Cleanup(func() { nc.Close() }) + subject := "test" + messages := make(map[string]struct{}) for i := 0; i < totalMessages; i++ { - require.NoError(t, nc.Publish(subject, []byte(fmt.Sprintf("Hello %d", i))), "failed to publish message to subject") - } - - // Read messages from all clients - for clientID, wsConn := range wsConns { - err = wsConn.SetReadDeadline(time.Now().Add(5000 * time.Millisecond)) - require.NoError(t, err, fmt.Sprintf("error setting read deadline for client %d", clientID)) + message := fmt.Sprintf("Hello %d", i) + messages[message] = struct{}{} + require.NoError(t, nc.Publish(subject, []byte(message)), "failed to publish message to subject") + } + + // Read messages from all subscribers + // Messages are distributed in a round-robin fashion, count the number of messages and check the messages individually. + var readMessages int + for readMessages < totalMessages { + for clientID, wsConn := range wsConns { + // We need to stop waiting for a message if the subscriber is consumed all of its received messages. 
+ err = wsConn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + require.NoError(t, err, fmt.Sprintf("error setting read deadline for client %d", clientID)) + + _, data, err := wsConn.ReadMessage() + if os.IsTimeout(err) { + continue + } + require.NoError(t, err, fmt.Sprintf("error reading message for client %d", clientID)) - for i := 0; i < totalMessages; i++ { - _, p, err := wsConn.ReadMessage() - require.NoError(t, err, fmt.Sprintf("error reading message for client %d, message %d", clientID, i)) - assert.Equal(t, fmt.Sprintf("Hello %d", i), string(p), fmt.Sprintf("message not equal for client %d", clientID)) + message := string(data) + _, ok := messages[message] + require.True(t, ok, fmt.Sprintf("message is unknown or consumed before %s", message)) + delete(messages, message) + readMessages++ } } - + // Consumed all messages + require.Empty(t, messages) } func setUpStreamAPI(ts *Test, apiName string, streamConfig string) error { @@ -781,3 +809,128 @@ func TestWebSocketConnectionClosedOnAPIReload(t *testing.T) { t.Log("WebSocket connection was successfully closed on API reload") } + +func TestStreamingAPISingleClient_Input_HTTPServer(t *testing.T) { + ts := StartTest(func(globalConf *config.Config) { + globalConf.Streaming.Enabled = true + }) + t.Cleanup(func() { + ts.Close() + }) + + apiName := "test-api" + if err := setUpStreamAPI(ts, apiName, bentoHTTPServerTemplate); err != nil { + t.Fatal(err) + } + + const totalMessages = 3 + + dialer := websocket.Dialer{ + HandshakeTimeout: 1 * time.Second, + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + + wsURL := strings.Replace(ts.URL, "http", "ws", 1) + fmt.Sprintf("/%s/subscribe", apiName) + wsConn, _, err := dialer.Dial(wsURL, nil) + require.NoError(t, err, "failed to connect to ws server") + t.Cleanup(func() { + if err = wsConn.Close(); err != nil { + t.Logf("failed to close ws connection: %v", err) + } + }) + + publishURL := fmt.Sprintf("%s/%s/post", ts.URL, apiName) + for i := 0; i < 
totalMessages; i++ { + data := []byte(fmt.Sprintf("{\"test\": \"message %d\"}", i)) + resp, err := http.Post(publishURL, "application/json", bytes.NewReader(data)) + require.NoError(t, err) + _ = resp.Body.Close() + } + + err = wsConn.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) + require.NoError(t, err, "error setting read deadline") + + for i := 0; i < totalMessages; i++ { + _, p, err := wsConn.ReadMessage() + require.NoError(t, err, "error reading message") + assert.Equal(t, fmt.Sprintf("{\"test\": \"message %d\"}", i), string(p), "message not equal") + } +} + +func TestStreamingAPIMultipleClients_Input_HTTPServer(t *testing.T) { + // Testing input http -> output http (3 output instances and 10 messages) + // Messages are distributed in a round-robin fashion. + + ts := StartTest(func(globalConf *config.Config) { + globalConf.Streaming.Enabled = true + }) + t.Cleanup(func() { + ts.Close() + }) + + apiName := "test-api" + if err := setUpStreamAPI(ts, apiName, bentoHTTPServerTemplate); err != nil { + t.Fatal(err) + } + + const ( + totalSubscribers = 3 + totalMessages = 10 + ) + dialer := websocket.Dialer{ + HandshakeTimeout: 1 * time.Second, + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + + wsURL := strings.Replace(ts.URL, "http", "ws", 1) + fmt.Sprintf("/%s/subscribe", apiName) + + // Create multiple WebSocket connections + var wsConns []*websocket.Conn + for i := 0; i < totalSubscribers; i++ { + wsConn, _, err := dialer.Dial(wsURL, nil) + require.NoError(t, err, fmt.Sprintf("failed to connect to ws server for client %d", i)) + wsConns = append(wsConns, wsConn) + t.Cleanup(func() { + if err := wsConn.Close(); err != nil { + t.Logf("failed to close ws connection: %v", err) + } + }) + } + + // Publish 10 messages + messages := make(map[string]struct{}) + publishURL := fmt.Sprintf("%s/%s/post", ts.URL, apiName) + for i := 0; i < totalMessages; i++ { + message := fmt.Sprintf("{\"test\": \"message %d\"}", i) + messages[message] = struct{}{} + 
+ data := []byte(message) + resp, err := http.Post(publishURL, "application/json", bytes.NewReader(data)) + require.NoError(t, err) + _ = resp.Body.Close() + } + + // Read messages from all subscribers + // Messages are distributed in a round-robin fashion; count the number of messages and check the messages individually. + var readMessages int + for readMessages < totalMessages { + for clientID, wsConn := range wsConns { + // We need to stop waiting for a message if the subscriber has consumed all of its received messages. + err := wsConn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + require.NoError(t, err, fmt.Sprintf("error while setting read deadline for client %d", clientID)) + + _, data, err := wsConn.ReadMessage() + if os.IsTimeout(err) { + continue + } + require.NoError(t, err, fmt.Sprintf("error while reading message %d", clientID)) + + message := string(data) + _, ok := messages[message] + require.True(t, ok, fmt.Sprintf("message is unknown or consumed before %s", message)) + delete(messages, message) + readMessages++ + } + } + require.Empty(t, messages) +}