From 3dc5640d9b0b1df459eb29dfa614632465efc8c8 Mon Sep 17 00:00:00 2001 From: Kirill Sysoev Date: Sun, 2 Jun 2024 21:33:08 +0800 Subject: [PATCH] Moves handling connection cloasure to connection registry --- channel/connection.go | 9 +-- channel/connection_registry.go | 65 ++++++++------------ channel/connection_registry_test.go | 31 +--------- channel/connection_test.go | 92 ++++++++++++----------------- 4 files changed, 66 insertions(+), 131 deletions(-) diff --git a/channel/connection.go b/channel/connection.go index 49cd720..ec682a6 100644 --- a/channel/connection.go +++ b/channel/connection.go @@ -35,7 +35,6 @@ type Conn struct { ws *websocket.Conn reqWG *sync.WaitGroup onMessageCB wasabi.OnMessage - onClose chan<- string ctxCancel context.CancelFunc bufferPool *bufferPool state *atomic.Int32 @@ -50,7 +49,6 @@ func NewConnection( ctx context.Context, ws *websocket.Conn, cb wasabi.OnMessage, - onClose chan<- string, bufferPool *bufferPool, concurrencyLimit uint, inActivityTimeout time.Duration, @@ -65,7 +63,6 @@ func NewConnection( ctx: ctx, ctxCancel: cancel, onMessageCB: cb, - onClose: onClose, reqWG: &sync.WaitGroup{}, state: &state, bufferPool: bufferPool, @@ -153,7 +150,7 @@ func (c *Conn) Send(msgType wasabi.MessageType, msg []byte) error { } // close closes the connection. -// It cancels the context, sends the connection ID to the onClose channel, +// It cancels the context // marks the connection as closed, and waits for any pending requests to complete. func (c *Conn) close() { if !c.state.CompareAndSwap(int32(connected), int32(terminated)) && @@ -162,7 +159,6 @@ func (c *Conn) close() { } c.ctxCancel() - c.onClose <- c.id // Terminate the connection immediately. _ = c.ws.CloseNow() @@ -177,7 +173,7 @@ func (c *Conn) close() { // before closing the connection. If the context is canceled, the connection // is closed immediately. If there are no pending requests, the connection is // closed immediately. After closing the connection, the connection state is -// set to terminated and the `onClose` channel is notified with the connection ID. +// set to terminated func (c *Conn) Close(status websocket.StatusCode, reason string, ctx ...context.Context) error { if !c.state.CompareAndSwap(int32(connected), int32(closing)) { return ErrConnectionClosed @@ -202,7 +198,6 @@ func (c *Conn) Close(status websocket.StatusCode, reason string, ctx ...context. c.ctxCancel() c.state.Store(int32(terminated)) - c.onClose <- c.id return nil } diff --git a/channel/connection_registry.go b/channel/connection_registry.go index 6334681..e756d58 100644 --- a/channel/connection_registry.go +++ b/channel/connection_registry.go @@ -2,7 +2,6 @@ package channel import ( "context" - "fmt" "sync" "time" @@ -22,7 +21,6 @@ type ConnectionHook func(wasabi.Connection) // ConnectionRegistry is default implementation of ConnectionRegistry type ConnectionRegistry struct { connections map[string]wasabi.Connection - onClose chan string bufferPool *bufferPool onConnect ConnectionHook onDisconnect ConnectionHook @@ -40,7 +38,6 @@ type ConnectionRegistryOption func(*ConnectionRegistry) func NewConnectionRegistry(opts ...ConnectionRegistryOption) *ConnectionRegistry { reg := &ConnectionRegistry{ connections: make(map[string]wasabi.Connection), - onClose: make(chan string), concurrencyLimit: concurencyLimitPerConnection, bufferPool: newBufferPool(), frameSizeLimit: frameSizeLimitInBytes, @@ -52,8 +49,6 @@ func NewConnectionRegistry(opts ...ConnectionRegistryOption) *ConnectionRegistry opt(reg) } - go reg.handleClose() - return reg } @@ -63,36 +58,49 @@ func (r *ConnectionRegistry) HandleConnection( ws *websocket.Conn, cb wasabi.OnMessage, ) { - r.mu.Lock() - defer r.mu.Unlock() + r.mu.RLock() + isClosed := r.isClosed + numOfConnections := len(r.connections) + r.mu.RUnlock() - if r.connectionLimit > 0 && len(r.connections) >= r.connectionLimit { - ws.Close(websocket.StatusTryAgainLater, "Connection limit reached") + if isClosed { + ws.Close(websocket.StatusServiceRestart, "Server is shutting down") return } - if r.isClosed { - ws.Close(websocket.StatusServiceRestart, "Server is shutting down") + if r.connectionLimit > 0 && numOfConnections >= r.connectionLimit { + ws.Close(websocket.StatusTryAgainLater, "Connection limit reached") return } - conn := NewConnection(ctx, ws, cb, r.onClose, r.bufferPool, r.concurrencyLimit, r.inActivityTimeout) - r.connections[conn.ID()] = conn - + conn := NewConnection(ctx, ws, cb, r.bufferPool, r.concurrencyLimit, r.inActivityTimeout) conn.ws.SetReadLimit(r.frameSizeLimit) + id := conn.ID() + + r.mu.Lock() + r.connections[id] = conn + r.mu.Unlock() + if r.onConnect != nil { r.onConnect(conn) } conn.handleRequests() + + r.mu.Lock() + connection := r.connections[id] + delete(r.connections, id) + r.mu.Unlock() + + if r.onDisconnect != nil { + r.onDisconnect(connection) + } } // CanAccept checks if the connection registry can accept new connections. // It returns true if the registry can accept new connections, and false otherwise. func (r *ConnectionRegistry) CanAccept() bool { - fmt.Println("Connection limit", r.connectionLimit) - if r.connectionLimit <= 0 { return true } @@ -100,8 +108,6 @@ func (r *ConnectionRegistry) CanAccept() bool { r.mu.RLock() defer r.mu.RUnlock() - fmt.Println("Connections", len(r.connections)) - return len(r.connections) < r.connectionLimit } @@ -113,29 +119,6 @@ func (r *ConnectionRegistry) GetConnection(id string) wasabi.Connection { return r.connections[id] } -// handleClose handles connection cloasures and removes them from registry -func (r *ConnectionRegistry) handleClose() { - wg := sync.WaitGroup{} - - for id := range r.onClose { - r.mu.Lock() - connection := r.connections[id] - delete(r.connections, id) - r.mu.Unlock() - - if r.onDisconnect != nil { - wg.Add(1) - - go func() { - defer wg.Done() - r.onDisconnect(connection) - }() - } - } - - wg.Wait() -} - // Shutdown closes all connections in the ConnectionRegistry. // It sets the isClosed flag to true, indicating that the registry is closed. // It then iterates over all connections, closes them with the given context, diff --git a/channel/connection_registry_test.go b/channel/connection_registry_test.go index 4898abb..f1248f0 100644 --- a/channel/connection_registry_test.go +++ b/channel/connection_registry_test.go @@ -3,7 +3,6 @@ package channel import ( "context" "net/http/httptest" - "sync" "testing" "time" @@ -121,32 +120,6 @@ func TestConnectionRegistry_GetConnection(t *testing.T) { } } -func TestConnectionRegistry_handleClose(t *testing.T) { - registry := NewConnectionRegistry() - - conn := mocks.NewMockConnection(t) - conn.EXPECT().ID().Return("testID") - registry.connections[conn.ID()] = conn - - var wg sync.WaitGroup - - wg.Add(1) - - go func() { - registry.handleClose() - wg.Done() - }() - - registry.onClose <- conn.ID() - close(registry.onClose) - - wg.Wait() - - if registry.GetConnection(conn.ID()) != nil { - t.Error("Expected connection to be removed from the registry") - } -} - func TestConnectionRegistry_WithMaxFrameLimit(t *testing.T) { registry := NewConnectionRegistry(WithMaxFrameLimit(100)) @@ -279,7 +252,7 @@ func TestConnectionRegistry_WithOnDisconnectHook(t *testing.T) { t.Error("Expected connection to be passed to onDisconnect hook") } - done <- struct{}{} + close(done) } registry = NewConnectionRegistry(WithOnDisconnectHook(hook)) @@ -320,8 +293,6 @@ func TestConnectionRegistry_WithOnDisconnectHook(t *testing.T) { t.Error("Expected connection to be handled") } - close(registry.onClose) - select { case <-done: case <-time.After(1 * time.Second): diff --git a/channel/connection_test.go b/channel/connection_test.go index 3dd5037..5dfb724 100644 --- a/channel/connection_test.go +++ b/channel/connection_test.go @@ -47,8 +47,7 @@ var wsHandlerEcho = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request func TestConn_ID(t *testing.T) { ws := &websocket.Conn{} - onClose := make(chan string) - conn := NewConnection(context.Background(), ws, nil, onClose, newBufferPool(), 1, 0) + conn := NewConnection(context.Background(), ws, nil, newBufferPool(), 1, 0) if conn.ID() == "" { t.Error("Expected connection ID to be non-empty") @@ -57,8 +56,7 @@ func TestConn_ID(t *testing.T) { func TestConn_Context(t *testing.T) { ws := &websocket.Conn{} - onClose := make(chan string) - conn := NewConnection(context.Background(), ws, nil, onClose, newBufferPool(), 1, 0) + conn := NewConnection(context.Background(), ws, nil, newBufferPool(), 1, 0) if conn.Context() == nil { t.Error("Expected connection context to be non-nil") @@ -83,8 +81,7 @@ func TestConn_handleRequests(t *testing.T) { defer func() { _ = ws.CloseNow() }() - onClose := make(chan string) - conn := NewConnection(context.Background(), ws, nil, onClose, newBufferPool(), 1, 0) + conn := NewConnection(context.Background(), ws, nil, newBufferPool(), 1, 0) // Mock OnMessage callback received := make(chan struct{}) @@ -123,8 +120,7 @@ func TestConn_Send(t *testing.T) { defer func() { _ = ws.CloseNow() }() - onClose := make(chan string) - conn := NewConnection(context.Background(), ws, nil, onClose, newBufferPool(), 1, 0) + conn := NewConnection(context.Background(), ws, nil, newBufferPool(), 1, 0) err = conn.Send(wasabi.MsgTypeText, []byte("test message")) if err != nil { @@ -148,24 +144,18 @@ func TestConn_close(t *testing.T) { defer func() { _ = ws.CloseNow() }() - onClose := make(chan string) - conn := NewConnection(context.Background(), ws, nil, onClose, newBufferPool(), 1, 0) + conn := NewConnection(context.Background(), ws, nil, newBufferPool(), 1, 0) + done := make(chan string) - // Mock OnClose channel - closeChan := make(chan string) - conn.onClose = closeChan + go func() { + conn.handleRequests() + close(done) + }() - go conn.close() + conn.close() select { - case id, ok := <-closeChan: - if !ok { - t.Error("Expected OnClose channel to be closed") - } - - if id != conn.ID() { - t.Errorf("Expected ID to be %s, but got %s", conn.ID(), id) - } + case <-done: case <-time.After(1 * time.Second): t.Error("Expected OnClose channel to be called") } @@ -188,13 +178,12 @@ func TestConn_Close_PendingRequests(t *testing.T) { defer func() { _ = ws.CloseNow() }() ctx := context.Background() - closedChan := make(chan string, 1) - c := NewConnection(ctx, ws, nil, closedChan, newBufferPool(), 1, 0) - - c.reqWG.Add(1) + c := NewConnection(ctx, ws, nil, newBufferPool(), 1, 0) + done := make(chan struct{}) go func() { - c.reqWG.Done() + c.handleRequests() + close(done) }() err = c.Close(websocket.StatusNormalClosure, "test reason", ctx) @@ -203,11 +192,8 @@ func TestConn_Close_PendingRequests(t *testing.T) { } select { - case id := <-closedChan: - if id != c.id { - t.Errorf("Expected ID to be %s, but got %s", c.id, id) - } - default: + case <-done: + case <-time.After(1 * time.Second): t.Error("Expected OnClose channel to be called") } } @@ -228,8 +214,14 @@ func TestConn_Close_NoContext(t *testing.T) { defer func() { _ = ws.CloseNow() }() - closedChan := make(chan string, 1) - c := NewConnection(context.Background(), ws, nil, closedChan, newBufferPool(), 1, 0) + c := NewConnection(context.Background(), ws, nil, newBufferPool(), 1, 0) + + done := make(chan struct{}) + + go func() { + c.handleRequests() + close(done) + }() err = c.Close(websocket.StatusNormalClosure, "test reason") if err != nil { @@ -237,31 +229,20 @@ func TestConn_Close_NoContext(t *testing.T) { } select { - case id := <-closedChan: - if id != c.id { - t.Errorf("Expected ID to be %s, but got %s", c.id, id) - } - default: + case <-done: + case <-time.After(1 * time.Second): t.Error("Expected OnClose channel to be called") } } func TestConn_Close_AlreadyClosed(t *testing.T) { - closedChan := make(chan string, 1) - - c := NewConnection(context.Background(), &websocket.Conn{}, nil, closedChan, newBufferPool(), 1, 0) + c := NewConnection(context.Background(), &websocket.Conn{}, nil, newBufferPool(), 1, 0) c.state.Store(int32(terminated)) err := c.Close(websocket.StatusNormalClosure, "test reason", context.Background()) if err != ErrConnectionClosed { t.Errorf("Expected error to be %v, but got %v", ErrConnectionClosed, err) } - - select { - case <-closedChan: - t.Error("Expected OnClose channel to not be called") - default: - } } func TestConn_watchInactivity(t *testing.T) { @@ -280,8 +261,14 @@ func TestConn_watchInactivity(t *testing.T) { defer func() { _ = ws.CloseNow() }() - onClose := make(chan string) - conn := NewConnection(context.Background(), ws, nil, onClose, newBufferPool(), 1, 10*time.Millisecond) + conn := NewConnection(context.Background(), ws, nil, newBufferPool(), 1, 10*time.Millisecond) + + done := make(chan struct{}) + + go func() { + conn.handleRequests() + close(done) + }() defer conn.Close(websocket.StatusNormalClosure, "", context.Background()) @@ -290,7 +277,7 @@ func TestConn_watchInactivity(t *testing.T) { // Check if the connection was closed due to inactivity select { - case <-onClose: + case <-done: // Expected case <-time.After(1 * time.Second): t.Error("Expected connection to be closed due to inactivity") @@ -313,8 +300,7 @@ func TestConn_watchInactivity_stopping_timer(t *testing.T) { defer func() { _ = ws.CloseNow() }() - onClose := make(chan string, 1) - conn := NewConnection(context.Background(), ws, nil, onClose, newBufferPool(), 1, 10*time.Millisecond) + conn := NewConnection(context.Background(), ws, nil, newBufferPool(), 1, 10*time.Millisecond) ctxClose, cancel := context.WithCancel(context.Background()) cancel()