From cdd952ef118ca3aa925faa16d7726e755d6e6267 Mon Sep 17 00:00:00 2001 From: Yourim Cha Date: Wed, 16 Oct 2024 15:08:51 +0900 Subject: [PATCH 1/9] Add Code type for authentication webhook response --- api/types/auth_webhook.go | 21 +++++++++++++++++-- server/rpc/auth/webhook.go | 30 +++++++++++++++++++-------- server/rpc/connecthelper/status.go | 10 +++++++-- test/integration/auth_webhook_test.go | 22 ++++++++++++-------- 4 files changed, 61 insertions(+), 22 deletions(-) diff --git a/api/types/auth_webhook.go b/api/types/auth_webhook.go index 841e672f4..cf340dd62 100644 --- a/api/types/auth_webhook.go +++ b/api/types/auth_webhook.go @@ -126,10 +126,27 @@ func NewAuthWebhookRequest(reader io.Reader) (*AuthWebhookRequest, error) { return req, nil } +// Code represents the result of an authentication webhook request. +type Code int + +const ( + // CodeOK indicates that the request is fully authenticated and has + // the necessary permissions. + CodeOK Code = 200 + + // CodeUnauthenticated indicates that the request does not have valid + // authentication credentials for the operation. + CodeUnauthenticated Code = 401 + + // CodePermissionDenied indicates that the authenticated request lacks + // the necessary permissions. + CodePermissionDenied Code = 403 +) + // AuthWebhookResponse represents the response of authentication webhook. type AuthWebhookResponse struct { - Allowed bool `json:"allowed"` - Reason string `json:"reason"` + Code Code `json:"code"` + Message string `json:"message"` } // NewAuthWebhookResponse creates a new instance of AuthWebhookResponse. diff --git a/server/rpc/auth/webhook.go b/server/rpc/auth/webhook.go index cb15a53c8..a25314a75 100644 --- a/server/rpc/auth/webhook.go +++ b/server/rpc/auth/webhook.go @@ -33,14 +33,20 @@ import ( ) var ( - // ErrNotAllowed is returned when the given user is not allowed for the access. - ErrNotAllowed = errors.New("method is not allowed for this user") + // ErrPermissionDenied is returned when the given user is not allowed for the access. + ErrPermissionDenied = errors.New("method is not allowed for this user") // ErrUnexpectedStatusCode is returned when the response code is not 200 from the webhook. ErrUnexpectedStatusCode = errors.New("unexpected status code from webhook") + // ErrUnexpectedResponse is returned when the response from the webhook is not as expected. + ErrUnexpectedResponse = errors.New("unexpected response from webhook") + // ErrWebhookTimeout is returned when the webhook does not respond in time. ErrWebhookTimeout = errors.New("webhook timeout") + + // ErrUnauthenticated is returned when the request lacks valid authentication credentials. + ErrUnauthenticated = errors.New("request lacks valid authentication credentials") ) // verifyAccess verifies the given user is allowed to access the given method. @@ -63,8 +69,8 @@ func verifyAccess( cacheKey := string(reqBody) if entry, ok := be.AuthWebhookCache.Get(cacheKey); ok { resp := entry - if !resp.Allowed { - return fmt.Errorf("%s: %w", resp.Reason, ErrNotAllowed) + if resp.Code != types.CodeOK { + return fmt.Errorf("%s: %w", resp.Message, ErrPermissionDenied) } return nil } @@ -95,13 +101,19 @@ func verifyAccess( return resp.StatusCode, err } - if !authResp.Allowed { - return resp.StatusCode, fmt.Errorf("%s: %w", authResp.Reason, ErrNotAllowed) + if authResp.Code == types.CodeOK { + return resp.StatusCode, nil + } + if authResp.Code == types.CodePermissionDenied { + return resp.StatusCode, fmt.Errorf("%s: %w", authResp.Message, ErrPermissionDenied) + } + if authResp.Code == types.CodeUnauthenticated { + return resp.StatusCode, fmt.Errorf("%s: %w", authResp.Message, ErrUnauthenticated) } - return resp.StatusCode, nil + return resp.StatusCode, fmt.Errorf("%d: %w", authResp.Code, ErrUnexpectedResponse) }); err != nil { - if errors.Is(err, ErrNotAllowed) { + if errors.Is(err, ErrPermissionDenied) { be.AuthWebhookCache.Add(cacheKey, authResp, be.Config.ParseAuthWebhookCacheUnauthTTL()) } @@ -120,7 +132,7 @@ func withExponentialBackoff(ctx context.Context, cfg *backend.Config, webhookFn statusCode, err := webhookFn() if !shouldRetry(statusCode, err) { if err == ErrUnexpectedStatusCode { - return fmt.Errorf("unexpected status code from webhook: %d", statusCode) + return fmt.Errorf("%d: %w", statusCode, ErrUnexpectedStatusCode) } return err diff --git a/server/rpc/connecthelper/status.go b/server/rpc/connecthelper/status.go index 8708b401b..39e8ef8ad 100644 --- a/server/rpc/connecthelper/status.go +++ b/server/rpc/connecthelper/status.go @@ -78,11 +78,15 @@ var errorToConnectCode = map[error]connect.Code{ converter.ErrUnsupportedCounterType: connect.CodeUnimplemented, // Unauthenticated means the request does not have valid authentication - auth.ErrNotAllowed: connect.CodeUnauthenticated, auth.ErrUnexpectedStatusCode: connect.CodeUnauthenticated, + auth.ErrUnexpectedResponse: connect.CodeUnauthenticated, auth.ErrWebhookTimeout: connect.CodeUnauthenticated, + auth.ErrUnauthenticated: connect.CodeUnauthenticated, database.ErrMismatchedPassword: connect.CodeUnauthenticated, + // PermissionDenied means the request does not have permission for the operation. + auth.ErrPermissionDenied: connect.CodePermissionDenied, + // Canceled means the operation was canceled (typically by the caller). context.Canceled: connect.CodeCanceled, } @@ -124,7 +128,9 @@ var errorToCode = map[error]string{ converter.ErrUnsupportedValueType: "ErrUnsupportedValueType", converter.ErrUnsupportedCounterType: "ErrUnsupportedCounterType", - auth.ErrNotAllowed: "ErrNotAllowed", + auth.ErrPermissionDenied: "ErrPermissionDenied", + auth.ErrUnauthenticated: "ErrUnauthenticated", + auth.ErrUnexpectedResponse: "ErrUnexpectedResponse", auth.ErrUnexpectedStatusCode: "ErrUnexpectedStatusCode", auth.ErrWebhookTimeout: "ErrWebhookTimeout", database.ErrMismatchedPassword: "ErrMismatchedPassword", diff --git a/test/integration/auth_webhook_test.go b/test/integration/auth_webhook_test.go index a37a6bf41..0b6a1f555 100644 --- a/test/integration/auth_webhook_test.go +++ b/test/integration/auth_webhook_test.go @@ -29,12 +29,15 @@ import ( "github.com/rs/xid" "github.com/stretchr/testify/assert" + "github.com/yorkie-team/yorkie/api/converter" "github.com/yorkie-team/yorkie/api/types" "github.com/yorkie-team/yorkie/client" "github.com/yorkie-team/yorkie/pkg/document" "github.com/yorkie-team/yorkie/pkg/document/json" "github.com/yorkie-team/yorkie/pkg/document/presence" "github.com/yorkie-team/yorkie/server" + "github.com/yorkie-team/yorkie/server/rpc/auth" + "github.com/yorkie-team/yorkie/server/rpc/connecthelper" "github.com/yorkie-team/yorkie/test/helper" ) @@ -47,9 +50,10 @@ func newAuthServer(t *testing.T) (*httptest.Server, string) { var res types.AuthWebhookResponse if req.Token == token { - res.Allowed = true + res.Code = types.CodeOK } else { - res.Reason = "invalid token" + res.Code = types.CodeUnauthenticated + res.Message = "invalid token" } _, err = res.Write(w) @@ -64,7 +68,7 @@ func newUnavailableAuthServer(t *testing.T, recoveryCnt uint64) *httptest.Server assert.NoError(t, err) var res types.AuthWebhookResponse - res.Allowed = true + res.Code = types.CodeOK if retries < recoveryCnt-1 { w.WriteHeader(http.StatusServiceUnavailable) retries++ @@ -186,8 +190,7 @@ func TestAuthWebhook(t *testing.T) { t.Run("authorization webhook that success after retries test", func(t *testing.T) { ctx := context.Background() - var recoveryCnt uint64 - recoveryCnt = 4 + var recoveryCnt uint64 = 4 authServer := newUnavailableAuthServer(t, recoveryCnt) conf := helper.TestConfig() @@ -264,6 +267,7 @@ func TestAuthWebhook(t *testing.T) { err = cli.Activate(ctx) assert.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err)) + assert.Equal(t, connecthelper.CodeOf(auth.ErrWebhookTimeout), converter.ErrorCodeOf(err)) }) t.Run("authorized request cache test", func(t *testing.T) { @@ -274,7 +278,7 @@ func TestAuthWebhook(t *testing.T) { assert.NoError(t, err) var res types.AuthWebhookResponse - res.Allowed = true + res.Code = types.CodeOK _, err = res.Write(w) assert.NoError(t, err) @@ -352,7 +356,7 @@ func TestAuthWebhook(t *testing.T) { assert.NoError(t, err) var res types.AuthWebhookResponse - res.Allowed = false + res.Code = types.CodePermissionDenied _, err = res.Write(w) assert.NoError(t, err) @@ -394,14 +398,14 @@ func TestAuthWebhook(t *testing.T) { // 01. multiple requests. for i := 0; i < 3; i++ { err = cli.Activate(ctx) - assert.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err)) + assert.Equal(t, connect.CodePermissionDenied, connect.CodeOf(err)) } // 02. multiple requests after eviction by ttl. time.Sleep(unauthorizedTTL) for i := 0; i < 3; i++ { err = cli.Activate(ctx) - assert.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err)) + assert.Equal(t, connect.CodePermissionDenied, connect.CodeOf(err)) } assert.Equal(t, 2, reqCnt) }) From b73e95b69e8eebc03e4038a1dd91401b8bee937f Mon Sep 17 00:00:00 2001 From: Yourim Cha Date: Wed, 16 Oct 2024 18:12:28 +0900 Subject: [PATCH 2/9] Modify tests to handle new Code type in auth webhook response --- test/integration/auth_webhook_test.go | 269 ++++++++++++++++++++++---- 1 file changed, 226 insertions(+), 43 deletions(-) diff --git a/test/integration/auth_webhook_test.go b/test/integration/auth_webhook_test.go index 0b6a1f555..8b1ed4742 100644 --- a/test/integration/auth_webhook_test.go +++ b/test/integration/auth_webhook_test.go @@ -51,6 +51,8 @@ func newAuthServer(t *testing.T) (*httptest.Server, string) { var res types.AuthWebhookResponse if req.Token == token { res.Code = types.CodeOK + } else if req.Token == "not allowed token" { + res.Code = types.CodePermissionDenied } else { res.Code = types.CodeUnauthenticated res.Message = "invalid token" @@ -62,22 +64,20 @@ func newAuthServer(t *testing.T) (*httptest.Server, string) { } func newUnavailableAuthServer(t *testing.T, recoveryCnt uint64) *httptest.Server { - var retries uint64 + var requestCount uint64 return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, err := types.NewAuthWebhookRequest(r.Body) assert.NoError(t, err) var res types.AuthWebhookResponse res.Code = types.CodeOK - if retries < recoveryCnt-1 { + + if requestCount < recoveryCnt { w.WriteHeader(http.StatusServiceUnavailable) - retries++ - } else { - retries = 0 } - _, err = res.Write(w) assert.NoError(t, err) + requestCount++ })) } @@ -93,7 +93,7 @@ func TestProjectAuthWebhook(t *testing.T) { project, err := adminCli.CreateProject(context.Background(), "auth-webhook-test") assert.NoError(t, err) - t.Run("authorization webhook test", func(t *testing.T) { + t.Run("successful authorization test", func(t *testing.T) { ctx := context.Background() authServer, token := newAuthServer(t) @@ -121,6 +121,22 @@ func TestProjectAuthWebhook(t *testing.T) { doc := document.New(helper.TestDocKey(t)) assert.NoError(t, cli.Attach(ctx, doc)) + }) + + t.Run("unauthenticated response test", func(t *testing.T) { + ctx := context.Background() + authServer, _ := newAuthServer(t) + + // project with authorization webhook + project.AuthWebhookURL = authServer.URL + _, err := adminCli.UpdateProject( + ctx, + project.ID.String(), + &types.UpdatableProjectFields{ + AuthWebhookURL: &project.AuthWebhookURL, + }, + ) + assert.NoError(t, err) // client without token cliWithoutToken, err := client.Dial( @@ -131,6 +147,7 @@ func TestProjectAuthWebhook(t *testing.T) { defer func() { assert.NoError(t, cliWithoutToken.Close()) }() err = cliWithoutToken.Activate(ctx) assert.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err)) + assert.Equal(t, connecthelper.CodeOf(auth.ErrUnauthenticated), converter.ErrorCodeOf(err)) // client with invalid token cliWithInvalidToken, err := client.Dial( @@ -142,9 +159,38 @@ func TestProjectAuthWebhook(t *testing.T) { defer func() { assert.NoError(t, cliWithInvalidToken.Close()) }() err = cliWithInvalidToken.Activate(ctx) assert.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err)) + assert.Equal(t, connecthelper.CodeOf(auth.ErrUnauthenticated), converter.ErrorCodeOf(err)) }) - t.Run("Selected method authorization webhook test", func(t *testing.T) { + t.Run("permission denied response test", func(t *testing.T) { + ctx := context.Background() + authServer, _ := newAuthServer(t) + + // project with authorization webhook + project.AuthWebhookURL = authServer.URL + _, err := adminCli.UpdateProject( + ctx, + project.ID.String(), + &types.UpdatableProjectFields{ + AuthWebhookURL: &project.AuthWebhookURL, + }, + ) + assert.NoError(t, err) + + // client with not allowed token + cliNotAllowed, err := client.Dial( + svr.RPCAddr(), + client.WithAPIKey(project.PublicKey), + client.WithToken("not allowed token"), + ) + assert.NoError(t, err) + defer func() { assert.NoError(t, cliNotAllowed.Close()) }() + err = cliNotAllowed.Activate(ctx) + assert.Equal(t, connect.CodePermissionDenied, connect.CodeOf(err)) + assert.Equal(t, connecthelper.CodeOf(auth.ErrPermissionDenied), converter.ErrorCodeOf(err)) + }) + + t.Run("selected method authorization webhook test", func(t *testing.T) { ctx := context.Background() authServer, _ := newAuthServer(t) @@ -186,25 +232,40 @@ func TestProjectAuthWebhook(t *testing.T) { }) } -func TestAuthWebhook(t *testing.T) { - t.Run("authorization webhook that success after retries test", func(t *testing.T) { +func TestAuthWebhookErrorHandling(t *testing.T) { + var recoveryCnt uint64 = 4 + + conf := helper.TestConfig() + conf.Backend.AuthWebhookMaxRetries = recoveryCnt + conf.Backend.AuthWebhookMaxWaitInterval = "1000ms" + svr, err := server.New(conf) + assert.NoError(t, err) + assert.NoError(t, svr.Start()) + defer func() { assert.NoError(t, svr.Shutdown(true)) }() + + adminCli := helper.CreateAdminCli(t, svr.RPCAddr()) + defer func() { adminCli.Close() }() + + t.Run("unexpected status code test", func(t *testing.T) { ctx := context.Background() + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := types.NewAuthWebhookRequest(r.Body) + assert.NoError(t, err) - var recoveryCnt uint64 = 4 - authServer := newUnavailableAuthServer(t, recoveryCnt) + var res types.AuthWebhookResponse + res.Code = types.CodeOK - conf := helper.TestConfig() - conf.Backend.AuthWebhookMaxRetries = recoveryCnt - conf.Backend.AuthWebhookMaxWaitInterval = "1000ms" - svr, err := server.New(conf) - assert.NoError(t, err) - assert.NoError(t, svr.Start()) - defer func() { assert.NoError(t, svr.Shutdown(true)) }() + // unexpected status code + w.WriteHeader(http.StatusBadRequest) - adminCli := helper.CreateAdminCli(t, svr.RPCAddr()) - defer func() { adminCli.Close() }() - project, err := adminCli.CreateProject(context.Background(), "success-webhook-after-retries") + _, err = res.Write(w) + assert.NoError(t, err) + })) + + // project with authorization webhook + project, err := adminCli.CreateProject(context.Background(), "unexpected-status-code") assert.NoError(t, err) + project.AuthWebhookURL = authServer.URL _, err = adminCli.UpdateProject( ctx, @@ -217,35 +278,61 @@ func TestAuthWebhook(t *testing.T) { cli, err := client.Dial( svr.RPCAddr(), - client.WithToken("token"), client.WithAPIKey(project.PublicKey), + client.WithToken("token"), ) assert.NoError(t, err) defer func() { assert.NoError(t, cli.Close()) }() - err = cli.Activate(ctx) + assert.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err)) + assert.Equal(t, connecthelper.CodeOf(auth.ErrUnexpectedStatusCode), converter.ErrorCodeOf(err)) + }) + + t.Run("unexpected webhook response test", func(t *testing.T) { + ctx := context.Background() + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := types.NewAuthWebhookRequest(r.Body) + assert.NoError(t, err) + + var res types.AuthWebhookResponse + // unexpected response code + res.Code = 555 + + _, err = res.Write(w) + assert.NoError(t, err) + })) + + // project with authorization webhook + project, err := adminCli.CreateProject(context.Background(), "unexpected-response-code") assert.NoError(t, err) - doc := document.New(helper.TestDocKey(t)) - err = cli.Attach(ctx, doc) + project.AuthWebhookURL = authServer.URL + _, err = adminCli.UpdateProject( + ctx, + project.ID.String(), + &types.UpdatableProjectFields{ + AuthWebhookURL: &project.AuthWebhookURL, + }, + ) + assert.NoError(t, err) + + cli, err := client.Dial( + svr.RPCAddr(), + client.WithAPIKey(project.PublicKey), + client.WithToken("token"), + ) assert.NoError(t, err) + defer func() { assert.NoError(t, cli.Close()) }() + err = cli.Activate(ctx) + assert.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err)) + assert.Equal(t, connecthelper.CodeOf(auth.ErrUnexpectedResponse), converter.ErrorCodeOf(err)) }) - t.Run("authorization webhook that fails after retries test", func(t *testing.T) { + t.Run("unavailable authentication server test(timeout)", func(t *testing.T) { ctx := context.Background() - authServer := newUnavailableAuthServer(t, 4) + authServer := newUnavailableAuthServer(t, recoveryCnt+1) - conf := helper.TestConfig() - conf.Backend.AuthWebhookMaxRetries = 2 - conf.Backend.AuthWebhookMaxWaitInterval = "1000ms" - svr, err := server.New(conf) - assert.NoError(t, err) - assert.NoError(t, svr.Start()) - defer func() { assert.NoError(t, svr.Shutdown(true)) }() - - adminCli := helper.CreateAdminCli(t, svr.RPCAddr()) - defer func() { adminCli.Close() }() - project, err := adminCli.CreateProject(context.Background(), "fail-webhook-after-retries") + project, err := adminCli.CreateProject(context.Background(), "unavailable-auth-server") assert.NoError(t, err) project.AuthWebhookURL = authServer.URL _, err = adminCli.UpdateProject( @@ -270,7 +357,41 @@ func TestAuthWebhook(t *testing.T) { assert.Equal(t, connecthelper.CodeOf(auth.ErrWebhookTimeout), converter.ErrorCodeOf(err)) }) - t.Run("authorized request cache test", func(t *testing.T) { + t.Run("successful authorization after temporarily unavailable server test", func(t *testing.T) { + ctx := context.Background() + authServer := newUnavailableAuthServer(t, recoveryCnt) + + project, err := adminCli.CreateProject(context.Background(), "success-webhook-after-retries") + assert.NoError(t, err) + project.AuthWebhookURL = authServer.URL + _, err = adminCli.UpdateProject( + ctx, + project.ID.String(), + &types.UpdatableProjectFields{ + AuthWebhookURL: &project.AuthWebhookURL, + }, + ) + assert.NoError(t, err) + + cli, err := client.Dial( + svr.RPCAddr(), + client.WithToken("token"), + client.WithAPIKey(project.PublicKey), + ) + assert.NoError(t, err) + defer func() { assert.NoError(t, cli.Close()) }() + + err = cli.Activate(ctx) + assert.NoError(t, err) + + doc := document.New(helper.TestDocKey(t)) + err = cli.Attach(ctx, doc) + assert.NoError(t, err) + }) +} + +func TestAuthWebhookCache(t *testing.T) { + t.Run("authorized response cache test", func(t *testing.T) { ctx := context.Background() reqCnt := 0 authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -299,7 +420,7 @@ func TestAuthWebhook(t *testing.T) { adminCli := helper.CreateAdminCli(t, svr.RPCAddr()) defer func() { adminCli.Close() }() - project, err := adminCli.CreateProject(context.Background(), "auth-request-cache") + project, err := adminCli.CreateProject(context.Background(), "authorized-response-cache") assert.NoError(t, err) project.AuthWebhookURL = authServer.URL _, err = adminCli.UpdateProject( @@ -348,7 +469,7 @@ func TestAuthWebhook(t *testing.T) { assert.Equal(t, 2, reqCnt) }) - t.Run("unauthorized request cache test", func(t *testing.T) { + t.Run("permission denied response cache test", func(t *testing.T) { ctx := context.Background() reqCnt := 0 authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -375,7 +496,7 @@ func TestAuthWebhook(t *testing.T) { adminCli := helper.CreateAdminCli(t, svr.RPCAddr()) defer func() { adminCli.Close() }() - project, err := adminCli.CreateProject(context.Background(), "unauth-request-cache") + project, err := adminCli.CreateProject(context.Background(), "permission-denied-cache") assert.NoError(t, err) project.AuthWebhookURL = authServer.URL _, err = adminCli.UpdateProject( @@ -409,4 +530,66 @@ func TestAuthWebhook(t *testing.T) { } assert.Equal(t, 2, reqCnt) }) + + t.Run("other response not cached test", func(t *testing.T) { + ctx := context.Background() + reqCnt := 0 + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := types.NewAuthWebhookRequest(r.Body) + assert.NoError(t, err) + + var res types.AuthWebhookResponse + res.Code = types.CodeUnauthenticated + + _, err = res.Write(w) + assert.NoError(t, err) + + reqCnt++ + })) + + unauthorizedTTL := 1 * time.Second + conf := helper.TestConfig() + conf.Backend.AuthWebhookCacheUnauthTTL = unauthorizedTTL.String() + + svr, err := server.New(conf) + assert.NoError(t, err) + assert.NoError(t, svr.Start()) + defer func() { assert.NoError(t, svr.Shutdown(true)) }() + + adminCli := helper.CreateAdminCli(t, svr.RPCAddr()) + defer func() { adminCli.Close() }() + project, err := adminCli.CreateProject(context.Background(), "other-response-not-cached") + assert.NoError(t, err) + project.AuthWebhookURL = authServer.URL + _, err = adminCli.UpdateProject( + ctx, + project.ID.String(), + &types.UpdatableProjectFields{ + AuthWebhookURL: &project.AuthWebhookURL, + }, + ) + assert.NoError(t, err) + + cli, err := client.Dial( + svr.RPCAddr(), + client.WithToken("token"), + client.WithAPIKey(project.PublicKey), + ) + assert.NoError(t, err) + defer func() { assert.NoError(t, cli.Close()) }() + + // 01. multiple requests. + for i := 0; i < 3; i++ { + err = cli.Activate(ctx) + assert.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err)) + } + + // 02. multiple requests after eviction by ttl. + time.Sleep(unauthorizedTTL) + for i := 0; i < 3; i++ { + err = cli.Activate(ctx) + assert.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err)) + } + assert.Equal(t, 6, reqCnt) + }) } From 411d1c9258b29f3d72f416506743850c948cdbff Mon Sep 17 00:00:00 2001 From: Yourim Cha Date: Fri, 18 Oct 2024 14:02:20 +0900 Subject: [PATCH 3/9] Introduce RichError struct to extend error responses with metadata for Auth webhook --- api/converter/errors.go | 19 ++++++++++++ internal/richerror/richerror.go | 28 +++++++++++++++++ server/rpc/auth/webhook.go | 7 ++++- server/rpc/connecthelper/status.go | 43 +++++++++++++++++++++++++++ test/integration/auth_webhook_test.go | 7 +++-- 5 files changed, 101 insertions(+), 3 deletions(-) create mode 100644 internal/richerror/richerror.go diff --git a/api/converter/errors.go b/api/converter/errors.go index 69ec75f4e..a89b0bfb1 100644 --- a/api/converter/errors.go +++ b/api/converter/errors.go @@ -25,3 +25,22 @@ func ErrorCodeOf(err error) string { } return "" } + +// ErrorMetadataOf returns the error metadata of the given error. +func ErrorMetadataOf(err error) map[string]string { + var connectErr *connect.Error + if !errors.As(err, &connectErr) { + return nil + } + for _, detail := range connectErr.Details() { + msg, valueErr := detail.Value() + if valueErr != nil { + continue + } + + if errorInfo, ok := msg.(*errdetails.ErrorInfo); ok { + return errorInfo.GetMetadata() + } + } + return nil +} diff --git a/internal/richerror/richerror.go b/internal/richerror/richerror.go new file mode 100644 index 000000000..b8e8f9fe6 --- /dev/null +++ b/internal/richerror/richerror.go @@ -0,0 +1,28 @@ +/* + * Copyright 2024 The Yorkie Authors. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package richerror provides a rich error type that can be used to wrap errors +package richerror + +// RichError is an error type that can be used to wrap errors with additional metadata +type RichError struct { + Err error + Metadata map[string]string +} + +func (e RichError) Error() string { + return e.Err.Error() +} diff --git a/server/rpc/auth/webhook.go b/server/rpc/auth/webhook.go index a25314a75..d9f4a198e 100644 --- a/server/rpc/auth/webhook.go +++ b/server/rpc/auth/webhook.go @@ -28,6 +28,7 @@ import ( "time" "github.com/yorkie-team/yorkie/api/types" + "github.com/yorkie-team/yorkie/internal/richerror" "github.com/yorkie-team/yorkie/server/backend" "github.com/yorkie-team/yorkie/server/logging" ) @@ -108,7 +109,11 @@ func verifyAccess( return resp.StatusCode, fmt.Errorf("%s: %w", authResp.Message, ErrPermissionDenied) } if authResp.Code == types.CodeUnauthenticated { - return resp.StatusCode, fmt.Errorf("%s: %w", authResp.Message, ErrUnauthenticated) + richError := &richerror.RichError{ + Err: ErrUnauthenticated, + Metadata: map[string]string{"message": authResp.Message}, + } + return resp.StatusCode, richError } return resp.StatusCode, fmt.Errorf("%d: %w", authResp.Code, ErrUnexpectedResponse) diff --git a/server/rpc/connecthelper/status.go b/server/rpc/connecthelper/status.go index 39e8ef8ad..bc9ce3ca8 100644 --- a/server/rpc/connecthelper/status.go +++ b/server/rpc/connecthelper/status.go @@ -26,6 +26,7 @@ import ( "github.com/yorkie-team/yorkie/api/converter" "github.com/yorkie-team/yorkie/api/types" + "github.com/yorkie-team/yorkie/internal/richerror" "github.com/yorkie-team/yorkie/internal/validation" "github.com/yorkie-team/yorkie/pkg/document/key" "github.com/yorkie-team/yorkie/pkg/document/time" @@ -185,6 +186,44 @@ func errorToConnectError(err error) (*connect.Error, bool) { return connectErr, true } +// richErrorToConnectError returns connect.Error from the given rich error. +func richErrorToConnectError(err error) (*connect.Error, bool) { + var richError *richerror.RichError + if !errors.As(err, &richError) { + return nil, false + } + + // NOTE(hackerwins): This prevents panic when the cause is an unhashable + // error. + var connectCode connect.Code + var ok bool + defer func() { + if r := recover(); r != nil { + ok = false + } + }() + + connectCode, ok = errorToConnectCode[richError.Err] + if !ok { + return nil, false + } + + connectErr := connect.NewError(connectCode, err) + if code, ok := errorToCode[richError.Err]; ok { + errorInfo := &errdetails.ErrorInfo{ + Metadata: map[string]string{"code": code}, + } + for key, value := range richError.Metadata { + errorInfo.Metadata[key] = value + } + if detail, detailErr := connect.NewErrorDetail(errorInfo); detailErr == nil { + connectErr.AddDetail(detail) + } + } + + return connectErr, true +} + // structErrorToConnectError returns connect.Error from the given struct error. func structErrorToConnectError(err error) (*connect.Error, bool) { var invalidFieldsError *validation.StructError @@ -231,6 +270,10 @@ func ToStatusError(err error) error { return nil } + if err, ok := richErrorToConnectError(err); ok { + return err + } + if err, ok := errorToConnectError(err); ok { return err } diff --git a/test/integration/auth_webhook_test.go b/test/integration/auth_webhook_test.go index 8b1ed4742..6126ee68a 100644 --- a/test/integration/auth_webhook_test.go +++ b/test/integration/auth_webhook_test.go @@ -53,6 +53,9 @@ func newAuthServer(t *testing.T) (*httptest.Server, string) { res.Code = types.CodeOK } else if req.Token == "not allowed token" { res.Code = types.CodePermissionDenied + } else if req.Token == "" { + res.Code = types.CodeUnauthenticated + res.Message = "no token" } else { res.Code = types.CodeUnauthenticated res.Message = "invalid token" @@ -147,7 +150,7 @@ func TestProjectAuthWebhook(t *testing.T) { defer func() { assert.NoError(t, cliWithoutToken.Close()) }() err = cliWithoutToken.Activate(ctx) assert.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err)) - assert.Equal(t, connecthelper.CodeOf(auth.ErrUnauthenticated), converter.ErrorCodeOf(err)) + assert.Equal(t, map[string]string{"code": connecthelper.CodeOf(auth.ErrUnauthenticated), "message": "no token"}, converter.ErrorMetadataOf(err)) // client with invalid token cliWithInvalidToken, err := client.Dial( @@ -159,7 +162,7 @@ func TestProjectAuthWebhook(t *testing.T) { defer func() { assert.NoError(t, cliWithInvalidToken.Close()) }() err = cliWithInvalidToken.Activate(ctx) assert.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err)) - assert.Equal(t, connecthelper.CodeOf(auth.ErrUnauthenticated), converter.ErrorCodeOf(err)) + assert.Equal(t, map[string]string{"code": connecthelper.CodeOf(auth.ErrUnauthenticated), "message": "invalid token"}, converter.ErrorMetadataOf(err)) }) t.Run("permission denied response test", func(t *testing.T) { From 148ac3791bcf64f71b980032ab7c2c71bb4cb00e Mon Sep 17 00:00:00 2001 From: Youngteac Hong Date: Thu, 31 Oct 2024 15:23:41 +0900 Subject: [PATCH 4/9] Rename StructError to FormError --- api/types/updatable_project_fields_test.go | 16 ++++++++-------- api/types/user_fields_test.go | 16 ++++++++-------- internal/validation/validation.go | 12 ++++++------ internal/validation/validation_test.go | 4 ++-- server/rpc/connecthelper/status.go | 10 +++++----- 5 files changed, 29 insertions(+), 29 deletions(-) diff --git a/api/types/updatable_project_fields_test.go b/api/types/updatable_project_fields_test.go index cba41401d..e8e61263b 100644 --- a/api/types/updatable_project_fields_test.go +++ b/api/types/updatable_project_fields_test.go @@ -26,7 +26,7 @@ import ( ) func TestUpdatableProjectFields(t *testing.T) { - var structError *validation.StructError + var formErr *validation.FormError t.Run("validation test", func(t *testing.T) { newName := "changed-name" newAuthWebhookURL := "http://localhost:3000" @@ -68,7 +68,7 @@ func TestUpdatableProjectFields(t *testing.T) { AuthWebhookMethods: &newAuthWebhookMethods, ClientDeactivateThreshold: &newClientDeactivateThreshold, } - assert.ErrorAs(t, fields.Validate(), &structError) + assert.ErrorAs(t, fields.Validate(), &formErr) }) t.Run("project name format test", func(t *testing.T) { @@ -82,36 +82,36 @@ func TestUpdatableProjectFields(t *testing.T) { fields = &types.UpdatableProjectFields{ Name: &invalidName, } - assert.ErrorAs(t, fields.Validate(), &structError) + assert.ErrorAs(t, fields.Validate(), &formErr) reservedName := "new" fields = &types.UpdatableProjectFields{ Name: &reservedName, } - assert.ErrorAs(t, fields.Validate(), &structError) + assert.ErrorAs(t, fields.Validate(), &formErr) reservedName = "default" fields = &types.UpdatableProjectFields{ Name: &reservedName, } - assert.ErrorAs(t, fields.Validate(), &structError) + assert.ErrorAs(t, fields.Validate(), &formErr) invalidName = "1" fields = &types.UpdatableProjectFields{ Name: &invalidName, } - assert.ErrorAs(t, fields.Validate(), &structError) + assert.ErrorAs(t, fields.Validate(), &formErr) invalidName = "over_30_chracaters_is_invalid_name" fields = &types.UpdatableProjectFields{ Name: &invalidName, } - assert.ErrorAs(t, fields.Validate(), &structError) + assert.ErrorAs(t, fields.Validate(), &formErr) invalidName = "invalid/name" fields = &types.UpdatableProjectFields{ Name: &invalidName, } - assert.ErrorAs(t, fields.Validate(), &structError) + assert.ErrorAs(t, fields.Validate(), &formErr) }) } diff --git a/api/types/user_fields_test.go b/api/types/user_fields_test.go index 15e30339a..e60aae49b 100644 --- a/api/types/user_fields_test.go +++ b/api/types/user_fields_test.go @@ -26,7 +26,7 @@ import ( ) func TestSignupFields(t *testing.T) { - var structError *validation.StructError + var formErr *validation.FormError t.Run("password validation test", func(t *testing.T) { validUsername := "test" @@ -42,48 +42,48 @@ func TestSignupFields(t *testing.T) { Username: &validUsername, Password: &invalidPassword, } - assert.ErrorAs(t, fields.Validate(), &structError) + assert.ErrorAs(t, fields.Validate(), &formErr) invalidPassword = "abcd" fields = &types.UserFields{ Username: &validUsername, Password: &invalidPassword, } - assert.ErrorAs(t, fields.Validate(), &structError) + assert.ErrorAs(t, fields.Validate(), &formErr) invalidPassword = "!@#$" fields = &types.UserFields{ Username: &validUsername, Password: &invalidPassword, } - assert.ErrorAs(t, fields.Validate(), &structError) + assert.ErrorAs(t, fields.Validate(), &formErr) invalidPassword = "abcd1234" fields = &types.UserFields{ Username: &validUsername, Password: &invalidPassword, } - assert.ErrorAs(t, fields.Validate(), &structError) + assert.ErrorAs(t, fields.Validate(), &formErr) invalidPassword = "abcd!@#$" fields = &types.UserFields{ Username: &validUsername, Password: &invalidPassword, } - assert.ErrorAs(t, fields.Validate(), &structError) + assert.ErrorAs(t, fields.Validate(), &formErr) invalidPassword = "1234!@#$" fields = &types.UserFields{ Username: &validUsername, Password: &invalidPassword, } - assert.ErrorAs(t, fields.Validate(), &structError) + assert.ErrorAs(t, fields.Validate(), &formErr) invalidPassword = "abcd1234!@abcd1234!@abcd1234!@1" fields = &types.UserFields{ Username: &validUsername, Password: &invalidPassword, } - assert.ErrorAs(t, fields.Validate(), &structError) + assert.ErrorAs(t, fields.Validate(), &formErr) }) } diff --git a/internal/validation/validation.go b/internal/validation/validation.go index fbf4cbd01..3c802c7df 100644 --- a/internal/validation/validation.go +++ b/internal/validation/validation.go @@ -119,13 +119,13 @@ func (e Violation) Error() string { return e.Err.Error() } -// StructError is the error returned by the validation of struct. -type StructError struct { +// FormError represents the error of the form validation. +type FormError struct { Violations []Violation } // Error returns the error message. -func (s StructError) Error() string { +func (s FormError) Error() string { sb := strings.Builder{} for _, v := range s.Violations { @@ -223,16 +223,16 @@ func Validate(v string, tagOrRules []interface{}) error { // ValidateStruct validates the struct func ValidateStruct(s interface{}) error { if err := defaultValidator.Struct(s); err != nil { - structError := &StructError{} + formErr := &FormError{} for _, e := range err.(validator.ValidationErrors) { - structError.Violations = append(structError.Violations, Violation{ + formErr.Violations = append(formErr.Violations, Violation{ Tag: e.Tag(), Field: e.StructField(), Err: e, Description: e.Translate(trans), }) } - return structError + return formErr } return nil diff --git a/internal/validation/validation_test.go b/internal/validation/validation_test.go index be513e21d..9fe6bcea4 100644 --- a/internal/validation/validation_test.go +++ b/internal/validation/validation_test.go @@ -61,8 +61,8 @@ func TestValidation(t *testing.T) { user := User{Name: "invalid-key-$-wrong-string-value", Country: "korea"} err := ValidateStruct(user) - structError := err.(*StructError) - assert.Len(t, structError.Violations, 2, "user should be invalid") + formErr := err.(*FormError) + assert.Len(t, formErr.Violations, 2, "user should be invalid") }) t.Run("custom rule test", func(t *testing.T) { diff --git a/server/rpc/connecthelper/status.go b/server/rpc/connecthelper/status.go index bc9ce3ca8..1dbbabd94 100644 --- a/server/rpc/connecthelper/status.go +++ b/server/rpc/connecthelper/status.go @@ -224,9 +224,9 @@ func richErrorToConnectError(err error) (*connect.Error, bool) { return connectErr, true } -// structErrorToConnectError returns connect.Error from the given struct error. -func structErrorToConnectError(err error) (*connect.Error, bool) { - var invalidFieldsError *validation.StructError +// formErrorToConnectError returns connect.Error from the given form error. +func formErrorToConnectError(err error) (*connect.Error, bool) { + var invalidFieldsError *validation.FormError if !errors.As(err, &invalidFieldsError) { return nil, false } @@ -244,7 +244,7 @@ func structErrorToConnectError(err error) (*connect.Error, bool) { } func badRequestFromError(err error) (*errdetails.BadRequest, bool) { - var invalidFieldsError *validation.StructError + var invalidFieldsError *validation.FormError if !errors.As(err, &invalidFieldsError) { return nil, false } @@ -278,7 +278,7 @@ func ToStatusError(err error) error { return err } - if err, ok := structErrorToConnectError(err); ok { + if err, ok := formErrorToConnectError(err); ok { return err } From 1bdbe91fc38f3b36f449925fb648c17ac5c47132 Mon Sep 17 00:00:00 2001 From: Youngteac Hong Date: Thu, 31 Oct 2024 15:54:31 +0900 Subject: [PATCH 5/9] Rename RichError to MetaError --- internal/metaerrors/metaerrors.go | 59 ++++++++++++++++++++++++++ internal/metaerrors/metaerrors_test.go | 57 +++++++++++++++++++++++++ internal/richerror/richerror.go | 28 ------------ internal/validation/validation.go | 2 +- server/rpc/auth/webhook.go | 11 +++-- server/rpc/connecthelper/status.go | 18 ++++---- 6 files changed, 131 insertions(+), 44 deletions(-) create mode 100644 internal/metaerrors/metaerrors.go create mode 100644 internal/metaerrors/metaerrors_test.go delete mode 100644 internal/richerror/richerror.go diff --git a/internal/metaerrors/metaerrors.go b/internal/metaerrors/metaerrors.go new file mode 100644 index 000000000..9fce8d7a1 --- /dev/null +++ b/internal/metaerrors/metaerrors.go @@ -0,0 +1,59 @@ +/* + * Copyright 2024 The Yorkie Authors. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package metaerrors provides a way to attach metadata to errors. +package metaerrors + +import "strings" + +// MetaError is an error that can have metadata attached to it. This can be used +// to send additional information to the SDK or to the user. +type MetaError struct { + // Err is the underlying error. + Err error + + // Metadata is a map of additional information that can be attached to the + // error. + Metadata map[string]string +} + +// New returns a new MetaError with the given error and metadata. +func New(err error, metadata map[string]string) *MetaError { + return &MetaError{ + Err: err, + Metadata: metadata, + } +} + +// Error returns the error message. +func (e MetaError) Error() string { + if len(e.Metadata) == 0 { + return e.Err.Error() + } + + sb := strings.Builder{} + + for key, val := range e.Metadata { + if sb.Len() > 0 { + sb.WriteString(",") + } + sb.WriteString(key) + sb.WriteString("=") + sb.WriteString(val) + } + + return e.Err.Error() + " [" + sb.String() + "]" +} diff --git a/internal/metaerrors/metaerrors_test.go b/internal/metaerrors/metaerrors_test.go new file mode 100644 index 000000000..948d92aa4 --- /dev/null +++ b/internal/metaerrors/metaerrors_test.go @@ -0,0 +1,57 @@ +/* + * Copyright 2024 The Yorkie Authors. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package metaerrors_test + +import ( + "errors" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/yorkie-team/yorkie/internal/metaerrors" +) + +func TestMetaError(t *testing.T) { + t.Run("test meta error", func(t *testing.T) { + err := errors.New("error message") + metaErr := metaerrors.New(err, map[string]string{"key": "value"}) + assert.Equal(t, "error message [key=value]", metaErr.Error()) + + err = errors.New("error message") + metaErr = metaerrors.New(err, map[string]string{"key1": "value1", "key2": "value2"}) + assert.Equal(t, "error message [key1=value1,key2=value2]", metaErr.Error()) + }) + + t.Run("test meta error without metadata", func(t *testing.T) { + err := errors.New("error message") + metaErr := metaerrors.New(err, nil) + assert.Equal(t, "error message", metaErr.Error()) + }) + + t.Run("test meta error with wrapped error", func(t *testing.T) { + err := fmt.Errorf("wrapped error: %w", errors.New("error message")) + metaErr := metaerrors.New(err, map[string]string{"key": "value"}) + assert.Equal(t, "wrapped error: error message [key=value]", metaErr.Error()) + + metaErr = metaerrors.New(errors.New("error message"), map[string]string{"key": "value"}) + assert.Equal(t, "error message [key=value]", metaErr.Error()) + + wrappedErr := fmt.Errorf("wrapped error: %w", metaErr) + assert.Equal(t, "wrapped error: error message [key=value]", wrappedErr.Error()) + }) +} diff --git a/internal/richerror/richerror.go b/internal/richerror/richerror.go deleted file mode 100644 index b8e8f9fe6..000000000 --- a/internal/richerror/richerror.go +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright 2024 The Yorkie Authors. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// Package richerror provides a rich error type that can be used to wrap errors -package richerror - -// RichError is an error type that can be used to wrap errors with additional metadata -type RichError struct { - Err error - Metadata map[string]string -} - -func (e RichError) Error() string { - return e.Err.Error() -} diff --git a/internal/validation/validation.go b/internal/validation/validation.go index 3c802c7df..e5ad41d6e 100644 --- a/internal/validation/validation.go +++ b/internal/validation/validation.go @@ -14,7 +14,7 @@ * limitations under the License. */ -// Package validation provides the validation functions. +// Package validation provides the validation functions for form and field. package validation import ( diff --git a/server/rpc/auth/webhook.go b/server/rpc/auth/webhook.go index d9f4a198e..a3b2ebd54 100644 --- a/server/rpc/auth/webhook.go +++ b/server/rpc/auth/webhook.go @@ -28,7 +28,7 @@ import ( "time" "github.com/yorkie-team/yorkie/api/types" - "github.com/yorkie-team/yorkie/internal/richerror" + "github.com/yorkie-team/yorkie/internal/metaerrors" "github.com/yorkie-team/yorkie/server/backend" "github.com/yorkie-team/yorkie/server/logging" ) @@ -109,11 +109,10 @@ func verifyAccess( return resp.StatusCode, fmt.Errorf("%s: %w", authResp.Message, ErrPermissionDenied) } if authResp.Code == types.CodeUnauthenticated { - richError := &richerror.RichError{ - Err: ErrUnauthenticated, - Metadata: map[string]string{"message": authResp.Message}, - } - return resp.StatusCode, richError + return resp.StatusCode, metaerrors.New( + ErrUnauthenticated, + map[string]string{"message": authResp.Message}, + ) } return resp.StatusCode, fmt.Errorf("%d: %w", authResp.Code, ErrUnexpectedResponse) diff --git a/server/rpc/connecthelper/status.go b/server/rpc/connecthelper/status.go index 1dbbabd94..1eef3306e 100644 --- a/server/rpc/connecthelper/status.go +++ b/server/rpc/connecthelper/status.go @@ -26,7 +26,7 @@ import ( "github.com/yorkie-team/yorkie/api/converter" "github.com/yorkie-team/yorkie/api/types" - "github.com/yorkie-team/yorkie/internal/richerror" + "github.com/yorkie-team/yorkie/internal/metaerrors" "github.com/yorkie-team/yorkie/internal/validation" "github.com/yorkie-team/yorkie/pkg/document/key" "github.com/yorkie-team/yorkie/pkg/document/time" @@ -186,10 +186,10 @@ func errorToConnectError(err error) (*connect.Error, bool) { return connectErr, true } -// richErrorToConnectError returns connect.Error from the given rich error. -func richErrorToConnectError(err error) (*connect.Error, bool) { - var richError *richerror.RichError - if !errors.As(err, &richError) { +// metaErrorToConnectError returns connect.Error from the given rich error. +func metaErrorToConnectError(err error) (*connect.Error, bool) { + var metaErr *metaerrors.MetaError + if !errors.As(err, &metaErr) { return nil, false } @@ -203,17 +203,17 @@ func richErrorToConnectError(err error) (*connect.Error, bool) { } }() - connectCode, ok = errorToConnectCode[richError.Err] + connectCode, ok = errorToConnectCode[metaErr.Err] if !ok { return nil, false } connectErr := connect.NewError(connectCode, err) - if code, ok := errorToCode[richError.Err]; ok { + if code, ok := errorToCode[metaErr.Err]; ok { errorInfo := &errdetails.ErrorInfo{ Metadata: map[string]string{"code": code}, } - for key, value := range richError.Metadata { + for key, value := range metaErr.Metadata { errorInfo.Metadata[key] = value } if detail, detailErr := connect.NewErrorDetail(errorInfo); detailErr == nil { @@ -270,7 +270,7 @@ func ToStatusError(err error) error { return nil } - if err, ok := richErrorToConnectError(err); ok { + if err, ok := metaErrorToConnectError(err); ok { return err } From fd417693e3756884c85a35ffda0ac60f0f10f9a4 Mon Sep 17 00:00:00 2001 From: Youngteac Hong Date: Thu, 31 Oct 2024 16:17:44 +0900 Subject: [PATCH 6/9] Adjust status codes of webhook errors --- internal/metaerrors/metaerrors_test.go | 4 ---- server/rpc/auth/webhook.go | 6 +++--- server/rpc/connecthelper/status.go | 8 +++++--- test/integration/auth_webhook_test.go | 6 +++--- 4 files changed, 11 insertions(+), 13 deletions(-) diff --git a/internal/metaerrors/metaerrors_test.go b/internal/metaerrors/metaerrors_test.go index 948d92aa4..70a3c286c 100644 --- a/internal/metaerrors/metaerrors_test.go +++ b/internal/metaerrors/metaerrors_test.go @@ -31,10 +31,6 @@ func TestMetaError(t *testing.T) { err := errors.New("error message") metaErr := metaerrors.New(err, map[string]string{"key": "value"}) assert.Equal(t, "error message [key=value]", metaErr.Error()) - - err = errors.New("error message") - metaErr = metaerrors.New(err, map[string]string{"key1": "value1", "key2": "value2"}) - assert.Equal(t, "error message [key1=value1,key2=value2]", metaErr.Error()) }) t.Run("test meta error without metadata", func(t *testing.T) { diff --git a/server/rpc/auth/webhook.go b/server/rpc/auth/webhook.go index a3b2ebd54..a58359397 100644 --- a/server/rpc/auth/webhook.go +++ b/server/rpc/auth/webhook.go @@ -34,6 +34,9 @@ import ( ) var ( + // ErrUnauthenticated is returned when the authentication is failed. + ErrUnauthenticated = errors.New("unauthenticated") + // ErrPermissionDenied is returned when the given user is not allowed for the access. ErrPermissionDenied = errors.New("method is not allowed for this user") @@ -45,9 +48,6 @@ var ( // ErrWebhookTimeout is returned when the webhook does not respond in time. ErrWebhookTimeout = errors.New("webhook timeout") - - // ErrUnauthenticated is returned when the request lacks valid authentication credentials. - ErrUnauthenticated = errors.New("request lacks valid authentication credentials") ) // verifyAccess verifies the given user is allowed to access the given method. diff --git a/server/rpc/connecthelper/status.go b/server/rpc/connecthelper/status.go index 1eef3306e..b56fd161c 100644 --- a/server/rpc/connecthelper/status.go +++ b/server/rpc/connecthelper/status.go @@ -79,12 +79,14 @@ var errorToConnectCode = map[error]connect.Code{ converter.ErrUnsupportedCounterType: connect.CodeUnimplemented, // Unauthenticated means the request does not have valid authentication - auth.ErrUnexpectedStatusCode: connect.CodeUnauthenticated, - auth.ErrUnexpectedResponse: connect.CodeUnauthenticated, - auth.ErrWebhookTimeout: connect.CodeUnauthenticated, auth.ErrUnauthenticated: connect.CodeUnauthenticated, database.ErrMismatchedPassword: connect.CodeUnauthenticated, + // Internal means an internal error occurred. + auth.ErrUnexpectedStatusCode: connect.CodeInternal, + auth.ErrUnexpectedResponse: connect.CodeInternal, + auth.ErrWebhookTimeout: connect.CodeInternal, + // PermissionDenied means the request does not have permission for the operation. auth.ErrPermissionDenied: connect.CodePermissionDenied, diff --git a/test/integration/auth_webhook_test.go b/test/integration/auth_webhook_test.go index 6126ee68a..f20c28604 100644 --- a/test/integration/auth_webhook_test.go +++ b/test/integration/auth_webhook_test.go @@ -287,7 +287,7 @@ func TestAuthWebhookErrorHandling(t *testing.T) { assert.NoError(t, err) defer func() { assert.NoError(t, cli.Close()) }() err = cli.Activate(ctx) - assert.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err)) + assert.Equal(t, connect.CodeInternal, connect.CodeOf(err)) assert.Equal(t, connecthelper.CodeOf(auth.ErrUnexpectedStatusCode), converter.ErrorCodeOf(err)) }) @@ -327,7 +327,7 @@ func TestAuthWebhookErrorHandling(t *testing.T) { assert.NoError(t, err) defer func() { assert.NoError(t, cli.Close()) }() err = cli.Activate(ctx) - assert.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err)) + assert.Equal(t, connect.CodeInternal, connect.CodeOf(err)) assert.Equal(t, connecthelper.CodeOf(auth.ErrUnexpectedResponse), converter.ErrorCodeOf(err)) }) @@ -356,7 +356,7 @@ func TestAuthWebhookErrorHandling(t *testing.T) { defer func() { assert.NoError(t, cli.Close()) }() err = cli.Activate(ctx) - assert.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err)) + assert.Equal(t, connect.CodeInternal, connect.CodeOf(err)) assert.Equal(t, connecthelper.CodeOf(auth.ErrWebhookTimeout), converter.ErrorCodeOf(err)) }) From 51d0adac49d836e5a970d29ea396e5f5056020bb Mon Sep 17 00:00:00 2001 From: Yourim Cha Date: Fri, 1 Nov 2024 15:45:39 +0900 Subject: [PATCH 7/9] Add SetToken to update token when token is invalid --- client/client.go | 16 ++++++++++ client/options.go | 3 ++ test/integration/auth_webhook_test.go | 46 +++++++++++++++++++++++++++ 3 files changed, 65 insertions(+) diff --git a/client/client.go b/client/client.go index 46b39d925..3d8e4a466 100644 --- a/client/client.go +++ b/client/client.go @@ -183,6 +183,7 @@ func Dial(rpcAddr string, opts ...Option) (*Client, error) { return nil, err } + cli.options.RPCAddress = rpcAddr if err := cli.Dial(rpcAddr); err != nil { return nil, err } @@ -205,6 +206,21 @@ func (c *Client) Dial(rpcAddr string) error { return nil } +// SetToken updates the client's token for reauthentication purposes. +func (c *Client) SetToken(token string) error { + newClientOptions := []connect.ClientOption{ + connect.WithInterceptors(NewAuthInterceptor(c.options.APIKey, token)), + } + if c.options.MaxCallRecvMsgSize != 0 { + newClientOptions = append(newClientOptions, + connect.WithReadMaxBytes(c.options.MaxCallRecvMsgSize)) + } + c.clientOptions = newClientOptions + + c.conn.CloseIdleConnections() + return c.Dial(c.options.RPCAddress) +} + // Close closes all resources of this client. func (c *Client) Close() error { if err := c.Deactivate(context.Background()); err != nil { diff --git a/client/options.go b/client/options.go index fcb03b6a9..03799cd8f 100644 --- a/client/options.go +++ b/client/options.go @@ -41,6 +41,9 @@ type Options struct { // CertFile is the path to the certificate file. CertFile string + // RPCAddress is the address of the RPC server. + RPCAddress string + // ServerNameOverride is the server name override. ServerNameOverride string diff --git a/test/integration/auth_webhook_test.go b/test/integration/auth_webhook_test.go index f20c28604..16d663a20 100644 --- a/test/integration/auth_webhook_test.go +++ b/test/integration/auth_webhook_test.go @@ -596,3 +596,49 @@ func TestAuthWebhookCache(t *testing.T) { assert.Equal(t, 6, reqCnt) }) } + +func TestAuthWebhookNewToken(t *testing.T) { + t.Run("reactivate with new token when receiving invalid token test", func(t *testing.T) { + ctx := context.Background() + authServer, validToken := newAuthServer(t) + + svr, err := server.New(helper.TestConfig()) + assert.NoError(t, err) + assert.NoError(t, svr.Start()) + defer func() { assert.NoError(t, svr.Shutdown(true)) }() + + adminCli := helper.CreateAdminCli(t, svr.RPCAddr()) + defer func() { adminCli.Close() }() + project, err := adminCli.CreateProject(context.Background(), "new-auth-token") + assert.NoError(t, err) + project.AuthWebhookURL = authServer.URL + _, err = adminCli.UpdateProject( + ctx, + project.ID.String(), + &types.UpdatableProjectFields{ + AuthWebhookURL: &project.AuthWebhookURL, + }, + ) + assert.NoError(t, err) + + cli, err := client.Dial( + svr.RPCAddr(), + client.WithToken("invalid"), + client.WithAPIKey(project.PublicKey), + ) + assert.NoError(t, err) + defer func() { assert.NoError(t, cli.Close()) }() + + err = cli.Activate(ctx) + // reactivate with new token + if err != nil { + metadata := converter.ErrorMetadataOf(err) + if metadata["message"] == "invalid token" { + err = cli.SetToken(validToken) + assert.NoError(t, err) + err = cli.Activate(ctx) + assert.NoError(t, err) + } + } + }) +} From 035634b5f7286c97e19156b09832e8e0700d6f84 Mon Sep 17 00:00:00 2001 From: Yourim Cha Date: Fri, 1 Nov 2024 16:29:21 +0900 Subject: [PATCH 8/9] Move authorized status codes to HTTP response header --- api/types/auth_webhook.go | 21 ++------------ server/rpc/auth/webhook.go | 20 ++++++++------ test/integration/auth_webhook_test.go | 40 +++++++++++++++------------ 3 files changed, 36 insertions(+), 45 deletions(-) diff --git a/api/types/auth_webhook.go b/api/types/auth_webhook.go index cf340dd62..841e672f4 100644 --- a/api/types/auth_webhook.go +++ b/api/types/auth_webhook.go @@ -126,27 +126,10 @@ func NewAuthWebhookRequest(reader io.Reader) (*AuthWebhookRequest, error) { return req, nil } -// Code represents the result of an authentication webhook request. -type Code int - -const ( - // CodeOK indicates that the request is fully authenticated and has - // the necessary permissions. - CodeOK Code = 200 - - // CodeUnauthenticated indicates that the request does not have valid - // authentication credentials for the operation. - CodeUnauthenticated Code = 401 - - // CodePermissionDenied indicates that the authenticated request lacks - // the necessary permissions. - CodePermissionDenied Code = 403 -) - // AuthWebhookResponse represents the response of authentication webhook. type AuthWebhookResponse struct { - Code Code `json:"code"` - Message string `json:"message"` + Allowed bool `json:"allowed"` + Reason string `json:"reason"` } // NewAuthWebhookResponse creates a new instance of AuthWebhookResponse. diff --git a/server/rpc/auth/webhook.go b/server/rpc/auth/webhook.go index a58359397..ec997c316 100644 --- a/server/rpc/auth/webhook.go +++ b/server/rpc/auth/webhook.go @@ -70,8 +70,8 @@ func verifyAccess( cacheKey := string(reqBody) if entry, ok := be.AuthWebhookCache.Get(cacheKey); ok { resp := entry - if resp.Code != types.CodeOK { - return fmt.Errorf("%s: %w", resp.Message, ErrPermissionDenied) + if !resp.Allowed { + return fmt.Errorf("%s: %w", resp.Reason, ErrPermissionDenied) } return nil } @@ -93,7 +93,9 @@ func verifyAccess( } }() - if http.StatusOK != resp.StatusCode { + if resp.StatusCode != http.StatusOK && + resp.StatusCode != http.StatusUnauthorized && + resp.StatusCode != http.StatusForbidden { return resp.StatusCode, ErrUnexpectedStatusCode } @@ -102,20 +104,20 @@ func verifyAccess( return resp.StatusCode, err } - if authResp.Code == types.CodeOK { + if resp.StatusCode == http.StatusOK && authResp.Allowed { return resp.StatusCode, nil } - if authResp.Code == types.CodePermissionDenied { - return resp.StatusCode, fmt.Errorf("%s: %w", authResp.Message, ErrPermissionDenied) + if resp.StatusCode == http.StatusForbidden && !authResp.Allowed { + return resp.StatusCode, fmt.Errorf("%s: %w", authResp.Reason, ErrPermissionDenied) } - if authResp.Code == types.CodeUnauthenticated { + if resp.StatusCode == http.StatusUnauthorized && !authResp.Allowed { return resp.StatusCode, metaerrors.New( ErrUnauthenticated, - map[string]string{"message": authResp.Message}, + map[string]string{"reason": authResp.Reason}, ) } - return resp.StatusCode, fmt.Errorf("%d: %w", authResp.Code, ErrUnexpectedResponse) + return resp.StatusCode, fmt.Errorf("%d: %w", resp.StatusCode, ErrUnexpectedResponse) }); err != nil { if errors.Is(err, ErrPermissionDenied) { be.AuthWebhookCache.Add(cacheKey, authResp, be.Config.ParseAuthWebhookCacheUnauthTTL()) diff --git a/test/integration/auth_webhook_test.go b/test/integration/auth_webhook_test.go index 16d663a20..4dc405812 100644 --- a/test/integration/auth_webhook_test.go +++ b/test/integration/auth_webhook_test.go @@ -50,15 +50,19 @@ func newAuthServer(t *testing.T) (*httptest.Server, string) { var res types.AuthWebhookResponse if req.Token == token { - res.Code = types.CodeOK + w.WriteHeader(http.StatusOK) // 200 + res.Allowed = true } else if req.Token == "not allowed token" { - res.Code = types.CodePermissionDenied + w.WriteHeader(http.StatusForbidden) // 403 + res.Allowed = false } else if req.Token == "" { - res.Code = types.CodeUnauthenticated - res.Message = "no token" + w.WriteHeader(http.StatusUnauthorized) // 401 + res.Allowed = false + res.Reason = "no token" } else { - res.Code = types.CodeUnauthenticated - res.Message = "invalid token" + w.WriteHeader(http.StatusUnauthorized) // 401 + res.Allowed = false + res.Reason = "invalid token" } _, err = res.Write(w) @@ -73,7 +77,7 @@ func newUnavailableAuthServer(t *testing.T, recoveryCnt uint64) *httptest.Server assert.NoError(t, err) var res types.AuthWebhookResponse - res.Code = types.CodeOK + res.Allowed = true if requestCount < recoveryCnt { w.WriteHeader(http.StatusServiceUnavailable) @@ -150,7 +154,7 @@ func TestProjectAuthWebhook(t *testing.T) { defer func() { assert.NoError(t, cliWithoutToken.Close()) }() err = cliWithoutToken.Activate(ctx) assert.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err)) - assert.Equal(t, map[string]string{"code": connecthelper.CodeOf(auth.ErrUnauthenticated), "message": "no token"}, converter.ErrorMetadataOf(err)) + assert.Equal(t, map[string]string{"code": connecthelper.CodeOf(auth.ErrUnauthenticated), "reason": "no token"}, converter.ErrorMetadataOf(err)) // client with invalid token cliWithInvalidToken, err := client.Dial( @@ -162,7 +166,7 @@ func TestProjectAuthWebhook(t *testing.T) { defer func() { assert.NoError(t, cliWithInvalidToken.Close()) }() err = cliWithInvalidToken.Activate(ctx) assert.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err)) - assert.Equal(t, map[string]string{"code": connecthelper.CodeOf(auth.ErrUnauthenticated), "message": "invalid token"}, converter.ErrorMetadataOf(err)) + assert.Equal(t, map[string]string{"code": connecthelper.CodeOf(auth.ErrUnauthenticated), "reason": "invalid token"}, converter.ErrorMetadataOf(err)) }) t.Run("permission denied response test", func(t *testing.T) { @@ -256,7 +260,7 @@ func TestAuthWebhookErrorHandling(t *testing.T) { assert.NoError(t, err) var res types.AuthWebhookResponse - res.Code = types.CodeOK + res.Allowed = true // unexpected status code w.WriteHeader(http.StatusBadRequest) @@ -298,8 +302,8 @@ func TestAuthWebhookErrorHandling(t *testing.T) { assert.NoError(t, err) var res types.AuthWebhookResponse - // unexpected response code - res.Code = 555 + // mismatched response + res.Allowed = false _, err = res.Write(w) assert.NoError(t, err) @@ -402,7 +406,7 @@ func TestAuthWebhookCache(t *testing.T) { assert.NoError(t, err) var res types.AuthWebhookResponse - res.Code = types.CodeOK + res.Allowed = true _, err = res.Write(w) assert.NoError(t, err) @@ -479,8 +483,9 @@ func TestAuthWebhookCache(t *testing.T) { _, err := types.NewAuthWebhookRequest(r.Body) assert.NoError(t, err) + w.WriteHeader(http.StatusForbidden) var res types.AuthWebhookResponse - res.Code = types.CodePermissionDenied + res.Allowed = false _, err = res.Write(w) assert.NoError(t, err) @@ -541,8 +546,9 @@ func TestAuthWebhookCache(t *testing.T) { _, err := types.NewAuthWebhookRequest(r.Body) assert.NoError(t, err) + w.WriteHeader(http.StatusUnauthorized) var res types.AuthWebhookResponse - res.Code = types.CodeUnauthenticated + res.Allowed = false _, err = res.Write(w) assert.NoError(t, err) @@ -598,7 +604,7 @@ func TestAuthWebhookCache(t *testing.T) { } func TestAuthWebhookNewToken(t *testing.T) { - t.Run("reactivate with new token when receiving invalid token test", func(t *testing.T) { + t.Run("set new token when receiving invalid token test", func(t *testing.T) { ctx := context.Background() authServer, validToken := newAuthServer(t) @@ -633,7 +639,7 @@ func TestAuthWebhookNewToken(t *testing.T) { // reactivate with new token if err != nil { metadata := converter.ErrorMetadataOf(err) - if metadata["message"] == "invalid token" { + if metadata["reason"] == "invalid token" { err = cli.SetToken(validToken) assert.NoError(t, err) err = cli.Activate(ctx) From e68cb7144d27bc2e31b23883179beaba257c2d42 Mon Sep 17 00:00:00 2001 From: Youngteac Hong Date: Fri, 1 Nov 2024 19:42:29 +0900 Subject: [PATCH 9/9] Clean up codes --- client/auth.go | 5 +++++ client/client.go | 23 +++++++---------------- client/options.go | 3 --- test/integration/auth_webhook_test.go | 19 ++++++++----------- 4 files changed, 20 insertions(+), 30 deletions(-) diff --git a/client/auth.go b/client/auth.go index 7e422c6e2..269d80d84 100644 --- a/client/auth.go +++ b/client/auth.go @@ -39,6 +39,11 @@ func NewAuthInterceptor(apiKey, token string) *AuthInterceptor { } } +// SetToken sets the token. +func (i *AuthInterceptor) SetToken(token string) { + i.token = token +} + // WrapUnary creates a unary server interceptor for authorization. func (i *AuthInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { return func( diff --git a/client/client.go b/client/client.go index 3d8e4a466..282a19858 100644 --- a/client/client.go +++ b/client/client.go @@ -99,6 +99,7 @@ type Client struct { client v1connect.YorkieServiceClient options Options clientOptions []connect.ClientOption + interceptor *AuthInterceptor logger *zap.Logger id *time.ActorID @@ -149,8 +150,8 @@ func New(opts ...Option) (*Client, error) { } var clientOptions []connect.ClientOption - - clientOptions = append(clientOptions, connect.WithInterceptors(NewAuthInterceptor(options.APIKey, options.Token))) + interceptor := NewAuthInterceptor(options.APIKey, options.Token) + clientOptions = append(clientOptions, connect.WithInterceptors(interceptor)) if options.MaxCallRecvMsgSize != 0 { clientOptions = append(clientOptions, connect.WithReadMaxBytes(options.MaxCallRecvMsgSize)) } @@ -169,6 +170,7 @@ func New(opts ...Option) (*Client, error) { clientOptions: clientOptions, options: options, logger: logger, + interceptor: interceptor, key: k, status: deactivated, @@ -183,7 +185,6 @@ func Dial(rpcAddr string, opts ...Option) (*Client, error) { return nil, err } - cli.options.RPCAddress = rpcAddr if err := cli.Dial(rpcAddr); err != nil { return nil, err } @@ -206,19 +207,9 @@ func (c *Client) Dial(rpcAddr string) error { return nil } -// SetToken updates the client's token for reauthentication purposes. -func (c *Client) SetToken(token string) error { - newClientOptions := []connect.ClientOption{ - connect.WithInterceptors(NewAuthInterceptor(c.options.APIKey, token)), - } - if c.options.MaxCallRecvMsgSize != 0 { - newClientOptions = append(newClientOptions, - connect.WithReadMaxBytes(c.options.MaxCallRecvMsgSize)) - } - c.clientOptions = newClientOptions - - c.conn.CloseIdleConnections() - return c.Dial(c.options.RPCAddress) +// SetToken sets the given token of this client. +func (c *Client) SetToken(token string) { + c.interceptor.SetToken(token) } // Close closes all resources of this client. diff --git a/client/options.go b/client/options.go index 03799cd8f..fcb03b6a9 100644 --- a/client/options.go +++ b/client/options.go @@ -41,9 +41,6 @@ type Options struct { // CertFile is the path to the certificate file. CertFile string - // RPCAddress is the address of the RPC server. - RPCAddress string - // ServerNameOverride is the server name override. ServerNameOverride string diff --git a/test/integration/auth_webhook_test.go b/test/integration/auth_webhook_test.go index 4dc405812..b8f9692b4 100644 --- a/test/integration/auth_webhook_test.go +++ b/test/integration/auth_webhook_test.go @@ -604,7 +604,7 @@ func TestAuthWebhookCache(t *testing.T) { } func TestAuthWebhookNewToken(t *testing.T) { - t.Run("set new token when receiving invalid token test", func(t *testing.T) { + t.Run("set valid token after invalid token test", func(t *testing.T) { ctx := context.Background() authServer, validToken := newAuthServer(t) @@ -636,15 +636,12 @@ func TestAuthWebhookNewToken(t *testing.T) { defer func() { assert.NoError(t, cli.Close()) }() err = cli.Activate(ctx) - // reactivate with new token - if err != nil { - metadata := converter.ErrorMetadataOf(err) - if metadata["reason"] == "invalid token" { - err = cli.SetToken(validToken) - assert.NoError(t, err) - err = cli.Activate(ctx) - assert.NoError(t, err) - } - } + assert.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err)) + + // activate again with valid token + metadata := converter.ErrorMetadataOf(err) + assert.Equal(t, "invalid token", metadata["reason"]) + cli.SetToken(validToken) + assert.NoError(t, cli.Activate(ctx)) }) }