Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TT-13139] Request times out in some cases when sending input via http inputs #6601

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
109 changes: 89 additions & 20 deletions gateway/mw_streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package gateway

import (
"context"
"crypto/md5"
"encoding/json"
"errors"
"fmt"
Expand All @@ -22,6 +23,8 @@ import (
const (
// ExtensionTykStreaming is the oas extension for tyk streaming
ExtensionTykStreaming = "x-tyk-streaming"
// StreamGCInterval is the period of the ticker that triggers garbage
// collection of inactive stream managers.
StreamGCInterval = 1 * time.Minute
// StreamInactiveLimit is how long a stream manager may go without recorded
// activity before garbage collection removes it.
StreamInactiveLimit = 10 * time.Minute
)

// StreamsConfig represents a stream configuration
Expand All @@ -38,11 +41,18 @@ var globalStreamCounter atomic.Int64
// StreamingMiddleware is a middleware that handles streaming functionality
type StreamingMiddleware struct {
*BaseMiddleware

// createStreamManagerLock serializes the cache lookup and creation of
// stream managers in createStreamManager.
createStreamManagerLock sync.Mutex
// streamManagerCache is keyed by the hex-encoded MD5 of the JSON-encoded
// streams configuration.
streamManagerCache sync.Map // Map of payload hash to StreamManager

// NOTE(review): in the visible code this map is only reset in Unload;
// confirm whether anything still populates it.
streamManagers sync.Map // Map of consumer group IDs to StreamManager
// ctx stops the garbage collection goroutine when it is done; cancel is
// invoked from Unload.
ctx context.Context
cancel context.CancelFunc
allowedUnsafe []string
// defaultStreamManager is created in Init with a nil request (dry run) and
// is never garbage collected.
defaultStreamManager *StreamManager

// NOTE(review): no reads or writes of this map are visible in this part of
// the file; activity is tracked per StreamManager instead. Confirm it is
// still needed.
lastActivity sync.Map // Map of stream IDs to last activity time

}

// StreamManager is responsible for creating a single stream
Expand All @@ -53,6 +63,8 @@ type StreamManager struct {
mw *StreamingMiddleware
dryRun bool
listenPaths []string

lastActivity atomic.Value // Last activity time for the StreamManager
}

func (sm *StreamManager) initStreams(r *http.Request, config *StreamsConfig) {
Expand Down Expand Up @@ -114,6 +126,34 @@ func (sm *StreamManager) removeStream(streamID string) error {
return nil
}

// garbageCollect sweeps streamManagerCache and tears down every stream manager
// whose last recorded activity is older than StreamInactiveLimit: each of its
// streams is removed and the manager is dropped from the cache. The default
// stream manager is never collected.
//
// This method is driven by a ticker on a background goroutine, where a failed
// type assertion would panic and take the whole process down. All assertions
// therefore use the comma-ok form and unexpected entries are skipped.
func (s *StreamingMiddleware) garbageCollect() {
	s.Logger().Debug("Starting garbage collection for inactive stream managers")
	now := time.Now()

	s.streamManagerCache.Range(func(key, value interface{}) bool {
		manager, ok := value.(*StreamManager)
		if !ok || manager == s.defaultStreamManager {
			return true
		}

		loaded := manager.lastActivity.Load()
		if loaded == nil {
			// No activity has been recorded yet. Start the inactivity clock
			// now so the manager is neither collected prematurely nor kept
			// forever.
			manager.lastActivity.Store(now)
			return true
		}
		lastActivityTime, ok := loaded.(time.Time)
		if !ok || now.Sub(lastActivityTime) <= StreamInactiveLimit {
			return true
		}

		s.Logger().Infof("Removing inactive stream manager: %v", key)
		manager.streams.Range(func(streamKey, _ interface{}) bool {
			streamID, ok := streamKey.(string)
			if !ok {
				return true
			}
			// A failure to remove one stream must not stop the sweep.
			if err := manager.removeStream(streamID); err != nil {
				s.Logger().Errorf("Error removing stream %s: %v", streamID, err)
			}
			return true
		})
		s.streamManagerCache.Delete(key)

		return true
	})
}

// Name is StreamingMiddleware
func (s *StreamingMiddleware) Name() string {
return "StreamingMiddleware"
Expand Down Expand Up @@ -150,20 +190,56 @@ func (s *StreamingMiddleware) Init() {

s.Logger().Debug("Initializing default stream manager")
s.defaultStreamManager = s.createStreamManager(nil)

// Start garbage collection routine
go func() {
ticker := time.NewTicker(StreamGCInterval)
defer ticker.Stop()

for {
select {
case <-ticker.C:
s.garbageCollect()
case <-s.ctx.Done():
return
}
}
}()
}

func (s *StreamingMiddleware) createStreamManager(r *http.Request) *StreamManager {
newStreamManager := &StreamManager{
muxer: mux.NewRouter(),
mw: s,
dryRun: r == nil,
streamsConfig := s.getStreamsConfig(r)
configJSON, _ := json.Marshal(streamsConfig)
cacheKey := fmt.Sprintf("%x", md5.Sum(configJSON))
buraksezer marked this conversation as resolved.
Show resolved Hide resolved

// Critical section starts here
// This section is called by ProcessRequest method of the middleware implementation
// Concurrent requests can call this method at the same time and those requests
// creates new StreamManagers and store them concurrently, as a result
// the returned stream manager has overwritten by a different one by leaking
// the previously stored StreamManager.
s.createStreamManagerLock.Lock()
defer s.createStreamManagerLock.Unlock()

s.Logger().Debug("Attempting to load stream manager from cache")
s.Logger().Debugf("Cache key: %s", cacheKey)
if cachedManager, found := s.streamManagerCache.Load(cacheKey); found {
s.Logger().Debug("Found cached stream manager")
return cachedManager.(*StreamManager)
}
streamID := fmt.Sprintf("_%d", time.Now().UnixNano())
s.streamManagers.Store(streamID, newStreamManager)

// Call initStreams for the new StreamManager
newStreamManager.initStreams(r, s.getStreamsConfig(r))
newStreamManager := &StreamManager{
muxer: mux.NewRouter(),
mw: s,
dryRun: r == nil,
lastActivity: atomic.Value{},
}
newStreamManager.lastActivity.Store(time.Now())
newStreamManager.initStreams(r, streamsConfig)

if r != nil {
s.streamManagerCache.Store(cacheKey, newStreamManager)
}
return newStreamManager
}

Expand Down Expand Up @@ -372,11 +448,8 @@ func (s *StreamingMiddleware) Unload() {
s.cancel()

s.Logger().Debug("Closing active streams")
s.streamManagers.Range(func(_, value interface{}) bool {
manager, ok := value.(*StreamManager)
if !ok {
return true
}
s.streamManagerCache.Range(func(_, value interface{}) bool {
manager := value.(*StreamManager)
manager.streams.Range(func(_, streamValue interface{}) bool {
if stream, ok := streamValue.(*streaming.Stream); ok {
if err := stream.Reset(); err != nil {
Expand All @@ -389,6 +462,7 @@ func (s *StreamingMiddleware) Unload() {
})

s.streamManagers = sync.Map{}
s.streamManagerCache = sync.Map{}

s.Logger().Info("All streams successfully removed")
}
Expand All @@ -411,14 +485,9 @@ func (h *handleFuncAdapter) HandleFunc(path string, f func(http.ResponseWriter,

h.sm.routeLock.Lock()
h.muxer.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
defer func() {
// Stop the stream when the HTTP request finishes
if err := h.sm.removeStream(h.streamID); err != nil {
h.logger.Errorf("Failed to stop stream %s: %v", h.streamID, err)
}
}()

h.sm.lastActivity.Store(time.Now())
f(w, r)
h.sm.lastActivity.Store(time.Now())
})
h.sm.routeLock.Unlock()
h.logger.Debugf("Registered handler for path: %s", path)
Expand Down
177 changes: 165 additions & 12 deletions gateway/mw_streaming_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gateway

import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
Expand Down Expand Up @@ -147,6 +148,17 @@ streams:
static_fields:
'@service': benthos
`
// bentoHTTPServerTemplate is a streams configuration with a single stream,
// "test", that reads messages from an HTTP server input at /post (1s timeout)
// and serves them to WebSocket subscribers at /subscribe.
const bentoHTTPServerTemplate = `
streams:
  test:
    input:
      http_server:
        path: /post
        timeout: 1s
    output:
      http_server:
        ws_path: /subscribe
`

func TestStreamingAPISingleClient(t *testing.T) {
ctx := context.Background()
Expand Down Expand Up @@ -273,23 +285,39 @@ func TestStreamingAPIMultipleClients(t *testing.T) {
t.Cleanup(func() {
nc.Close()
})

subject := "test"
messages := make(map[string]struct{})
for i := 0; i < totalMessages; i++ {
require.NoError(t, nc.Publish(subject, []byte(fmt.Sprintf("Hello %d", i))), "failed to publish message to subject")
}

// Read messages from all clients
for clientID, wsConn := range wsConns {
err = wsConn.SetReadDeadline(time.Now().Add(5000 * time.Millisecond))
require.NoError(t, err, fmt.Sprintf("error setting read deadline for client %d", clientID))
message := fmt.Sprintf("Hello %d", i)
messages[message] = struct{}{}
require.NoError(t, nc.Publish(subject, []byte(message)), "failed to publish message to subject")
}

// Read messages from all subscribers
// Messages are distributed in a round-robin fashion, count the number of messages and check the messages individually.
var readMessages int
for readMessages < totalMessages {
for clientID, wsConn := range wsConns {
// We need to stop waiting for a message if the subscriber has consumed all of its received messages.
err = wsConn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
require.NoError(t, err, fmt.Sprintf("error setting read deadline for client %d", clientID))

_, data, err := wsConn.ReadMessage()
if os.IsTimeout(err) {
continue
}
require.NoError(t, err, fmt.Sprintf("error reading message for client %d", clientID))

for i := 0; i < totalMessages; i++ {
_, p, err := wsConn.ReadMessage()
require.NoError(t, err, fmt.Sprintf("error reading message for client %d, message %d", clientID, i))
assert.Equal(t, fmt.Sprintf("Hello %d", i), string(p), fmt.Sprintf("message not equal for client %d", clientID))
message := string(data)
_, ok := messages[message]
require.True(t, ok, fmt.Sprintf("message is unknown or consumed before %s", message))
delete(messages, message)
readMessages++
}
}

// Consumed all messages
require.Empty(t, messages)
}

func setUpStreamAPI(ts *Test, apiName string, streamConfig string) error {
Expand Down Expand Up @@ -781,3 +809,128 @@ func TestWebSocketConnectionClosedOnAPIReload(t *testing.T) {

t.Log("WebSocket connection was successfully closed on API reload")
}

// TestStreamingAPISingleClient_Input_HTTPServer posts three JSON messages to
// the stream's HTTP input and expects a single WebSocket subscriber to receive
// all of them in publish order.
func TestStreamingAPISingleClient_Input_HTTPServer(t *testing.T) {
	ts := StartTest(func(globalConf *config.Config) {
		globalConf.Streaming.Enabled = true
	})
	t.Cleanup(func() {
		ts.Close()
	})

	const (
		apiName       = "test-api"
		totalMessages = 3
	)

	setupErr := setUpStreamAPI(ts, apiName, bentoHTTPServerTemplate)
	if setupErr != nil {
		t.Fatal(setupErr)
	}

	// payloadFor builds the JSON body of the idx-th message; it is used both
	// for publishing and for the expected value on the subscriber side.
	payloadFor := func(idx int) string {
		return fmt.Sprintf("{\"test\": \"message %d\"}", idx)
	}

	// Subscribe before publishing.
	wsDialer := websocket.Dialer{
		HandshakeTimeout: 1 * time.Second,
		TLSClientConfig:  &tls.Config{InsecureSkipVerify: true},
	}
	subscribeURL := fmt.Sprintf("%s/%s/subscribe", strings.Replace(ts.URL, "http", "ws", 1), apiName)
	wsConn, _, err := wsDialer.Dial(subscribeURL, nil)
	require.NoError(t, err, "failed to connect to ws server")
	t.Cleanup(func() {
		closeErr := wsConn.Close()
		if closeErr != nil {
			t.Logf("failed to close ws connection: %v", closeErr)
		}
	})

	// Publish every message through the HTTP input.
	postURL := fmt.Sprintf("%s/%s/post", ts.URL, apiName)
	for sent := 0; sent < totalMessages; sent++ {
		body := bytes.NewReader([]byte(payloadFor(sent)))
		resp, postErr := http.Post(postURL, "application/json", body)
		require.NoError(t, postErr)
		_ = resp.Body.Close()
	}

	// One deadline covers the whole read phase.
	err = wsConn.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
	require.NoError(t, err, "error setting read deadline")

	for received := 0; received < totalMessages; received++ {
		_, frame, readErr := wsConn.ReadMessage()
		require.NoError(t, readErr, "error reading message")
		assert.Equal(t, payloadFor(received), string(frame), "message not equal")
	}
}

// TestStreamingAPIMultipleClients_Input_HTTPServer tests input http -> output
// http with 3 WebSocket subscribers and 10 published messages. Messages are
// spread across the subscribers, so the test only requires that every
// published message is received exactly once by some subscriber.
//
// Each connection is drained by its own goroutine and the whole receive phase
// is bounded by a single timeout. Polling the connections with short read
// deadlines is avoided on purpose: the loop would have no upper bound if a
// message were lost, and (assuming gorilla/websocket — confirm against the
// file's imports) a read error, including a deadline timeout, is permanent for
// the connection, so later reads on it could never succeed.
func TestStreamingAPIMultipleClients_Input_HTTPServer(t *testing.T) {
	ts := StartTest(func(globalConf *config.Config) {
		globalConf.Streaming.Enabled = true
	})
	t.Cleanup(func() {
		ts.Close()
	})

	apiName := "test-api"
	if err := setUpStreamAPI(ts, apiName, bentoHTTPServerTemplate); err != nil {
		t.Fatal(err)
	}

	const (
		totalSubscribers = 3
		totalMessages    = 10
		// receiveTimeout bounds the whole receive phase so that a lost
		// message fails the test quickly instead of hanging it.
		receiveTimeout = 5 * time.Second
	)
	dialer := websocket.Dialer{
		HandshakeTimeout: 1 * time.Second,
		TLSClientConfig:  &tls.Config{InsecureSkipVerify: true},
	}

	wsURL := strings.Replace(ts.URL, "http", "ws", 1) + fmt.Sprintf("/%s/subscribe", apiName)

	// Create multiple WebSocket connections
	wsConns := make([]*websocket.Conn, 0, totalSubscribers)
	for i := 0; i < totalSubscribers; i++ {
		wsConn, _, err := dialer.Dial(wsURL, nil)
		require.NoError(t, err, fmt.Sprintf("failed to connect to ws server for client %d", i))
		wsConns = append(wsConns, wsConn)
		t.Cleanup(func() {
			if err := wsConn.Close(); err != nil {
				t.Logf("failed to close ws connection: %v", err)
			}
		})
	}

	// delivery is one ReadMessage result, tagged with the subscriber it came from.
	type delivery struct {
		clientID int
		data     []byte
		err      error
	}
	deliveries := make(chan delivery)
	// done releases readers that are blocked on sending once the test body
	// returns; readers blocked in ReadMessage are released when the cleanup
	// functions close their connections.
	done := make(chan struct{})
	defer close(done)

	for clientID, wsConn := range wsConns {
		go func(clientID int, wsConn *websocket.Conn) {
			for {
				_, data, err := wsConn.ReadMessage()
				select {
				case deliveries <- delivery{clientID: clientID, data: data, err: err}:
				case <-done:
					return
				}
				if err != nil {
					// Stop reading from this connection after the first error.
					return
				}
			}
		}(clientID, wsConn)
	}

	// Publish 10 messages
	messages := make(map[string]struct{}, totalMessages)
	publishURL := fmt.Sprintf("%s/%s/post", ts.URL, apiName)
	for i := 0; i < totalMessages; i++ {
		message := fmt.Sprintf("{\"test\": \"message %d\"}", i)
		messages[message] = struct{}{}

		resp, err := http.Post(publishURL, "application/json", bytes.NewReader([]byte(message)))
		require.NoError(t, err)
		_ = resp.Body.Close()
	}

	// Collect messages from all subscribers. Assertions run on the test
	// goroutine only; the readers just forward what they read.
	timeout := time.After(receiveTimeout)
	for readMessages := 0; readMessages < totalMessages; readMessages++ {
		select {
		case d := <-deliveries:
			require.NoError(t, d.err, fmt.Sprintf("error while reading message %d", d.clientID))

			message := string(d.data)
			_, ok := messages[message]
			require.True(t, ok, fmt.Sprintf("message is unknown or consumed before %s", message))
			delete(messages, message)
		case <-timeout:
			t.Fatalf("timed out after %s: received %d of %d messages", receiveTimeout, readMessages, totalMessages)
		}
	}

	// Consumed all messages
	require.Empty(t, messages)
}
Loading