Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: reworked error messages #214

Merged
merged 5 commits into from
Jun 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions core/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright 2023 Mia srl
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package core

import "fmt"

var (
ErrMissingRegoModules = fmt.Errorf("no rego module found in directory")
ErrRegoModuleReadFailed = fmt.Errorf("failed rego file read")

ErrEvaluatorCreationFailed = fmt.Errorf("error during evaluator creation")
ErrEvaluatorNotFound = fmt.Errorf("evaluator not found")

ErrPolicyEvalFailed = fmt.Errorf("policy evaluation failed")
ErrPartialPolicyEvalFailed = fmt.Errorf("partial %w", ErrPolicyEvalFailed)
ErrResponsePolicyEvalFailed = fmt.Errorf("response %w", ErrPolicyEvalFailed)

ErrFailedInputParse = fmt.Errorf("failed input parse")
ErrFailedInputEncode = fmt.Errorf("failed input encode")
ErrFailedInputRequestParse = fmt.Errorf("failed request body parse")
ErrFailedInputRequestDeserialization = fmt.Errorf("failed request body deserialization")

ErrUnexepectedContentType = fmt.Errorf("unexpected content type")

ErrOPATransportInvalidResponseBody = fmt.Errorf("response body is not valid")
)
10 changes: 6 additions & 4 deletions core/input.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,11 @@ func CreateRegoQueryInput(

inputBytes, err := json.Marshal(input)
if err != nil {
return nil, fmt.Errorf("failed input JSON encode: %v", err)
return nil, fmt.Errorf("%w: %v", ErrFailedInputEncode, err)
}
logger.Tracef("OPA input rego creation in: %+v", time.Since(opaInputCreationTime))
logger.
WithField("inputCreationTimeMicroseconds", time.Since(opaInputCreationTime).Microseconds()).
Tracef("input creation time")
return inputBytes, nil
}

