From 41c14b3af635cb5fd351503ade21cbcc7d28395f Mon Sep 17 00:00:00 2001 From: Davide Bianchi Date: Mon, 10 Jul 2023 15:57:58 +0200 Subject: [PATCH] update PR --- sdk/context.go | 12 ++++++++---- sdk/context_test.go | 7 ++++--- sdk/evaluator.go | 4 +++- sdk/openapi.go | 1 + sdk/openapi_test.go | 3 ++- sdk/rondinput/http/input_test.go | 1 + sdk/sdk.go | 2 +- sdk/sdk_test.go | 4 ++-- service/handler.go | 2 +- service/opa_transport_test.go | 3 ++- service/opamiddleware.go | 2 +- service/opamiddleware_test.go | 14 +++++++------- service/router_test.go | 2 +- 13 files changed, 34 insertions(+), 23 deletions(-) diff --git a/sdk/context.go b/sdk/context.go index 4d494dcf..b9d90d1b 100644 --- a/sdk/context.go +++ b/sdk/context.go @@ -16,19 +16,23 @@ package sdk import ( "context" - "fmt" + "errors" +) + +var ( + ErrGetEvaluator = errors.New("no Evaluator found in request context") ) type sdkKey struct{} -func WithEvaluatorSDK(ctx context.Context, evaluator Evaluator) context.Context { +func WithEvaluator(ctx context.Context, evaluator Evaluator) context.Context { return context.WithValue(ctx, sdkKey{}, evaluator) } -func GetEvaluatorSKD(ctx context.Context) (Evaluator, error) { +func GetEvaluator(ctx context.Context) (Evaluator, error) { sdk, ok := ctx.Value(sdkKey{}).(Evaluator) if !ok { - return nil, fmt.Errorf("no SDKEvaluator found in request context") + return nil, ErrGetEvaluator } return sdk, nil diff --git a/sdk/context_test.go b/sdk/context_test.go index 41e1944f..92dee25a 100644 --- a/sdk/context_test.go +++ b/sdk/context_test.go @@ -19,6 +19,7 @@ import ( "testing" "github.com/rond-authz/rond/openapi" + "github.com/stretchr/testify/require" ) @@ -39,15 +40,15 @@ func TestContext(t *testing.T) { rondConfig: rondConfig, } - ctx = WithEvaluatorSDK(ctx, expectedEvaluator) + ctx = WithEvaluator(ctx, expectedEvaluator) - actualEvaluator, err := GetEvaluatorSKD(ctx) + actualEvaluator, err := GetEvaluator(ctx) require.NoError(t, err) require.Equal(t, expectedEvaluator, actualEvaluator) }) t.Run("throws if not in context", func(t *testing.T) { - actualEvaluator, err := GetEvaluatorSKD(context.Background()) + actualEvaluator, err := GetEvaluator(context.Background()) require.EqualError(t, err, "no SDKEvaluator found in request context") require.Nil(t, actualEvaluator) }) diff --git a/sdk/evaluator.go b/sdk/evaluator.go index 18693650..b45ef87c 100644 --- a/sdk/evaluator.go +++ b/sdk/evaluator.go @@ -22,6 +22,7 @@ import ( "github.com/rond-authz/rond/core" "github.com/rond-authz/rond/openapi" "github.com/rond-authz/rond/types" + "github.com/sirupsen/logrus" ) @@ -31,8 +32,9 @@ type PolicyResult struct { } // Warning: This interface is experimental, and it could change with breaking also in rond patches. -// Do not use outside this repository until it is not ready. +// Do not use outside this repository until it is ready. type Evaluator interface { + // retrieve the RondConfig used to generate the evaluator Config() openapi.RondConfig EvaluateRequestPolicy(ctx context.Context, input core.RondInput, userInfo types.User) (PolicyResult, error) diff --git a/sdk/openapi.go b/sdk/openapi.go index bef4f2a9..e9766977 100644 --- a/sdk/openapi.go +++ b/sdk/openapi.go @@ -17,6 +17,7 @@ package sdk import ( "github.com/rond-authz/rond/core" "github.com/rond-authz/rond/openapi" + "github.com/sirupsen/logrus" "github.com/uptrace/bunrouter" ) diff --git a/sdk/openapi_test.go b/sdk/openapi_test.go index 7448950b..c62d4475 100644 --- a/sdk/openapi_test.go +++ b/sdk/openapi_test.go @@ -19,9 +19,10 @@ import ( "net/http" "testing" - "github.com/prometheus/client_golang/prometheus" "github.com/rond-authz/rond/core" "github.com/rond-authz/rond/openapi" + + "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/require" diff --git a/sdk/rondinput/http/input_test.go b/sdk/rondinput/http/input_test.go index 71421ed0..e8635d12 100644 --- a/sdk/rondinput/http/input_test.go +++ b/sdk/rondinput/http/input_test.go @@ -23,6 +23,7 @@ import ( "github.com/rond-authz/rond/internal/utils" "github.com/rond-authz/rond/types" + "github.com/stretchr/testify/require" ) diff --git a/sdk/sdk.go b/sdk/sdk.go index 93d40a00..0ecb9a8e 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -44,7 +44,7 @@ func NewFromOAS(ctx context.Context, opaModuleConfig *core.OPAModuleConfig, oas options = &FromOASOptions{} } if options.Logger == nil { - // TODO: default to a logger instead of return error + // TODO: default to a fake silent logger instead of return error return nil, fmt.Errorf("logger is required inside options") } diff --git a/sdk/sdk_test.go b/sdk/sdk_test.go index 099ccbf7..7d1a4eb1 100644 --- a/sdk/sdk_test.go +++ b/sdk/sdk_test.go @@ -23,10 +23,10 @@ import ( "github.com/rond-authz/rond/internal/mocks" "github.com/rond-authz/rond/openapi" "github.com/rond-authz/rond/types" - "github.com/sirupsen/logrus" - "github.com/sirupsen/logrus/hooks/test" "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/require" ) diff --git a/service/handler.go b/service/handler.go index 0b1fb2c8..ac31b21d 100644 --- a/service/handler.go +++ b/service/handler.go @@ -77,7 +77,7 @@ func rbacHandler(w http.ResponseWriter, req *http.Request) { return } - evaluatorSdk, err := sdk.GetEvaluatorSKD(requestContext) + evaluatorSdk, err := sdk.GetEvaluator(requestContext) if err != nil { logger.WithField("error", logrus.Fields{"message": err.Error()}).Error("no evaluatorSdk found in context") utils.FailResponse(w, "no evaluators sdk found in context", utils.GENERIC_BUSINESS_ERROR_MESSAGE) diff --git a/service/opa_transport_test.go b/service/opa_transport_test.go index c588a5b7..f8e43583 100644 --- a/service/opa_transport_test.go +++ b/service/opa_transport_test.go @@ -25,7 +25,6 @@ import ( "strconv" "testing" - "github.com/prometheus/client_golang/prometheus" "github.com/rond-authz/rond/core" "github.com/rond-authz/rond/internal/mocks" "github.com/rond-authz/rond/internal/mongoclient" @@ -33,6 +32,8 @@ import ( "github.com/rond-authz/rond/openapi" "github.com/rond-authz/rond/sdk" "github.com/rond-authz/rond/types" + + "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/require" diff --git a/service/opamiddleware.go b/service/opamiddleware.go index 33205d10..5fad08d8 100644 --- a/service/opamiddleware.go +++ b/service/opamiddleware.go @@ -93,7 +93,7 @@ func OPAMiddleware( return } - ctx := sdk.WithEvaluatorSDK(r.Context(), evaluator) + ctx := sdk.WithEvaluator(r.Context(), evaluator) next.ServeHTTP(w, r.WithContext(ctx)) }) diff --git a/service/opamiddleware_test.go b/service/opamiddleware_test.go index e8d471ac..e6d91194 100644 --- a/service/opamiddleware_test.go +++ b/service/opamiddleware_test.go @@ -187,7 +187,7 @@ foobar { true }`, middleware := OPAMiddleware(opaModule, rondSDK, routesNotToProxy, "", nil) builtHandler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - actual, err := sdk.GetEvaluatorSKD(r.Context()) + actual, err := sdk.GetEvaluator(r.Context()) require.NoError(t, err, "Unexpected error") require.Equal(t, openapi.RondConfig{RequestFlow: openapi.RequestFlow{PolicyName: "todo"}}, actual.Config()) w.WriteHeader(http.StatusOK) @@ -209,7 +209,7 @@ foobar { true }`, middleware := OPAMiddleware(opaModule, rondSDK, routesNotToProxy, "", nil) builtHandler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - actual, err := sdk.GetEvaluatorSKD(r.Context()) + actual, err := sdk.GetEvaluator(r.Context()) require.NoError(t, err) require.Equal(t, openapi.RondConfig{RequestFlow: openapi.RequestFlow{PolicyName: "todo"}}, actual.Config()) w.WriteHeader(http.StatusOK) @@ -232,7 +232,7 @@ very_very_composed_permission { true }`, middleware := OPAMiddleware(opaModule, rondSDK, routesNotToProxy, "", nil) builtHandler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - actual, err := sdk.GetEvaluatorSKD(r.Context()) + actual, err := sdk.GetEvaluator(r.Context()) require.NoError(t, err) require.Equal(t, openapi.RondConfig{RequestFlow: openapi.RequestFlow{PolicyName: "very.very.composed.permission"}}, actual.Config()) w.WriteHeader(http.StatusOK) @@ -260,7 +260,7 @@ very_very_composed_permission_with_eval { true }`, middleware := OPAMiddleware(opaModule, rondSDK, routesNotToProxy, "", options) builtHandler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - actual, err := sdk.GetEvaluatorSKD(r.Context()) + actual, err := sdk.GetEvaluator(r.Context()) require.NoError(t, err) require.Equal(t, openapi.RondConfig{RequestFlow: openapi.RequestFlow{PolicyName: "very.very.composed.permission.with.eval"}}, actual.Config()) w.WriteHeader(http.StatusOK) @@ -278,7 +278,7 @@ very_very_composed_permission_with_eval { true }`, routesNotToProxy := []string{"/not/proxy"} middleware := OPAMiddleware(nil, nil, routesNotToProxy, "", nil) builtHandler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, err := sdk.GetEvaluatorSKD(r.Context()) + _, err := sdk.GetEvaluator(r.Context()) require.EqualError(t, err, "no SDKEvaluator found in request context") })) @@ -322,7 +322,7 @@ func TestOPAMiddlewareStandaloneIntegration(t *testing.T) { rondSDK := getSdk(t, opaModule) middleware := OPAMiddleware(opaModule, rondSDK, routesNotToProxy, "", options) builtHandler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - actual, err := sdk.GetEvaluatorSKD(r.Context()) + actual, err := sdk.GetEvaluator(r.Context()) require.NoError(t, err) require.Equal(t, openapi.RondConfig{RequestFlow: openapi.RequestFlow{PolicyName: "very.very.composed.permission"}}, actual.Config()) w.WriteHeader(http.StatusOK) @@ -345,7 +345,7 @@ very_very_composed_permission_with_eval { true }`, rondSDK := getSdk(t, opaModule) middleware := OPAMiddleware(opaModule, rondSDK, routesNotToProxy, "", options) builtHandler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - actual, err := sdk.GetEvaluatorSKD(r.Context()) + actual, err := sdk.GetEvaluator(r.Context()) require.NoError(t, err) require.Equal(t, openapi.RondConfig{RequestFlow: openapi.RequestFlow{PolicyName: "very.very.composed.permission.with.eval"}}, actual.Config()) w.WriteHeader(http.StatusOK) diff --git a/service/router_test.go b/service/router_test.go index 585202e6..7b250c38 100644 --- a/service/router_test.go +++ b/service/router_test.go @@ -240,7 +240,7 @@ func createContext( var partialContext context.Context partialContext = context.WithValue(originalCtx, config.EnvKey{}, env) - partialContext = sdk.WithEvaluatorSDK(partialContext, evaluator) + partialContext = sdk.WithEvaluator(partialContext, evaluator) if mongoClient != nil { partialContext = context.WithValue(partialContext, types.MongoClientContextKey{}, mongoClient)