From a5db173321d29dd70784b104ae77b91bb777b360 Mon Sep 17 00:00:00 2001 From: Joe Corall Date: Fri, 4 Oct 2024 14:40:38 -0400 Subject: [PATCH] fix tests --- main.go | 16 ++++------------ main_test.go | 27 +++++++++++++++------------ server.go | 20 +++++++++++--------- 3 files changed, 30 insertions(+), 33 deletions(-) diff --git a/main.go b/main.go index 7ddc083..3cef14d 100644 --- a/main.go +++ b/main.go @@ -7,25 +7,17 @@ import ( scyllaridae "github.com/lehigh-university-libraries/scyllaridae/internal/config" ) -var ( - config *scyllaridae.ServerConfig -) - -func init() { - var err error - - config, err = scyllaridae.ReadConfig("scyllaridae.yml") +func main() { + config, err := scyllaridae.ReadConfig("scyllaridae.yml") if err != nil { slog.Error("Could not read YML", "err", err) os.Exit(1) } -} -func main() { if len(config.QueueMiddlewares) > 0 { runStompSubscribers(config) } else { - runHTTPServer(config) + server := &Server{Config: config} + runHTTPServer(server) } } - diff --git a/main_test.go b/main_test.go index a4ad585..327a682 100644 --- a/main_test.go +++ b/main_test.go @@ -25,18 +25,19 @@ type Test struct { } func TestMessageHandler_MethodNotAllowed(t *testing.T) { + testConfig := &scyllaridae.ServerConfig{} + server := &Server{Config: testConfig} + req, err := http.NewRequest("POST", "/", nil) if err != nil { t.Fatal(err) } rr := httptest.NewRecorder() - handler := http.HandlerFunc(MessageHandler) - + handler := http.HandlerFunc(server.MessageHandler) handler.ServeHTTP(rr, req) - if status := rr.Code; status != http.StatusMethodNotAllowed { - t.Errorf("handler returned wrong status code: got %v want %v", + t.Errorf("Handler returned wrong status code: got %v want %v", status, http.StatusMethodNotAllowed) } } @@ -216,19 +217,20 @@ cmdByMimeType: destinationServer := createMockDestinationServer(t, tt.returnedBody) defer destinationServer.Close() - sourceServer := createMockSourceServer(t, tt.mimetype, tt.authHeader, destinationServer.URL) - defer sourceServer.Close() - os.Setenv("SCYLLARIDAE_YML", tt.yml) - // set the config based on tt.yml - config, err = scyllaridae.ReadConfig("") + config, err := scyllaridae.ReadConfig("") + + sourceServer := createMockSourceServer(t, config, tt.mimetype, tt.authHeader, destinationServer.URL) + defer sourceServer.Close() if err != nil { t.Fatalf("Could not read YML: %v", err) - os.Exit(1) } + // Create a Server instance with the test config + server := &Server{Config: config} + // Configure and start the main server - setupServer := httptest.NewServer(http.HandlerFunc(MessageHandler)) + setupServer := httptest.NewServer(http.HandlerFunc(server.MessageHandler)) defer setupServer.Close() // Send the mock message to the main server @@ -260,6 +262,7 @@ cmdByMimeType: } }) } + } func createMockDestinationServer(t *testing.T, content string) *httptest.Server { @@ -270,7 +273,7 @@ func createMockDestinationServer(t *testing.T, content string) *httptest.Server })) } -func createMockSourceServer(t *testing.T, mimetype, auth, content string) *httptest.Server { +func createMockSourceServer(t *testing.T, config *scyllaridae.ServerConfig, mimetype, auth, content string) *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if config.ForwardAuth && r.Header.Get("Authorization") != auth { w.WriteHeader(http.StatusUnauthorized) diff --git a/server.go b/server.go index fa723ae..9bd25a2 100644 --- a/server.go +++ b/server.go @@ -11,15 +11,18 @@ import ( "github.com/lehigh-university-libraries/scyllaridae/pkg/api" ) -func runHTTPServer(config *scyllaridae.ServerConfig) { +type Server struct { + Config *scyllaridae.ServerConfig +} + +func runHTTPServer(server *Server) { http.HandleFunc("/healthcheck", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) fmt.Fprintln(w, "OK") }) - http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - MessageHandler(w, r, config) - }) + // Use the method as the handler + http.HandleFunc("/", server.MessageHandler) port := os.Getenv("PORT") if port == "" { @@ -32,7 +35,7 @@ func runHTTPServer(config *scyllaridae.ServerConfig) { } } -func MessageHandler(w http.ResponseWriter, r *http.Request, config *scyllaridae.ServerConfig) { +func (s *Server) MessageHandler(w http.ResponseWriter, r *http.Request) { slog.Info(r.RequestURI, "method", r.Method, "ip", r.RemoteAddr, "proto", r.Proto) if r.Method != http.MethodGet { @@ -48,7 +51,7 @@ func MessageHandler(w http.ResponseWriter, r *http.Request, config *scyllaridae. // Read the Alpaca message payload auth := "" - if config.ForwardAuth { + if s.Config.ForwardAuth { auth = r.Header.Get("Authorization") } message, err := api.DecodeAlpacaMessage(r, auth) @@ -65,7 +68,7 @@ func MessageHandler(w http.ResponseWriter, r *http.Request, config *scyllaridae. http.Error(w, "Bad request", http.StatusBadRequest) return } - if config.ForwardAuth { + if s.Config.ForwardAuth { req.Header.Set("Authorization", auth) } sourceResp, err := http.DefaultClient.Do(req) @@ -81,7 +84,7 @@ func MessageHandler(w http.ResponseWriter, r *http.Request, config *scyllaridae. return } - cmd, err := scyllaridae.BuildExecCommand(message, config) + cmd, err := scyllaridae.BuildExecCommand(message, s.Config) if err != nil { slog.Error("Error building command", "err", err) http.Error(w, "Bad request", http.StatusBadRequest) @@ -103,4 +106,3 @@ func MessageHandler(w http.ResponseWriter, r *http.Request, config *scyllaridae. return } } -