Skip to content

Commit

Permalink
Refactor error handling and initialization in server
Browse files Browse the repository at this point in the history
Introduce dedicated error messages for various server issues and update the `New` function to return an error for invalid input. Modify `Start`, `Stop`, and `Close` methods to use the new error types for enhanced clarity and debugging. Adjust tests to handle the new error-returning `New` function.
  • Loading branch information
dmitrymomot committed Nov 13, 2024
1 parent 90b3f84 commit 3f192e9
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 56 deletions.
11 changes: 11 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package httpserver

import "errors"

var (
ErrEmptyAddress = errors.New("server address cannot be empty")
ErrNilHandler = errors.New("server handler cannot be nil")
ErrServerStart = errors.New("server failed to start")
ErrServerStop = errors.New("server failed to stop")
ErrServerForceClose = errors.New("server force close failed")
)
32 changes: 20 additions & 12 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package httpserver
import (
"context"
"errors"
"fmt"
"log/slog"
"net/http"
"os"
Expand All @@ -22,7 +21,6 @@ type Server struct {
httpServer *http.Server
shutdownTimeout time.Duration
log Logger
errCh chan error
}

// Logger is an interface that defines the logging methods used by the server.
Expand All @@ -41,7 +39,14 @@ type Logger interface {
// The opt parameter is a variadic list of server options.
// The server options are applied in order, so the last option overrides the previous ones.
// The server options are applied before the server is started.
func New(addr string, handler http.Handler, opt ...serverOption) *Server {
func New(addr string, handler http.Handler, opt ...serverOption) (*Server, error) {
if addr == "" {
return nil, ErrEmptyAddress
}
if handler == nil {
return nil, ErrNilHandler
}

s := &Server{
httpServer: &http.Server{
Addr: addr,
Expand All @@ -53,15 +58,14 @@ func New(addr string, handler http.Handler, opt ...serverOption) *Server {
},
shutdownTimeout: 5 * time.Second,
log: slog.Default().With(slog.String("component", "httpserver")),
errCh: make(chan error, 1), // Initialize error channel
}

// Apply options
for _, o := range opt {
o(s)
}

return s
return s, nil
}

// Start starts the server and listens for incoming requests.
Expand All @@ -85,7 +89,7 @@ func (s *Server) Start(ctx context.Context) error {
// Start the server in a new goroutine within the errgroup
g.Go(func() error {
if err := s.httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("HTTP server ListenAndServe: %w", err)
return errors.Join(ErrServerStart, err)
}
return nil
})
Expand Down Expand Up @@ -116,7 +120,7 @@ func (s *Server) Start(ctx context.Context) error {
// It uses the provided timeout to gracefully shutdown the underlying HTTP server.
// If the timeout is reached before the server is fully stopped, an error is returned.
func (s *Server) Stop(ctx context.Context, timeout time.Duration) error {
s.log.InfoContext(context.Background(), "stopping HTTP server", "timeout", timeout)
s.log.InfoContext(ctx, "stopping HTTP server", "timeout", timeout)

// Create a new context for shutdown with timeout
shutdownCtx, cancel := context.WithTimeout(ctx, timeout)
Expand All @@ -127,9 +131,10 @@ func (s *Server) Stop(ctx context.Context, timeout time.Duration) error {

// Shutdown the HTTP server
g.Go(func() error {
err := s.httpServer.Shutdown(shutdownCtx)
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("server shutdown error: %w", err)
if err := s.httpServer.Shutdown(shutdownCtx); err != nil &&
!errors.Is(err, context.Canceled) &&
!errors.Is(err, http.ErrServerClosed) {
return errors.Join(ErrServerStop, err)
}
return nil
})
Expand All @@ -153,7 +158,7 @@ func (s *Server) Close(ctx context.Context) error {

if err := s.httpServer.Close(); err != nil && !errors.Is(err, http.ErrServerClosed) {
s.log.ErrorContext(ctx, "error during force close", "error", err)
return fmt.Errorf("server force close error: %w", err)
return errors.Join(ErrServerForceClose, err)
}
return nil
}
Expand All @@ -171,6 +176,9 @@ func signalChan() <-chan os.Signal {
// The addr parameter specifies the address to listen on, e.g., ":8080" for all interfaces on port 8080.
// The handler parameter is an http.Handler that defines the behavior of the server.
func Run(ctx context.Context, addr string, handler http.Handler) error {
server := New(addr, handler)
server, err := New(addr, handler)
if err != nil {
return err
}
return server.Start(ctx)
}
89 changes: 45 additions & 44 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,48 +13,49 @@ import (
)

func TestServer(t *testing.T) {
listenAddr := "localhost:9999"

handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, World!")
})
server := httpserver.New(listenAddr, handler)

// Create a context with cancel for server control
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// Channel to catch server errors
serverErr := make(chan error, 1)

// Start the server in a goroutine
go func() {
serverErr <- server.Start(ctx)
}()

// Wait for the server to start
time.Sleep(500 * time.Millisecond)

// Test server response
resp, err := http.Get(fmt.Sprintf("http://%s", listenAddr))
require.NoError(t, err, "Unexpected error in GET request")
require.Equal(t, http.StatusOK, resp.StatusCode, "Unexpected status code")
resp.Body.Close()

// Initiate graceful shutdown
cancel()

// Wait for server to shut down with timeout
shutdownTimeout := time.After(5 * time.Second)
select {
case err := <-serverErr:
require.True(t, err == nil || errors.Is(err, context.Canceled),
"Expected nil or context.Canceled error, got: %v", err)
case <-shutdownTimeout:
t.Fatal("Server shutdown timed out")
}

// Verify server is no longer accepting connections
_, err = http.Get(fmt.Sprintf("http://%s", listenAddr))
require.Error(t, err, "Expected error after server shutdown")
listenAddr := "localhost:9999"

handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, World!")
})
server, err := httpserver.New(listenAddr, handler)
require.NoError(t, err, "Unexpected error creating server")

// Create a context with cancel for server control
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// Channel to catch server errors
serverErr := make(chan error, 1)

// Start the server in a goroutine
go func() {
serverErr <- server.Start(ctx)
}()

// Wait for the server to start
time.Sleep(500 * time.Millisecond)

// Test server response
resp, err := http.Get(fmt.Sprintf("http://%s", listenAddr))
require.NoError(t, err, "Unexpected error in GET request")
require.Equal(t, http.StatusOK, resp.StatusCode, "Unexpected status code")
resp.Body.Close()

// Initiate graceful shutdown
cancel()

// Wait for server to shut down with timeout
shutdownTimeout := time.After(5 * time.Second)
select {
case err := <-serverErr:
require.True(t, err == nil || errors.Is(err, context.Canceled),
"Expected nil or context.Canceled error, got: %v", err)
case <-shutdownTimeout:
t.Fatal("Server shutdown timed out")
}

// Verify server is no longer accepting connections
_, err = http.Get(fmt.Sprintf("http://%s", listenAddr))
require.Error(t, err, "Expected error after server shutdown")
}

0 comments on commit 3f192e9

Please sign in to comment.