Skip to content

Commit

Permalink
Moves handling connection cloasure to connection registry
Browse files Browse the repository at this point in the history
  • Loading branch information
ksysoev committed Jun 2, 2024
1 parent f2c8be9 commit 3dc5640
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 131 deletions.
9 changes: 2 additions & 7 deletions channel/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ type Conn struct {
ws *websocket.Conn
reqWG *sync.WaitGroup
onMessageCB wasabi.OnMessage
onClose chan<- string
ctxCancel context.CancelFunc
bufferPool *bufferPool
state *atomic.Int32
Expand All @@ -50,7 +49,6 @@ func NewConnection(
ctx context.Context,
ws *websocket.Conn,
cb wasabi.OnMessage,
onClose chan<- string,
bufferPool *bufferPool,
concurrencyLimit uint,
inActivityTimeout time.Duration,
Expand All @@ -65,7 +63,6 @@ func NewConnection(
ctx: ctx,
ctxCancel: cancel,
onMessageCB: cb,
onClose: onClose,
reqWG: &sync.WaitGroup{},
state: &state,
bufferPool: bufferPool,
Expand Down Expand Up @@ -153,7 +150,7 @@ func (c *Conn) Send(msgType wasabi.MessageType, msg []byte) error {
}

// close closes the connection.
// It cancels the context, sends the connection ID to the onClose channel,
// It cancels the context
// marks the connection as closed, and waits for any pending requests to complete.
func (c *Conn) close() {
if !c.state.CompareAndSwap(int32(connected), int32(terminated)) &&
Expand All @@ -162,7 +159,6 @@ func (c *Conn) close() {
}

c.ctxCancel()
c.onClose <- c.id

// Terminate the connection immediately.
_ = c.ws.CloseNow()
Expand All @@ -177,7 +173,7 @@ func (c *Conn) close() {
// before closing the connection. If the context is canceled, the connection
// is closed immediately. If there are no pending requests, the connection is
// closed immediately. After closing the connection, the connection state is
// set to terminated and the `onClose` channel is notified with the connection ID.
// set to terminated
func (c *Conn) Close(status websocket.StatusCode, reason string, ctx ...context.Context) error {
if !c.state.CompareAndSwap(int32(connected), int32(closing)) {
return ErrConnectionClosed
Expand All @@ -202,7 +198,6 @@ func (c *Conn) Close(status websocket.StatusCode, reason string, ctx ...context.

c.ctxCancel()
c.state.Store(int32(terminated))
c.onClose <- c.id

return nil
}
Expand Down
65 changes: 24 additions & 41 deletions channel/connection_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package channel

import (
"context"
"fmt"
"sync"
"time"

Expand All @@ -22,7 +21,6 @@ type ConnectionHook func(wasabi.Connection)
// ConnectionRegistry is default implementation of ConnectionRegistry
type ConnectionRegistry struct {
connections map[string]wasabi.Connection
onClose chan string
bufferPool *bufferPool
onConnect ConnectionHook
onDisconnect ConnectionHook
Expand All @@ -40,7 +38,6 @@ type ConnectionRegistryOption func(*ConnectionRegistry)
func NewConnectionRegistry(opts ...ConnectionRegistryOption) *ConnectionRegistry {
reg := &ConnectionRegistry{
connections: make(map[string]wasabi.Connection),
onClose: make(chan string),
concurrencyLimit: concurencyLimitPerConnection,
bufferPool: newBufferPool(),
frameSizeLimit: frameSizeLimitInBytes,
Expand All @@ -52,8 +49,6 @@ func NewConnectionRegistry(opts ...ConnectionRegistryOption) *ConnectionRegistry
opt(reg)
}

go reg.handleClose()

return reg
}

Expand All @@ -63,45 +58,56 @@ func (r *ConnectionRegistry) HandleConnection(
ws *websocket.Conn,
cb wasabi.OnMessage,
) {
r.mu.Lock()
defer r.mu.Unlock()
r.mu.RLock()
isClosed := r.isClosed
numOfConnections := len(r.connections)
r.mu.RUnlock()

if r.connectionLimit > 0 && len(r.connections) >= r.connectionLimit {
ws.Close(websocket.StatusTryAgainLater, "Connection limit reached")
if isClosed {
ws.Close(websocket.StatusServiceRestart, "Server is shutting down")
return
}

if r.isClosed {
ws.Close(websocket.StatusServiceRestart, "Server is shutting down")
if r.connectionLimit > 0 && numOfConnections >= r.connectionLimit {
ws.Close(websocket.StatusTryAgainLater, "Connection limit reached")
return
}

conn := NewConnection(ctx, ws, cb, r.onClose, r.bufferPool, r.concurrencyLimit, r.inActivityTimeout)
r.connections[conn.ID()] = conn

conn := NewConnection(ctx, ws, cb, r.bufferPool, r.concurrencyLimit, r.inActivityTimeout)
conn.ws.SetReadLimit(r.frameSizeLimit)

id := conn.ID()

r.mu.Lock()
r.connections[id] = conn
r.mu.Unlock()

if r.onConnect != nil {
r.onConnect(conn)
}

conn.handleRequests()

r.mu.Lock()
connection := r.connections[id]
delete(r.connections, id)
r.mu.Unlock()

if r.onDisconnect != nil {
r.onDisconnect(connection)
}
}

// CanAccept checks if the connection registry can accept new connections.
// It returns true if the registry can accept new connections, and false otherwise.
func (r *ConnectionRegistry) CanAccept() bool {
fmt.Println("Connection limit", r.connectionLimit)

if r.connectionLimit <= 0 {
return true
}

r.mu.RLock()
defer r.mu.RUnlock()

fmt.Println("Connections", len(r.connections))

return len(r.connections) < r.connectionLimit
}

Expand All @@ -113,29 +119,6 @@ func (r *ConnectionRegistry) GetConnection(id string) wasabi.Connection {
return r.connections[id]
}

// handleClose handles connection cloasures and removes them from registry
func (r *ConnectionRegistry) handleClose() {
wg := sync.WaitGroup{}

for id := range r.onClose {
r.mu.Lock()
connection := r.connections[id]
delete(r.connections, id)
r.mu.Unlock()

if r.onDisconnect != nil {
wg.Add(1)

go func() {
defer wg.Done()
r.onDisconnect(connection)
}()
}
}

wg.Wait()
}

// Shutdown closes all connections in the ConnectionRegistry.
// It sets the isClosed flag to true, indicating that the registry is closed.
// It then iterates over all connections, closes them with the given context,
Expand Down
31 changes: 1 addition & 30 deletions channel/connection_registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package channel
import (
"context"
"net/http/httptest"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -121,32 +120,6 @@ func TestConnectionRegistry_GetConnection(t *testing.T) {
}
}

func TestConnectionRegistry_handleClose(t *testing.T) {
registry := NewConnectionRegistry()

conn := mocks.NewMockConnection(t)
conn.EXPECT().ID().Return("testID")
registry.connections[conn.ID()] = conn

var wg sync.WaitGroup

wg.Add(1)

go func() {
registry.handleClose()
wg.Done()
}()

registry.onClose <- conn.ID()
close(registry.onClose)

wg.Wait()

if registry.GetConnection(conn.ID()) != nil {
t.Error("Expected connection to be removed from the registry")
}
}

func TestConnectionRegistry_WithMaxFrameLimit(t *testing.T) {
registry := NewConnectionRegistry(WithMaxFrameLimit(100))

Expand Down Expand Up @@ -279,7 +252,7 @@ func TestConnectionRegistry_WithOnDisconnectHook(t *testing.T) {
t.Error("Expected connection to be passed to onDisconnect hook")
}

done <- struct{}{}
close(done)
}

registry = NewConnectionRegistry(WithOnDisconnectHook(hook))
Expand Down Expand Up @@ -320,8 +293,6 @@ func TestConnectionRegistry_WithOnDisconnectHook(t *testing.T) {
t.Error("Expected connection to be handled")
}

close(registry.onClose)

select {
case <-done:
case <-time.After(1 * time.Second):
Expand Down
Loading

0 comments on commit 3dc5640

Please sign in to comment.