Skip to content

Commit

Permalink
feat: clean interface
Browse files Browse the repository at this point in the history
  • Loading branch information
davidebianchi committed Jun 26, 2023
1 parent a002292 commit 614942e
Show file tree
Hide file tree
Showing 8 changed files with 12 additions and 43 deletions.
11 changes: 0 additions & 11 deletions core/input.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package core

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -115,8 +114,6 @@ func CreateRegoQueryInput(

type RondInput interface {
FromRequestInfo(user types.User, responseBody any) (Input, error)
Context() context.Context
OriginalRequest() *http.Request
}

type requestInfo struct {
Expand Down Expand Up @@ -164,14 +161,6 @@ func (req requestInfo) FromRequestInfo(user types.User, responseBody any) (Input
}, nil
}

func (r requestInfo) Context() context.Context {
return r.Request.Context()
}

func (r requestInfo) OriginalRequest() *http.Request {
return r.Request
}

func NewRondInput(req *http.Request, clientTypeHeaderKey string, pathParams map[string]string) RondInput {
return requestInfo{
Request: req,
Expand Down
16 changes: 0 additions & 16 deletions core/input_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,21 +320,5 @@ func TestRondInput(t *testing.T) {
require.NoError(t, err, "Unexpected error")
require.Nil(t, input.Request.Body)
})

t.Run("get context", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)

rondRequest := NewRondInput(req, clientTypeHeaderKey, pathParams)
ctx := rondRequest.Context()
require.Equal(t, req.Context(), ctx)
})

t.Run("get original request", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)

rondRequest := NewRondInput(req, clientTypeHeaderKey, pathParams)
originalRequest := rondRequest.OriginalRequest()
require.Equal(t, req, originalRequest)
})
})
}
3 changes: 1 addition & 2 deletions core/opaevaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
Expand Down Expand Up @@ -192,7 +191,7 @@ func NewOPAEvaluator(ctx context.Context, policy string, opaModuleConfig *OPAMod
}, nil
}

