diff --git a/health/http.go b/health/http.go new file mode 100644 index 0000000..720a673 --- /dev/null +++ b/health/http.go @@ -0,0 +1,108 @@ +package health + +import ( + "context" + "fmt" + "net/http" + + log "github.com/sirupsen/logrus" +) + +const ( + defaultHealthCheckRoute = "/health" + defaultReadinessCheckRoute = "/ready" + defaultPort = 4444 +) + +type CheckFunc func() error + +type Option func(*server) + +type server struct { + healthCheckRoute string + readinessCheckRoute string + port int + healthCheckFunc CheckFunc + readinessCheckFunc CheckFunc +} + +func WithHealthCheckRoute(route string) Option { + return func(s *server) { + s.healthCheckRoute = route + } +} + +func WithReadinessCheckRoute(route string) Option { + return func(s *server) { + s.readinessCheckRoute = route + } +} + +func WithPort(port int) Option { + return func(s *server) { + s.port = port + } +} + +func WithHealthCheckFunc(healthCheckFunc CheckFunc) Option { + return func(s *server) { + s.healthCheckFunc = healthCheckFunc + } +} + +func WithReadinessCheckFunc(readinessCheckFunc CheckFunc) Option { + return func(s *server) { + s.readinessCheckFunc = readinessCheckFunc + } +} + +func handle(handler *http.ServeMux, route string, handleFunc CheckFunc) { + handler.HandleFunc(route, func(w http.ResponseWriter, r *http.Request) { + if handleFunc == nil { + w.WriteHeader(http.StatusOK) + return + } + + if err := handleFunc(); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) + }) +} + +// StartHealthCheckServer starts a HTTP server to handle health check and readiness check requests. +func StartHealthCheckServer(ctx context.Context, opts ...Option) error { + hcServer := &server{ + healthCheckRoute: defaultHealthCheckRoute, + readinessCheckRoute: defaultReadinessCheckRoute, + port: defaultPort, + } + + for _, opt := range opts { + opt(hcServer) + } + + handler := http.NewServeMux() + handle(handler, hcServer.healthCheckRoute, hcServer.healthCheckFunc) + handle(handler, hcServer.readinessCheckRoute, hcServer.readinessCheckFunc) + + srv := &http.Server{ + Addr: fmt.Sprintf(":%d", hcServer.port), + Handler: handler, + } + + go func() { + <-ctx.Done() + if err := srv.Shutdown(ctx); err != nil { + log.Info("server shutdown: ", err) + } + }() + + if err := srv.ListenAndServe(); err != http.ErrServerClosed { + return err + } + + return nil +} diff --git a/health/http_test.go b/health/http_test.go new file mode 100644 index 0000000..935f162 --- /dev/null +++ b/health/http_test.go @@ -0,0 +1,153 @@ +package health_test + +import ( + "context" + "errors" + "fmt" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + . "github.com/trustwallet/go-libs/health" +) + +func TestStartHealthCheckServer(t *testing.T) { + tests := []struct { + name string + healthCheckFunc func() error + readinessCheckFunc func() error + healthCheckRoute string + readinessCheckRoute string + port int + expHealthy bool + expReady bool + }{ + { + name: "default case", + expHealthy: true, + expReady: true, + }, + { + name: "not healthy", + healthCheckFunc: func() error { return errors.New("health check") }, + readinessCheckFunc: func() error { return nil }, + port: 1111, + expHealthy: false, + expReady: true, + }, + { + name: "not ready", + healthCheckFunc: func() error { return nil }, + readinessCheckFunc: func() error { return errors.New("health check") }, + port: 2222, + expHealthy: true, + expReady: false, + }, + { + name: "custom routes and port", + healthCheckFunc: func() error { return nil }, + readinessCheckFunc: func() error { return nil }, + healthCheckRoute: "/custom-health", + readinessCheckRoute: "/custom-ready", + port: 3333, + expHealthy: true, + expReady: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var opts []Option + if test.healthCheckFunc != nil { + opts = append(opts, WithHealthCheckFunc(test.healthCheckFunc)) + } + + if test.readinessCheckFunc != nil { + opts = append(opts, WithReadinessCheckFunc(test.readinessCheckFunc)) + } + + if test.healthCheckRoute != "" { + opts = append(opts, WithHealthCheckRoute(test.healthCheckRoute)) + } + + if test.readinessCheckRoute != "" { + opts = append(opts, WithReadinessCheckRoute(test.readinessCheckRoute)) + } + + if test.port != 0 { + opts = append(opts, WithPort(test.port)) + } + + port := 4444 + if test.port != 0 { + port = test.port + } + + healthRoute := "/health" + if test.healthCheckRoute != "" { + healthRoute = test.healthCheckRoute + } + + healthURL := fmt.Sprintf("http://:%d/%s", port, healthRoute) + + readinessRoute := "/ready" + if test.readinessCheckRoute != "" { + readinessRoute = test.readinessCheckRoute + } + + readinessURL := fmt.Sprintf("http://:%d/%s", port, readinessRoute) + + go func() { + assert.NoError(t, StartHealthCheckServer(ctx, opts...)) + }() + waitForServerToStart(t, healthURL, 20*time.Millisecond, 1*time.Second) + + resp, err := http.Get(healthURL) + assert.NoError(t, err) + assert.True(t, (test.expHealthy && resp.StatusCode == http.StatusOK) || (!test.expHealthy && resp.StatusCode != http.StatusOK)) + + resp, err = http.Get(readinessURL) + assert.NoError(t, err) + assert.True(t, (test.expReady && resp.StatusCode == http.StatusOK) || (!test.expReady && resp.StatusCode != http.StatusOK)) + + cancel() + }) + } +} + +func waitForServerToStart(t *testing.T, url string, interval time.Duration, timeout time.Duration) { + tick := time.NewTicker(interval) + defer tick.Stop() + now := time.Now() + for { + if time.Since(now) > timeout { + t.Fatal("timeout to connect to server") + return + } + + <-tick.C + if _, err := http.Get(url); err == nil { + return + } + } +} + +func TestServerClosedOnContextCancellation(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + assert.NoError(t, StartHealthCheckServer(ctx)) + }() + waitForServerToStart(t, "http://:4444/health", 20*time.Millisecond, 1*time.Second) + + cancel() + time.Sleep(time.Millisecond * 100) + _, err := http.Get("http://:4444/health") + assert.Error(t, err) // server was shut down +}