From 6b1e3ba5e065d77f0185ce4f7719366bbc885c7e Mon Sep 17 00:00:00 2001 From: shan-96 Date: Mon, 24 Jun 2024 23:41:15 +0700 Subject: [PATCH] address review comments --- .golangci.yml | 9 ---- examples/echo/main.go | 2 +- examples/http_backend/main.go | 2 +- examples/passthrough/main.go | 2 +- server/{server_config.go => config.go} | 4 +- server/server.go | 27 ++++++++--- server/server_test.go | 67 ++++++++++++++++++++------ tests/echo_test.go | 2 +- 8 files changed, 78 insertions(+), 37 deletions(-) rename server/{server_config.go => config.go} (81%) diff --git a/.golangci.yml b/.golangci.yml index 257fbcc..46d3ef8 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,13 +1,4 @@ linters-settings: - revive: - rules: - # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#exported - - name: exported - severity: warning - disabled: false - exclude: [""] - arguments: - - "disableStutteringCheck" exhaustive: default-signifies-exhaustive: true errcheck: diff --git a/examples/echo/main.go b/examples/echo/main.go index 8bc8a59..993ff19 100644 --- a/examples/echo/main.go +++ b/examples/echo/main.go @@ -30,7 +30,7 @@ func main() { }) channel := channel.NewChannel("/", dispatcher, channel.NewConnectionRegistry(), channel.WithOriginPatterns("*")) - server := server.NewServer(Addr, server.DefaultServerConfig, server.WithBaseContext(context.Background()), server.WithProfilerEndpoint()) + server := server.NewServer(Addr, server.WithBaseContext(context.Background()), server.WithProfilerEndpoint()) server.AddChannel(channel) diff --git a/examples/http_backend/main.go b/examples/http_backend/main.go index 14e8b8b..4325076 100644 --- a/examples/http_backend/main.go +++ b/examples/http_backend/main.go @@ -36,7 +36,7 @@ func main() { channel := channel.NewChannel("/", dispatcher, channel.NewConnectionRegistry(), channel.WithOriginPatterns("*")) - server := server.NewServer(Addr, server.DefaultServerConfig, server.WithBaseContext(context.Background())) + server := server.NewServer(Addr, server.WithBaseContext(context.Background())) server.AddChannel(channel) if err := server.Run(); err != nil { diff --git a/examples/passthrough/main.go b/examples/passthrough/main.go index 9a7e461..12f04e8 100644 --- a/examples/passthrough/main.go +++ b/examples/passthrough/main.go @@ -42,7 +42,7 @@ func main() { }) channel := channel.NewChannel("/", dispatcher, channel.NewConnectionRegistry(), channel.WithOriginPatterns("*")) - server := server.NewServer(Addr, server.DefaultServerConfig, server.WithBaseContext(context.Background())) + server := server.NewServer(Addr, server.WithBaseContext(context.Background())) server.AddChannel(channel) if err := server.Run(); err != nil { diff --git a/server/server_config.go b/server/config.go similarity index 81% rename from server/server_config.go rename to server/config.go index 6908068..bb36ed0 100644 --- a/server/server_config.go +++ b/server/config.go @@ -4,12 +4,12 @@ import ( "time" ) -type ServerConfig struct { +type Config struct { ReadHeaderTimeout time.Duration ReadTimeout time.Duration } -var DefaultServerConfig = ServerConfig{ +var DefaultConfig = Config{ ReadHeaderTimeout: ReadHeaderTimeoutSeconds * time.Second, ReadTimeout: ReadTimeoutSeconds * time.Second, } diff --git a/server/server.go b/server/server.go index 17d2389..e09bb01 100644 --- a/server/server.go +++ b/server/server.go @@ -27,7 +27,6 @@ import ( "fmt" "net" "net/http" - "reflect" "sync" "github.com/ksysoev/wasabi" @@ -52,6 +51,8 @@ type Server struct { type Option func(*Server) +type ctxConfigKey struct{} + // WithReadinessChan sets ch to [Server] and will be closed once the [Server] is // ready to accept connection. Typically used in testing after calling [Run] // method and waiting for ch to close, before continuing with test logics. @@ -64,7 +65,7 @@ func WithReadinessChan(ch chan<- struct{}) Option { // NewServer creates new instance of Wasabi server // port - port to listen on // returns new instance of Server -func NewServer(addr string, serverConfig ServerConfig, opts ...Option) *Server { +func NewServer(addr string, opts ...Option) *Server { if addr == "" { addr = ":http" } @@ -81,6 +82,8 @@ func NewServer(addr string, serverConfig ServerConfig, opts ...Option) *Server { opt(server) } + serverConfig := server.GetServerConfig() + server.handler = &http.Server{ Addr: addr, ReadHeaderTimeout: serverConfig.ReadHeaderTimeout, @@ -90,14 +93,17 @@ func NewServer(addr string, serverConfig ServerConfig, opts ...Option) *Server { }, } - server.ApplyServerConfig(serverConfig) - return server } -// Applies the ServerConfig as a value in the baseCtx of the server -func (s *Server) ApplyServerConfig(serverConfig ServerConfig) { - s.baseCtx = context.WithValue(s.baseCtx, reflect.TypeOf(serverConfig), serverConfig) +// Helper to fetch server config. Returns Default config if base ctx has no config set +func (s *Server) GetServerConfig() Config { + config, ok := s.baseCtx.Value(ctxConfigKey{}).(Config) + if !ok { + return DefaultConfig + } + + return config } // AddChannel adds new channel to server @@ -252,3 +258,10 @@ func WithProfilerEndpoint() Option { s.pprofEnabled = true } } + +// TODO: add comment +func WithServerConfig(config Config) Option { + return func(s *Server) { + s.baseCtx = context.WithValue(s.baseCtx, ctxConfigKey{}, config) + } +} diff --git a/server/server_test.go b/server/server_test.go index 0c13d03..563b9a7 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -7,7 +7,6 @@ import ( "fmt" "net/http" "os" - "reflect" "testing" "time" @@ -20,7 +19,7 @@ type testCtxKey string func TestNewServer(t *testing.T) { addr := ":8080" - server := NewServer(addr, DefaultServerConfig) + server := NewServer(addr) if server.addr != addr { t.Errorf("Expected port %s, but got %s", addr, server.addr) @@ -34,7 +33,7 @@ func TestNewServer(t *testing.T) { t.Error("Expected non-nil mutex") } - server = NewServer("", DefaultServerConfig) + server = NewServer("") if server.addr != ":http" { t.Errorf("Expected default port :http, but got %s", server.addr) @@ -42,7 +41,7 @@ func TestNewServer(t *testing.T) { } func TestServer_AddChannel(t *testing.T) { // Create a new Server instance - server := NewServer(":0", DefaultServerConfig) + server := NewServer(":0") // Create a new channel channel := mocks.NewMockChannel(t) @@ -65,7 +64,7 @@ func TestServer_WithBaseContext(t *testing.T) { // Create a new Server instance with a base context ctx := context.WithValue(context.Background(), testCtxKey("test"), "test") - server := NewServer(":0", DefaultServerConfig, WithBaseContext(ctx)) + server := NewServer(":0", WithBaseContext(ctx)) // Check if the base context was set correctly if server.baseCtx == nil { @@ -75,17 +74,55 @@ func TestServer_WithBaseContext(t *testing.T) { if server.baseCtx.Value(testCtxKey("test")) != "test" { t.Errorf("Expected context value 'test', but got '%s'", server.baseCtx.Value("test")) } +} + +func TestServer_WithConfig(t *testing.T) { + // Create a new Server instance with config + serverConfig := Config{ + ReadHeaderTimeout: 1 * time.Second, + ReadTimeout: 10 * time.Second, + } + + server := NewServer(":0", WithServerConfig(serverConfig)) + + // Check if the base context was set correctly + if server.baseCtx == nil { + t.Error("Expected non-nil base context") + } + + if server.GetServerConfig() == DefaultConfig { + t.Errorf("Expected non-Default context value for server config") + } + + config := server.GetServerConfig() + + if config.ReadHeaderTimeout != 1*time.Second { + t.Errorf("Expected config ReadHeaderTimeout to be %s but got %s", 1*time.Second, config.ReadHeaderTimeout) + } + + if config.ReadTimeout != 10*time.Second { + t.Errorf("Expected config ReadTimeout to be %s but got %s", 10*time.Second, config.ReadTimeout) + } +} + +func TestServer_WithDefaultConfig(t *testing.T) { + // Create a new vanilla Server instance + server := NewServer(":0") + + // Check if the base context was set correctly + if server.baseCtx == nil { + t.Error("Expected non-nil base context") + } - // Check that server config is part of context - if server.baseCtx.Value(reflect.TypeFor[ServerConfig]()) != DefaultServerConfig { - t.Errorf("Expected context to key value ServerConfig with value '%s' but got '%s'", DefaultServerConfig, server.baseCtx.Value(reflect.TypeFor[ServerConfig]())) + if server.GetServerConfig() != DefaultConfig { + t.Errorf("Expected Default context value for server config") } } func TestServer_WithReadinessChan(t *testing.T) { // Create a new Server instance with a base context ready := make(chan struct{}) - server := NewServer(":0", DefaultServerConfig, WithReadinessChan(ready)) + server := NewServer(":0", WithReadinessChan(ready)) if server.ready == nil { t.Error("Expected non-nil channel") @@ -105,7 +142,7 @@ func TestServer_Run(t *testing.T) { t.Run(fmt.Sprintf("%d times of calling Run", run), func(t *testing.T) { // Create a new Server instance ready := make(chan struct{}) - server := NewServer(":0", DefaultServerConfig, WithReadinessChan(ready)) + server := NewServer(":0", WithReadinessChan(ready)) channel := mocks.NewMockChannel(t) channel.EXPECT().Path().Return("/test") @@ -158,7 +195,7 @@ func TestServer_Run(t *testing.T) { func TestServer_Close(t *testing.T) { // Create a new Server instance ready := make(chan struct{}) - server := NewServer(":0", DefaultServerConfig, WithReadinessChan(ready)) + server := NewServer(":0", WithReadinessChan(ready)) // Create a context with a timeout ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) @@ -208,7 +245,7 @@ func TestServer_Close(t *testing.T) { func TestServer_Close_NoContext(t *testing.T) { // Create a new Server instance ready := make(chan struct{}) - server := NewServer(":0", DefaultServerConfig, WithReadinessChan(ready)) + server := NewServer(":0", WithReadinessChan(ready)) // Create a mock channel channel := mocks.NewMockChannel(t) @@ -255,7 +292,7 @@ func TestServer_Addr(t *testing.T) { // Create a new Server instance done := make(chan struct{}) - server := NewServer(":0", DefaultServerConfig, WithReadinessChan(done)) + server := NewServer(":0", WithReadinessChan(done)) defer server.Close() @@ -290,7 +327,7 @@ func TestServer_Addr(t *testing.T) { } func TestServer_WithTLS(t *testing.T) { // Create a new Server instance - server := NewServer(":0", DefaultServerConfig) + server := NewServer(":0") // Set TLS configuration using WithTLS certPath := "/path/to/cert.pem" keyPath := "/path/to/key.pem" @@ -329,7 +366,7 @@ func TestServer_WithTLS(t *testing.T) { func TestServer_WithProfilerEndpoint(t *testing.T) { ready := make(chan struct{}) // Create a new Server instance - server := NewServer(":0", DefaultServerConfig, WithReadinessChan(ready)) + server := NewServer(":0", WithReadinessChan(ready)) // Check if the profiler endpoint is disabled by default if server.pprofEnabled { diff --git a/tests/echo_test.go b/tests/echo_test.go index 69bdd49..0234044 100644 --- a/tests/echo_test.go +++ b/tests/echo_test.go @@ -28,7 +28,7 @@ func TestEcho(t *testing.T) { ch := channel.NewChannel("/", dispatcher, channel.NewConnectionRegistry(), channel.WithOriginPatterns("*")) ready := make(chan struct{}) - s := server.NewServer(":0", server.DefaultServerConfig, server.WithBaseContext(context.Background()), server.WithReadinessChan(ready)) + s := server.NewServer(":0", server.WithBaseContext(context.Background()), server.WithReadinessChan(ready)) s.AddChannel(ch) go func() {