diff --git a/middleware/http/basic_auth.go b/middleware/http/basic_auth.go new file mode 100644 index 0000000..ac0a515 --- /dev/null +++ b/middleware/http/basic_auth.go @@ -0,0 +1,34 @@ +package http + +import "net/http" + +// NewBasicAuthMiddleware returns a middleware function that performs basic authentication. +// It takes a map of users and passwords, and a realm string as input. +// The returned middleware function checks if the request contains valid basic authentication credentials. +// If the credentials are valid, it calls the next handler in the chain. +// If the credentials are invalid or missing, it sends an HTTP 401 Unauthorized response. +func NewBasicAuthMiddleware(users map[string]string, realm string) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user, pass, ok := r.BasicAuth() + + if !ok { + unauthorized(w, realm) + return + } + + if p, ok := users[user]; !ok || p != pass { + unauthorized(w, realm) + return + } + + next.ServeHTTP(w, r) + }) + } +} + +// unauthorized sends an HTTP 401 Unauthorized response with the specified realm. +func unauthorized(w http.ResponseWriter, realm string) { + w.Header().Set("WWW-Authenticate", `Basic realm="`+realm+`"`) + http.Error(w, "Unauthorized", http.StatusUnauthorized) +} diff --git a/middleware/http/basic_auth_test.go b/middleware/http/basic_auth_test.go new file mode 100644 index 0000000..8ed595c --- /dev/null +++ b/middleware/http/basic_auth_test.go @@ -0,0 +1,88 @@ +package http + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestUnauthorized(t *testing.T) { + w := httptest.NewRecorder() + realm := "Test Realm" + + unauthorized(w, realm) + + resp := w.Result() + defer resp.Body.Close() + + if resp.StatusCode != http.StatusUnauthorized { + t.Errorf("Expected status code %d, got %d", http.StatusUnauthorized, resp.StatusCode) + } + + authHeader := resp.Header.Get("WWW-Authenticate") + expectedAuthHeader := `Basic realm="` + realm + `"` + + if authHeader != expectedAuthHeader { + t.Errorf("Expected WWW-Authenticate header %q, got %q", expectedAuthHeader, authHeader) + } +} + +func TestNewBasicAuthMiddleware(t *testing.T) { + users := map[string]string{ + "admin": "password", + "user": "123456", + } + realm := "Test Realm" + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + middleware := NewBasicAuthMiddleware(users, realm) + handler := middleware(nextHandler) + + // Test case 1: Valid credentials + req1, _ := http.NewRequest("GET", "/", http.NoBody) + req1.SetBasicAuth("admin", "password") + + w1 := httptest.NewRecorder() + + handler.ServeHTTP(w1, req1) + + resp1 := w1.Result() + + defer resp1.Body.Close() + + if resp1.StatusCode != http.StatusOK { + t.Errorf("Expected status code %d, got %d", http.StatusOK, resp1.StatusCode) + } + + // Test case 2: Invalid credentials + req2, _ := http.NewRequest("GET", "/", http.NoBody) + req2.SetBasicAuth("admin", "wrongpassword") + + w2 := httptest.NewRecorder() + + handler.ServeHTTP(w2, req2) + + resp2 := w2.Result() + + defer resp2.Body.Close() + + if resp2.StatusCode != http.StatusUnauthorized { + t.Errorf("Expected status code %d, got %d", http.StatusUnauthorized, resp2.StatusCode) + } + + // Test case 3: Missing credentials + req3, _ := http.NewRequest("GET", "/", http.NoBody) + w3 := httptest.NewRecorder() + + handler.ServeHTTP(w3, req3) + + resp3 := w3.Result() + + defer resp3.Body.Close() + + if resp3.StatusCode != http.StatusUnauthorized { + t.Errorf("Expected status code %d, got %d", http.StatusUnauthorized, resp3.StatusCode) + } +}