-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #82 from ksysoev/73-basic-auth-middleware
Add basic authentication middleware for HTTP requests
- Loading branch information
Showing
2 changed files
with
122 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
package http | ||
|
||
import "net/http" | ||
|
||
// NewBasicAuthMiddleware returns a middleware function that performs basic authentication. | ||
// It takes a map of users and passwords, and a realm string as input. | ||
// The returned middleware function checks if the request contains valid basic authentication credentials. | ||
// If the credentials are valid, it calls the next handler in the chain. | ||
// If the credentials are invalid or missing, it sends an HTTP 401 Unauthorized response. | ||
func NewBasicAuthMiddleware(users map[string]string, realm string) func(next http.Handler) http.Handler { | ||
return func(next http.Handler) http.Handler { | ||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
user, pass, ok := r.BasicAuth() | ||
|
||
if !ok { | ||
unauthorized(w, realm) | ||
return | ||
} | ||
|
||
if p, ok := users[user]; !ok || p != pass { | ||
unauthorized(w, realm) | ||
return | ||
} | ||
|
||
next.ServeHTTP(w, r) | ||
}) | ||
} | ||
} | ||
|
||
// unauthorized sends an HTTP 401 Unauthorized response with the specified realm. | ||
func unauthorized(w http.ResponseWriter, realm string) { | ||
w.Header().Set("WWW-Authenticate", `Basic realm="`+realm+`"`) | ||
http.Error(w, "Unauthorized", http.StatusUnauthorized) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
package http | ||
|
||
import ( | ||
"net/http" | ||
"net/http/httptest" | ||
"testing" | ||
) | ||
|
||
func TestUnauthorized(t *testing.T) { | ||
w := httptest.NewRecorder() | ||
realm := "Test Realm" | ||
|
||
unauthorized(w, realm) | ||
|
||
resp := w.Result() | ||
defer resp.Body.Close() | ||
|
||
if resp.StatusCode != http.StatusUnauthorized { | ||
t.Errorf("Expected status code %d, got %d", http.StatusUnauthorized, resp.StatusCode) | ||
} | ||
|
||
authHeader := resp.Header.Get("WWW-Authenticate") | ||
expectedAuthHeader := `Basic realm="` + realm + `"` | ||
|
||
if authHeader != expectedAuthHeader { | ||
t.Errorf("Expected WWW-Authenticate header %q, got %q", expectedAuthHeader, authHeader) | ||
} | ||
} | ||
|
||
func TestNewBasicAuthMiddleware(t *testing.T) { | ||
users := map[string]string{ | ||
"admin": "password", | ||
"user": "123456", | ||
} | ||
realm := "Test Realm" | ||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
w.WriteHeader(http.StatusOK) | ||
}) | ||
|
||
middleware := NewBasicAuthMiddleware(users, realm) | ||
handler := middleware(nextHandler) | ||
|
||
// Test case 1: Valid credentials | ||
req1, _ := http.NewRequest("GET", "/", http.NoBody) | ||
req1.SetBasicAuth("admin", "password") | ||
|
||
w1 := httptest.NewRecorder() | ||
|
||
handler.ServeHTTP(w1, req1) | ||
|
||
resp1 := w1.Result() | ||
|
||
defer resp1.Body.Close() | ||
|
||
if resp1.StatusCode != http.StatusOK { | ||
t.Errorf("Expected status code %d, got %d", http.StatusOK, resp1.StatusCode) | ||
} | ||
|
||
// Test case 2: Invalid credentials | ||
req2, _ := http.NewRequest("GET", "/", http.NoBody) | ||
req2.SetBasicAuth("admin", "wrongpassword") | ||
|
||
w2 := httptest.NewRecorder() | ||
|
||
handler.ServeHTTP(w2, req2) | ||
|
||
resp2 := w2.Result() | ||
|
||
defer resp2.Body.Close() | ||
|
||
if resp2.StatusCode != http.StatusUnauthorized { | ||
t.Errorf("Expected status code %d, got %d", http.StatusUnauthorized, resp2.StatusCode) | ||
} | ||
|
||
// Test case 3: Missing credentials | ||
req3, _ := http.NewRequest("GET", "/", http.NoBody) | ||
w3 := httptest.NewRecorder() | ||
|
||
handler.ServeHTTP(w3, req3) | ||
|
||
resp3 := w3.Result() | ||
|
||
defer resp3.Body.Close() | ||
|
||
if resp3.StatusCode != http.StatusUnauthorized { | ||
t.Errorf("Expected status code %d, got %d", http.StatusUnauthorized, resp3.StatusCode) | ||
} | ||
} |