Skip to content

Commit

Permalink
Move authorized status codes to HTTP response header
Browse files Browse the repository at this point in the history
  • Loading branch information
chacha912 committed Nov 1, 2024
1 parent 51d0ada commit 035634b
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 45 deletions.
21 changes: 2 additions & 19 deletions api/types/auth_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,27 +126,10 @@ func NewAuthWebhookRequest(reader io.Reader) (*AuthWebhookRequest, error) {
return req, nil
}

// Code represents the result of an authentication webhook request.
type Code int

const (
// CodeOK indicates that the request is fully authenticated and has
// the necessary permissions.
CodeOK Code = 200

// CodeUnauthenticated indicates that the request does not have valid
// authentication credentials for the operation.
CodeUnauthenticated Code = 401

// CodePermissionDenied indicates that the authenticated request lacks
// the necessary permissions.
CodePermissionDenied Code = 403
)

// AuthWebhookResponse represents the response of authentication webhook.
type AuthWebhookResponse struct {
Code Code `json:"code"`
Message string `json:"message"`
Allowed bool `json:"allowed"`
Reason string `json:"reason"`
}

// NewAuthWebhookResponse creates a new instance of AuthWebhookResponse.
Expand Down
20 changes: 11 additions & 9 deletions server/rpc/auth/webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ func verifyAccess(
cacheKey := string(reqBody)
if entry, ok := be.AuthWebhookCache.Get(cacheKey); ok {
resp := entry
if resp.Code != types.CodeOK {
return fmt.Errorf("%s: %w", resp.Message, ErrPermissionDenied)
if !resp.Allowed {
return fmt.Errorf("%s: %w", resp.Reason, ErrPermissionDenied)
}
return nil
}
Expand All @@ -93,7 +93,9 @@ func verifyAccess(
}
}()

if http.StatusOK != resp.StatusCode {
if resp.StatusCode != http.StatusOK &&
resp.StatusCode != http.StatusUnauthorized &&
resp.StatusCode != http.StatusForbidden {
return resp.StatusCode, ErrUnexpectedStatusCode
}

Expand All @@ -102,20 +104,20 @@ func verifyAccess(
return resp.StatusCode, err
}

if authResp.Code == types.CodeOK {
if resp.StatusCode == http.StatusOK && authResp.Allowed {
return resp.StatusCode, nil
}
if authResp.Code == types.CodePermissionDenied {
return resp.StatusCode, fmt.Errorf("%s: %w", authResp.Message, ErrPermissionDenied)
if resp.StatusCode == http.StatusForbidden && !authResp.Allowed {
return resp.StatusCode, fmt.Errorf("%s: %w", authResp.Reason, ErrPermissionDenied)
}
if authResp.Code == types.CodeUnauthenticated {
if resp.StatusCode == http.StatusUnauthorized && !authResp.Allowed {
return resp.StatusCode, metaerrors.New(
ErrUnauthenticated,
map[string]string{"message": authResp.Message},
map[string]string{"reason": authResp.Reason},
)
}

return resp.StatusCode, fmt.Errorf("%d: %w", authResp.Code, ErrUnexpectedResponse)
return resp.StatusCode, fmt.Errorf("%d: %w", resp.StatusCode, ErrUnexpectedResponse)
}); err != nil {
if errors.Is(err, ErrPermissionDenied) {
be.AuthWebhookCache.Add(cacheKey, authResp, be.Config.ParseAuthWebhookCacheUnauthTTL())
Expand Down
40 changes: 23 additions & 17 deletions test/integration/auth_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,19 @@ func newAuthServer(t *testing.T) (*httptest.Server, string) {

var res types.AuthWebhookResponse
if req.Token == token {
res.Code = types.CodeOK
w.WriteHeader(http.StatusOK) // 200
res.Allowed = true
} else if req.Token == "not allowed token" {
res.Code = types.CodePermissionDenied
w.WriteHeader(http.StatusForbidden) // 403
res.Allowed = false
} else if req.Token == "" {
res.Code = types.CodeUnauthenticated
res.Message = "no token"
w.WriteHeader(http.StatusUnauthorized) // 401
res.Allowed = false
res.Reason = "no token"
} else {
res.Code = types.CodeUnauthenticated
res.Message = "invalid token"
w.WriteHeader(http.StatusUnauthorized) // 401
res.Allowed = false
res.Reason = "invalid token"
}

_, err = res.Write(w)
Expand All @@ -73,7 +77,7 @@ func newUnavailableAuthServer(t *testing.T, recoveryCnt uint64) *httptest.Server
assert.NoError(t, err)

var res types.AuthWebhookResponse
res.Code = types.CodeOK
res.Allowed = true

if requestCount < recoveryCnt {
w.WriteHeader(http.StatusServiceUnavailable)
Expand Down Expand Up @@ -150,7 +154,7 @@ func TestProjectAuthWebhook(t *testing.T) {
defer func() { assert.NoError(t, cliWithoutToken.Close()) }()
err = cliWithoutToken.Activate(ctx)
assert.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err))
assert.Equal(t, map[string]string{"code": connecthelper.CodeOf(auth.ErrUnauthenticated), "message": "no token"}, converter.ErrorMetadataOf(err))
assert.Equal(t, map[string]string{"code": connecthelper.CodeOf(auth.ErrUnauthenticated), "reason": "no token"}, converter.ErrorMetadataOf(err))

// client with invalid token
cliWithInvalidToken, err := client.Dial(
Expand All @@ -162,7 +166,7 @@ func TestProjectAuthWebhook(t *testing.T) {
defer func() { assert.NoError(t, cliWithInvalidToken.Close()) }()
err = cliWithInvalidToken.Activate(ctx)
assert.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err))
assert.Equal(t, map[string]string{"code": connecthelper.CodeOf(auth.ErrUnauthenticated), "message": "invalid token"}, converter.ErrorMetadataOf(err))
assert.Equal(t, map[string]string{"code": connecthelper.CodeOf(auth.ErrUnauthenticated), "reason": "invalid token"}, converter.ErrorMetadataOf(err))
})

t.Run("permission denied response test", func(t *testing.T) {
Expand Down Expand Up @@ -256,7 +260,7 @@ func TestAuthWebhookErrorHandling(t *testing.T) {
assert.NoError(t, err)

var res types.AuthWebhookResponse
res.Code = types.CodeOK
res.Allowed = true

// unexpected status code
w.WriteHeader(http.StatusBadRequest)
Expand Down Expand Up @@ -298,8 +302,8 @@ func TestAuthWebhookErrorHandling(t *testing.T) {
assert.NoError(t, err)

var res types.AuthWebhookResponse
// unexpected response code
res.Code = 555
// mismatched response
res.Allowed = false

_, err = res.Write(w)
assert.NoError(t, err)
Expand Down Expand Up @@ -402,7 +406,7 @@ func TestAuthWebhookCache(t *testing.T) {
assert.NoError(t, err)

var res types.AuthWebhookResponse
res.Code = types.CodeOK
res.Allowed = true

_, err = res.Write(w)
assert.NoError(t, err)
Expand Down Expand Up @@ -479,8 +483,9 @@ func TestAuthWebhookCache(t *testing.T) {
_, err := types.NewAuthWebhookRequest(r.Body)
assert.NoError(t, err)

w.WriteHeader(http.StatusForbidden)
var res types.AuthWebhookResponse
res.Code = types.CodePermissionDenied
res.Allowed = false

_, err = res.Write(w)
assert.NoError(t, err)
Expand Down Expand Up @@ -541,8 +546,9 @@ func TestAuthWebhookCache(t *testing.T) {
_, err := types.NewAuthWebhookRequest(r.Body)
assert.NoError(t, err)

w.WriteHeader(http.StatusUnauthorized)
var res types.AuthWebhookResponse
res.Code = types.CodeUnauthenticated
res.Allowed = false

_, err = res.Write(w)
assert.NoError(t, err)
Expand Down Expand Up @@ -598,7 +604,7 @@ func TestAuthWebhookCache(t *testing.T) {
}

func TestAuthWebhookNewToken(t *testing.T) {
t.Run("reactivate with new token when receiving invalid token test", func(t *testing.T) {
t.Run("set new token when receiving invalid token test", func(t *testing.T) {
ctx := context.Background()
authServer, validToken := newAuthServer(t)

Expand Down Expand Up @@ -633,7 +639,7 @@ func TestAuthWebhookNewToken(t *testing.T) {
// reactivate with new token
if err != nil {
metadata := converter.ErrorMetadataOf(err)
if metadata["message"] == "invalid token" {
if metadata["reason"] == "invalid token" {
err = cli.SetToken(validToken)
assert.NoError(t, err)
err = cli.Activate(ctx)
Expand Down

0 comments on commit 035634b

Please sign in to comment.