Skip to content

Commit

Permalink
Implements possibility to limit number of active WS connections
Browse files Browse the repository at this point in the history
  • Loading branch information
ksysoev committed Jun 2, 2024
1 parent 51aea57 commit 25bdd6b
Show file tree
Hide file tree
Showing 6 changed files with 240 additions and 2 deletions.
10 changes: 8 additions & 2 deletions channel/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ func (c *Channel) wsConnectionHandler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

if !c.connRegistry.CanAccept() {
http.Error(w, "Connection limit reached", http.StatusServiceUnavailable)
return
}

ws, err := websocket.Accept(w, r, &websocket.AcceptOptions{
OriginPatterns: c.config.originPatterns,
})
Expand All @@ -75,8 +80,9 @@ func (c *Channel) wsConnectionHandler() http.Handler {
return
}

conn := c.connRegistry.AddConnection(ctx, ws, c.disptacher.Dispatch)
conn.HandleRequests()
if conn := c.connRegistry.AddConnection(ctx, ws, c.disptacher.Dispatch); conn != nil {
conn.HandleRequests()
}
})
}

Expand Down
58 changes: 58 additions & 0 deletions channel/channel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package channel
import (
"context"
"net/http"
"net/http/httptest"
"testing"

"github.com/ksysoev/wasabi/mocks"
Expand Down Expand Up @@ -146,3 +147,60 @@ func TestChannel_Shutdown(t *testing.T) {
t.Errorf("Unexpected error: %v", err)
}
}
func TestChannel_wsConnectionHandler_CannotAcceptNewConnection(t *testing.T) {
path := "/test/path"
dispatcher := mocks.NewMockDispatcher(t)
connRegistry := mocks.NewMockConnectionRegistry(t)
connRegistry.EXPECT().CanAccept().Return(false)

channel := NewChannel(path, dispatcher, connRegistry)

// Create a mock request
mockRequest := httptest.NewRequest(http.MethodGet, "http://example.com", http.NoBody)

// Create a mock response writer
mockResponseWriter := httptest.NewRecorder()

// Call the wsConnectionHandler method
handler := channel.wsConnectionHandler()

// Serve the mock request
handler.ServeHTTP(mockResponseWriter, mockRequest)

res := mockResponseWriter.Result()

defer res.Body.Close()

if res.StatusCode != http.StatusServiceUnavailable {
t.Errorf("Unexpected status code: got %d, expected %d", res.StatusCode, http.StatusServiceUnavailable)
}
}

func TestChannel_wsConnectionHandler_CanAcceptNewConnection(t *testing.T) {
path := "/test/path"
dispatcher := mocks.NewMockDispatcher(t)
connRegistry := mocks.NewMockConnectionRegistry(t)
connRegistry.EXPECT().CanAccept().Return(true)

channel := NewChannel(path, dispatcher, connRegistry)

// Create a mock request
mockRequest := httptest.NewRequest(http.MethodGet, "http://example.com", http.NoBody)

// Create a mock response writer
mockResponseWriter := httptest.NewRecorder()

// Call the wsConnectionHandler method
handler := channel.wsConnectionHandler()

// Serve the mock request
handler.ServeHTTP(mockResponseWriter, mockRequest)

res := mockResponseWriter.Result()

defer res.Body.Close()

if res.StatusCode != http.StatusUpgradeRequired {
t.Errorf("Unexpected status code: got %d, expected %d", res.StatusCode, http.StatusUpgradeRequired)
}
}
36 changes: 36 additions & 0 deletions channel/connection_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package channel

import (
"context"
"fmt"
"sync"
"time"

Expand All @@ -13,6 +14,7 @@ const (
concurencyLimitPerConnection = 25
frameSizeLimitInBytes = 32768
inActivityTimeout = 0 * time.Second
connectionLimt = -1
)

type ConnectionHook func(wasabi.Connection)
Expand All @@ -25,6 +27,7 @@ type ConnectionRegistry struct {
onConnect ConnectionHook
onDisconnect ConnectionHook
concurrencyLimit uint
connectionLimit int
frameSizeLimit int64
inActivityTimeout time.Duration
mu sync.RWMutex
Expand All @@ -42,6 +45,7 @@ func NewConnectionRegistry(opts ...ConnectionRegistryOption) *ConnectionRegistry
bufferPool: newBufferPool(),
frameSizeLimit: frameSizeLimitInBytes,
isClosed: false,
connectionLimit: connectionLimt,
}

for _, opt := range opts {
Expand All @@ -62,6 +66,11 @@ func (r *ConnectionRegistry) AddConnection(
r.mu.Lock()
defer r.mu.Unlock()

if r.connectionLimit > 0 && len(r.connections) >= r.connectionLimit {
ws.Close(websocket.StatusTryAgainLater, "Connection limit reached")
return nil
}

if r.isClosed {
return nil
}
Expand All @@ -78,6 +87,23 @@ func (r *ConnectionRegistry) AddConnection(
return conn
}

// CanAccept checks if the connection registry can accept new connections.
// It returns true if the registry can accept new connections, and false otherwise.
func (r *ConnectionRegistry) CanAccept() bool {
fmt.Println("Connection limit", r.connectionLimit)

if r.connectionLimit <= 0 {
return true
}

r.mu.RLock()
defer r.mu.RUnlock()

fmt.Println("Connections", len(r.connections))

return len(r.connections) < r.connectionLimit
}

// GetConnection returns connection by id
func (r *ConnectionRegistry) GetConnection(id string) wasabi.Connection {
r.mu.RLock()
Expand Down Expand Up @@ -189,3 +215,13 @@ func WithOnDisconnectHook(cb ConnectionHook) ConnectionRegistryOption {
r.onDisconnect = cb
}
}

// WithConnectionLimit sets the maximum number of connections that can be accepted by the ConnectionRegistry.
// The default connection limit is -1, which means there is no limit on the number of connections.
// If the connection limit is set to a positive integer, the ConnectionRegistry will not accept new connections
// once the number of active connections reaches the specified limit.
func WithConnectionLimit(limit int) ConnectionRegistryOption {
return func(r *ConnectionRegistry) {
r.connectionLimit = limit
}
}
92 changes: 92 additions & 0 deletions channel/connection_registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,3 +273,95 @@ func TestConnectionRegistry_WithOnDisconnectHook(t *testing.T) {
t.Error("Expected onDisconnect hook to be executed")
}
}

func TestConnectionRegistry_WithConnectionLimit(t *testing.T) {
registry := NewConnectionRegistry()

if registry.connectionLimit != -1 {
t.Errorf("Unexpected connection limit: got %d, expected %d", registry.connectionLimit, -1)
}

registry = NewConnectionRegistry(WithConnectionLimit(10))

if registry.connectionLimit != 10 {
t.Errorf("Unexpected connection limit: got %d, expected %d", registry.connectionLimit, 10)
}
}

func TestConnectionRegistry_AddConnection_ConnectionLimitReached(t *testing.T) {
registry := NewConnectionRegistry(WithConnectionLimit(2))
conn1 := mocks.NewMockConnection(t)
conn2 := mocks.NewMockConnection(t)
conn3 := mocks.NewMockConnection(t)

conn1.EXPECT().ID().Return("conn1")
conn2.EXPECT().ID().Return("conn2")
conn3.EXPECT().ID().Return("conn3")

registry.connections[conn1.ID()] = conn1
registry.connections[conn2.ID()] = conn2

ctx := context.Background()
cb := func(wasabi.Connection, wasabi.MessageType, []byte) {}

server := httptest.NewServer(wsHandlerEcho)
defer server.Close()
url := "ws://" + server.Listener.Addr().String()

ws, resp, err := websocket.Dial(context.Background(), url, nil)
if err != nil {
t.Errorf("Unexpected error dialing websocket: %v", err)
}

if resp.Body != nil {
resp.Body.Close()
}

conn := registry.AddConnection(ctx, ws, cb)

if conn != nil {
t.Error("Expected connection to be nil")
}

if _, ok := registry.connections[conn3.ID()]; ok {
t.Error("Expected connection to not be added to the registry")
}
}

func TestConnectionRegistry_CanAccept_ConnectionLimitNotSet(t *testing.T) {
registry := NewConnectionRegistry()

if !registry.CanAccept() {
t.Error("Expected CanAccept to return true when connection limit is not set")
}

conn := mocks.NewMockConnection(t)
conn.EXPECT().ID().Return("conn1")

registry.connections[conn.ID()] = conn

if !registry.CanAccept() {
t.Error("Expected CanAccept to return true when connection limit is not set")
}
}

func TestConnectionRegistry_CanAccept_ConnectionLimitReached(t *testing.T) {
registry := NewConnectionRegistry(WithConnectionLimit(2))

conn1 := mocks.NewMockConnection(t)
conn1.EXPECT().ID().Return("conn1")

registry.connections[conn1.ID()] = conn1

if !registry.CanAccept() {
t.Error("Expected CanAccept to return true when connection limit is reached")
}

conn2 := mocks.NewMockConnection(t)
conn2.EXPECT().ID().Return("conn2")
registry.connections[conn2.ID()] = conn2

if registry.CanAccept() {
t.Error("Expected CanAccept to return false when connection limit is reached")
}
}
1 change: 1 addition & 0 deletions interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,5 @@ type ConnectionRegistry interface {
) Connection
GetConnection(id string) Connection
Close(ctx ...context.Context) error
CanAccept() bool
}
45 changes: 45 additions & 0 deletions mocks/mock_ConnectionRegistry.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 25bdd6b

Please sign in to comment.