From f2c8be99079c3f5f3adc6f654c84662131766f9b Mon Sep 17 00:00:00 2001 From: Kirill Sysoev Date: Sun, 2 Jun 2024 20:52:05 +0800 Subject: [PATCH 1/2] Refactor HandlingRequest method on connection --- channel/channel.go | 4 +- channel/connection.go | 4 +- channel/connection_registry.go | 11 +-- channel/connection_registry_test.go | 103 ++++++++++++++++++++++------ channel/connection_test.go | 4 +- channel/connection_wrapper.go | 5 -- channel/connection_wrapper_test.go | 10 --- interfaces.go | 5 +- mocks/mock_Connection.go | 32 --------- mocks/mock_ConnectionRegistry.go | 85 ++++++++++------------- 10 files changed, 129 insertions(+), 134 deletions(-) diff --git a/channel/channel.go b/channel/channel.go index 31f9ae8..77ced9e 100644 --- a/channel/channel.go +++ b/channel/channel.go @@ -80,9 +80,7 @@ func (c *Channel) wsConnectionHandler() http.Handler { return } - if conn := c.connRegistry.AddConnection(ctx, ws, c.disptacher.Dispatch); conn != nil { - conn.HandleRequests() - } + c.connRegistry.HandleConnection(ctx, ws, c.disptacher.Dispatch) }) } diff --git a/channel/connection.go b/channel/connection.go index 2bc1d72..49cd720 100644 --- a/channel/connection.go +++ b/channel/connection.go @@ -91,8 +91,8 @@ func (c *Conn) Context() context.Context { return c.ctx } -// HandleRequests handles incoming messages -func (c *Conn) HandleRequests() { +// handleRequests handles incoming messages +func (c *Conn) handleRequests() { defer c.close() for c.ctx.Err() == nil { diff --git a/channel/connection_registry.go b/channel/connection_registry.go index 8aa9a87..6334681 100644 --- a/channel/connection_registry.go +++ b/channel/connection_registry.go @@ -58,21 +58,22 @@ func NewConnectionRegistry(opts ...ConnectionRegistryOption) *ConnectionRegistry } // AddConnection adds new Websocket connection to registry -func (r *ConnectionRegistry) AddConnection( +func (r *ConnectionRegistry) HandleConnection( ctx context.Context, ws *websocket.Conn, cb wasabi.OnMessage, -) wasabi.Connection { +) { r.mu.Lock() defer r.mu.Unlock() if r.connectionLimit > 0 && len(r.connections) >= r.connectionLimit { ws.Close(websocket.StatusTryAgainLater, "Connection limit reached") - return nil + return } if r.isClosed { - return nil + ws.Close(websocket.StatusServiceRestart, "Server is shutting down") + return } conn := NewConnection(ctx, ws, cb, r.onClose, r.bufferPool, r.concurrencyLimit, r.inActivityTimeout) @@ -84,7 +85,7 @@ func (r *ConnectionRegistry) AddConnection( r.onConnect(conn) } - return conn + conn.handleRequests() } // CanAccept checks if the connection registry can accept new connections. diff --git a/channel/connection_registry_test.go b/channel/connection_registry_test.go index 3fa1641..4898abb 100644 --- a/channel/connection_registry_test.go +++ b/channel/connection_registry_test.go @@ -12,7 +12,7 @@ import ( "nhooyr.io/websocket" ) -func TestConnectionRegistry_AddConnection(t *testing.T) { +func TestConnectionRegistry_HandleConnection(t *testing.T) { server := httptest.NewServer(wsHandlerEcho) defer server.Close() url := "ws://" + server.Listener.Addr().String() @@ -20,28 +20,50 @@ func TestConnectionRegistry_AddConnection(t *testing.T) { ws, resp, err := websocket.Dial(context.Background(), url, nil) if err != nil { - t.Error(err) + t.Errorf("Unexpected error dialing websocket: %v", err) + } + + err = ws.Write(context.Background(), websocket.MessageText, []byte("test")) + if err != nil { + t.Errorf("Unexpected error writing to websocket: %v", err) } if resp.Body != nil { resp.Body.Close() } - ctx := context.Background() - - cb := func(wasabi.Connection, wasabi.MessageType, []byte) {} + ready := make(chan struct{}) + cb := func(wasabi.Connection, wasabi.MessageType, []byte) { + close(ready) + } registry := NewConnectionRegistry() - conn := registry.AddConnection(ctx, ws, cb) + ctx, cancel := context.WithCancel(context.Background()) - if conn == nil { - t.Error("Expected connection to be created") + done := make(chan struct{}) + go func() { + registry.HandleConnection(ctx, ws, cb) + close(done) + }() + + select { + case <-ready: + case <-time.After(1 * time.Second): + t.Error("Expected connection to be handled") } - if _, ok := registry.connections[conn.ID()]; !ok { + if len(registry.connections) != 1 { t.Error("Expected connection to be added to the registry") } + + cancel() + + select { + case <-done: + case <-time.After(1 * time.Second): + t.Error("Expected connection to be closed") + } } func TestConnectionRegistry_AddConnection_ToClosedRegistry(t *testing.T) { @@ -67,10 +89,16 @@ func TestConnectionRegistry_AddConnection_ToClosedRegistry(t *testing.T) { cb := func(wasabi.Connection, wasabi.MessageType, []byte) {} - conn := registry.AddConnection(ctx, ws, cb) + done := make(chan struct{}) + go func() { + registry.HandleConnection(ctx, ws, cb) + close(done) + }() - if conn != nil { - t.Error("Expected connection to be nil") + select { + case <-done: + case <-time.After(1 * time.Second): + t.Error("Expected connection to be closed") } } @@ -181,6 +209,7 @@ func TestConnectionRegistry_WithInActivityTimeout(t *testing.T) { t.Errorf("Unexpected inactivity timeout: got %s, expected %s", registry.inActivityTimeout, 5*time.Minute) } } + func TestConnectionRegistry_WithOnConnect(t *testing.T) { registry := NewConnectionRegistry() @@ -217,7 +246,20 @@ func TestConnectionRegistry_WithOnConnect(t *testing.T) { t.Error("Expected onConnect callback to be set") } - registry.AddConnection(context.Background(), ws, func(wasabi.Connection, wasabi.MessageType, []byte) {}) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + done := make(chan struct{}) + go func() { + registry.HandleConnection(ctx, ws, func(wasabi.Connection, wasabi.MessageType, []byte) {}) + close(done) + }() + + select { + case <-done: + case <-time.After(1 * time.Second): + t.Error("Expected connection to be closed") + } if !executed { t.Error("Expected onConnect callback to be executed") @@ -260,11 +302,24 @@ func TestConnectionRegistry_WithOnDisconnectHook(t *testing.T) { resp.Body.Close() } - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + cancel() + cb := func(wasabi.Connection, wasabi.MessageType, []byte) {} - conn := registry.AddConnection(ctx, ws, cb) - registry.onClose <- conn.ID() + ready := make(chan struct{}) + + go func() { + registry.HandleConnection(ctx, ws, cb) + close(ready) + }() + + select { + case <-ready: + case <-time.After(1 * time.Second): + t.Error("Expected connection to be handled") + } + close(registry.onClose) select { @@ -292,11 +347,9 @@ func TestConnectionRegistry_AddConnection_ConnectionLimitReached(t *testing.T) { registry := NewConnectionRegistry(WithConnectionLimit(2)) conn1 := mocks.NewMockConnection(t) conn2 := mocks.NewMockConnection(t) - conn3 := mocks.NewMockConnection(t) conn1.EXPECT().ID().Return("conn1") conn2.EXPECT().ID().Return("conn2") - conn3.EXPECT().ID().Return("conn3") registry.connections[conn1.ID()] = conn1 registry.connections[conn2.ID()] = conn2 @@ -317,13 +370,19 @@ func TestConnectionRegistry_AddConnection_ConnectionLimitReached(t *testing.T) { resp.Body.Close() } - conn := registry.AddConnection(ctx, ws, cb) + done := make(chan struct{}) + go func() { + registry.HandleConnection(ctx, ws, cb) + close(done) + }() - if conn != nil { - t.Error("Expected connection to be nil") + select { + case <-done: + case <-time.After(1 * time.Second): + t.Error("Expected connection to be handled") } - if _, ok := registry.connections[conn3.ID()]; ok { + if len(registry.connections) != 2 { t.Error("Expected connection to not be added to the registry") } } diff --git a/channel/connection_test.go b/channel/connection_test.go index 1417cad..3dd5037 100644 --- a/channel/connection_test.go +++ b/channel/connection_test.go @@ -65,7 +65,7 @@ func TestConn_Context(t *testing.T) { } } -func TestConn_HandleRequests(t *testing.T) { +func TestConn_handleRequests(t *testing.T) { server := httptest.NewServer(wsHandlerEcho) defer server.Close() @@ -91,7 +91,7 @@ func TestConn_HandleRequests(t *testing.T) { conn.onMessageCB = func(c wasabi.Connection, msgType wasabi.MessageType, data []byte) { received <- struct{}{} } - go conn.HandleRequests() + go conn.handleRequests() // Send message to trigger OnMessage callback err = ws.Write(context.Background(), websocket.MessageText, []byte("test message")) diff --git a/channel/connection_wrapper.go b/channel/connection_wrapper.go index 72967c4..9852d01 100644 --- a/channel/connection_wrapper.go +++ b/channel/connection_wrapper.go @@ -48,11 +48,6 @@ func (cw *ConnectionWrapper) Context() context.Context { return cw.connection.Context() } -// HandleRequests handles incoming requests on the connection. -func (cw *ConnectionWrapper) HandleRequests() { - cw.connection.HandleRequests() -} - // Send sends a message of the specified type and content over the connection. // If an onSendWrapper function is set, it will be called instead of directly sending the message. // The onSendWrapper function should have the signature func(connection Connection, msgType MessageType, msg []byte) error. diff --git a/channel/connection_wrapper_test.go b/channel/connection_wrapper_test.go index 82e673d..0673c27 100644 --- a/channel/connection_wrapper_test.go +++ b/channel/connection_wrapper_test.go @@ -43,16 +43,6 @@ func TestConnectionWrapper_Context(t *testing.T) { assert.Equal(t, expectedContext, actualContext) } -func TestConnectionWrapper_HandleRequests(t *testing.T) { - mockConnection := mocks.NewMockConnection(t) - wrapper := NewConnectionWrapper(mockConnection) - - mockConnection.On("HandleRequests").Once() - - wrapper.HandleRequests() - - mockConnection.AssertExpectations(t) -} func TestConnectionWrapper_Send_WithOnSendWrapper(t *testing.T) { mockConnection := mocks.NewMockConnection(t) wrapper := NewConnectionWrapper(mockConnection) diff --git a/interfaces.go b/interfaces.go index 8a75acb..30998b0 100644 --- a/interfaces.go +++ b/interfaces.go @@ -34,7 +34,6 @@ type Connection interface { Send(msgType MessageType, msg []byte) error Context() context.Context ID() string - HandleRequests() Close(status websocket.StatusCode, reason string, closingCtx ...context.Context) error } @@ -52,11 +51,11 @@ type Channel interface { // ConnectionRegistry is interface for connection registries type ConnectionRegistry interface { - AddConnection( + HandleConnection( ctx context.Context, ws *websocket.Conn, cb OnMessage, - ) Connection + ) GetConnection(id string) Connection Close(ctx ...context.Context) error CanAccept() bool diff --git a/mocks/mock_Connection.go b/mocks/mock_Connection.go index a12a34e..e31c4ab 100644 --- a/mocks/mock_Connection.go +++ b/mocks/mock_Connection.go @@ -134,38 +134,6 @@ func (_c *MockConnection_Context_Call) RunAndReturn(run func() context.Context) return _c } -// HandleRequests provides a mock function with given fields: -func (_m *MockConnection) HandleRequests() { - _m.Called() -} - -// MockConnection_HandleRequests_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HandleRequests' -type MockConnection_HandleRequests_Call struct { - *mock.Call -} - -// HandleRequests is a helper method to define mock.On call -func (_e *MockConnection_Expecter) HandleRequests() *MockConnection_HandleRequests_Call { - return &MockConnection_HandleRequests_Call{Call: _e.mock.On("HandleRequests")} -} - -func (_c *MockConnection_HandleRequests_Call) Run(run func()) *MockConnection_HandleRequests_Call { - _c.Call.Run(func(args mock.Arguments) { - run() - }) - return _c -} - -func (_c *MockConnection_HandleRequests_Call) Return() *MockConnection_HandleRequests_Call { - _c.Call.Return() - return _c -} - -func (_c *MockConnection_HandleRequests_Call) RunAndReturn(run func()) *MockConnection_HandleRequests_Call { - _c.Call.Return(run) - return _c -} - // ID provides a mock function with given fields: func (_m *MockConnection) ID() string { ret := _m.Called() diff --git a/mocks/mock_ConnectionRegistry.go b/mocks/mock_ConnectionRegistry.go index d927b40..5bba9f1 100644 --- a/mocks/mock_ConnectionRegistry.go +++ b/mocks/mock_ConnectionRegistry.go @@ -26,56 +26,6 @@ func (_m *MockConnectionRegistry) EXPECT() *MockConnectionRegistry_Expecter { return &MockConnectionRegistry_Expecter{mock: &_m.Mock} } -// AddConnection provides a mock function with given fields: ctx, ws, cb -func (_m *MockConnectionRegistry) AddConnection(ctx context.Context, ws *websocket.Conn, cb wasabi.OnMessage) wasabi.Connection { - ret := _m.Called(ctx, ws, cb) - - if len(ret) == 0 { - panic("no return value specified for AddConnection") - } - - var r0 wasabi.Connection - if rf, ok := ret.Get(0).(func(context.Context, *websocket.Conn, wasabi.OnMessage) wasabi.Connection); ok { - r0 = rf(ctx, ws, cb) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(wasabi.Connection) - } - } - - return r0 -} - -// MockConnectionRegistry_AddConnection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddConnection' -type MockConnectionRegistry_AddConnection_Call struct { - *mock.Call -} - -// AddConnection is a helper method to define mock.On call -// - ctx context.Context -// - ws *websocket.Conn -// - cb wasabi.OnMessage -func (_e *MockConnectionRegistry_Expecter) AddConnection(ctx interface{}, ws interface{}, cb interface{}) *MockConnectionRegistry_AddConnection_Call { - return &MockConnectionRegistry_AddConnection_Call{Call: _e.mock.On("AddConnection", ctx, ws, cb)} -} - -func (_c *MockConnectionRegistry_AddConnection_Call) Run(run func(ctx context.Context, ws *websocket.Conn, cb wasabi.OnMessage)) *MockConnectionRegistry_AddConnection_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(*websocket.Conn), args[2].(wasabi.OnMessage)) - }) - return _c -} - -func (_c *MockConnectionRegistry_AddConnection_Call) Return(_a0 wasabi.Connection) *MockConnectionRegistry_AddConnection_Call { - _c.Call.Return(_a0) - return _c -} - -func (_c *MockConnectionRegistry_AddConnection_Call) RunAndReturn(run func(context.Context, *websocket.Conn, wasabi.OnMessage) wasabi.Connection) *MockConnectionRegistry_AddConnection_Call { - _c.Call.Return(run) - return _c -} - // CanAccept provides a mock function with given fields: func (_m *MockConnectionRegistry) CanAccept() bool { ret := _m.Called() @@ -228,6 +178,41 @@ func (_c *MockConnectionRegistry_GetConnection_Call) RunAndReturn(run func(strin return _c } +// HandleConnection provides a mock function with given fields: ctx, ws, cb +func (_m *MockConnectionRegistry) HandleConnection(ctx context.Context, ws *websocket.Conn, cb wasabi.OnMessage) { + _m.Called(ctx, ws, cb) +} + +// MockConnectionRegistry_HandleConnection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HandleConnection' +type MockConnectionRegistry_HandleConnection_Call struct { + *mock.Call +} + +// HandleConnection is a helper method to define mock.On call +// - ctx context.Context +// - ws *websocket.Conn +// - cb wasabi.OnMessage +func (_e *MockConnectionRegistry_Expecter) HandleConnection(ctx interface{}, ws interface{}, cb interface{}) *MockConnectionRegistry_HandleConnection_Call { + return &MockConnectionRegistry_HandleConnection_Call{Call: _e.mock.On("HandleConnection", ctx, ws, cb)} +} + +func (_c *MockConnectionRegistry_HandleConnection_Call) Run(run func(ctx context.Context, ws *websocket.Conn, cb wasabi.OnMessage)) *MockConnectionRegistry_HandleConnection_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*websocket.Conn), args[2].(wasabi.OnMessage)) + }) + return _c +} + +func (_c *MockConnectionRegistry_HandleConnection_Call) Return() *MockConnectionRegistry_HandleConnection_Call { + _c.Call.Return() + return _c +} + +func (_c *MockConnectionRegistry_HandleConnection_Call) RunAndReturn(run func(context.Context, *websocket.Conn, wasabi.OnMessage)) *MockConnectionRegistry_HandleConnection_Call { + _c.Call.Return(run) + return _c +} + // NewMockConnectionRegistry creates a new instance of MockConnectionRegistry. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMockConnectionRegistry(t interface { From 3dc5640d9b0b1df459eb29dfa614632465efc8c8 Mon Sep 17 00:00:00 2001 From: Kirill Sysoev Date: Sun, 2 Jun 2024 21:33:08 +0800 Subject: [PATCH 2/2] Moves handling connection cloasure to connection registry --- channel/connection.go | 9 +-- channel/connection_registry.go | 65 ++++++++------------ channel/connection_registry_test.go | 31 +--------- channel/connection_test.go | 92 ++++++++++++----------------- 4 files changed, 66 insertions(+), 131 deletions(-) diff --git a/channel/connection.go b/channel/connection.go index 49cd720..ec682a6 100644 --- a/channel/connection.go +++ b/channel/connection.go @@ -35,7 +35,6 @@ type Conn struct { ws *websocket.Conn reqWG *sync.WaitGroup onMessageCB wasabi.OnMessage - onClose chan<- string ctxCancel context.CancelFunc bufferPool *bufferPool state *atomic.Int32 @@ -50,7 +49,6 @@ func NewConnection( ctx context.Context, ws *websocket.Conn, cb wasabi.OnMessage, - onClose chan<- string, bufferPool *bufferPool, concurrencyLimit uint, inActivityTimeout time.Duration, @@ -65,7 +63,6 @@ func NewConnection( ctx: ctx, ctxCancel: cancel, onMessageCB: cb, - onClose: onClose, reqWG: &sync.WaitGroup{}, state: &state, bufferPool: bufferPool, @@ -153,7 +150,7 @@ func (c *Conn) Send(msgType wasabi.MessageType, msg []byte) error { } // close closes the connection. -// It cancels the context, sends the connection ID to the onClose channel, +// It cancels the context // marks the connection as closed, and waits for any pending requests to complete. func (c *Conn) close() { if !c.state.CompareAndSwap(int32(connected), int32(terminated)) && @@ -162,7 +159,6 @@ func (c *Conn) close() { } c.ctxCancel() - c.onClose <- c.id // Terminate the connection immediately. _ = c.ws.CloseNow() @@ -177,7 +173,7 @@ func (c *Conn) close() { // before closing the connection. If the context is canceled, the connection // is closed immediately. If there are no pending requests, the connection is // closed immediately. After closing the connection, the connection state is -// set to terminated and the `onClose` channel is notified with the connection ID. +// set to terminated func (c *Conn) Close(status websocket.StatusCode, reason string, ctx ...context.Context) error { if !c.state.CompareAndSwap(int32(connected), int32(closing)) { return ErrConnectionClosed @@ -202,7 +198,6 @@ func (c *Conn) Close(status websocket.StatusCode, reason string, ctx ...context. c.ctxCancel() c.state.Store(int32(terminated)) - c.onClose <- c.id return nil } diff --git a/channel/connection_registry.go b/channel/connection_registry.go index 6334681..e756d58 100644 --- a/channel/connection_registry.go +++ b/channel/connection_registry.go @@ -2,7 +2,6 @@ package channel import ( "context" - "fmt" "sync" "time" @@ -22,7 +21,6 @@ type ConnectionHook func(wasabi.Connection) // ConnectionRegistry is default implementation of ConnectionRegistry type ConnectionRegistry struct { connections map[string]wasabi.Connection - onClose chan string bufferPool *bufferPool onConnect ConnectionHook onDisconnect ConnectionHook @@ -40,7 +38,6 @@ type ConnectionRegistryOption func(*ConnectionRegistry) func NewConnectionRegistry(opts ...ConnectionRegistryOption) *ConnectionRegistry { reg := &ConnectionRegistry{ connections: make(map[string]wasabi.Connection), - onClose: make(chan string), concurrencyLimit: concurencyLimitPerConnection, bufferPool: newBufferPool(), frameSizeLimit: frameSizeLimitInBytes, @@ -52,8 +49,6 @@ func NewConnectionRegistry(opts ...ConnectionRegistryOption) *ConnectionRegistry opt(reg) } - go reg.handleClose() - return reg } @@ -63,36 +58,49 @@ func (r *ConnectionRegistry) HandleConnection( ws *websocket.Conn, cb wasabi.OnMessage, ) { - r.mu.Lock() - defer r.mu.Unlock() + r.mu.RLock() + isClosed := r.isClosed + numOfConnections := len(r.connections) + r.mu.RUnlock() - if r.connectionLimit > 0 && len(r.connections) >= r.connectionLimit { - ws.Close(websocket.StatusTryAgainLater, "Connection limit reached") + if isClosed { + ws.Close(websocket.StatusServiceRestart, "Server is shutting down") return } - if r.isClosed { - ws.Close(websocket.StatusServiceRestart, "Server is shutting down") + if r.connectionLimit > 0 && numOfConnections >= r.connectionLimit { + ws.Close(websocket.StatusTryAgainLater, "Connection limit reached") return } - conn := NewConnection(ctx, ws, cb, r.onClose, r.bufferPool, r.concurrencyLimit, r.inActivityTimeout) - r.connections[conn.ID()] = conn - + conn := NewConnection(ctx, ws, cb, r.bufferPool, r.concurrencyLimit, r.inActivityTimeout) conn.ws.SetReadLimit(r.frameSizeLimit) + id := conn.ID() + + r.mu.Lock() + r.connections[id] = conn + r.mu.Unlock() + if r.onConnect != nil { r.onConnect(conn) } conn.handleRequests() + + r.mu.Lock() + connection := r.connections[id] + delete(r.connections, id) + r.mu.Unlock() + + if r.onDisconnect != nil { + r.onDisconnect(connection) + } } // CanAccept checks if the connection registry can accept new connections. // It returns true if the registry can accept new connections, and false otherwise. func (r *ConnectionRegistry) CanAccept() bool { - fmt.Println("Connection limit", r.connectionLimit) - if r.connectionLimit <= 0 { return true } @@ -100,8 +108,6 @@ func (r *ConnectionRegistry) CanAccept() bool { r.mu.RLock() defer r.mu.RUnlock() - fmt.Println("Connections", len(r.connections)) - return len(r.connections) < r.connectionLimit } @@ -113,29 +119,6 @@ func (r *ConnectionRegistry) GetConnection(id string) wasabi.Connection { return r.connections[id] } -// handleClose handles connection cloasures and removes them from registry -func (r *ConnectionRegistry) handleClose() { - wg := sync.WaitGroup{} - - for id := range r.onClose { - r.mu.Lock() - connection := r.connections[id] - delete(r.connections, id) - r.mu.Unlock() - - if r.onDisconnect != nil { - wg.Add(1) - - go func() { - defer wg.Done() - r.onDisconnect(connection) - }() - } - } - - wg.Wait() -} - // Shutdown closes all connections in the ConnectionRegistry. // It sets the isClosed flag to true, indicating that the registry is closed. // It then iterates over all connections, closes them with the given context, diff --git a/channel/connection_registry_test.go b/channel/connection_registry_test.go index 4898abb..f1248f0 100644 --- a/channel/connection_registry_test.go +++ b/channel/connection_registry_test.go @@ -3,7 +3,6 @@ package channel import ( "context" "net/http/httptest" - "sync" "testing" "time" @@ -121,32 +120,6 @@ func TestConnectionRegistry_GetConnection(t *testing.T) { } } -func TestConnectionRegistry_handleClose(t *testing.T) { - registry := NewConnectionRegistry() - - conn := mocks.NewMockConnection(t) - conn.EXPECT().ID().Return("testID") - registry.connections[conn.ID()] = conn - - var wg sync.WaitGroup - - wg.Add(1) - - go func() { - registry.handleClose() - wg.Done() - }() - - registry.onClose <- conn.ID() - close(registry.onClose) - - wg.Wait() - - if registry.GetConnection(conn.ID()) != nil { - t.Error("Expected connection to be removed from the registry") - } -} - func TestConnectionRegistry_WithMaxFrameLimit(t *testing.T) { registry := NewConnectionRegistry(WithMaxFrameLimit(100)) @@ -279,7 +252,7 @@ func TestConnectionRegistry_WithOnDisconnectHook(t *testing.T) { t.Error("Expected connection to be passed to onDisconnect hook") } - done <- struct{}{} + close(done) } registry = NewConnectionRegistry(WithOnDisconnectHook(hook)) @@ -320,8 +293,6 @@ func TestConnectionRegistry_WithOnDisconnectHook(t *testing.T) { t.Error("Expected connection to be handled") } - close(registry.onClose) - select { case <-done: case <-time.After(1 * time.Second): diff --git a/channel/connection_test.go b/channel/connection_test.go index 3dd5037..5dfb724 100644 --- a/channel/connection_test.go +++ b/channel/connection_test.go @@ -47,8 +47,7 @@ var wsHandlerEcho = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request func TestConn_ID(t *testing.T) { ws := &websocket.Conn{} - onClose := make(chan string) - conn := NewConnection(context.Background(), ws, nil, onClose, newBufferPool(), 1, 0) + conn := NewConnection(context.Background(), ws, nil, newBufferPool(), 1, 0) if conn.ID() == "" { t.Error("Expected connection ID to be non-empty") @@ -57,8 +56,7 @@ func TestConn_ID(t *testing.T) { func TestConn_Context(t *testing.T) { ws := &websocket.Conn{} - onClose := make(chan string) - conn := NewConnection(context.Background(), ws, nil, onClose, newBufferPool(), 1, 0) + conn := NewConnection(context.Background(), ws, nil, newBufferPool(), 1, 0) if conn.Context() == nil { t.Error("Expected connection context to be non-nil") @@ -83,8 +81,7 @@ func TestConn_handleRequests(t *testing.T) { defer func() { _ = ws.CloseNow() }() - onClose := make(chan string) - conn := NewConnection(context.Background(), ws, nil, onClose, newBufferPool(), 1, 0) + conn := NewConnection(context.Background(), ws, nil, newBufferPool(), 1, 0) // Mock OnMessage callback received := make(chan struct{}) @@ -123,8 +120,7 @@ func TestConn_Send(t *testing.T) { defer func() { _ = ws.CloseNow() }() - onClose := make(chan string) - conn := NewConnection(context.Background(), ws, nil, onClose, newBufferPool(), 1, 0) + conn := NewConnection(context.Background(), ws, nil, newBufferPool(), 1, 0) err = conn.Send(wasabi.MsgTypeText, []byte("test message")) if err != nil { @@ -148,24 +144,18 @@ func TestConn_close(t *testing.T) { defer func() { _ = ws.CloseNow() }() - onClose := make(chan string) - conn := NewConnection(context.Background(), ws, nil, onClose, newBufferPool(), 1, 0) + conn := NewConnection(context.Background(), ws, nil, newBufferPool(), 1, 0) + done := make(chan string) - // Mock OnClose channel - closeChan := make(chan string) - conn.onClose = closeChan + go func() { + conn.handleRequests() + close(done) + }() - go conn.close() + conn.close() select { - case id, ok := <-closeChan: - if !ok { - t.Error("Expected OnClose channel to be closed") - } - - if id != conn.ID() { - t.Errorf("Expected ID to be %s, but got %s", conn.ID(), id) - } + case <-done: case <-time.After(1 * time.Second): t.Error("Expected OnClose channel to be called") } @@ -188,13 +178,12 @@ func TestConn_Close_PendingRequests(t *testing.T) { defer func() { _ = ws.CloseNow() }() ctx := context.Background() - closedChan := make(chan string, 1) - c := NewConnection(ctx, ws, nil, closedChan, newBufferPool(), 1, 0) - - c.reqWG.Add(1) + c := NewConnection(ctx, ws, nil, newBufferPool(), 1, 0) + done := make(chan struct{}) go func() { - c.reqWG.Done() + c.handleRequests() + close(done) }() err = c.Close(websocket.StatusNormalClosure, "test reason", ctx) @@ -203,11 +192,8 @@ func TestConn_Close_PendingRequests(t *testing.T) { } select { - case id := <-closedChan: - if id != c.id { - t.Errorf("Expected ID to be %s, but got %s", c.id, id) - } - default: + case <-done: + case <-time.After(1 * time.Second): t.Error("Expected OnClose channel to be called") } } @@ -228,8 +214,14 @@ func TestConn_Close_NoContext(t *testing.T) { defer func() { _ = ws.CloseNow() }() - closedChan := make(chan string, 1) - c := NewConnection(context.Background(), ws, nil, closedChan, newBufferPool(), 1, 0) + c := NewConnection(context.Background(), ws, nil, newBufferPool(), 1, 0) + + done := make(chan struct{}) + + go func() { + c.handleRequests() + close(done) + }() err = c.Close(websocket.StatusNormalClosure, "test reason") if err != nil { @@ -237,31 +229,20 @@ func TestConn_Close_NoContext(t *testing.T) { } select { - case id := <-closedChan: - if id != c.id { - t.Errorf("Expected ID to be %s, but got %s", c.id, id) - } - default: + case <-done: + case <-time.After(1 * time.Second): t.Error("Expected OnClose channel to be called") } } func TestConn_Close_AlreadyClosed(t *testing.T) { - closedChan := make(chan string, 1) - - c := NewConnection(context.Background(), &websocket.Conn{}, nil, closedChan, newBufferPool(), 1, 0) + c := NewConnection(context.Background(), &websocket.Conn{}, nil, newBufferPool(), 1, 0) c.state.Store(int32(terminated)) err := c.Close(websocket.StatusNormalClosure, "test reason", context.Background()) if err != ErrConnectionClosed { t.Errorf("Expected error to be %v, but got %v", ErrConnectionClosed, err) } - - select { - case <-closedChan: - t.Error("Expected OnClose channel to not be called") - default: - } } func TestConn_watchInactivity(t *testing.T) { @@ -280,8 +261,14 @@ func TestConn_watchInactivity(t *testing.T) { defer func() { _ = ws.CloseNow() }() - onClose := make(chan string) - conn := NewConnection(context.Background(), ws, nil, onClose, newBufferPool(), 1, 10*time.Millisecond) + conn := NewConnection(context.Background(), ws, nil, newBufferPool(), 1, 10*time.Millisecond) + + done := make(chan struct{}) + + go func() { + conn.handleRequests() + close(done) + }() defer conn.Close(websocket.StatusNormalClosure, "", context.Background()) @@ -290,7 +277,7 @@ func TestConn_watchInactivity(t *testing.T) { // Check if the connection was closed due to inactivity select { - case <-onClose: + case <-done: // Expected case <-time.After(1 * time.Second): t.Error("Expected connection to be closed due to inactivity") @@ -313,8 +300,7 @@ func TestConn_watchInactivity_stopping_timer(t *testing.T) { defer func() { _ = ws.CloseNow() }() - onClose := make(chan string, 1) - conn := NewConnection(context.Background(), ws, nil, onClose, newBufferPool(), 1, 10*time.Millisecond) + conn := NewConnection(context.Background(), ws, nil, newBufferPool(), 1, 10*time.Millisecond) ctxClose, cancel := context.WithCancel(context.Background()) cancel()