From 514495e52e107afffb0b728b7984a9490bdb573f Mon Sep 17 00:00:00 2001 From: Kirill Sysoev Date: Mon, 17 Jun 2024 20:12:31 +0800 Subject: [PATCH 1/2] feat: Add TLS support to server with custom certificate and key paths --- server/server.go | 24 +++++++++++++++++++++++- server/server_test.go | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/server/server.go b/server/server.go index c6423d2..e3e7ec0 100644 --- a/server/server.go +++ b/server/server.go @@ -2,6 +2,7 @@ package server import ( "context" + "crypto/tls" "fmt" "net" "net/http" @@ -20,6 +21,8 @@ const ( var ErrServerAlreadyRunning = fmt.Errorf("server is already running") type Server struct { + certPath string + keyPath string baseCtx context.Context listener net.Listener listenerLock *sync.Mutex @@ -114,7 +117,11 @@ func (s *Server) Run() (err error) { close(s.ready) } - err = s.handler.Serve(s.listener) + if s.certPath != "" && s.keyPath != "" { + err = s.handler.ServeTLS(s.listener, s.certPath, s.keyPath) + } else { + err = s.handler.Serve(s.listener) + } if err != nil && err != http.ErrServerClosed { return err @@ -186,3 +193,18 @@ func WithBaseContext(ctx context.Context) Option { s.baseCtx = ctx } } + +// WithTLS is an option function that configures the server to use TLS (Transport Layer Security). +// It sets the certificate and key file paths, and optionally allows custom TLS configuration. +// The certificate and key file paths must be provided as arguments. +// If a custom TLS configuration is provided, it will be applied to the server's handler. +func WithTLS(certFile, keyFile string, config ...*tls.Config) Option { + return func(s *Server) { + s.certPath = certFile + s.keyPath = keyFile + + if len(config) > 0 { + s.handler.TLSConfig = config[0] + } + } +} diff --git a/server/server_test.go b/server/server_test.go index d6525ed..48f07a5 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -2,8 +2,11 @@ package server import ( "context" + "crypto/tls" + "errors" "fmt" "net/http" + "os" "testing" "time" @@ -277,3 +280,33 @@ func TestServer_Addr(t *testing.T) { t.Error("Expected non-empty address") } } +func TestServer_WithTLS(t *testing.T) { + // Create a new Server instance + server := NewServer(":0") + // Set TLS configuration using WithTLS + certPath := "/path/to/cert.pem" + keyPath := "/path/to/key.pem" + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + } + WithTLS(certPath, keyPath, tlsConfig)(server) + // Check if the certificate and key paths are set correctly + if server.certPath != certPath { + t.Errorf("Expected certificate path %s, but got %s", certPath, server.certPath) + } + if server.keyPath != keyPath { + t.Errorf("Expected key path %s, but got %s", keyPath, server.keyPath) + } + // Check if the TLS configuration is set correctly + if server.handler.TLSConfig == nil { + t.Error("Expected non-nil TLS configuration") + } + if server.handler.TLSConfig.InsecureSkipVerify != true { + t.Error("Expected InsecureSkipVerify to be true") + } + + err := server.Run() + if !errors.Is(err, os.ErrNotExist) { + t.Errorf("Got unexpected error: %v", err) + } +} From 8db99633da9c86ad72ef943936791f98945e725a Mon Sep 17 00:00:00 2001 From: Kirill Sysoev Date: Mon, 17 Jun 2024 20:26:17 +0800 Subject: [PATCH 2/2] Makes code tidy --- server/server_test.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/server/server_test.go b/server/server_test.go index 48f07a5..1db103e 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -286,21 +286,28 @@ func TestServer_WithTLS(t *testing.T) { // Set TLS configuration using WithTLS certPath := "/path/to/cert.pem" keyPath := "/path/to/key.pem" + + // #nosec G402 - InsecureSkipVerify is used for testing purposes tlsConfig := &tls.Config{ InsecureSkipVerify: true, } + WithTLS(certPath, keyPath, tlsConfig)(server) + // Check if the certificate and key paths are set correctly if server.certPath != certPath { t.Errorf("Expected certificate path %s, but got %s", certPath, server.certPath) } + if server.keyPath != keyPath { t.Errorf("Expected key path %s, but got %s", keyPath, server.keyPath) } + // Check if the TLS configuration is set correctly if server.handler.TLSConfig == nil { t.Error("Expected non-nil TLS configuration") } + if server.handler.TLSConfig.InsecureSkipVerify != true { t.Error("Expected InsecureSkipVerify to be true") }