func (config *OPAModuleConfig) CreateQueryEvaluator(ctx context.Context, logger *logrus.Entry, req *http.Request, policy string, input []byte, responseBody interface{}, options *EvaluatorOptions) (*OPAEvaluator, error) {
func (config *OPAModuleConfig) CreateQueryEvaluator(ctx context.Context, logger *logrus.Entry, policy string, input []byte, options *EvaluatorOptions) (*OPAEvaluator, error) {
// TODO: remove logger and set in sdk
logger.WithFields(logrus.Fields{
"policyName": policy,
Expand Down
6 changes: 2 additions & 4 deletions core/opaevaluator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,22 +127,20 @@ column_policy{

opaModuleConfig := &OPAModuleConfig{Name: "mypolicy.rego", Content: policy}

r, err := http.NewRequestWithContext(context.Background(), "GET", "http://www.example.com:8080/api", nil)
require.NoError(t, err, "Unexpected error")
log, _ := test.NewNullLogger()
logger := logrus.NewEntry(log)

input := Input{Request: InputRequest{}, Response: InputResponse{}}
inputBytes, _ := json.Marshal(input)

t.Run("create evaluator with allowPolicy", func(t *testing.T) {
evaluator, err := opaModuleConfig.CreateQueryEvaluator(context.Background(), logger, r, permission.AllowPermission, inputBytes, nil, nil)
evaluator, err := opaModuleConfig.CreateQueryEvaluator(context.Background(), logger, permission.AllowPermission, inputBytes, nil)
require.True(t, evaluator != nil)
require.NoError(t, err, "Unexpected status code.")
})

t.Run("create evaluator with policy for column filtering", func(t *testing.T) {
evaluator, err := opaModuleConfig.CreateQueryEvaluator(context.Background(), logger, r, permission.ResponseFilter.Policy, inputBytes, nil, nil)
evaluator, err := opaModuleConfig.CreateQueryEvaluator(context.Background(), logger, permission.ResponseFilter.Policy, inputBytes, nil)
require.True(t, evaluator != nil)
require.NoError(t, err, "Unexpected status code.")
})
Expand Down
9 changes: 4 additions & 5 deletions core/sdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ type SDKEvaluator interface {
Config() openapi.RondConfig
PartialResultsEvaluators() PartialResultsEvaluators

EvaluateRequestPolicy(req RondInput, userInfo types.User) (PolicyResult, error)
EvaluateRequestPolicy(ctx context.Context, req RondInput, userInfo types.User) (PolicyResult, error)
EvaluateResponsePolicy(ctx context.Context, rondInput RondInput, userInfo types.User, decodedBody any) ([]byte, error)
}

Expand All @@ -69,13 +69,12 @@ func (e evaluator) PartialResultsEvaluators() PartialResultsEvaluators {
return e.rond.evaluator
}

func (e evaluator) EvaluateRequestPolicy(req RondInput, userInfo types.User) (PolicyResult, error) {
func (e evaluator) EvaluateRequestPolicy(ctx context.Context, req RondInput, userInfo types.User) (PolicyResult, error) {
if req == nil {
return PolicyResult{}, fmt.Errorf("RondInput cannot be empty")
}

rondConfig := e.Config()
requestContext := req.Context()

input, err := req.FromRequestInfo(userInfo, nil)
if err != nil {
Expand All @@ -91,12 +90,12 @@ func (e evaluator) EvaluateRequestPolicy(req RondInput, userInfo types.User) (Po

var evaluatorAllowPolicy *OPAEvaluator
if !rondConfig.RequestFlow.GenerateQuery {
evaluatorAllowPolicy, err = e.rond.evaluator.GetEvaluatorFromPolicy(requestContext, rondConfig.RequestFlow.PolicyName, regoInput, e.rond.evaluatorOptions)
evaluatorAllowPolicy, err = e.rond.evaluator.GetEvaluatorFromPolicy(ctx, rondConfig.RequestFlow.PolicyName, regoInput, e.rond.evaluatorOptions)
if err != nil {
return PolicyResult{}, err
}
} else {
evaluatorAllowPolicy, err = e.rond.opaModuleConfig.CreateQueryEvaluator(requestContext, e.logger, req.OriginalRequest(), rondConfig.RequestFlow.PolicyName, regoInput, nil, e.rond.evaluatorOptions)
evaluatorAllowPolicy, err = e.rond.opaModuleConfig.CreateQueryEvaluator(ctx, e.logger, rondConfig.RequestFlow.PolicyName, regoInput, e.rond.evaluatorOptions)
if err != nil {
return PolicyResult{}, err
}
Expand Down
6 changes: 3 additions & 3 deletions core/sdk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ func TestEvaluateRequestPolicy(t *testing.T) {
evaluator, err := sdk.FindEvaluator(logger, http.MethodGet, "/users/")
require.NoError(t, err)

_, err = evaluator.EvaluateRequestPolicy(nil, types.User{})
_, err = evaluator.EvaluateRequestPolicy(context.Background(), nil, types.User{})
require.EqualError(t, err, "RondInput cannot be empty")
})

Expand Down Expand Up @@ -400,7 +400,7 @@ func TestEvaluateRequestPolicy(t *testing.T) {
}
rondInput := NewRondInput(req, clientTypeHeaderKey, nil)

actual, err := evaluator.EvaluateRequestPolicy(rondInput, test.user)
actual, err := evaluator.EvaluateRequestPolicy(context.Background(), rondInput, test.user)
if test.expectedErr {
require.EqualError(t, err, test.expectedErrMessage)
} else {
Expand Down Expand Up @@ -600,7 +600,7 @@ func BenchmarkEvaluateRequest(b *testing.B) {
b.StartTimer()
evaluator, err := sdk.FindEvaluator(logger, http.MethodGet, "/projects/project123")
require.NoError(b, err)
evaluator.EvaluateRequestPolicy(rondInput, types.User{})
evaluator.EvaluateRequestPolicy(context.Background(), rondInput, types.User{})
b.StopTimer()
require.Equal(b, http.StatusOK, recorder.Code)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/fake/sdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func NewSDKEvaluator(
}
}

func (s SDKEvaluator) EvaluateRequestPolicy(req core.RondInput, userInfo types.User) (core.PolicyResult, error) {
func (s SDKEvaluator) EvaluateRequestPolicy(ctx context.Context, req core.RondInput, userInfo types.User) (core.PolicyResult, error) {
if s.requestPolicyEvaluatorResult == nil {
return core.PolicyResult{}, nil
}
Expand Down
2 changes: 1 addition & 1 deletion service/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func EvaluateRequest(
}

rondInput := core.NewRondInput(req, env.ClientTypeHeader, mux.Vars(req))
result, err := evaluatorSdk.EvaluateRequestPolicy(rondInput, userInfo)
result, err := evaluatorSdk.EvaluateRequestPolicy(req.Context(), rondInput, userInfo)
if err != nil {
if errors.Is(err, opatranslator.ErrEmptyQuery) && utils.HasApplicationJSONContentType(req.Header) {
w.Header().Set(utils.ContentTypeHeaderKey, utils.JSONContentTypeHeader)
Expand Down

0 comments on commit 614942e

Please sign in to comment.