diff --git a/examples/echo/main.go b/examples/echo/main.go index cc4ed62..993ff19 100644 --- a/examples/echo/main.go +++ b/examples/echo/main.go @@ -31,6 +31,7 @@ func main() { channel := channel.NewChannel("/", dispatcher, channel.NewConnectionRegistry(), channel.WithOriginPatterns("*")) server := server.NewServer(Addr, server.WithBaseContext(context.Background()), server.WithProfilerEndpoint()) + server.AddChannel(channel) if err := server.Run(); err != nil { diff --git a/server/config.go b/server/config.go new file mode 100644 index 0000000..bb36ed0 --- /dev/null +++ b/server/config.go @@ -0,0 +1,19 @@ +package server + +import ( + "time" +) + +type Config struct { + ReadHeaderTimeout time.Duration + ReadTimeout time.Duration +} + +var DefaultConfig = Config{ + ReadHeaderTimeout: ReadHeaderTimeoutSeconds * time.Second, + ReadTimeout: ReadTimeoutSeconds * time.Second, +} + +const ReadHeaderTimeoutSeconds = 3 + +const ReadTimeoutSeconds = 30 diff --git a/server/server.go b/server/server.go index 6bc2b40..f87c625 100644 --- a/server/server.go +++ b/server/server.go @@ -28,17 +28,11 @@ import ( "net" "net/http" "sync" - "time" "github.com/ksysoev/wasabi" "golang.org/x/exp/slog" ) -const ( - ReadHeaderTimeout = 3 * time.Second - ReadTimeout = 30 * time.Second -) - var ErrServerAlreadyRunning = fmt.Errorf("server is already running") type Server struct { @@ -57,6 +51,8 @@ type Server struct { type Option func(*Server) +type ctxConfigKey struct{} + // WithReadinessChan sets ch to [Server] and will be closed once the [Server] is // ready to accept connection. Typically used in testing after calling [Run] // method and waiting for ch to close, before continuing with test logics. @@ -86,10 +82,12 @@ func NewServer(addr string, opts ...Option) *Server { opt(server) } + serverConfig := server.GetServerConfig() + server.handler = &http.Server{ Addr: addr, - ReadHeaderTimeout: ReadHeaderTimeout, - ReadTimeout: ReadTimeout, + ReadHeaderTimeout: serverConfig.ReadHeaderTimeout, + ReadTimeout: serverConfig.ReadTimeout, BaseContext: func(_ net.Listener) context.Context { return server.baseCtx }, @@ -98,6 +96,16 @@ func NewServer(addr string, opts ...Option) *Server { return server } +// Helper to fetch server config. Returns Default config if base ctx has no config set +func (s *Server) GetServerConfig() Config { + config, ok := s.baseCtx.Value(ctxConfigKey{}).(Config) + if !ok { + return DefaultConfig + } + + return config +} + // AddChannel adds new channel to server func (s *Server) AddChannel(channel wasabi.Channel) { s.channels = append(s.channels, channel) @@ -217,7 +225,8 @@ func WithBaseContext(ctx context.Context) Option { } return func(s *Server) { - s.baseCtx = ctx + config := s.GetServerConfig() + s.baseCtx = context.WithValue(ctx, ctxConfigKey{}, config) } } @@ -250,3 +259,11 @@ func WithProfilerEndpoint() Option { s.pprofEnabled = true } } + +// WithServerConfig is an option function that overrides the default configuration settings of the server. +// This can give the client more control over the server feature functions +func WithServerConfig(config Config) Option { + return func(s *Server) { + s.baseCtx = context.WithValue(s.baseCtx, ctxConfigKey{}, config) + } +} diff --git a/server/server_test.go b/server/server_test.go index eefe407..3f2e8ab 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -76,6 +76,82 @@ func TestServer_WithBaseContext(t *testing.T) { } } +func TestServer_WithConfig(t *testing.T) { + // Create a new Server instance with config + serverConfig := Config{ + ReadHeaderTimeout: 1 * time.Second, + ReadTimeout: 10 * time.Second, + } + + server := NewServer(":0", WithServerConfig(serverConfig)) + + // Check if the base context was set correctly + if server.baseCtx == nil { + t.Error("Expected non-nil base context") + } + + if server.GetServerConfig() == DefaultConfig { + t.Errorf("Expected non-Default context value for server config") + } + + config := server.GetServerConfig() + + if config.ReadHeaderTimeout != 1*time.Second { + t.Errorf("Expected config ReadHeaderTimeout to be %s but got %s", 1*time.Second, config.ReadHeaderTimeout) + } + + if config.ReadTimeout != 10*time.Second { + t.Errorf("Expected config ReadTimeout to be %s but got %s", 10*time.Second, config.ReadTimeout) + } +} + +func TestServer_WithDefaultConfig(t *testing.T) { + // Create a new vanilla Server instance + server := NewServer(":0") + + // Check if the base context was set correctly + if server.baseCtx == nil { + t.Error("Expected non-nil base context") + } + + if server.GetServerConfig() != DefaultConfig { + t.Errorf("Expected Default context value for server config") + } +} + +func TestServer_WithBaseContextAndConfig(t *testing.T) { + serverConfig := Config{ + ReadHeaderTimeout: 2 * time.Second, + ReadTimeout: 20 * time.Second, + } + ctx := context.WithValue(context.Background(), testCtxKey("test"), "test") + + // Create 2 servers with different order of optional methods + server1 := NewServer(":0", WithBaseContext(ctx), WithServerConfig(serverConfig)) + server2 := NewServer(":1", WithServerConfig(serverConfig), WithBaseContext(ctx)) + + // Check if the base context was set correctly + if server1.baseCtx == nil || server2.baseCtx == nil { + t.Error("Expected non-nil base contexts for both servers") + } + + if server1.baseCtx.Value(testCtxKey("test")) != "test" { + t.Errorf("Expected context value 'test', but got '%s'", server1.baseCtx.Value("test")) + } + + if server2.baseCtx.Value(testCtxKey("test")) != "test" { + t.Errorf("Expected context value 'test', but got '%s'", server2.baseCtx.Value("test")) + } + + if server1.GetServerConfig() != serverConfig { + t.Errorf("Expected config for server1 to be %s but got %s", serverConfig, server1.GetServerConfig()) + } + + if server2.GetServerConfig() != serverConfig { + t.Errorf("Expected config for server2 to be %s but got %s", serverConfig, server2.GetServerConfig()) + } +} + func TestServer_WithReadinessChan(t *testing.T) { // Create a new Server instance with a base context ready := make(chan struct{})