diff --git a/core/opaevaluator.go b/core/opaevaluator.go index 21a7fe83..8beffd58 100644 --- a/core/opaevaluator.go +++ b/core/opaevaluator.go @@ -24,6 +24,8 @@ import ( "strings" "time" + "github.com/prometheus/client_golang/prometheus" + "github.com/rond-authz/rond/internal/metrics" "github.com/rond-authz/rond/internal/mongoclient" "github.com/rond-authz/rond/internal/opatranslator" "github.com/rond-authz/rond/internal/utils" @@ -50,6 +52,9 @@ type OPAEvaluator struct { PolicyEvaluator Evaluator PolicyName string Context context.Context + + m *metrics.Metrics + routerInfo openapi.RouterInfo } type PartialResultsEvaluatorConfigKey struct{} @@ -156,6 +161,19 @@ func (h printHook) Print(_ print.Context, message string) error { type EvaluatorOptions struct { EnablePrintStatements bool MongoClient types.IMongoClient + + Metrics *metrics.Metrics + RouterInfo openapi.RouterInfo +} + +func (e *EvaluatorOptions) WithMetrics(metrics metrics.Metrics) *EvaluatorOptions { + e.Metrics = &metrics + return e +} + +func (e *EvaluatorOptions) WithRouterInfo(routerInfo openapi.RouterInfo) *EvaluatorOptions { + e.RouterInfo = routerInfo + return e } func NewOPAEvaluator(ctx context.Context, policy string, opaModuleConfig *OPAModuleConfig, input []byte, options *EvaluatorOptions) (*OPAEvaluator, error) { @@ -188,6 +206,9 @@ func NewOPAEvaluator(ctx context.Context, policy string, opaModuleConfig *OPAMod PolicyEvaluator: query, PolicyName: policy, Context: ctx, + + m: options.Metrics, + routerInfo: options.RouterInfo, }, nil } @@ -258,17 +279,44 @@ func (partialEvaluators PartialResultsEvaluators) GetEvaluatorFromPolicy(ctx con PolicyName: policy, PolicyEvaluator: evaluator, Context: ctx, + + m: options.Metrics, + routerInfo: options.RouterInfo, }, nil } return nil, fmt.Errorf("policy evaluator not found: %s", policy) } +func (evaluator *OPAEvaluator) metrics() metrics.Metrics { + if evaluator.m != nil { + return *evaluator.m + } + return metrics.SetupMetrics("rond") +} + func (evaluator *OPAEvaluator) partiallyEvaluate(logger *logrus.Entry) (primitive.M, error) { + opaEvaluationTimeStart := time.Now() partialResults, err := evaluator.PolicyEvaluator.Partial(evaluator.Context) if err != nil { return nil, fmt.Errorf("policy Evaluation has failed when partially evaluating the query: %s", err.Error()) } + opaEvaluationTime := time.Since(opaEvaluationTimeStart) + + evaluator.metrics().PolicyEvaluationDurationMilliseconds.With(prometheus.Labels{ + "policy_name": evaluator.PolicyName, + }).Observe(float64(opaEvaluationTime.Milliseconds())) + + logger.WithFields(logrus.Fields{ + "evaluationTimeMicroseconds": opaEvaluationTime.Microseconds(), + "policyName": evaluator.PolicyName, + "partialEval": true, + "allowed": true, + "matchedPath": evaluator.routerInfo.MatchedPath, + "requestedPath": evaluator.routerInfo.RequestedPath, + "method": evaluator.routerInfo.Method, + }).Debug("policy evaluation completed") + client := opatranslator.OPAClient{} q, err := client.ProcessQuery(partialResults) if err != nil { @@ -284,11 +332,29 @@ func (evaluator *OPAEvaluator) partiallyEvaluate(logger *logrus.Entry) (primitiv } func (evaluator *OPAEvaluator) Evaluate(logger *logrus.Entry) (interface{}, error) { + opaEvaluationTimeStart := time.Now() + results, err := evaluator.PolicyEvaluator.Eval(evaluator.Context) if err != nil { return nil, fmt.Errorf("policy Evaluation has failed when evaluating the query: %s", err.Error()) } + opaEvaluationTime := time.Since(opaEvaluationTimeStart) + evaluator.metrics().PolicyEvaluationDurationMilliseconds.With(prometheus.Labels{ + "policy_name": evaluator.PolicyName, + }).Observe(float64(opaEvaluationTime.Milliseconds())) + + logger.WithFields(logrus.Fields{ + "evaluationTimeMicroseconds": opaEvaluationTime.Microseconds(), + "policyName": evaluator.PolicyName, + "partialEval": false, + "allowed": results.Allowed(), + "resultsLength": len(results), + "matchedPath": evaluator.routerInfo.MatchedPath, + "requestedPath": evaluator.routerInfo.RequestedPath, + "method": evaluator.routerInfo.Method, + }).Debug("policy evaluation completed") + if results.Allowed() { logger.WithFields(logrus.Fields{ "policyName": evaluator.PolicyName, @@ -297,7 +363,6 @@ func (evaluator *OPAEvaluator) Evaluate(logger *logrus.Entry) (interface{}, erro }).Tracef("policy results") return nil, nil } - // The results returned by OPA are a list of Results object with fields: // - Expressions: list of list // - Bindings: object diff --git a/core/opamiddleware.go b/core/opamiddleware.go index 0b5ea7eb..62adc022 100644 --- a/core/opamiddleware.go +++ b/core/opamiddleware.go @@ -54,8 +54,12 @@ func OPAMiddleware( logger := glogger.Get(r.Context()) evaluator, err := sdk.FindEvaluator(logger, r.Method, path) - permission := evaluator.Config() - if r.Method == http.MethodGet && r.URL.Path == targetServiceOASPath && permission.RequestFlow.PolicyName == "" { + rondConfig := openapi.RondConfig{} + if err == nil { + rondConfig = evaluator.Config() + } + + if r.Method == http.MethodGet && r.URL.Path == targetServiceOASPath && rondConfig.RequestFlow.PolicyName == "" { fields := logrus.Fields{} if err != nil { fields["error"] = logrus.Fields{"message": err.Error()} @@ -65,13 +69,13 @@ func OPAMiddleware( return } - if err != nil || permission.RequestFlow.PolicyName == "" { + if err != nil || rondConfig.RequestFlow.PolicyName == "" { errorMessage := "User is not allowed to request the API" statusCode := http.StatusForbidden fields := logrus.Fields{ "originalRequestPath": utils.SanitizeString(r.URL.Path), "method": utils.SanitizeString(r.Method), - "allowPermission": utils.SanitizeString(permission.RequestFlow.PolicyName), + "allowPermission": utils.SanitizeString(rondConfig.RequestFlow.PolicyName), } technicalError := "" if err != nil { diff --git a/core/sdk.go b/core/sdk.go index a762d147..d938c22a 100644 --- a/core/sdk.go +++ b/core/sdk.go @@ -18,7 +18,6 @@ import ( "context" "encoding/json" "fmt" - "time" "github.com/rond-authz/rond/internal/metrics" "github.com/rond-authz/rond/openapi" @@ -50,19 +49,12 @@ type SDKEvaluator interface { } type evaluator struct { - rond rondImpl - logger *logrus.Entry - rondConfig openapi.RondConfig - - routeInfo openapi.RouterInfo -} - -func (e evaluator) metrics() metrics.Metrics { - return e.rond.metrics -} + logger *logrus.Entry + rondConfig openapi.RondConfig + opaModuleConfig *OPAModuleConfig + partialResultEvaluators PartialResultsEvaluators -func (e evaluator) partialResultEvaluators() PartialResultsEvaluators { - return e.rond.partialResultEvaluators + evaluatorOptions *EvaluatorOptions } func (e evaluator) Config() openapi.RondConfig { @@ -90,36 +82,19 @@ func (e evaluator) EvaluateRequestPolicy(ctx context.Context, req RondInput, use var evaluatorAllowPolicy *OPAEvaluator if !rondConfig.RequestFlow.GenerateQuery { - evaluatorAllowPolicy, err = e.partialResultEvaluators().GetEvaluatorFromPolicy(ctx, rondConfig.RequestFlow.PolicyName, regoInput, e.rond.evaluatorOptions) + evaluatorAllowPolicy, err = e.partialResultEvaluators.GetEvaluatorFromPolicy(ctx, rondConfig.RequestFlow.PolicyName, regoInput, e.evaluatorOptions) if err != nil { return PolicyResult{}, err } } else { - evaluatorAllowPolicy, err = e.rond.opaModuleConfig.CreateQueryEvaluator(ctx, e.logger, rondConfig.RequestFlow.PolicyName, regoInput, e.rond.evaluatorOptions) + evaluatorAllowPolicy, err = e.opaModuleConfig.CreateQueryEvaluator(ctx, e.logger, rondConfig.RequestFlow.PolicyName, regoInput, e.evaluatorOptions) if err != nil { return PolicyResult{}, err } } - opaEvaluationTimeStart := time.Now() - _, query, err := evaluatorAllowPolicy.PolicyEvaluation(e.logger, &rondConfig) - policyName := rondConfig.RequestFlow.PolicyName - opaEvaluationTime := time.Since(opaEvaluationTimeStart) - e.metrics().PolicyEvaluationDurationMilliseconds.With(prometheus.Labels{ - "policy_name": policyName, - }).Observe(float64(opaEvaluationTime.Milliseconds())) - - e.logger.WithFields(logrus.Fields{ - "evaluationTimeMicroseconds": opaEvaluationTime.Microseconds(), - "policyName": policyName, - "partialEval": rondConfig.RequestFlow.GenerateQuery, - "allowed": err == nil, - "matchedPath": e.routeInfo.MatchedPath, - "requestedPath": e.routeInfo.RequestedPath, - "method": e.routeInfo.Method, - }).Debug("policy evaluation completed") if err != nil { e.logger.WithField("error", logrus.Fields{ "policyName": rondConfig.RequestFlow.PolicyName, @@ -161,26 +136,7 @@ func (e evaluator) EvaluateResponsePolicy(ctx context.Context, rondInput RondInp return nil, err } - opaEvaluationTimeStart := time.Now() - - evaluator, err := e.partialResultEvaluators().GetEvaluatorFromPolicy(ctx, e.rondConfig.ResponseFlow.PolicyName, regoInput, e.rond.evaluatorOptions) - - policyName := rondConfig.ResponseFlow.PolicyName - opaEvaluationTime := time.Since(opaEvaluationTimeStart) - e.metrics().PolicyEvaluationDurationMilliseconds.With(prometheus.Labels{ - "policy_name": policyName, - }).Observe(float64(opaEvaluationTime.Milliseconds())) - - e.logger.WithFields(logrus.Fields{ - "evaluationTimeMicroseconds": opaEvaluationTime.Microseconds(), - "policyName": policyName, - "partialEval": false, - "allowed": err == nil, - "matchedPath": e.routeInfo.MatchedPath, - "requestedPath": e.routeInfo.RequestedPath, - "method": e.routeInfo.Method, - }).Debug("policy evaluation completed") - + evaluator, err := e.partialResultEvaluators.GetEvaluatorFromPolicy(ctx, e.rondConfig.ResponseFlow.PolicyName, regoInput, e.evaluatorOptions) if err != nil { return nil, err } @@ -205,19 +161,20 @@ type rondImpl struct { oas *openapi.OpenAPISpec opaModuleConfig *OPAModuleConfig - metrics metrics.Metrics - clientTypeHeaderKey string } func (r rondImpl) FindEvaluator(logger *logrus.Entry, method, path string) (SDKEvaluator, error) { permission, routerInfo, err := r.oas.FindPermission(r.oasRouter, path, method) + if err != nil { + return nil, err + } return evaluator{ - rondConfig: permission, - logger: logger, - rond: r, - - routeInfo: routerInfo, + rondConfig: permission, + logger: logger, + opaModuleConfig: r.opaModuleConfig, + partialResultEvaluators: r.partialResultEvaluators, + evaluatorOptions: r.evaluatorOptions.WithRouterInfo(routerInfo), }, err } @@ -249,6 +206,10 @@ func NewSDK( if registry != nil { m.MustRegister(registry) } + if evaluatorOptions == nil { + evaluatorOptions = &EvaluatorOptions{} + } + evaluatorOptions.WithMetrics(m) return rondImpl{ partialResultEvaluators: evaluator, @@ -257,8 +218,6 @@ func NewSDK( oas: oas, opaModuleConfig: opaModuleConfig, - metrics: m, - clientTypeHeaderKey: clientTypeHeaderKey, }, nil } diff --git a/core/sdk_test.go b/core/sdk_test.go index 712070ce..50bde638 100644 --- a/core/sdk_test.go +++ b/core/sdk_test.go @@ -100,35 +100,30 @@ func TestSDK(t *testing.T) { t.Run("throws if path and method not found", func(t *testing.T) { actual, err := sdk.FindEvaluator(logger, http.MethodGet, "/not-existent/path") require.ErrorContains(t, err, "not found oas definition: GET /not-existent/path") - require.Equal(t, evaluator{ - rondConfig: openapi.RondConfig{}, - logger: logger, - rond: rond, - - routeInfo: openapi.RouterInfo{ - RequestedPath: "/not-existent/path", - Method: "GET", - }, - }, actual) + require.Nil(t, actual) }) t.Run("returns correct evaluator", func(t *testing.T) { actual, err := sdk.FindEvaluator(logger, http.MethodGet, "/users/") require.NoError(t, err) + evaluatorOptions := &EvaluatorOptions{ + Metrics: rond.evaluatorOptions.Metrics, + RouterInfo: openapi.RouterInfo{ + MatchedPath: "/users/", + RequestedPath: "/users/", + Method: http.MethodGet, + }, + } require.Equal(t, evaluator{ rondConfig: openapi.RondConfig{ RequestFlow: openapi.RequestFlow{ PolicyName: "todo", }, }, - logger: logger, - rond: rond, - - routeInfo: openapi.RouterInfo{ - RequestedPath: "/users/", - Method: "GET", - MatchedPath: "/users/", - }, + opaModuleConfig: opaModule, + partialResultEvaluators: rond.partialResultEvaluators, + logger: logger, + evaluatorOptions: evaluatorOptions, }, actual) t.Run("get permissions", func(t *testing.T) { @@ -418,14 +413,26 @@ func TestEvaluateRequestPolicy(t *testing.T) { require.NotNil(t, actual) delete(actualEntry.Data, "evaluationTimeMicroseconds") - require.Equal(t, logrus.Fields{ + + resultLength := 1 + if !actual.Allowed { + resultLength = 0 + } + + fields := logrus.Fields{ "allowed": actual.Allowed, "requestedPath": testCase.path, - "matchedPath": evaluatorInfo.routeInfo.MatchedPath, + "matchedPath": evaluatorInfo.evaluatorOptions.RouterInfo.MatchedPath, "method": testCase.method, "partialEval": evaluate.Config().RequestFlow.GenerateQuery, "policyName": evaluate.Config().RequestFlow.PolicyName, - }, actualEntry.Data) + } + + if !evaluate.Config().RequestFlow.GenerateQuery { + fields["resultsLength"] = resultLength + } + + require.Equal(t, fields, actualEntry.Data) }) t.Run("metrics", func(t *testing.T) { @@ -596,12 +603,13 @@ func TestEvaluateResponsePolicy(t *testing.T) { require.NotNil(t, actual) delete(actual.Data, "evaluationTimeMicroseconds") require.Equal(t, logrus.Fields{ - "allowed": true, + "allowed": false, "requestedPath": testCase.path, - "matchedPath": evaluatorInfo.routeInfo.MatchedPath, + "matchedPath": evaluatorInfo.evaluatorOptions.RouterInfo.MatchedPath, "method": testCase.method, "partialEval": false, "policyName": evaluate.Config().ResponseFlow.PolicyName, + "resultsLength": 1, }, actual.Data) }) diff --git a/service/handler_test.go b/service/handler_test.go index 7cb57936..63fa2dc6 100644 --- a/service/handler_test.go +++ b/service/handler_test.go @@ -756,6 +756,7 @@ allow { "partialEval": false, "policyName": "todo", "requestedPath": "/api", + "resultsLength": 1, }, actualLog[0].Data) })