Skip to content

Commit

Permalink
Implements tests for server shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
ksysoev committed Apr 27, 2024
1 parent 98802fc commit f6de6c1
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 7 deletions.
4 changes: 2 additions & 2 deletions channel/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ func (c *Channel) Use(middlewere Middlewere) {
// Shutdown gracefully shuts down the Channel by shutting down the underlying connection registry.
// It waits for all active connections to be closed or until the context is canceled.
// Returns an error if the shutdown process encounters any issues.
func (srv *Channel) Shutdown(ctx context.Context) error {
return srv.connRegistry.Shutdown(ctx)
func (c *Channel) Shutdown(ctx context.Context) error {
return c.connRegistry.Shutdown(ctx)
}

// useMiddleware applies middlewares to handler
Expand Down
10 changes: 10 additions & 0 deletions channel/connection_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,26 +84,36 @@ func (r *ConnectionRegistry) handleClose() {
}
}

// Shutdown closes all connections in the ConnectionRegistry.
// It sets the isClosed flag to true, indicating that the registry is closed.
// It then iterates over all connections, closes them with the given context,
// and waits for all closures to complete before returning.
func (r *ConnectionRegistry) Shutdown(ctx context.Context) error {
r.mu.Lock()
r.isClosed = true
connections := make([]wasabi.Connection, 0, len(r.connections))

for _, conn := range r.connections {
connections = append(connections, conn)
}

r.mu.Unlock()

wg := sync.WaitGroup{}

for _, conn := range connections {
c := conn

wg.Add(1)

go func() {
c.Close(ctx, websocket.StatusServiceRestart, "")
wg.Done()
}()
}

wg.Wait()

return nil
}

Expand Down
15 changes: 11 additions & 4 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,25 +93,32 @@ func (s *Server) Run() error {
return nil
}

// Shutdown gracefully shuts down the server and all its channels.
// It waits for all channels to be shut down before returning.
// If the context is canceled before all channels are shut down, it returns the context error.
// If any error occurs during the shutdown process, it returns the first error encountered.
func (s *Server) Shutdown(ctx context.Context) error {
done := make(chan error)

go func() {
defer close(done)
err := s.handler.Shutdown(ctx)
if err != nil {

if err := s.handler.Shutdown(ctx); err != nil {
done <- err
}
}()

wg := sync.WaitGroup{}

for _, channel := range s.channels {
c := channel

wg.Add(1)

go func() {
defer wg.Done()
err := c.Shutdown(ctx)
if err != nil {

if err := c.Shutdown(ctx); err != nil {
slog.Error("Error shutting down channel:" + err.Error())
}
}()
Expand Down
58 changes: 57 additions & 1 deletion server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func TestServer_Run(t *testing.T) {
go func() {
err := server.Run()
switch err {
case http.ErrServerClosed:
case nil:
close(done)
case ErrServerAlreadyRunning:
done <- struct{}{}
Expand Down Expand Up @@ -110,3 +110,59 @@ func TestServer_Run(t *testing.T) {
t.Error("Expected server to stop")
}
}
func TestServer_Shutdown(t *testing.T) {
// Create a new Server instance
server := NewServer(":0")

// Create a context with a timeout
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)

// Create a mock channel
channel := mocks.NewMockChannel(t)
channel.EXPECT().Path().Return("/test")
channel.EXPECT().Handler().Return(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
channel.EXPECT().Shutdown(ctx).Return(nil)

server.AddChannel(channel)

defer cancel()

// Start the server in a separate goroutine
done := make(chan struct{})

// Run the server
for i := 0; i < 2; i++ {
go func() {
err := server.Run()
switch err {
case nil:
close(done)
case ErrServerAlreadyRunning:
done <- struct{}{}
default:
t.Errorf("Got unexpected error: %v", err)
}
}()
}

select {
case _, ok := <-done:
if !ok {
t.Error("Expected server to start")
}
case <-time.After(1 * time.Second):
t.Error("Expected server to start")
}

// Call the Shutdown method
err := server.Shutdown(ctx)
if err != nil {
t.Errorf("Unexpected error shutting down server: %v", err)
}

select {
case <-done:
case <-time.After(1 * time.Second):
t.Error("Expected server to stop")
}
}

0 comments on commit f6de6c1

Please sign in to comment.