diff --git a/core/errors.go b/core/errors.go new file mode 100644 index 00000000..cb2fa4fd --- /dev/null +++ b/core/errors.go @@ -0,0 +1,38 @@ +// Copyright 2023 Mia srl +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package core + +import "fmt" + +var ( + ErrMissingRegoModules = fmt.Errorf("no rego module found in directory") + ErrRegoModuleReadFailed = fmt.Errorf("failed rego file read") + + ErrEvaluatorCreationFailed = fmt.Errorf("error during evaluator creation") + ErrEvaluatorNotFound = fmt.Errorf("evaluator not found") + + ErrPolicyEvalFailed = fmt.Errorf("policy evaluation failed") + ErrPartialPolicyEvalFailed = fmt.Errorf("partial %w", ErrPolicyEvalFailed) + ErrResponsePolicyEvalFailed = fmt.Errorf("response %w", ErrPolicyEvalFailed) + + ErrFailedInputParse = fmt.Errorf("failed input parse") + ErrFailedInputEncode = fmt.Errorf("failed input encode") + ErrFailedInputRequestParse = fmt.Errorf("failed request body parse") + ErrFailedInputRequestDeserialization = fmt.Errorf("failed request body deserialization") + + ErrUnexepectedContentType = fmt.Errorf("unexpected content type") + + ErrOPATransportInvalidResponseBody = fmt.Errorf("response body is not valid") +) diff --git a/core/input.go b/core/input.go index 4a2b9318..0327da6e 100644 --- a/core/input.go +++ b/core/input.go @@ -106,9 +106,11 @@ func CreateRegoQueryInput( inputBytes, err := json.Marshal(input) if err != nil { - return nil, fmt.Errorf("failed input JSON encode: %v", err) + return nil, fmt.Errorf("%w: %v", ErrFailedInputEncode, err) } - logger.Tracef("OPA input rego creation in: %+v", time.Since(opaInputCreationTime)) + logger. + WithField("inputCreationTimeMicroseconds", time.Since(opaInputCreationTime).Microseconds()). + Tracef("input creation time") return inputBytes, nil } @@ -131,10 +133,10 @@ func (req requestInfo) Input(user types.User, responseBody any) (Input, error) { if shouldParseJSONBody { bodyBytes, err := io.ReadAll(req.Body) if err != nil { - return Input{}, fmt.Errorf("failed request body parse: %s", err.Error()) + return Input{}, fmt.Errorf("%w: %s", ErrFailedInputRequestParse, err.Error()) } if err := json.Unmarshal(bodyBytes, &requestBody); err != nil { - return Input{}, fmt.Errorf("failed request body deserialization: %s", err.Error()) + return Input{}, fmt.Errorf("%w: %s", ErrFailedInputRequestDeserialization, err.Error()) } req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) } diff --git a/core/opa_transport.go b/core/opa_transport.go index 0620aceb..be3cd27e 100644 --- a/core/opa_transport.go +++ b/core/opa_transport.go @@ -92,13 +92,13 @@ func (t *OPATransport) RoundTrip(req *http.Request) (resp *http.Response, err er if !utils.HasApplicationJSONContentType(resp.Header) { t.logger.WithField("foundContentType", resp.Header.Get(utils.ContentTypeHeaderKey)).Debug("found content type") - t.responseWithError(resp, fmt.Errorf("content-type is not application/json"), http.StatusInternalServerError) + t.responseWithError(resp, fmt.Errorf("%w: response content-type is not application/json", ErrUnexepectedContentType), http.StatusInternalServerError) return resp, nil } var decodedBody interface{} if err := json.Unmarshal(b, &decodedBody); err != nil { - return nil, fmt.Errorf("response body is not valid: %s", err.Error()) + return nil, fmt.Errorf("%w: %s", ErrOPATransportInvalidResponseBody, err.Error()) } userInfo, err := mongoclient.RetrieveUserBindingsAndRoles(t.logger, t.request, t.userHeaders) @@ -121,7 +121,7 @@ func (t *OPATransport) RoundTrip(req *http.Request) (resp *http.Response, err er } func (t *OPATransport) responseWithError(resp *http.Response, err error, statusCode int) { - t.logger.WithField("error", logrus.Fields{"message": err.Error()}).Error("error while evaluating column filter query") + t.logger.WithField("error", logrus.Fields{"message": err.Error()}).Error(ErrResponsePolicyEvalFailed) message := utils.NO_PERMISSIONS_ERROR_MESSAGE if statusCode != http.StatusForbidden { message = utils.GENERIC_BUSINESS_ERROR_MESSAGE diff --git a/core/opaevaluator.go b/core/opaevaluator.go index d8b8d461..a1ea5b6e 100644 --- a/core/opaevaluator.go +++ b/core/opaevaluator.go @@ -24,7 +24,7 @@ import ( "strings" "time" - "github.com/prometheus/client_golang/prometheus" + "github.com/rond-authz/rond/custom_builtins" "github.com/rond-authz/rond/internal/metrics" "github.com/rond-authz/rond/internal/mongoclient" "github.com/rond-authz/rond/internal/opatranslator" @@ -32,11 +32,10 @@ import ( "github.com/rond-authz/rond/openapi" "github.com/rond-authz/rond/types" - "github.com/rond-authz/rond/custom_builtins" - "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/rego" "github.com/open-policy-agent/opa/topdown/print" + "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" "go.mongodb.org/mongo-driver/bson/primitive" ) @@ -65,23 +64,29 @@ type PartialEvaluator struct { } func createPartialEvaluator(ctx context.Context, logger *logrus.Entry, policy string, oas *openapi.OpenAPISpec, opaModuleConfig *OPAModuleConfig, options *EvaluatorOptions) (*PartialEvaluator, error) { - logger.Infof("precomputing rego query for allow policy: %s", policy) + logger.WithField("policyName", policy).Info("precomputing rego policy") policyEvaluatorTime := time.Now() partialResultEvaluator, err := NewPartialResultEvaluator(ctx, policy, opaModuleConfig, options) - if err == nil { - logger.Infof("computed rego query for policy: %s in %s", policy, time.Since(policyEvaluatorTime)) - return &PartialEvaluator{ - PartialEvaluator: partialResultEvaluator, - }, nil + if err != nil { + return nil, err } - return nil, err + + logger. + WithFields(logrus.Fields{ + "policyName": policy, + "computationTimeMicroserconds": time.Since(policyEvaluatorTime).Microseconds, + }). + Info("precomputation time") + + return &PartialEvaluator{PartialEvaluator: partialResultEvaluator}, nil } func SetupEvaluators(ctx context.Context, logger *logrus.Entry, oas *openapi.OpenAPISpec, opaModuleConfig *OPAModuleConfig, options *EvaluatorOptions) (PartialResultsEvaluators, error) { if oas == nil { return nil, fmt.Errorf("oas must not be nil") } + policyEvaluators := PartialResultsEvaluators{} for path, OASContent := range oas.Paths { for verb, verbConfig := range OASContent { @@ -92,7 +97,15 @@ func SetupEvaluators(ctx context.Context, logger *logrus.Entry, oas *openapi.Ope allowPolicy := verbConfig.PermissionV2.RequestFlow.PolicyName responsePolicy := verbConfig.PermissionV2.ResponseFlow.PolicyName - logger.Infof("precomputing rego queries for API: %s %s. Allow policy: %s. Response policy: %s.", verb, path, allowPolicy, responsePolicy) + logger. + WithFields(logrus.Fields{ + "verb": verb, + "policyName": allowPolicy, + "path": path, + "responsePolicyName": responsePolicy, + }). + Info("precomputing rego queries for API") + if allowPolicy == "" { // allow policy is required, if missing assume the API has no valid x-rond configuration. continue @@ -100,9 +113,8 @@ func SetupEvaluators(ctx context.Context, logger *logrus.Entry, oas *openapi.Ope if _, ok := policyEvaluators[allowPolicy]; !ok { evaluator, err := createPartialEvaluator(ctx, logger, allowPolicy, oas, opaModuleConfig, options) - if err != nil { - return nil, fmt.Errorf("error during evaluator creation: %s", err.Error()) + return nil, fmt.Errorf("%w: %s", ErrEvaluatorCreationFailed, err.Error()) } policyEvaluators[allowPolicy] = *evaluator @@ -111,9 +123,8 @@ func SetupEvaluators(ctx context.Context, logger *logrus.Entry, oas *openapi.Ope if responsePolicy != "" { if _, ok := policyEvaluators[responsePolicy]; !ok { evaluator, err := createPartialEvaluator(ctx, logger, responsePolicy, oas, opaModuleConfig, options) - if err != nil { - return nil, fmt.Errorf("error during evaluator creation: %s", err.Error()) + return nil, fmt.Errorf("%w: %s", ErrEvaluatorCreationFailed, err.Error()) } policyEvaluators[responsePolicy] = *evaluator @@ -182,7 +193,7 @@ func NewOPAEvaluator(ctx context.Context, policy string, opaModuleConfig *OPAMod } inputTerm, err := ast.ParseTerm(string(input)) if err != nil { - return nil, fmt.Errorf("failed input parse: %v", err) + return nil, fmt.Errorf("%w: %v", ErrFailedInputParse, err) } sanitizedPolicy := strings.Replace(policy, ".", "_", -1) @@ -221,10 +232,12 @@ func (config *OPAModuleConfig) CreateQueryEvaluator(ctx context.Context, logger opaEvaluatorInstanceTime := time.Now() evaluator, err := NewOPAEvaluator(ctx, policy, config, input, options) if err != nil { - logger.WithError(err).Error("failed RBAC policy creation") + logger.WithError(err).Error(ErrEvaluatorCreationFailed) return nil, err } - logger.Tracef("OPA evaluator instantiated in: %+v", time.Since(opaEvaluatorInstanceTime)) + logger. + WithField("evaluatorCreationTimeMicroseconds", time.Since(opaEvaluatorInstanceTime).Microseconds()). + Trace("evaluator creation time") return evaluator, nil } @@ -266,7 +279,7 @@ func (partialEvaluators PartialResultsEvaluators) GetEvaluatorFromPolicy(ctx con if eval, ok := partialEvaluators[policy]; ok { inputTerm, err := ast.ParseTerm(string(input)) if err != nil { - return nil, fmt.Errorf("failed input parse: %v", err) + return nil, fmt.Errorf("%w: %v", ErrFailedInputParse, err) } evaluator := eval.PartialEvaluator.Rego( @@ -284,7 +297,7 @@ func (partialEvaluators PartialResultsEvaluators) GetEvaluatorFromPolicy(ctx con routerInfo: options.RouterInfo, }, nil } - return nil, fmt.Errorf("policy evaluator not found: %s", policy) + return nil, fmt.Errorf("%w: %s", ErrEvaluatorNotFound, policy) } func (evaluator *OPAEvaluator) metrics() metrics.Metrics { @@ -298,7 +311,7 @@ func (evaluator *OPAEvaluator) partiallyEvaluate(logger *logrus.Entry) (primitiv opaEvaluationTimeStart := time.Now() partialResults, err := evaluator.PolicyEvaluator.Partial(evaluator.Context) if err != nil { - return nil, fmt.Errorf("policy Evaluation has failed when partially evaluating the query: %s", err.Error()) + return nil, fmt.Errorf("%w: %s", ErrPartialPolicyEvalFailed, err.Error()) } opaEvaluationTime := time.Since(opaEvaluationTimeStart) @@ -336,7 +349,7 @@ func (evaluator *OPAEvaluator) Evaluate(logger *logrus.Entry) (interface{}, erro results, err := evaluator.PolicyEvaluator.Eval(evaluator.Context) if err != nil { - return nil, fmt.Errorf("policy Evaluation has failed when evaluating the query: %s", err.Error()) + return nil, fmt.Errorf("%w: %s", ErrPolicyEvalFailed, err.Error()) } opaEvaluationTime := time.Since(opaEvaluationTimeStart) @@ -364,7 +377,7 @@ func (evaluator *OPAEvaluator) Evaluate(logger *logrus.Entry) (interface{}, erro if allowed { return responseBodyOverwriter, nil } - return nil, fmt.Errorf("RBAC policy evaluation failed, user is not allowed") + return nil, ErrPolicyEvalFailed } func (evaluator *OPAEvaluator) PolicyEvaluation(logger *logrus.Entry, permission *openapi.RondConfig) (interface{}, primitive.M, error) { @@ -418,11 +431,11 @@ func LoadRegoModule(rootDirectory string) (*OPAModuleConfig, error) { }) if regoModulePath == "" { - return nil, fmt.Errorf("no rego module found in directory") + return nil, ErrMissingRegoModules } fileContent, err := utils.ReadFile(regoModulePath) if err != nil { - return nil, fmt.Errorf("failed rego file read: %s", err.Error()) + return nil, fmt.Errorf("%w: %s", ErrRegoModuleReadFailed, err.Error()) } return &OPAModuleConfig{ diff --git a/core/sdk_test.go b/core/sdk_test.go index d88313fa..511f2b64 100644 --- a/core/sdk_test.go +++ b/core/sdk_test.go @@ -161,9 +161,8 @@ func TestEvaluateRequestPolicy(t *testing.T) { reqHeaders map[string]string mongoClient types.IMongoClient - expectedPolicy PolicyResult - expectedErr bool - expectedErrMessage string + expectedPolicy PolicyResult + expectedErr error } t.Run("evaluate request", func(t *testing.T) { @@ -196,9 +195,8 @@ func TestEvaluateRequestPolicy(t *testing.T) { UserID: "my-user", }, - expectedPolicy: PolicyResult{}, - expectedErr: true, - expectedErrMessage: "RBAC policy evaluation failed, user is not allowed", + expectedPolicy: PolicyResult{}, + expectedErr: ErrPolicyEvalFailed, }, "not allowed policy result": { method: http.MethodGet, @@ -208,9 +206,8 @@ func TestEvaluateRequestPolicy(t *testing.T) { }, opaModuleContent: `package policies todo { false }`, - expectedPolicy: PolicyResult{}, - expectedErr: true, - expectedErrMessage: "RBAC policy evaluation failed, user is not allowed", + expectedPolicy: PolicyResult{}, + expectedErr: ErrPolicyEvalFailed, }, "with empty filter query": { method: http.MethodGet, @@ -396,8 +393,8 @@ func TestEvaluateRequestPolicy(t *testing.T) { rondInput := NewRondInput(req, clientTypeHeaderKey, nil) actual, err := evaluate.EvaluateRequestPolicy(context.Background(), rondInput, testCase.user) - if testCase.expectedErr { - require.EqualError(t, err, testCase.expectedErrMessage) + if testCase.expectedErr != nil { + require.EqualError(t, err, testCase.expectedErr.Error()) } else { require.NoError(t, err) } @@ -490,10 +487,9 @@ func TestEvaluateResponsePolicy(t *testing.T) { decodedBody any - expectedBody string - expectedErr bool - expectedErrMessage string - notAllowed bool + expectedBody string + expectedErr error + notAllowed bool } t.Run("evaluate response", func(t *testing.T) { @@ -537,10 +533,9 @@ func TestEvaluateResponsePolicy(t *testing.T) { false body := input.response.body }`, - expectedErr: true, - expectedErrMessage: "RBAC policy evaluation failed, user is not allowed", - expectedBody: "", - notAllowed: true, + expectedErr: ErrPolicyEvalFailed, + expectedBody: "", + notAllowed: true, }, "with mongo query and body changed": { method: http.MethodGet, @@ -605,8 +600,8 @@ func TestEvaluateResponsePolicy(t *testing.T) { rondInput := NewRondInput(req, clientTypeHeaderKey, nil) actual, err := evaluate.EvaluateResponsePolicy(context.Background(), rondInput, testCase.user, testCase.decodedBody) - if testCase.expectedErr { - require.EqualError(t, err, testCase.expectedErrMessage) + if testCase.expectedErr != nil { + require.EqualError(t, err, testCase.expectedErr.Error()) } else { require.NoError(t, err) }