Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

"Implement connection limit for WebSocket connections" #64

Merged
merged 1 commit into from
Jun 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions channel/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ func (c *Channel) wsConnectionHandler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

if !c.connRegistry.CanAccept() {
http.Error(w, "Connection limit reached", http.StatusServiceUnavailable)
return
}

ws, err := websocket.Accept(w, r, &websocket.AcceptOptions{
OriginPatterns: c.config.originPatterns,
})
Expand All @@ -75,8 +80,9 @@ func (c *Channel) wsConnectionHandler() http.Handler {
return
}

conn := c.connRegistry.AddConnection(ctx, ws, c.disptacher.Dispatch)
conn.HandleRequests()
if conn := c.connRegistry.AddConnection(ctx, ws, c.disptacher.Dispatch); conn != nil {
conn.HandleRequests()
}
})
}

Expand Down
58 changes: 58 additions & 0 deletions channel/channel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package channel
import (
"context"
"net/http"
"net/http/httptest"
"testing"

"github.com/ksysoev/wasabi/mocks"
Expand Down Expand Up @@ -146,3 +147,60 @@ func TestChannel_Shutdown(t *testing.T) {
t.Errorf("Unexpected error: %v", err)
}
}
func TestChannel_wsConnectionHandler_CannotAcceptNewConnection(t *testing.T) {
path := "/test/path"
dispatcher := mocks.NewMockDispatcher(t)
connRegistry := mocks.NewMockConnectionRegistry(t)
connRegistry.EXPECT().CanAccept().Return(false)

channel := NewChannel(path, dispatcher, connRegistry)

// Create a mock request
mockRequest := httptest.NewRequest(http.MethodGet, "http://example.com", http.NoBody)

// Create a mock response writer
mockResponseWriter := httptest.NewRecorder()

// Call the wsConnectionHandler method
handler := channel.wsConnectionHandler()

// Serve the mock request
handler.ServeHTTP(mockResponseWriter, mockRequest)

res := mockResponseWriter.Result()

defer res.Body.Close()

if res.StatusCode != http.StatusServiceUnavailable {
t.Errorf("Unexpected status code: got %d, expected %d", res.StatusCode, http.StatusServiceUnavailable)
}
}

func TestChannel_wsConnectionHandler_CanAcceptNewConnection(t *testing.T) {
path := "/test/path"
dispatcher := mocks.NewMockDispatcher(t)
connRegistry := mocks.NewMockConnectionRegistry(t)
connRegistry.EXPECT().CanAccept().Return(true)

channel := NewChannel(path, dispatcher, connRegistry)

// Create a mock request
mockRequest := httptest.NewRequest(http.MethodGet, "http://example.com", http.NoBody)

// Create a mock response writer
mockResponseWriter := httptest.NewRecorder()

// Call the wsConnectionHandler method
handler := channel.wsConnectionHandler()

// Serve the mock request
handler.ServeHTTP(mockResponseWriter, mockRequest)

res := mockResponseWriter.Result()

defer res.Body.Close()

if res.StatusCode != http.StatusUpgradeRequired {
t.Errorf("Unexpected status code: got %d, expected %d", res.StatusCode, http.StatusUpgradeRequired)
}
}
36 changes: 36 additions & 0 deletions channel/connection_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package channel