Expand All @@ -131,10 +133,10 @@ func (req requestInfo) Input(user types.User, responseBody any) (Input, error) {
if shouldParseJSONBody {
bodyBytes, err := io.ReadAll(req.Body)
if err != nil {
return Input{}, fmt.Errorf("failed request body parse: %s", err.Error())
return Input{}, fmt.Errorf("%w: %s", ErrFailedInputRequestParse, err.Error())
}
if err := json.Unmarshal(bodyBytes, &requestBody); err != nil {
return Input{}, fmt.Errorf("failed request body deserialization: %s", err.Error())
return Input{}, fmt.Errorf("%w: %s", ErrFailedInputRequestDeserialization, err.Error())
}
req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
}
Expand Down
6 changes: 3 additions & 3 deletions core/opa_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,13 @@ func (t *OPATransport) RoundTrip(req *http.Request) (resp *http.Response, err er

if !utils.HasApplicationJSONContentType(resp.Header) {
t.logger.WithField("foundContentType", resp.Header.Get(utils.ContentTypeHeaderKey)).Debug("found content type")
t.responseWithError(resp, fmt.Errorf("content-type is not application/json"), http.StatusInternalServerError)
t.responseWithError(resp, fmt.Errorf("%w: response content-type is not application/json", ErrUnexepectedContentType), http.StatusInternalServerError)
return resp, nil
}

var decodedBody interface{}
if err := json.Unmarshal(b, &decodedBody); err != nil {
return nil, fmt.Errorf("response body is not valid: %s", err.Error())
return nil, fmt.Errorf("%w: %s", ErrOPATransportInvalidResponseBody, err.Error())
}

userInfo, err := mongoclient.RetrieveUserBindingsAndRoles(t.logger, t.request, t.userHeaders)
Expand All @@ -121,7 +121,7 @@ func (t *OPATransport) RoundTrip(req *http.Request) (resp *http.Response, err er
}

func (t *OPATransport) responseWithError(resp *http.Response, err error, statusCode int) {
t.logger.WithField("error", logrus.Fields{"message": err.Error()}).Error("error while evaluating column filter query")
t.logger.WithField("error", logrus.Fields{"message": err.Error()}).Error(ErrResponsePolicyEvalFailed)
message := utils.NO_PERMISSIONS_ERROR_MESSAGE
if statusCode != http.StatusForbidden {
message = utils.GENERIC_BUSINESS_ERROR_MESSAGE
Expand Down
63 changes: 38 additions & 25 deletions core/opaevaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,18 @@ import (
"strings"
"time"

"github.com/prometheus/client_golang/prometheus"
"github.com/rond-authz/rond/custom_builtins"
"github.com/rond-authz/rond/internal/metrics"
"github.com/rond-authz/rond/internal/mongoclient"
"github.com/rond-authz/rond/internal/opatranslator"
"github.com/rond-authz/rond/internal/utils"
"github.com/rond-authz/rond/openapi"
"github.com/rond-authz/rond/types"

"github.com/rond-authz/rond/custom_builtins"

"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/rego"
"github.com/open-policy-agent/opa/topdown/print"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
"go.mongodb.org/mongo-driver/bson/primitive"
)
Expand Down Expand Up @@ -65,23 +64,29 @@ type PartialEvaluator struct {
}

func createPartialEvaluator(ctx context.Context, logger *logrus.Entry, policy string, oas *openapi.OpenAPISpec, opaModuleConfig *OPAModuleConfig, options *EvaluatorOptions) (*PartialEvaluator, error) {
logger.Infof("precomputing rego query for allow policy: %s", policy)
logger.WithField("policyName", policy).Info("precomputing rego policy")

policyEvaluatorTime := time.Now()
partialResultEvaluator, err := NewPartialResultEvaluator(ctx, policy, opaModuleConfig, options)
if err == nil {
logger.Infof("computed rego query for policy: %s in %s", policy, time.Since(policyEvaluatorTime))
return &PartialEvaluator{
PartialEvaluator: partialResultEvaluator,
}, nil
if err != nil {
return nil, err
}
return nil, err

logger.
WithFields(logrus.Fields{
"policyName": policy,
"computationTimeMicroserconds": time.Since(policyEvaluatorTime).Microseconds,
}).
Info("precomputation time")

return &PartialEvaluator{PartialEvaluator: partialResultEvaluator}, nil
}

func SetupEvaluators(ctx context.Context, logger *logrus.Entry, oas *openapi.OpenAPISpec, opaModuleConfig *OPAModuleConfig, options *EvaluatorOptions) (PartialResultsEvaluators, error) {
if oas == nil {
return nil, fmt.Errorf("oas must not be nil")
}

policyEvaluators := PartialResultsEvaluators{}
for path, OASContent := range oas.Paths {
for verb, verbConfig := range OASContent {
Expand All @@ -92,17 +97,24 @@ func SetupEvaluators(ctx context.Context, logger *logrus.Entry, oas *openapi.Ope
allowPolicy := verbConfig.PermissionV2.RequestFlow.PolicyName
responsePolicy := verbConfig.PermissionV2.ResponseFlow.PolicyName

logger.Infof("precomputing rego queries for API: %s %s. Allow policy: %s. Response policy: %s.", verb, path, allowPolicy, responsePolicy)
logger.
WithFields(logrus.Fields{
"verb": verb,
"policyName": allowPolicy,
"path": path,
"responsePolicyName": responsePolicy,
}).
Info("precomputing rego queries for API")

if allowPolicy == "" {
// allow policy is required, if missing assume the API has no valid x-rond configuration.
continue
}

if _, ok := policyEvaluators[allowPolicy]; !ok {
evaluator, err := createPartialEvaluator(ctx, logger, allowPolicy, oas, opaModuleConfig, options)

if err != nil {
return nil, fmt.Errorf("error during evaluator creation: %s", err.Error())
return nil, fmt.Errorf("%w: %s", ErrEvaluatorCreationFailed, err.Error())
}

policyEvaluators[allowPolicy] = *evaluator
Expand All @@ -111,9 +123,8 @@ func SetupEvaluators(ctx context.Context, logger *logrus.Entry, oas *openapi.Ope
if responsePolicy != "" {
if _, ok := policyEvaluators[responsePolicy]; !ok {
evaluator, err := createPartialEvaluator(ctx, logger, responsePolicy, oas, opaModuleConfig, options)

if err != nil {
return nil, fmt.Errorf("error during evaluator creation: %s", err.Error())
return nil, fmt.Errorf("%w: %s", ErrEvaluatorCreationFailed, err.Error())
}

policyEvaluators[responsePolicy] = *evaluator
Expand Down Expand Up @@ -182,7 +193,7 @@ func NewOPAEvaluator(ctx context.Context, policy string, opaModuleConfig *OPAMod
}
inputTerm, err := ast.ParseTerm(string(input))
if err != nil {
return nil, fmt.Errorf("failed input parse: %v", err)
return nil, fmt.Errorf("%w: %v", ErrFailedInputParse, err)
}

sanitizedPolicy := strings.Replace(policy, ".", "_", -1)
Expand Down Expand Up @@ -221,10 +232,12 @@ func (config *OPAModuleConfig) CreateQueryEvaluator(ctx context.Context, logger
opaEvaluatorInstanceTime := time.Now()
evaluator, err := NewOPAEvaluator(ctx, policy, config, input, options)
if err != nil {
logger.WithError(err).Error("failed RBAC policy creation")
logger.WithError(err).Error(ErrEvaluatorCreationFailed)
return nil, err
}
logger.Tracef("OPA evaluator instantiated in: %+v", time.Since(opaEvaluatorInstanceTime))
logger.
WithField("evaluatorCreationTimeMicroseconds", time.Since(opaEvaluatorInstanceTime).Microseconds()).
Trace("evaluator creation time")
return evaluator, nil
}

Expand Down Expand Up @@ -266,7 +279,7 @@ func (partialEvaluators PartialResultsEvaluators) GetEvaluatorFromPolicy(ctx con
if eval, ok := partialEvaluators[policy]; ok {
inputTerm, err := ast.ParseTerm(string(input))
if err != nil {
return nil, fmt.Errorf("failed input parse: %v", err)
return nil, fmt.Errorf("%w: %v", ErrFailedInputParse, err)
}

evaluator := eval.PartialEvaluator.Rego(
Expand All @@ -284,7 +297,7 @@ func (partialEvaluators PartialResultsEvaluators) GetEvaluatorFromPolicy(ctx con
routerInfo: options.RouterInfo,
}, nil
}
return nil, fmt.Errorf("policy evaluator not found: %s", policy)
return nil, fmt.Errorf("%w: %s", ErrEvaluatorNotFound, policy)
}

func (evaluator *OPAEvaluator) metrics() metrics.Metrics {
Expand All @@ -298,7 +311,7 @@ func (evaluator *OPAEvaluator) partiallyEvaluate(logger *logrus.Entry) (primitiv
opaEvaluationTimeStart := time.Now()
partialResults, err := evaluator.PolicyEvaluator.Partial(evaluator.Context)
if err != nil {
return nil, fmt.Errorf("policy Evaluation has failed when partially evaluating the query: %s", err.Error())
return nil, fmt.Errorf("%w: %s", ErrPartialPolicyEvalFailed, err.Error())
}

opaEvaluationTime := time.Since(opaEvaluationTimeStart)
Expand Down Expand Up @@ -336,7 +349,7 @@ func (evaluator *OPAEvaluator) Evaluate(logger *logrus.Entry) (interface{}, erro

results, err := evaluator.PolicyEvaluator.Eval(evaluator.Context)
if err != nil {
return nil, fmt.Errorf("policy Evaluation has failed when evaluating the query: %s", err.Error())
return nil, fmt.Errorf("%w: %s", ErrPolicyEvalFailed, err.Error())
}

opaEvaluationTime := time.Since(opaEvaluationTimeStart)
Expand Down Expand Up @@ -364,7 +377,7 @@ func (evaluator *OPAEvaluator) Evaluate(logger *logrus.Entry) (interface{}, erro
if allowed {
return responseBodyOverwriter, nil
}
return nil, fmt.Errorf("RBAC policy evaluation failed, user is not allowed")
return nil, ErrPolicyEvalFailed
}

func (evaluator *OPAEvaluator) PolicyEvaluation(logger *logrus.Entry, permission *openapi.RondConfig) (interface{}, primitive.M, error) {
Expand Down Expand Up @@ -418,11 +431,11 @@ func LoadRegoModule(rootDirectory string) (*OPAModuleConfig, error) {
})

if regoModulePath == "" {
return nil, fmt.Errorf("no rego module found in directory")
return nil, ErrMissingRegoModules
}
fileContent, err := utils.ReadFile(regoModulePath)
if err != nil {
return nil, fmt.Errorf("failed rego file read: %s", err.Error())
return nil, fmt.Errorf("%w: %s", ErrRegoModuleReadFailed, err.Error())
}

return &OPAModuleConfig{
Expand Down
37 changes: 16 additions & 21 deletions core/sdk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,8 @@ func TestEvaluateRequestPolicy(t *testing.T) {
reqHeaders map[string]string
mongoClient types.IMongoClient

expectedPolicy PolicyResult
expectedErr bool
expectedErrMessage string
expectedPolicy PolicyResult
expectedErr error
}

t.Run("evaluate request", func(t *testing.T) {
Expand Down Expand Up @@ -196,9 +195,8 @@ func TestEvaluateRequestPolicy(t *testing.T) {
UserID: "my-user",
},

expectedPolicy: PolicyResult{},
expectedErr: true,
expectedErrMessage: "RBAC policy evaluation failed, user is not allowed",
expectedPolicy: PolicyResult{},
expectedErr: ErrPolicyEvalFailed,
},
"not allowed policy result": {
method: http.MethodGet,
Expand All @@ -208,9 +206,8 @@ func TestEvaluateRequestPolicy(t *testing.T) {
},
opaModuleContent: `package policies todo { false }`,

expectedPolicy: PolicyResult{},
expectedErr: true,
expectedErrMessage: "RBAC policy evaluation failed, user is not allowed",
expectedPolicy: PolicyResult{},
expectedErr: ErrPolicyEvalFailed,
},
"with empty filter query": {
method: http.MethodGet,
Expand Down Expand Up @@ -396,8 +393,8 @@ func TestEvaluateRequestPolicy(t *testing.T) {
rondInput := NewRondInput(req, clientTypeHeaderKey, nil)

actual, err := evaluate.EvaluateRequestPolicy(context.Background(), rondInput, testCase.user)
if testCase.expectedErr {
require.EqualError(t, err, testCase.expectedErrMessage)
if testCase.expectedErr != nil {
require.EqualError(t, err, testCase.expectedErr.Error())
} else {
require.NoError(t, err)
}
Expand Down Expand Up @@ -490,10 +487,9 @@ func TestEvaluateResponsePolicy(t *testing.T) {

decodedBody any

expectedBody string
expectedErr bool
expectedErrMessage string
notAllowed bool
expectedBody string
expectedErr error
notAllowed bool
}

t.Run("evaluate response", func(t *testing.T) {
Expand Down Expand Up @@ -537,10 +533,9 @@ func TestEvaluateResponsePolicy(t *testing.T) {
false
body := input.response.body
}`,
expectedErr: true,
expectedErrMessage: "RBAC policy evaluation failed, user is not allowed",
expectedBody: "",
notAllowed: true,
expectedErr: ErrPolicyEvalFailed,
expectedBody: "",
notAllowed: true,
},
"with mongo query and body changed": {
method: http.MethodGet,
Expand Down Expand Up @@ -605,8 +600,8 @@ func TestEvaluateResponsePolicy(t *testing.T) {
rondInput := NewRondInput(req, clientTypeHeaderKey, nil)

actual, err := evaluate.EvaluateResponsePolicy(context.Background(), rondInput, testCase.user, testCase.decodedBody)
if testCase.expectedErr {
require.EqualError(t, err, testCase.expectedErrMessage)
if testCase.expectedErr != nil {
require.EqualError(t, err, testCase.expectedErr.Error())
} else {
require.NoError(t, err)
}
Expand Down