diff --git a/channel/connection_registry.go b/channel/connection_registry.go index 934f433..2eba3ac 100644 --- a/channel/connection_registry.go +++ b/channel/connection_registry.go @@ -1,7 +1,6 @@ package channel import ( - "bytes" "context" "sync" @@ -11,6 +10,7 @@ import ( const ( DefaultConcurencyLimitPerConnection = 25 + FrameSizeLimitInBytes = 32768 ) // DefaultConnectionRegistry is default implementation of ConnectionRegistry @@ -20,15 +20,23 @@ type DefaultConnectionRegistry struct { bufferPool *bufferPool concurrencyLimit uint mu sync.RWMutex + frameSizeLimit int64 } +type ConnectionRegistryOption func(*DefaultConnectionRegistry) + // NewDefaultConnectionRegistry creates new instance of DefaultConnectionRegistry -func NewDefaultConnectionRegistry() *DefaultConnectionRegistry { +func NewDefaultConnectionRegistry(opts ...ConnectionRegistryOption) *DefaultConnectionRegistry { reg := &DefaultConnectionRegistry{ connections: make(map[string]wasabi.Connection), onClose: make(chan string), concurrencyLimit: DefaultConcurencyLimitPerConnection, bufferPool: newBufferPool(), + frameSizeLimit: FrameSizeLimitInBytes, + } + + for _, opt := range opts { + opt(reg) } go reg.handleClose() @@ -48,6 +56,8 @@ func (r *DefaultConnectionRegistry) AddConnection( conn := NewConnection(ctx, ws, cb, r.onClose, r.bufferPool, r.concurrencyLimit) r.connections[conn.ID()] = conn + conn.ws.SetReadLimit(r.frameSizeLimit) + return conn } @@ -68,25 +78,14 @@ func (r *DefaultConnectionRegistry) handleClose() { } } -type bufferPool struct { - pool *sync.Pool -} - -func newBufferPool() *bufferPool { - return &bufferPool{ - pool: &sync.Pool{ - New: func() interface{} { - return &bytes.Buffer{} - }, - }, +// WithMaxFrameLimit sets the maximum frame size limit for incomming messages to the ConnectionRegistry. +// The limit parameter specifies the maximum frame size limit in bytes. +// This option can be used when creating a new DefaultConnectionRegistry instance. +// The default frame size limit is 32768 bytes. +// If the limit is set to -1, the frame size limit is disabled. +// When the frame size limit is exceeded, the connection is closed with status 1009 (message too large). +func WithMaxFrameLimit(limit int64) ConnectionRegistryOption { + return func(r *DefaultConnectionRegistry) { + r.frameSizeLimit = limit } } - -func (p *bufferPool) get() *bytes.Buffer { - return p.pool.Get().(*bytes.Buffer) -} - -func (p *bufferPool) put(b *bytes.Buffer) { - b.Reset() - p.pool.Put(b) -} diff --git a/channel/connection_registry_test.go b/channel/connection_registry_test.go index 73c9791..a0a80f0 100644 --- a/channel/connection_registry_test.go +++ b/channel/connection_registry_test.go @@ -2,6 +2,7 @@ package channel import ( "context" + "net/http/httptest" "sync" "testing" @@ -11,8 +12,22 @@ import ( ) func TestDefaultConnectionRegistry_AddConnection(t *testing.T) { + server := httptest.NewServer(wsHandlerEcho) + defer server.Close() + url := "ws://" + server.Listener.Addr().String() + + ws, resp, err := websocket.Dial(context.Background(), url, nil) + + if err != nil { + t.Error(err) + } + + if resp.Body != nil { + resp.Body.Close() + } + ctx := context.Background() - ws := &websocket.Conn{} + cb := func(wasabi.Connection, wasabi.MessageType, []byte) {} registry := NewDefaultConnectionRegistry() @@ -72,3 +87,11 @@ func TestDefaultConnectionRegistry_handleClose(t *testing.T) { t.Error("Expected connection to be removed from the registry") } } + +func TestDefaultConnectionRegistry_WithMaxFrameLimit(t *testing.T) { + registry := NewDefaultConnectionRegistry(WithMaxFrameLimit(100)) + + if registry.frameSizeLimit != 100 { + t.Errorf("Unexpected frame size limit: got %d, expected %d", registry.frameSizeLimit, 100) + } +}