Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implements graceful shutdown #22

Merged
merged 4 commits into from
Apr 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions channel/channel.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package channel

import (
"context"
"net/http"

"github.com/ksysoev/wasabi"
Expand Down Expand Up @@ -84,6 +85,13 @@ func (c *Channel) Use(middlewere Middlewere) {
c.middlewares = append(c.middlewares, middlewere)
}

// Shutdown gracefully shuts down the Channel by shutting down the underlying connection registry.
// It waits for all active connections to be closed or until the context is canceled.
// Returns an error if the shutdown process encounters any issues.
func (c *Channel) Shutdown(ctx context.Context) error {
return c.connRegistry.Shutdown(ctx)
}

// useMiddleware applies middlewares to handler
func (c *Channel) wrapMiddleware(handler http.Handler) http.Handler {
for i := len(c.middlewares) - 1; i >= 0; i-- {
Expand Down
14 changes: 14 additions & 0 deletions channel/channel_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package channel

import (
"context"
"net/http"
"testing"

Expand Down Expand Up @@ -132,3 +133,16 @@ func TestChannel_WithOriginPatterns(t *testing.T) {
t.Errorf("Unexpected to get default origin pattern: got %s, expected %s", channel.config.originPatterns[1], "test2")
}
}
func TestChannel_Shutdown(t *testing.T) {
path := "/test/path"
dispatcher := mocks.NewMockDispatcher(t)

channel := NewChannel(path, dispatcher, NewConnectionRegistry())

// Call the Shutdown method
err := channel.Shutdown(context.Background())

if err != nil {
t.Errorf("Unexpected error: %v", err)
}
}
39 changes: 39 additions & 0 deletions channel/connection_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type ConnectionRegistry struct {
concurrencyLimit uint
mu sync.RWMutex
frameSizeLimit int64
isClosed bool
}

type ConnectionRegistryOption func(*ConnectionRegistry)
Expand All @@ -33,6 +34,7 @@ func NewConnectionRegistry(opts ...ConnectionRegistryOption) *ConnectionRegistry
concurrencyLimit: concurencyLimitPerConnection,
bufferPool: newBufferPool(),
frameSizeLimit: frameSizeLimitInBytes,
isClosed: false,
}

for _, opt := range opts {
Expand All @@ -53,6 +55,10 @@ func (r *ConnectionRegistry) AddConnection(
r.mu.Lock()
defer r.mu.Unlock()

if r.isClosed {
return nil
}

conn := NewConnection(ctx, ws, cb, r.onClose, r.bufferPool, r.concurrencyLimit)
r.connections[conn.ID()] = conn

Expand All @@ -78,6 +84,39 @@ func (r *ConnectionRegistry) handleClose() {
}
}

// Shutdown closes all connections in the ConnectionRegistry.
// It sets the isClosed flag to true, indicating that the registry is closed.
// It then iterates over all connections, closes them with the given context,
// and waits for all closures to complete before returning.
func (r *ConnectionRegistry) Shutdown(ctx context.Context) error {
r.mu.Lock()
r.isClosed = true
connections := make([]wasabi.Connection, 0, len(r.connections))

for _, conn := range r.connections {
connections = append(connections, conn)
}

r.mu.Unlock()

wg := sync.WaitGroup{}

for _, conn := range connections {
c := conn

wg.Add(1)

go func() {
defer wg.Done()
c.Close(ctx, websocket.StatusServiceRestart, "")
}()
}

wg.Wait()

return nil
}

// WithMaxFrameLimit sets the maximum frame size limit for incomming messages to the ConnectionRegistry.
// The limit parameter specifies the maximum frame size limit in bytes.
// This option can be used when creating a new ConnectionRegistry instance.
Expand Down
29 changes: 29 additions & 0 deletions channel/connection_registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,32 @@ func TestConnectionRegistry_WithMaxFrameLimit(t *testing.T) {
t.Errorf("Unexpected frame size limit: got %d, expected %d", registry.frameSizeLimit, 100)
}
}
func TestConnectionRegistry_Shutdown(t *testing.T) {
ctx := context.Background()
registry := NewConnectionRegistry()

// Add some mock connections to the registry
conn1 := mocks.NewMockConnection(t)
conn2 := mocks.NewMockConnection(t)

conn1.EXPECT().ID().Return("conn1")
conn2.EXPECT().ID().Return("conn2")

registry.connections[conn1.ID()] = conn1
registry.connections[conn2.ID()] = conn2

// Set up expectations for the Close method
conn1.EXPECT().Close(ctx, websocket.StatusServiceRestart, "").Return(nil)
conn2.EXPECT().Close(ctx, websocket.StatusServiceRestart, "").Return(nil)

err := registry.Shutdown(ctx)

if err != nil {
t.Errorf("Unexpected error: %v", err)
}

// Verify that the registry is closed
if !registry.isClosed {
t.Error("Expected registry to be closed")
}
}
2 changes: 2 additions & 0 deletions interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ type RequestHandler interface {
type Channel interface {
Path() string
Handler() http.Handler
Shutdown(ctx context.Context) error
}

// ConnectionRegistry is interface for connection registries
Expand All @@ -61,4 +62,5 @@ type ConnectionRegistry interface {
cb OnMessage,
) Connection
GetConnection(id string) Connection
Shutdown(ctx context.Context) error
}
2 changes: 1 addition & 1 deletion loadtesting/k6.js
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ export default function () {
const url = 'ws://localhost:8080/';
const params = { tags: { my_tag: 'hello' } };

let counter = 100;
let counter = 2000;

const res = ws.connect(url, params, function (socket) {
socket.on('open', () => {
Expand Down
37 changes: 25 additions & 12 deletions mocks/mock_Channel.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

46 changes: 46 additions & 0 deletions mocks/mock_ConnectionRegistry.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

53 changes: 52 additions & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,58 @@ func (s *Server) Run() error {

slog.Info("Starting app server on " + s.addr)

return s.handler.ListenAndServe()
err := s.handler.ListenAndServe()

if err != nil && err != http.ErrServerClosed {
return err
}

return nil
}

// Shutdown gracefully shuts down the server and all its channels.
// It waits for all channels to be shut down before returning.
// If the context is canceled before all channels are shut down, it returns the context error.
// If any error occurs during the shutdown process, it returns the first error encountered.
func (s *Server) Shutdown(ctx context.Context) error {
done := make(chan error)

go func() {
defer close(done)

if err := s.handler.Shutdown(ctx); err != nil {
done <- err
}
}()

wg := sync.WaitGroup{}

for _, channel := range s.channels {
c := channel

wg.Add(1)

go func() {
defer wg.Done()

if err := c.Shutdown(ctx); err != nil {
slog.Error("Error shutting down channel:" + err.Error())
}
}()
}

wg.Wait()

select {
case <-ctx.Done():
return ctx.Err()
case err, ok := <-done:
if !ok {
return nil
}

return err
}
}

// BaseContext optionally specifies based context that will be used for all connections.
Expand Down
Loading
Loading