import (
"context"
"fmt"
"sync"
"time"

Expand All @@ -13,6 +14,7 @@ const (
concurencyLimitPerConnection = 25
frameSizeLimitInBytes = 32768
inActivityTimeout = 0 * time.Second
connectionLimt = -1
)

type ConnectionHook func(wasabi.Connection)
Expand All @@ -25,6 +27,7 @@ type ConnectionRegistry struct {
onConnect ConnectionHook
onDisconnect ConnectionHook
concurrencyLimit uint
connectionLimit int
frameSizeLimit int64
inActivityTimeout time.Duration
mu sync.RWMutex
Expand All @@ -42,6 +45,7 @@ func NewConnectionRegistry(opts ...ConnectionRegistryOption) *ConnectionRegistry
bufferPool: newBufferPool(),
frameSizeLimit: frameSizeLimitInBytes,
isClosed: false,
connectionLimit: connectionLimt,
}

for _, opt := range opts {
Expand All @@ -62,6 +66,11 @@ func (r *ConnectionRegistry) AddConnection(
r.mu.Lock()
defer r.mu.Unlock()

if r.connectionLimit > 0 && len(r.connections) >= r.connectionLimit {
ws.Close(websocket.StatusTryAgainLater, "Connection limit reached")
return nil
}

if r.isClosed {
return nil
}
Expand All @@ -78,6 +87,23 @@ func (r *ConnectionRegistry) AddConnection(
return conn
}

// CanAccept checks if the connection registry can accept new connections.
// It returns true if the registry can accept new connections, and false otherwise.
func (r *ConnectionRegistry) CanAccept() bool {
fmt.Println("Connection limit", r.connectionLimit)

if r.connectionLimit <= 0 {
return true
}

r.mu.RLock()
defer r.mu.RUnlock()

fmt.Println("Connections", len(r.connections))

return len(r.connections) < r.connectionLimit
}

// GetConnection returns connection by id
func (r *ConnectionRegistry) GetConnection(id string) wasabi.Connection {
r.mu.RLock()
Expand Down Expand Up @@ -189,3 +215,13 @@ func WithOnDisconnectHook(cb ConnectionHook) ConnectionRegistryOption {
r.onDisconnect = cb
}
}

// WithConnectionLimit sets the maximum number of connections that can be accepted by the ConnectionRegistry.
// The default connection limit is -1, which means there is no limit on the number of connections.
// If the connection limit is set to a positive integer, the ConnectionRegistry will not accept new connections
// once the number of active connections reaches the specified limit.
func WithConnectionLimit(limit int) ConnectionRegistryOption {
return func(r *ConnectionRegistry) {
r.connectionLimit = limit
}
}
92 changes: 92 additions & 0 deletions channel/connection_registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,3 +273,95 @@ func TestConnectionRegistry_WithOnDisconnectHook(t *testing.T) {
t.Error("Expected onDisconnect hook to be executed")
}
}

func TestConnectionRegistry_WithConnectionLimit(t *testing.T) {
registry := NewConnectionRegistry()

if registry.connectionLimit != -1 {
t.Errorf("Unexpected connection limit: got %d, expected %d", registry.connectionLimit, -1)
}

registry = NewConnectionRegistry(WithConnectionLimit(10))

if registry.connectionLimit != 10 {
t.Errorf("Unexpected connection limit: got %d, expected %d", registry.connectionLimit, 10)
}
}

func TestConnectionRegistry_AddConnection_ConnectionLimitReached(t *testing.T) {
registry := NewConnectionRegistry(WithConnectionLimit(2))
conn1 := mocks.NewMockConnection(t)
conn2 := mocks.NewMockConnection(t)
conn3 := mocks.NewMockConnection(t)

conn1.EXPECT().ID().Return("conn1")
conn2.EXPECT().ID().Return("conn2")
conn3.EXPECT().ID().Return("conn3")

registry.connections[conn1.ID()] = conn1
registry.connections[conn2.ID()] = conn2

ctx := context.Background()
cb := func(wasabi.Connection, wasabi.MessageType, []byte) {}

server := httptest.NewServer(wsHandlerEcho)
defer server.Close()
url := "ws://" + server.Listener.Addr().String()

ws, resp, err := websocket.Dial(context.Background(), url, nil)
if err != nil {
t.Errorf("Unexpected error dialing websocket: %v", err)
}

if resp.Body != nil {
resp.Body.Close()
}

conn := registry.AddConnection(ctx, ws, cb)

if conn != nil {
t.Error("Expected connection to be nil")
}

if _, ok := registry.connections[conn3.ID()]; ok {
t.Error("Expected connection to not be added to the registry")
}
}

func TestConnectionRegistry_CanAccept_ConnectionLimitNotSet(t *testing.T) {
registry := NewConnectionRegistry()

if !registry.CanAccept() {
t.Error("Expected CanAccept to return true when connection limit is not set")
}

conn := mocks.NewMockConnection(t)
conn.EXPECT().ID().Return("conn1")

registry.connections[conn.ID()] = conn

if !registry.CanAccept() {
t.Error("Expected CanAccept to return true when connection limit is not set")
}
}

func TestConnectionRegistry_CanAccept_ConnectionLimitReached(t *testing.T) {
registry := NewConnectionRegistry(WithConnectionLimit(2))

conn1 := mocks.NewMockConnection(t)
conn1.EXPECT().ID().Return("conn1")

registry.connections[conn1.ID()] = conn1

if !registry.CanAccept() {
t.Error("Expected CanAccept to return true when connection limit is reached")
}

conn2 := mocks.NewMockConnection(t)
conn2.EXPECT().ID().Return("conn2")
registry.connections[conn2.ID()] = conn2

if registry.CanAccept() {
t.Error("Expected CanAccept to return false when connection limit is reached")
}
}
1 change: 1 addition & 0 deletions interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,5 @@ type ConnectionRegistry interface {
) Connection
GetConnection(id string) Connection
Close(ctx ...context.Context) error
CanAccept() bool
}
45 changes: 45 additions & 0 deletions mocks/mock_ConnectionRegistry.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading