diff --git a/examples/echo/main.go b/examples/echo/main.go index 7d7b74e..cc4ed62 100644 --- a/examples/echo/main.go +++ b/examples/echo/main.go @@ -30,7 +30,7 @@ func main() { }) channel := channel.NewChannel("/", dispatcher, channel.NewConnectionRegistry(), channel.WithOriginPatterns("*")) - server := server.NewServer(Addr, server.WithBaseContext(context.Background())) + server := server.NewServer(Addr, server.WithBaseContext(context.Background()), server.WithProfilerEndpoint()) server.AddChannel(channel) if err := server.Run(); err != nil { diff --git a/server/server.go b/server/server.go index e2d3f01..3b66afd 100644 --- a/server/server.go +++ b/server/server.go @@ -52,6 +52,7 @@ type Server struct { ready chan<- struct{} addr string channels []wasabi.Channel + pprofEnabled bool } type Option func(*Server) @@ -121,6 +122,10 @@ func (s *Server) Run() (err error) { ) } + if s.pprofEnabled { + mux.Handle("/debug/pprof/", http.DefaultServeMux) + } + s.handler.Handler = mux s.listenerLock.Lock() @@ -229,3 +234,12 @@ func WithTLS(certFile, keyFile string, config ...*tls.Config) Option { } } } + +// WithProfilerEndpoint is an option function that enables the profiler endpoint for the server. +// Enabling the profiler endpoint allows profiling and performance monitoring of the server. +// The profiler endpoint is available at /debug/pprof/. +func WithProfilerEndpoint() Option { + return func(s *Server) { + s.pprofEnabled = true + } +} diff --git a/server/server_test.go b/server/server_test.go index 1db103e..70f1330 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -317,3 +317,49 @@ func TestServer_WithTLS(t *testing.T) { t.Errorf("Got unexpected error: %v", err) } } +func TestServer_WithProfilerEndpoint(t *testing.T) { + ready := make(chan struct{}) + // Create a new Server instance + server := NewServer(":0", WithReadinessChan(ready)) + + // Check if the profiler endpoint is disabled by default + if server.pprofEnabled { + t.Error("Expected profiler endpoint to be disabled") + } + + // Apply the WithProfilerEndpoint option + WithProfilerEndpoint()(server) + + // Check if the profiler endpoint is enabled + if !server.pprofEnabled { + t.Error("Expected profiler endpoint to be enabled") + } + + go func() { + err := server.Run() + if err != nil { + t.Errorf("Got unexpected error: %v", err) + } + }() + + defer server.Close() + + select { + case <-ready: + case <-time.After(1 * time.Second): + t.Error("Expected server to start") + } + + // Check if the profiler endpoint is enabled + res, err := http.Get("http://" + server.Addr().String() + "/debug/pprof/") + if err != nil { + t.Errorf("Got unexpected error: %v", err) + } + + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + // TODO: Fix this test, WHY IS IT FAILING? + t.Errorf("Expected status code 200, but got %d", res.StatusCode) + } +}