Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
joecorall committed Oct 4, 2024
1 parent d04d811 commit a5db173
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 33 deletions.
16 changes: 4 additions & 12 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,17 @@ import (
scyllaridae "github.com/lehigh-university-libraries/scyllaridae/internal/config"
)

var (
config *scyllaridae.ServerConfig
)

func init() {
var err error

config, err = scyllaridae.ReadConfig("scyllaridae.yml")
func main() {
config, err := scyllaridae.ReadConfig("scyllaridae.yml")
if err != nil {
slog.Error("Could not read YML", "err", err)
os.Exit(1)
}
}

func main() {
if len(config.QueueMiddlewares) > 0 {
runStompSubscribers(config)
} else {
runHTTPServer(config)
server := &Server{Config: config}
runHTTPServer(server)
}
}

27 changes: 15 additions & 12 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,19 @@ type Test struct {
}

func TestMessageHandler_MethodNotAllowed(t *testing.T) {
testConfig := &scyllaridae.ServerConfig{}
server := &Server{Config: testConfig}

req, err := http.NewRequest("POST", "/", nil)
if err != nil {
t.Fatal(err)
}

rr := httptest.NewRecorder()
handler := http.HandlerFunc(MessageHandler)

handler := http.HandlerFunc(server.MessageHandler)
handler.ServeHTTP(rr, req)

if status := rr.Code; status != http.StatusMethodNotAllowed {
t.Errorf("handler returned wrong status code: got %v want %v",
t.Errorf("Handler returned wrong status code: got %v want %v",
status, http.StatusMethodNotAllowed)
}
}
Expand Down Expand Up @@ -216,19 +217,20 @@ cmdByMimeType:
destinationServer := createMockDestinationServer(t, tt.returnedBody)
defer destinationServer.Close()

sourceServer := createMockSourceServer(t, tt.mimetype, tt.authHeader, destinationServer.URL)
defer sourceServer.Close()

os.Setenv("SCYLLARIDAE_YML", tt.yml)
// set the config based on tt.yml
config, err = scyllaridae.ReadConfig("")
config, err := scyllaridae.ReadConfig("")

sourceServer := createMockSourceServer(t, config, tt.mimetype, tt.authHeader, destinationServer.URL)
defer sourceServer.Close()
if err != nil {
t.Fatalf("Could not read YML: %v", err)
os.Exit(1)
}

// Create a Server instance with the test config
server := &Server{Config: config}

// Configure and start the main server
setupServer := httptest.NewServer(http.HandlerFunc(MessageHandler))
setupServer := httptest.NewServer(http.HandlerFunc(server.MessageHandler))
defer setupServer.Close()

// Send the mock message to the main server
Expand Down Expand Up @@ -260,6 +262,7 @@ cmdByMimeType:
}
})
}

}

func createMockDestinationServer(t *testing.T, content string) *httptest.Server {
Expand All @@ -270,7 +273,7 @@ func createMockDestinationServer(t *testing.T, content string) *httptest.Server
}))
}

func createMockSourceServer(t *testing.T, mimetype, auth, content string) *httptest.Server {
func createMockSourceServer(t *testing.T, config *scyllaridae.ServerConfig, mimetype, auth, content string) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if config.ForwardAuth && r.Header.Get("Authorization") != auth {
w.WriteHeader(http.StatusUnauthorized)
Expand Down
20 changes: 11 additions & 9 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,18 @@ import (
"github.com/lehigh-university-libraries/scyllaridae/pkg/api"
)

func runHTTPServer(config *scyllaridae.ServerConfig) {
type Server struct {
Config *scyllaridae.ServerConfig
}

func runHTTPServer(server *Server) {
http.HandleFunc("/healthcheck", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
fmt.Fprintln(w, "OK")
})

http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
MessageHandler(w, r, config)
})
// Use the method as the handler
http.HandleFunc("/", server.MessageHandler)

port := os.Getenv("PORT")
if port == "" {
Expand All @@ -32,7 +35,7 @@ func runHTTPServer(config *scyllaridae.ServerConfig) {
}
}

func MessageHandler(w http.ResponseWriter, r *http.Request, config *scyllaridae.ServerConfig) {
func (s *Server) MessageHandler(w http.ResponseWriter, r *http.Request) {
slog.Info(r.RequestURI, "method", r.Method, "ip", r.RemoteAddr, "proto", r.Proto)

if r.Method != http.MethodGet {
Expand All @@ -48,7 +51,7 @@ func MessageHandler(w http.ResponseWriter, r *http.Request, config *scyllaridae.

// Read the Alpaca message payload
auth := ""
if config.ForwardAuth {
if s.Config.ForwardAuth {
auth = r.Header.Get("Authorization")
}
message, err := api.DecodeAlpacaMessage(r, auth)
Expand All @@ -65,7 +68,7 @@ func MessageHandler(w http.ResponseWriter, r *http.Request, config *scyllaridae.
http.Error(w, "Bad request", http.StatusBadRequest)
return
}
if config.ForwardAuth {
if s.Config.ForwardAuth {
req.Header.Set("Authorization", auth)
}
sourceResp, err := http.DefaultClient.Do(req)
Expand All @@ -81,7 +84,7 @@ func MessageHandler(w http.ResponseWriter, r *http.Request, config *scyllaridae.
return
}

cmd, err := scyllaridae.BuildExecCommand(message, config)
cmd, err := scyllaridae.BuildExecCommand(message, s.Config)
if err != nil {
slog.Error("Error building command", "err", err)
http.Error(w, "Bad request", http.StatusBadRequest)
Expand All @@ -103,4 +106,3 @@ func MessageHandler(w http.ResponseWriter, r *http.Request, config *scyllaridae.
return
}
}

0 comments on commit a5db173

Please sign in to comment.