Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
fredmaggiowski committed Jun 30, 2023
2 parents 98f1711 + 79fe754 commit 0fca669
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 95 deletions.
34 changes: 34 additions & 0 deletions core/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright 2023 Mia srl
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package core

import "fmt"

var (
ErrMissingRegoModules = fmt.Errorf("no rego module found in directory")
ErrRegoModuleReadFailed = fmt.Errorf("failed rego file read")

ErrEvaluatorCreationFailed = fmt.Errorf("error during evaluator creation")
ErrEvaluatorNotFound = fmt.Errorf("evaluator not found")

ErrPolicyEvalFailed = fmt.Errorf("policy evaluation failed")
ErrPartialPolicyEvalFailed = fmt.Errorf("partial %w", ErrPolicyEvalFailed)
ErrResponsePolicyEvalFailed = fmt.Errorf("response %w", ErrPolicyEvalFailed)

ErrFailedInputParse = fmt.Errorf("failed input parse")
ErrFailedInputEncode = fmt.Errorf("failed input encode")
ErrFailedInputRequestParse = fmt.Errorf("failed request body parse")
ErrFailedInputRequestDeserialization = fmt.Errorf("failed request body deserialization")
)
6 changes: 4 additions & 2 deletions core/input.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,11 @@ func CreateRegoQueryInput(

inputBytes, err := json.Marshal(input)
if err != nil {
return nil, fmt.Errorf("failed input JSON encode: %v", err)
return nil, fmt.Errorf("%w: %v", ErrFailedInputEncode, err)
}
logger.Tracef("OPA input rego creation in: %+v", time.Since(opaInputCreationTime))
logger.
WithField("inputCreationTimeMicroseconds", time.Since(opaInputCreationTime).Microseconds()).
Tracef("input creation time")
return inputBytes, nil
}

Expand Down
119 changes: 72 additions & 47 deletions core/opaevaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,18 @@ import (
"strings"
"time"

"github.com/prometheus/client_golang/prometheus"
"github.com/rond-authz/rond/custom_builtins"
"github.com/rond-authz/rond/internal/metrics"
"github.com/rond-authz/rond/internal/mongoclient"
"github.com/rond-authz/rond/internal/opatranslator"
"github.com/rond-authz/rond/internal/utils"
"github.com/rond-authz/rond/openapi"
"github.com/rond-authz/rond/types"

"github.com/rond-authz/rond/custom_builtins"

"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/rego"
"github.com/open-policy-agent/opa/topdown/print"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
"go.mongodb.org/mongo-driver/bson/primitive"
)
Expand Down Expand Up @@ -65,23 +64,29 @@ type PartialEvaluator struct {
}

