Skip to content

Commit

Permalink
refactor: reworked error messages (#214)
Browse files Browse the repository at this point in the history
* fix

* refactor: core error messages

* fix: missing saved file

* fix: test

* fix: test
  • Loading branch information
fredmaggiowski authored Jun 30, 2023
1 parent d6af895 commit 48ec7dd
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 53 deletions.
38 changes: 38 additions & 0 deletions core/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright 2023 Mia srl
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package core

import "fmt"

var (
ErrMissingRegoModules = fmt.Errorf("no rego module found in directory")
ErrRegoModuleReadFailed = fmt.Errorf("failed rego file read")

ErrEvaluatorCreationFailed = fmt.Errorf("error during evaluator creation")
ErrEvaluatorNotFound = fmt.Errorf("evaluator not found")

ErrPolicyEvalFailed = fmt.Errorf("policy evaluation failed")
ErrPartialPolicyEvalFailed = fmt.Errorf("partial %w", ErrPolicyEvalFailed)
ErrResponsePolicyEvalFailed = fmt.Errorf("response %w", ErrPolicyEvalFailed)

ErrFailedInputParse = fmt.Errorf("failed input parse")
ErrFailedInputEncode = fmt.Errorf("failed input encode")
ErrFailedInputRequestParse = fmt.Errorf("failed request body parse")
ErrFailedInputRequestDeserialization = fmt.Errorf("failed request body deserialization")

ErrUnexepectedContentType = fmt.Errorf("unexpected content type")

ErrOPATransportInvalidResponseBody = fmt.Errorf("response body is not valid")
)
10 changes: 6 additions & 4 deletions core/input.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,11 @@ func CreateRegoQueryInput(

inputBytes, err := json.Marshal(input)
if err != nil {
return nil, fmt.Errorf("failed input JSON encode: %v", err)
return nil, fmt.Errorf("%w: %v", ErrFailedInputEncode, err)
}
logger.Tracef("OPA input rego creation in: %+v", time.Since(opaInputCreationTime))
logger.
WithField("inputCreationTimeMicroseconds", time.Since(opaInputCreationTime).Microseconds()).
Tracef("input creation time")
return inputBytes, nil
}

Expand All @@ -131,10 +133,10 @@ func (req requestInfo) Input(user types.User, responseBody any) (Input, error) {
if shouldParseJSONBody {
bodyBytes, err := io.ReadAll(req.Body)
if err != nil {
return Input{}, fmt.Errorf("failed request body parse: %s", err.Error())
return Input{}, fmt.Errorf("%w: %s", ErrFailedInputRequestParse, err.Error())
}
if err := json.Unmarshal(bodyBytes, &requestBody); err != nil {
return Input{}, fmt.Errorf("failed request body deserialization: %s", err.Error())
return Input{}, fmt.Errorf("%w: %s", ErrFailedInputRequestDeserialization, err.Error())
}
req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
}
Expand Down
6 changes: 3 additions & 3 deletions core/opa_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,13 @@ func (t *OPATransport) RoundTrip(req *http.Request) (resp *http.Response, err er

if !utils.HasApplicationJSONContentType(resp.Header) {
t.logger.WithField("foundContentType", resp.Header.Get(utils.ContentTypeHeaderKey)).Debug("found content type")
t.responseWithError(resp, fmt.Errorf("content-type is not application/json"), http.StatusInternalServerError)
t.responseWithError(resp, fmt.Errorf("%w: response content-type is not application/json", ErrUnexepectedContentType), http.StatusInternalServerError)
return resp, nil
}

var decodedBody interface{}
if err := json.Unmarshal(b, &decodedBody); err != nil {
return nil, fmt.Errorf("response body is not valid: %s", err.Error())
return nil, fmt.Errorf("%w: %s", ErrOPATransportInvalidResponseBody, err.Error())
}

userInfo, err := mongoclient.RetrieveUserBindingsAndRoles(t.logger, t.request, t.userHeaders)
Expand All @@ -121,7 +121,7 @@ func (t *OPATransport) RoundTrip(req *http.Request) (resp *http.Response, err er
}

func (t *OPATransport) responseWithError(resp *http.Response, err error, statusCode int) {
t.logger.WithField("error", logrus.Fields{"message": err.Error()}).Error("error while evaluating column filter query")
t.logger.WithField("error", logrus.Fields{"message": err.Error()}).Error(ErrResponsePolicyEvalFailed)
message := utils.NO_PERMISSIONS_ERROR_MESSAGE
if statusCode != http.StatusForbidden {
message = utils.GENERIC_BUSINESS_ERROR_MESSAGE
Expand Down
63 changes: 38 additions & 25 deletions core/opaevaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,18 @@ import (
"strings"
"time"

"github.com/prometheus/client_golang/prometheus"
"github.com/rond-authz/rond/custom_builtins"
"github.com/rond-authz/rond/internal/metrics"
"github.com/rond-authz/rond/internal/mongoclient"
"github.com/rond-authz/rond/internal/opatranslator"
"github.com/rond-authz/rond/internal/utils"
"github.com/rond-authz/rond/openapi"
"github.com/rond-authz/rond/types"

"github.com/rond-authz/rond/custom_builtins"

"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/rego"
"github.com/open-policy-agent/opa/topdown/print"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
"go.mongodb.org/mongo-driver/bson/primitive"
)
Expand Down Expand Up @@ -65,23 +64,29 @@ type PartialEvaluator struct {
}

func createPartialEvaluator(ctx context.Context, logger *logrus.Entry, policy string, oas *openapi.OpenAPISpec, opaModuleConfig *OPAModuleConfig, options *EvaluatorOptions) (*PartialEvaluator, error) {
logger.Infof("precomputing rego query for allow policy: %s", policy)
logger.WithField("policyName", policy).Info("precomputing rego policy")

policyEvaluatorTime := time.Now()
partialResultEvaluator, err := NewPartialResultEvaluator(ctx, policy, opaModuleConfig, options)
if err == nil {
logger.Infof("computed rego query for policy: %s in %s", policy, time.Since(policyEvaluatorTime))
return &PartialEvaluator{
PartialEvaluator: partialResultEvaluator,
}, nil
if err != nil {
return nil, err
}
return nil, err

logger.
WithFields(logrus.Fields{
"policyName": policy,
"computationTimeMicroserconds": time.Since(policyEvaluatorTime).Microseconds,
}).
Info("precomputation time")

return &PartialEvaluator{PartialEvaluator: partialResultEvaluator}, nil
}

func SetupEvaluators(ctx context.Context, logger *logrus.Entry, oas *openapi.OpenAPISpec, opaModuleConfig *OPAModuleConfig, options *EvaluatorOptions) (PartialResultsEvaluators, error) {
if oas == nil {
return nil, fmt.Errorf("oas must not be nil")
}

policyEvaluators := PartialResultsEvaluators{}
for path, OASContent := range oas.Paths {
for verb, verbConfig := range OASContent {
Expand All @@ -92,17 +97,24 @@ func SetupEvaluators(ctx context.Context, logger *logrus.Entry, oas *openapi.Ope
allowPolicy := verbConfig.PermissionV2.RequestFlow.PolicyName
responsePolicy := verbConfig.PermissionV2.ResponseFlow.PolicyName

logger.Infof("precomputing rego queries for API: %s %s. Allow policy: %s. Response policy: %s.", verb, path, allowPolicy, responsePolicy)
logger.
WithFields(logrus.Fields{
"verb": verb,
"policyName": allowPolicy,
"path": path,
"responsePolicyName": responsePolicy,
}).
Info("precomputing rego queries for API")

if allowPolicy == "" {
// allow policy is required, if missing assume the API has no valid x-rond configuration.
continue
}

if _, ok := policyEvaluators[allowPolicy]; !ok {
evaluator, err := createPartialEvaluator(ctx, logger, allowPolicy, oas, opaModuleConfig, options)

if err != nil {
return nil, fmt.Errorf("error during evaluator creation: %s", err.Error())
return nil, fmt.Errorf("%w: %s", ErrEvaluatorCreationFailed, err.Error())
}

policyEvaluators[allowPolicy] = *evaluator
Expand All @@ -111,9 +123,8 @@ func SetupEvaluators(ctx context.Context, logger *logrus.Entry, oas *openapi.Ope
if responsePolicy != "" {
if _, ok := policyEvaluators[responsePolicy]; !ok {
evaluator, err := createPartialEvaluator(ctx, logger, responsePolicy, oas, opaModuleConfig, options)

if err != nil {
return nil, fmt.Errorf("error during evaluator creation: %s", err.Error())
return nil, fmt.Errorf("%w: %s", ErrEvaluatorCreationFailed, err.Error())
}

policyEvaluators[responsePolicy] = *evaluator
Expand Down Expand Up @@ -182,7 +193,7 @@ func NewOPAEvaluator(ctx context.Context, policy string, opaModuleConfig *OPAMod
}
inputTerm, err := ast.ParseTerm(string(input))
if err != nil {
return nil, fmt.Errorf("failed input parse: %v", err)
return nil, fmt.Errorf("%w: %v", ErrFailedInputParse, err)
}

sanitizedPolicy := strings.Replace(policy, ".", "_", -1)
Expand Down Expand Up @@ -221,10 +232,12 @@ func (config *OPAModuleConfig) CreateQueryEvaluator(ctx context.Context, logger
opaEvaluatorInstanceTime := time.Now()
evaluator, err := NewOPAEvaluator(ctx, policy, config, input, options)
if err != nil {
logger.WithError(err).Error("failed RBAC policy creation")
logger.WithError(err).Error(ErrEvaluatorCreationFailed)
return nil, err
}
logger.Tracef("OPA evaluator instantiated in: %+v", time.Since(opaEvaluatorInstanceTime))
logger.
WithField("evaluatorCreationTimeMicroseconds", time.Since(opaEvaluatorInstanceTime).Microseconds()).
Trace("evaluator creation time")
return evaluator, nil
}

Expand Down Expand Up @@ -266,7 +279,7 @@ func (partialEvaluators PartialResultsEvaluators) GetEvaluatorFromPolicy(ctx con
if eval, ok := partialEvaluators[policy]; ok {
inputTerm, err := ast.ParseTerm(string(input))
if err != nil {
return nil, fmt.Errorf("failed input parse: %v", err)
return nil, fmt.Errorf("%w: %v", ErrFailedInputParse, err)
}

evaluator := eval.PartialEvaluator.Rego(
Expand All @@ -284,7 +297,7 @@ func (partialEvaluators PartialResultsEvaluators) GetEvaluatorFromPolicy(ctx con
routerInfo: options.RouterInfo,
}, nil
}
return nil, fmt.Errorf("policy evaluator not found: %s", policy)
return nil, fmt.Errorf("%w: %s", ErrEvaluatorNotFound, policy)
}

func (evaluator *OPAEvaluator) metrics() metrics.Metrics {
Expand All @@ -298,7 +311,7 @@ func (evaluator *OPAEvaluator) partiallyEvaluate(logger *logrus.Entry) (primitiv
opaEvaluationTimeStart := time.Now()
partialResults, err := evaluator.PolicyEvaluator.Partial(evaluator.Context)
if err != nil {
return nil, fmt.Errorf("policy Evaluation has failed when partially evaluating the query: %s", err.Error())
return nil, fmt.Errorf("%w: %s", ErrPartialPolicyEvalFailed, err.Error())
}

opaEvaluationTime := time.Since(opaEvaluationTimeStart)
Expand Down Expand Up @@ -336,7 +349,7 @@ func (evaluator *OPAEvaluator) Evaluate(logger *logrus.Entry) (interface{}, erro

results, err := evaluator.PolicyEvaluator.Eval(evaluator.Context)
if err != nil {
return nil, fmt.Errorf("policy Evaluation has failed when evaluating the query: %s", err.Error())
return nil, fmt.Errorf("%w: %s", ErrPolicyEvalFailed, err.Error())
}

opaEvaluationTime := time.Since(opaEvaluationTimeStart)
Expand Down Expand Up @@ -364,7 +377,7 @@ func (evaluator *OPAEvaluator) Evaluate(logger *logrus.Entry) (interface{}, erro
if allowed {
return responseBodyOverwriter, nil
}
return nil, fmt.Errorf("RBAC policy evaluation failed, user is not allowed")
return nil, ErrPolicyEvalFailed
}

func (evaluator *OPAEvaluator) PolicyEvaluation(logger *logrus.Entry, permission *openapi.RondConfig) (interface{}, primitive.M, error) {
Expand Down Expand Up @@ -418,11 +431,11 @@ func LoadRegoModule(rootDirectory string) (*OPAModuleConfig, error) {
})

if regoModulePath == "" {
return nil, fmt.Errorf("no rego module found in directory")
return nil, ErrMissingRegoModules
}
fileContent, err := utils.ReadFile(regoModulePath)
if err != nil {
return nil, fmt.Errorf("failed rego file read: %s", err.Error())
return nil, fmt.Errorf("%w: %s", ErrRegoModuleReadFailed, err.Error())
}

return &OPAModuleConfig{
Expand Down
37 changes: 16 additions & 21 deletions core/sdk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,8 @@ func TestEvaluateRequestPolicy(t *testing.T) {
reqHeaders map[string]string
mongoClient types.IMongoClient

expectedPolicy PolicyResult
expectedErr bool
expectedErrMessage string
expectedPolicy PolicyResult
expectedErr error
}

t.Run("evaluate request", func(t *testing.T) {
Expand Down Expand Up @@ -196,9 +195,8 @@ func TestEvaluateRequestPolicy(t *testing.T) {
UserID: "my-user",
},

expectedPolicy: PolicyResult{},
expectedErr: true,
expectedErrMessage: "RBAC policy evaluation failed, user is not allowed",
expectedPolicy: PolicyResult{},
expectedErr: ErrPolicyEvalFailed,
},
"not allowed policy result": {
method: http.MethodGet,
Expand All @@ -208,9 +206,8 @@ func TestEvaluateRequestPolicy(t *testing.T) {
},
opaModuleContent: `package policies todo { false }`,

expectedPolicy: PolicyResult{},
expectedErr: true,
expectedErrMessage: "RBAC policy evaluation failed, user is not allowed",
expectedPolicy: PolicyResult{},
expectedErr: ErrPolicyEvalFailed,
},
"with empty filter query": {
method: http.MethodGet,
Expand Down Expand Up @@ -396,8 +393,8 @@ func TestEvaluateRequestPolicy(t *testing.T) {
rondInput := NewRondInput(req, clientTypeHeaderKey, nil)

actual, err := evaluate.EvaluateRequestPolicy(context.Background(), rondInput, testCase.user)
if testCase.expectedErr {
require.EqualError(t, err, testCase.expectedErrMessage)
if testCase.expectedErr != nil {
require.EqualError(t, err, testCase.expectedErr.Error())
} else {
require.NoError(t, err)
}
Expand Down Expand Up @@ -490,10 +487,9 @@ func TestEvaluateResponsePolicy(t *testing.T) {

decodedBody any

expectedBody string
expectedErr bool
expectedErrMessage string
notAllowed bool
expectedBody string
expectedErr error
notAllowed bool
}

t.Run("evaluate response", func(t *testing.T) {
Expand Down Expand Up @@ -537,10 +533,9 @@ func TestEvaluateResponsePolicy(t *testing.T) {
false
body := input.response.body
}`,
expectedErr: true,
expectedErrMessage: "RBAC policy evaluation failed, user is not allowed",
expectedBody: "",
notAllowed: true,
expectedErr: ErrPolicyEvalFailed,
expectedBody: "",
notAllowed: true,
},
"with mongo query and body changed": {
method: http.MethodGet,
Expand Down Expand Up @@ -605,8 +600,8 @@ func TestEvaluateResponsePolicy(t *testing.T) {
rondInput := NewRondInput(req, clientTypeHeaderKey, nil)

actual, err := evaluate.EvaluateResponsePolicy(context.Background(), rondInput, testCase.user, testCase.decodedBody)
if testCase.expectedErr {
require.EqualError(t, err, testCase.expectedErrMessage)
if testCase.expectedErr != nil {
require.EqualError(t, err, testCase.expectedErr.Error())
} else {
require.NoError(t, err)
}
Expand Down

0 comments on commit 48ec7dd

Please sign in to comment.