From f6de6c1309a3b57b48baf21e764bbdfee9ea5bed Mon Sep 17 00:00:00 2001 From: Kirill Sysoev Date: Sat, 27 Apr 2024 14:52:45 +0800 Subject: [PATCH] Implements tests for server shutdown --- channel/channel.go | 4 +-- channel/connection_registry.go | 10 ++++++ server/server.go | 15 ++++++--- server/server_test.go | 58 +++++++++++++++++++++++++++++++++- 4 files changed, 80 insertions(+), 7 deletions(-) diff --git a/channel/channel.go b/channel/channel.go index 95675b8..733384c 100644 --- a/channel/channel.go +++ b/channel/channel.go @@ -88,8 +88,8 @@ func (c *Channel) Use(middlewere Middlewere) { // Shutdown gracefully shuts down the Channel by shutting down the underlying connection registry. // It waits for all active connections to be closed or until the context is canceled. // Returns an error if the shutdown process encounters any issues. -func (srv *Channel) Shutdown(ctx context.Context) error { - return srv.connRegistry.Shutdown(ctx) +func (c *Channel) Shutdown(ctx context.Context) error { + return c.connRegistry.Shutdown(ctx) } // useMiddleware applies middlewares to handler diff --git a/channel/connection_registry.go b/channel/connection_registry.go index 5f7bbb6..602a36a 100644 --- a/channel/connection_registry.go +++ b/channel/connection_registry.go @@ -84,19 +84,28 @@ func (r *ConnectionRegistry) handleClose() { } } +// Shutdown closes all connections in the ConnectionRegistry. +// It sets the isClosed flag to true, indicating that the registry is closed. +// It then iterates over all connections, closes them with the given context, +// and waits for all closures to complete before returning. func (r *ConnectionRegistry) Shutdown(ctx context.Context) error { r.mu.Lock() r.isClosed = true connections := make([]wasabi.Connection, 0, len(r.connections)) + for _, conn := range r.connections { connections = append(connections, conn) } + r.mu.Unlock() wg := sync.WaitGroup{} + for _, conn := range connections { c := conn + wg.Add(1) + go func() { c.Close(ctx, websocket.StatusServiceRestart, "") wg.Done() @@ -104,6 +113,7 @@ func (r *ConnectionRegistry) Shutdown(ctx context.Context) error { } wg.Wait() + return nil } diff --git a/server/server.go b/server/server.go index 75a2e81..cbbae70 100644 --- a/server/server.go +++ b/server/server.go @@ -93,25 +93,32 @@ func (s *Server) Run() error { return nil } +// Shutdown gracefully shuts down the server and all its channels. +// It waits for all channels to be shut down before returning. +// If the context is canceled before all channels are shut down, it returns the context error. +// If any error occurs during the shutdown process, it returns the first error encountered. func (s *Server) Shutdown(ctx context.Context) error { done := make(chan error) go func() { defer close(done) - err := s.handler.Shutdown(ctx) - if err != nil { + + if err := s.handler.Shutdown(ctx); err != nil { done <- err } }() wg := sync.WaitGroup{} + for _, channel := range s.channels { c := channel + wg.Add(1) + go func() { defer wg.Done() - err := c.Shutdown(ctx) - if err != nil { + + if err := c.Shutdown(ctx); err != nil { slog.Error("Error shutting down channel:" + err.Error()) } }() diff --git a/server/server_test.go b/server/server_test.go index 1807a3d..4632fc3 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -81,7 +81,7 @@ func TestServer_Run(t *testing.T) { go func() { err := server.Run() switch err { - case http.ErrServerClosed: + case nil: close(done) case ErrServerAlreadyRunning: done <- struct{}{} @@ -110,3 +110,59 @@ func TestServer_Run(t *testing.T) { t.Error("Expected server to stop") } } +func TestServer_Shutdown(t *testing.T) { + // Create a new Server instance + server := NewServer(":0") + + // Create a context with a timeout + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + + // Create a mock channel + channel := mocks.NewMockChannel(t) + channel.EXPECT().Path().Return("/test") + channel.EXPECT().Handler().Return(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + channel.EXPECT().Shutdown(ctx).Return(nil) + + server.AddChannel(channel) + + defer cancel() + + // Start the server in a separate goroutine + done := make(chan struct{}) + + // Run the server + for i := 0; i < 2; i++ { + go func() { + err := server.Run() + switch err { + case nil: + close(done) + case ErrServerAlreadyRunning: + done <- struct{}{} + default: + t.Errorf("Got unexpected error: %v", err) + } + }() + } + + select { + case _, ok := <-done: + if !ok { + t.Error("Expected server to start") + } + case <-time.After(1 * time.Second): + t.Error("Expected server to start") + } + + // Call the Shutdown method + err := server.Shutdown(ctx) + if err != nil { + t.Errorf("Unexpected error shutting down server: %v", err) + } + + select { + case <-done: + case <-time.After(1 * time.Second): + t.Error("Expected server to stop") + } +}