func createPartialEvaluator(ctx context.Context, logger *logrus.Entry, policy string, oas *openapi.OpenAPISpec, opaModuleConfig *OPAModuleConfig, options *EvaluatorOptions) (*PartialEvaluator, error) {
logger.Infof("precomputing rego query for allow policy: %s", policy)
logger.WithField("policyName", policy).Info("precomputing rego policy")

policyEvaluatorTime := time.Now()
partialResultEvaluator, err := NewPartialResultEvaluator(ctx, policy, opaModuleConfig, options)
if err == nil {
logger.Infof("computed rego query for policy: %s in %s", policy, time.Since(policyEvaluatorTime))
return &PartialEvaluator{
PartialEvaluator: partialResultEvaluator,
}, nil
if err != nil {
return nil, err
}
return nil, err

logger.
WithFields(logrus.Fields{
"policyName": policy,
"computationTimeMicroserconds": time.Since(policyEvaluatorTime).Microseconds,
}).
Info("precomputation time")

return &PartialEvaluator{PartialEvaluator: partialResultEvaluator}, nil
}

func SetupEvaluators(ctx context.Context, logger *logrus.Entry, oas *openapi.OpenAPISpec, opaModuleConfig *OPAModuleConfig, options *EvaluatorOptions) (PartialResultsEvaluators, error) {
if oas == nil {
return nil, fmt.Errorf("oas must not be nil")
}

policyEvaluators := PartialResultsEvaluators{}
for path, OASContent := range oas.Paths {
for verb, verbConfig := range OASContent {
Expand All @@ -92,17 +97,24 @@ func SetupEvaluators(ctx context.Context, logger *logrus.Entry, oas *openapi.Ope
allowPolicy := verbConfig.PermissionV2.RequestFlow.PolicyName
responsePolicy := verbConfig.PermissionV2.ResponseFlow.PolicyName

logger.Infof("precomputing rego queries for API: %s %s. Allow policy: %s. Response policy: %s.", verb, path, allowPolicy, responsePolicy)
logger.
WithFields(logrus.Fields{
"verb": verb,
"policyName": allowPolicy,
"path": path,
"responsePolicyName": responsePolicy,
}).
Info("precomputing rego queries for API")

if allowPolicy == "" {
// allow policy is required, if missing assume the API has no valid x-rond configuration.
continue
}

if _, ok := policyEvaluators[allowPolicy]; !ok {
evaluator, err := createPartialEvaluator(ctx, logger, allowPolicy, oas, opaModuleConfig, options)

if err != nil {
return nil, fmt.Errorf("error during evaluator creation: %s", err.Error())
return nil, fmt.Errorf("%w: %s", ErrEvaluatorCreationFailed, err.Error())
}

policyEvaluators[allowPolicy] = *evaluator
Expand All @@ -111,9 +123,8 @@ func SetupEvaluators(ctx context.Context, logger *logrus.Entry, oas *openapi.Ope
if responsePolicy != "" {
if _, ok := policyEvaluators[responsePolicy]; !ok {
evaluator, err := createPartialEvaluator(ctx, logger, responsePolicy, oas, opaModuleConfig, options)

if err != nil {
return nil, fmt.Errorf("error during evaluator creation: %s", err.Error())
return nil, fmt.Errorf("%w: %s", ErrEvaluatorCreationFailed, err.Error())
}

policyEvaluators[responsePolicy] = *evaluator
Expand Down Expand Up @@ -182,7 +193,7 @@ func NewOPAEvaluator(ctx context.Context, policy string, opaModuleConfig *OPAMod
}
inputTerm, err := ast.ParseTerm(string(input))
if err != nil {
return nil, fmt.Errorf("failed input parse: %v", err)
return nil, fmt.Errorf("%w: %v", ErrFailedInputParse, err)
}

sanitizedPolicy := strings.Replace(policy, ".", "_", -1)
Expand Down Expand Up @@ -221,10 +232,12 @@ func (config *OPAModuleConfig) CreateQueryEvaluator(ctx context.Context, logger
opaEvaluatorInstanceTime := time.Now()
evaluator, err := NewOPAEvaluator(ctx, policy, config, input, options)
if err != nil {
logger.WithError(err).Error("failed RBAC policy creation")
logger.WithError(err).Error(ErrEvaluatorCreationFailed)
return nil, err
}
logger.Tracef("OPA evaluator instantiated in: %+v", time.Since(opaEvaluatorInstanceTime))
logger.
WithField("evaluatorCreationTimeMicroseconds", time.Since(opaEvaluatorInstanceTime).Microseconds()).
Trace("evaluator creation time")
return evaluator, nil
}

Expand Down Expand Up @@ -266,7 +279,7 @@ func (partialEvaluators PartialResultsEvaluators) GetEvaluatorFromPolicy(ctx con
if eval, ok := partialEvaluators[policy]; ok {
inputTerm, err := ast.ParseTerm(string(input))
if err != nil {
return nil, fmt.Errorf("failed input parse: %v", err)
return nil, fmt.Errorf("%w: %v", ErrFailedInputParse, err)
}

evaluator := eval.PartialEvaluator.Rego(
Expand All @@ -284,7 +297,7 @@ func (partialEvaluators PartialResultsEvaluators) GetEvaluatorFromPolicy(ctx con
routerInfo: options.RouterInfo,
}, nil
}
return nil, fmt.Errorf("policy evaluator not found: %s", policy)
return nil, fmt.Errorf("%w: %s", ErrEvaluatorNotFound, policy)
}

func (evaluator *OPAEvaluator) metrics() metrics.Metrics {
Expand All @@ -298,7 +311,7 @@ func (evaluator *OPAEvaluator) partiallyEvaluate(logger *logrus.Entry) (primitiv
opaEvaluationTimeStart := time.Now()
partialResults, err := evaluator.PolicyEvaluator.Partial(evaluator.Context)
if err != nil {
return nil, fmt.Errorf("policy Evaluation has failed when partially evaluating the query: %s", err.Error())
return nil, fmt.Errorf("%w: %s", ErrPartialPolicyEvalFailed, err.Error())
}

opaEvaluationTime := time.Since(opaEvaluationTimeStart)
Expand Down Expand Up @@ -336,49 +349,35 @@ func (evaluator *OPAEvaluator) Evaluate(logger *logrus.Entry) (interface{}, erro

results, err := evaluator.PolicyEvaluator.Eval(evaluator.Context)
if err != nil {
return nil, fmt.Errorf("policy Evaluation has failed when evaluating the query: %s", err.Error())
return nil, fmt.Errorf("%w: %s", ErrPolicyEvalFailed, err.Error())
}

opaEvaluationTime := time.Since(opaEvaluationTimeStart)
evaluator.metrics().PolicyEvaluationDurationMilliseconds.With(prometheus.Labels{
"policy_name": evaluator.PolicyName,
}).Observe(float64(opaEvaluationTime.Milliseconds()))

allowed, responseBodyOverwriter := processResults(results)
logger.WithFields(logrus.Fields{
"evaluationTimeMicroseconds": opaEvaluationTime.Microseconds(),
"policyName": evaluator.PolicyName,
"partialEval": false,
"allowed": results.Allowed(),
"allowed": allowed,
"resultsLength": len(results),
"matchedPath": evaluator.routerInfo.MatchedPath,
"requestedPath": evaluator.routerInfo.RequestedPath,
"method": evaluator.routerInfo.Method,
}).Debug("policy evaluation completed")

if results.Allowed() {
logger.WithFields(logrus.Fields{
"policyName": evaluator.PolicyName,
"allowed": results.Allowed(),
"resultsLength": len(results),
}).Tracef("policy results")
return nil, nil
}
// The results returned by OPA are a list of Results object with fields:
// - Expressions: list of list
// - Bindings: object
// e.g. [{Expressions:[[map["element": true]]] Bindings:map[]}]
// Since we are ALWAYS querying ONE specific policy the result length could not be greater than 1
if len(results) == 1 {
if exprs := results[0].Expressions; len(exprs) == 1 {
if value, ok := exprs[0].Value.([]interface{}); ok && value != nil && len(value) != 0 {
return value[0], nil
}
}
}
logger.WithFields(logrus.Fields{
"policyName": evaluator.PolicyName,
}).Error("policy resulted in not allowed")
return nil, fmt.Errorf("RBAC policy evaluation failed, user is not allowed")
"allowed": allowed,
}).Info("policy result")

if allowed {
return responseBodyOverwriter, nil
}
return nil, ErrPolicyEvalFailed
}

func (evaluator *OPAEvaluator) PolicyEvaluation(logger *logrus.Entry, permission *openapi.RondConfig) (interface{}, primitive.M, error) {
Expand Down Expand Up @@ -432,15 +431,41 @@ func LoadRegoModule(rootDirectory string) (*OPAModuleConfig, error) {
})

if regoModulePath == "" {
return nil, fmt.Errorf("no rego module found in directory")
return nil, ErrMissingRegoModules
}
fileContent, err := utils.ReadFile(regoModulePath)
if err != nil {
return nil, fmt.Errorf("failed rego file read: %s", err.Error())
return nil, fmt.Errorf("%w: %s", ErrRegoModuleReadFailed, err.Error())
}

return &OPAModuleConfig{
Name: filepath.Base(regoModulePath),
Content: string(fileContent),
}, nil
}

func processResults(results rego.ResultSet) (allowed bool, responseBodyOverwriter any) {
// Use strict allowed check for basic request flow allow policies.
if results.Allowed() {
allowed = true
return
}

// Here extract first result set to get the response body for the response policy evaluation.
// The results returned by OPA are a list of Results object with fields:
// - Expressions: list of list
// - Bindings: object
// e.g. [{Expressions:[[map["element": true]]] Bindings:map[]}]
// Since we are ALWAYS querying ONE specific policy the result length could not be greater than 1
if len(results) == 1 {
if exprs := results[0].Expressions; len(exprs) == 1 {
if value, ok := exprs[0].Value.([]interface{}); ok && value != nil && len(value) != 0 {
allowed = true
responseBodyOverwriter = value[0]
return
}
}
}

return
}
Loading

0 comments on commit 0fca669

Please sign in to comment.