diff --git a/channel/channel.go b/channel/channel.go index 3e8d8ba..31f9ae8 100644 --- a/channel/channel.go +++ b/channel/channel.go @@ -67,6 +67,11 @@ func (c *Channel) wsConnectionHandler() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + if !c.connRegistry.CanAccept() { + http.Error(w, "Connection limit reached", http.StatusServiceUnavailable) + return + } + ws, err := websocket.Accept(w, r, &websocket.AcceptOptions{ OriginPatterns: c.config.originPatterns, }) @@ -75,8 +80,9 @@ func (c *Channel) wsConnectionHandler() http.Handler { return } - conn := c.connRegistry.AddConnection(ctx, ws, c.disptacher.Dispatch) - conn.HandleRequests() + if conn := c.connRegistry.AddConnection(ctx, ws, c.disptacher.Dispatch); conn != nil { + conn.HandleRequests() + } }) } diff --git a/channel/channel_test.go b/channel/channel_test.go index e1ecab2..ebe34af 100644 --- a/channel/channel_test.go +++ b/channel/channel_test.go @@ -3,6 +3,7 @@ package channel import ( "context" "net/http" + "net/http/httptest" "testing" "github.com/ksysoev/wasabi/mocks" @@ -146,3 +147,60 @@ func TestChannel_Shutdown(t *testing.T) { t.Errorf("Unexpected error: %v", err) } } +func TestChannel_wsConnectionHandler_CannotAcceptNewConnection(t *testing.T) { + path := "/test/path" + dispatcher := mocks.NewMockDispatcher(t) + connRegistry := mocks.NewMockConnectionRegistry(t) + connRegistry.EXPECT().CanAccept().Return(false) + + channel := NewChannel(path, dispatcher, connRegistry) + + // Create a mock request + mockRequest := httptest.NewRequest(http.MethodGet, "http://example.com", http.NoBody) + + // Create a mock response writer + mockResponseWriter := httptest.NewRecorder() + + // Call the wsConnectionHandler method + handler := channel.wsConnectionHandler() + + // Serve the mock request + handler.ServeHTTP(mockResponseWriter, mockRequest) + + res := mockResponseWriter.Result() + + defer res.Body.Close() + + if res.StatusCode != http.StatusServiceUnavailable { + t.Errorf("Unexpected status code: got %d, expected %d", res.StatusCode, http.StatusServiceUnavailable) + } +} + +func TestChannel_wsConnectionHandler_CanAcceptNewConnection(t *testing.T) { + path := "/test/path" + dispatcher := mocks.NewMockDispatcher(t) + connRegistry := mocks.NewMockConnectionRegistry(t) + connRegistry.EXPECT().CanAccept().Return(true) + + channel := NewChannel(path, dispatcher, connRegistry) + + // Create a mock request + mockRequest := httptest.NewRequest(http.MethodGet, "http://example.com", http.NoBody) + + // Create a mock response writer + mockResponseWriter := httptest.NewRecorder() + + // Call the wsConnectionHandler method + handler := channel.wsConnectionHandler() + + // Serve the mock request + handler.ServeHTTP(mockResponseWriter, mockRequest) + + res := mockResponseWriter.Result() + + defer res.Body.Close() + + if res.StatusCode != http.StatusUpgradeRequired { + t.Errorf("Unexpected status code: got %d, expected %d", res.StatusCode, http.StatusUpgradeRequired) + } +} diff --git a/channel/connection_registry.go b/channel/connection_registry.go index 3d727c6..8aa9a87 100644 --- a/channel/connection_registry.go +++ b/channel/connection_registry.go @@ -2,6 +2,7 @@ package channel import ( "context" + "fmt" "sync" "time" @@ -13,6 +14,7 @@ const ( concurencyLimitPerConnection = 25 frameSizeLimitInBytes = 32768 inActivityTimeout = 0 * time.Second + connectionLimt = -1 ) type ConnectionHook func(wasabi.Connection) @@ -25,6 +27,7 @@ type ConnectionRegistry struct { onConnect ConnectionHook onDisconnect ConnectionHook concurrencyLimit uint + connectionLimit int frameSizeLimit int64 inActivityTimeout time.Duration mu sync.RWMutex @@ -42,6 +45,7 @@ func NewConnectionRegistry(opts ...ConnectionRegistryOption) *ConnectionRegistry bufferPool: newBufferPool(), frameSizeLimit: frameSizeLimitInBytes, isClosed: false, + connectionLimit: connectionLimt, } for _, opt := range opts { @@ -62,6 +66,11 @@ func (r *ConnectionRegistry) AddConnection( r.mu.Lock() defer r.mu.Unlock() + if r.connectionLimit > 0 && len(r.connections) >= r.connectionLimit { + ws.Close(websocket.StatusTryAgainLater, "Connection limit reached") + return nil + } + if r.isClosed { return nil } @@ -78,6 +87,23 @@ func (r *ConnectionRegistry) AddConnection( return conn } +// CanAccept checks if the connection registry can accept new connections. +// It returns true if the registry can accept new connections, and false otherwise. +func (r *ConnectionRegistry) CanAccept() bool { + fmt.Println("Connection limit", r.connectionLimit) + + if r.connectionLimit <= 0 { + return true + } + + r.mu.RLock() + defer r.mu.RUnlock() + + fmt.Println("Connections", len(r.connections)) + + return len(r.connections) < r.connectionLimit +} + // GetConnection returns connection by id func (r *ConnectionRegistry) GetConnection(id string) wasabi.Connection { r.mu.RLock() @@ -189,3 +215,13 @@ func WithOnDisconnectHook(cb ConnectionHook) ConnectionRegistryOption { r.onDisconnect = cb } } + +// WithConnectionLimit sets the maximum number of connections that can be accepted by the ConnectionRegistry. +// The default connection limit is -1, which means there is no limit on the number of connections. +// If the connection limit is set to a positive integer, the ConnectionRegistry will not accept new connections +// once the number of active connections reaches the specified limit. +func WithConnectionLimit(limit int) ConnectionRegistryOption { + return func(r *ConnectionRegistry) { + r.connectionLimit = limit + } +} diff --git a/channel/connection_registry_test.go b/channel/connection_registry_test.go index ed2ff82..3fa1641 100644 --- a/channel/connection_registry_test.go +++ b/channel/connection_registry_test.go @@ -273,3 +273,95 @@ func TestConnectionRegistry_WithOnDisconnectHook(t *testing.T) { t.Error("Expected onDisconnect hook to be executed") } } + +func TestConnectionRegistry_WithConnectionLimit(t *testing.T) { + registry := NewConnectionRegistry() + + if registry.connectionLimit != -1 { + t.Errorf("Unexpected connection limit: got %d, expected %d", registry.connectionLimit, -1) + } + + registry = NewConnectionRegistry(WithConnectionLimit(10)) + + if registry.connectionLimit != 10 { + t.Errorf("Unexpected connection limit: got %d, expected %d", registry.connectionLimit, 10) + } +} + +func TestConnectionRegistry_AddConnection_ConnectionLimitReached(t *testing.T) { + registry := NewConnectionRegistry(WithConnectionLimit(2)) + conn1 := mocks.NewMockConnection(t) + conn2 := mocks.NewMockConnection(t) + conn3 := mocks.NewMockConnection(t) + + conn1.EXPECT().ID().Return("conn1") + conn2.EXPECT().ID().Return("conn2") + conn3.EXPECT().ID().Return("conn3") + + registry.connections[conn1.ID()] = conn1 + registry.connections[conn2.ID()] = conn2 + + ctx := context.Background() + cb := func(wasabi.Connection, wasabi.MessageType, []byte) {} + + server := httptest.NewServer(wsHandlerEcho) + defer server.Close() + url := "ws://" + server.Listener.Addr().String() + + ws, resp, err := websocket.Dial(context.Background(), url, nil) + if err != nil { + t.Errorf("Unexpected error dialing websocket: %v", err) + } + + if resp.Body != nil { + resp.Body.Close() + } + + conn := registry.AddConnection(ctx, ws, cb) + + if conn != nil { + t.Error("Expected connection to be nil") + } + + if _, ok := registry.connections[conn3.ID()]; ok { + t.Error("Expected connection to not be added to the registry") + } +} + +func TestConnectionRegistry_CanAccept_ConnectionLimitNotSet(t *testing.T) { + registry := NewConnectionRegistry() + + if !registry.CanAccept() { + t.Error("Expected CanAccept to return true when connection limit is not set") + } + + conn := mocks.NewMockConnection(t) + conn.EXPECT().ID().Return("conn1") + + registry.connections[conn.ID()] = conn + + if !registry.CanAccept() { + t.Error("Expected CanAccept to return true when connection limit is not set") + } +} + +func TestConnectionRegistry_CanAccept_ConnectionLimitReached(t *testing.T) { + registry := NewConnectionRegistry(WithConnectionLimit(2)) + + conn1 := mocks.NewMockConnection(t) + conn1.EXPECT().ID().Return("conn1") + + registry.connections[conn1.ID()] = conn1 + + if !registry.CanAccept() { + t.Error("Expected CanAccept to return true when connection limit is reached") + } + + conn2 := mocks.NewMockConnection(t) + conn2.EXPECT().ID().Return("conn2") + registry.connections[conn2.ID()] = conn2 + + if registry.CanAccept() { + t.Error("Expected CanAccept to return false when connection limit is reached") + } +} diff --git a/interfaces.go b/interfaces.go index 8d07b8b..8a75acb 100644 --- a/interfaces.go +++ b/interfaces.go @@ -59,4 +59,5 @@ type ConnectionRegistry interface { ) Connection GetConnection(id string) Connection Close(ctx ...context.Context) error + CanAccept() bool } diff --git a/mocks/mock_ConnectionRegistry.go b/mocks/mock_ConnectionRegistry.go index 57d6972..d927b40 100644 --- a/mocks/mock_ConnectionRegistry.go +++ b/mocks/mock_ConnectionRegistry.go @@ -76,6 +76,51 @@ func (_c *MockConnectionRegistry_AddConnection_Call) RunAndReturn(run func(conte return _c } +// CanAccept provides a mock function with given fields: +func (_m *MockConnectionRegistry) CanAccept() bool { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for CanAccept") + } + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// MockConnectionRegistry_CanAccept_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CanAccept' +type MockConnectionRegistry_CanAccept_Call struct { + *mock.Call +} + +// CanAccept is a helper method to define mock.On call +func (_e *MockConnectionRegistry_Expecter) CanAccept() *MockConnectionRegistry_CanAccept_Call { + return &MockConnectionRegistry_CanAccept_Call{Call: _e.mock.On("CanAccept")} +} + +func (_c *MockConnectionRegistry_CanAccept_Call) Run(run func()) *MockConnectionRegistry_CanAccept_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockConnectionRegistry_CanAccept_Call) Return(_a0 bool) *MockConnectionRegistry_CanAccept_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockConnectionRegistry_CanAccept_Call) RunAndReturn(run func() bool) *MockConnectionRegistry_CanAccept_Call { + _c.Call.Return(run) + return _c +} + // Close provides a mock function with given fields: ctx func (_m *MockConnectionRegistry) Close(ctx ...context.Context) error { _va := make([]interface{}, len(ctx))