Skip to content

Commit

Permalink
Merge pull request #82 from ksysoev/73-basic-auth-middleware
Browse files Browse the repository at this point in the history
Add basic authentication middleware for HTTP requests
  • Loading branch information
ksysoev committed Jun 12, 2024
2 parents 7cf34e4 + d87d3bc commit ad6128a
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 0 deletions.
34 changes: 34 additions & 0 deletions middleware/http/basic_auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package http

import "net/http"

// NewBasicAuthMiddleware returns a middleware function that performs basic authentication.
// It takes a map of users and passwords, and a realm string as input.
// The returned middleware function checks if the request contains valid basic authentication credentials.
// If the credentials are valid, it calls the next handler in the chain.
// If the credentials are invalid or missing, it sends an HTTP 401 Unauthorized response.
func NewBasicAuthMiddleware(users map[string]string, realm string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user, pass, ok := r.BasicAuth()

if !ok {
unauthorized(w, realm)
return
}

if p, ok := users[user]; !ok || p != pass {
unauthorized(w, realm)
return
}

next.ServeHTTP(w, r)
})
}
}

// unauthorized sends an HTTP 401 Unauthorized response with the specified realm.
func unauthorized(w http.ResponseWriter, realm string) {
w.Header().Set("WWW-Authenticate", `Basic realm="`+realm+`"`)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
}
88 changes: 88 additions & 0 deletions middleware/http/basic_auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package http

import (
"net/http"
"net/http/httptest"
"testing"
)

func TestUnauthorized(t *testing.T) {
w := httptest.NewRecorder()
realm := "Test Realm"

unauthorized(w, realm)

resp := w.Result()
defer resp.Body.Close()

if resp.StatusCode != http.StatusUnauthorized {
t.Errorf("Expected status code %d, got %d", http.StatusUnauthorized, resp.StatusCode)
}

authHeader := resp.Header.Get("WWW-Authenticate")
expectedAuthHeader := `Basic realm="` + realm + `"`

if authHeader != expectedAuthHeader {
t.Errorf("Expected WWW-Authenticate header %q, got %q", expectedAuthHeader, authHeader)
}
}

func TestNewBasicAuthMiddleware(t *testing.T) {
users := map[string]string{
"admin": "password",
"user": "123456",
}
realm := "Test Realm"
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})

middleware := NewBasicAuthMiddleware(users, realm)
handler := middleware(nextHandler)

// Test case 1: Valid credentials
req1, _ := http.NewRequest("GET", "/", http.NoBody)
req1.SetBasicAuth("admin", "password")

w1 := httptest.NewRecorder()

handler.ServeHTTP(w1, req1)

resp1 := w1.Result()

defer resp1.Body.Close()

if resp1.StatusCode != http.StatusOK {
t.Errorf("Expected status code %d, got %d", http.StatusOK, resp1.StatusCode)
}

// Test case 2: Invalid credentials
req2, _ := http.NewRequest("GET", "/", http.NoBody)
req2.SetBasicAuth("admin", "wrongpassword")

w2 := httptest.NewRecorder()

handler.ServeHTTP(w2, req2)

resp2 := w2.Result()

defer resp2.Body.Close()

if resp2.StatusCode != http.StatusUnauthorized {
t.Errorf("Expected status code %d, got %d", http.StatusUnauthorized, resp2.StatusCode)
}

// Test case 3: Missing credentials
req3, _ := http.NewRequest("GET", "/", http.NoBody)
w3 := httptest.NewRecorder()

handler.ServeHTTP(w3, req3)

resp3 := w3.Result()

defer resp3.Body.Close()

if resp3.StatusCode != http.StatusUnauthorized {
t.Errorf("Expected status code %d, got %d", http.StatusUnauthorized, resp3.StatusCode)
}
}

0 comments on commit ad6128a

Please sign in to comment.