Skip to content

Commit

Permalink
update PR
Browse files Browse the repository at this point in the history
  • Loading branch information
davidebianchi committed Jul 10, 2023
1 parent 0898817 commit 41c14b3
Show file tree
Hide file tree
Showing 13 changed files with 34 additions and 23 deletions.
12 changes: 8 additions & 4 deletions sdk/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,23 @@ package sdk

import (
"context"
"fmt"
"errors"
)

var (
ErrGetEvaluator = errors.New("no Evaluator found in request context")
)

type sdkKey struct{}

func WithEvaluatorSDK(ctx context.Context, evaluator Evaluator) context.Context {
func WithEvaluator(ctx context.Context, evaluator Evaluator) context.Context {
return context.WithValue(ctx, sdkKey{}, evaluator)
}

func GetEvaluatorSKD(ctx context.Context) (Evaluator, error) {
func GetEvaluator(ctx context.Context) (Evaluator, error) {
sdk, ok := ctx.Value(sdkKey{}).(Evaluator)
if !ok {
return nil, fmt.Errorf("no SDKEvaluator found in request context")
return nil, ErrGetEvaluator
}

return sdk, nil
Expand Down
7 changes: 4 additions & 3 deletions sdk/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"testing"

"github.com/rond-authz/rond/openapi"

"github.com/stretchr/testify/require"
)

Expand All @@ -39,15 +40,15 @@ func TestContext(t *testing.T) {
rondConfig: rondConfig,
}

ctx = WithEvaluatorSDK(ctx, expectedEvaluator)
ctx = WithEvaluator(ctx, expectedEvaluator)

actualEvaluator, err := GetEvaluatorSKD(ctx)
actualEvaluator, err := GetEvaluator(ctx)
require.NoError(t, err)
require.Equal(t, expectedEvaluator, actualEvaluator)
})

t.Run("throws if not in context", func(t *testing.T) {
actualEvaluator, err := GetEvaluatorSKD(context.Background())
actualEvaluator, err := GetEvaluator(context.Background())
require.EqualError(t, err, "no SDKEvaluator found in request context")
require.Nil(t, actualEvaluator)
})
Expand Down
4 changes: 3 additions & 1 deletion sdk/evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/rond-authz/rond/core"
"github.com/rond-authz/rond/openapi"
"github.com/rond-authz/rond/types"

"github.com/sirupsen/logrus"
)

Expand All @@ -31,8 +32,9 @@ type PolicyResult struct {
}

// Warning: This interface is experimental, and it could change with breaking also in rond patches.
// Do not use outside this repository until it is not ready.
// Do not use outside this repository until it is ready.
type Evaluator interface {
// retrieve the RondConfig used to generate the evaluator
Config() openapi.RondConfig

EvaluateRequestPolicy(ctx context.Context, input core.RondInput, userInfo types.User) (PolicyResult, error)
Expand Down
1 change: 1 addition & 0 deletions sdk/openapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package sdk
import (
"github.com/rond-authz/rond/core"
"github.com/rond-authz/rond/openapi"

"github.com/sirupsen/logrus"
"github.com/uptrace/bunrouter"
)
Expand Down
3 changes: 2 additions & 1 deletion sdk/openapi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ import (
"net/http"
"testing"

"github.com/prometheus/client_golang/prometheus"
"github.com/rond-authz/rond/core"
"github.com/rond-authz/rond/openapi"

"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
"github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
Expand Down
1 change: 1 addition & 0 deletions sdk/rondinput/http/input_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (

"github.com/rond-authz/rond/internal/utils"
"github.com/rond-authz/rond/types"

"github.com/stretchr/testify/require"
)

Expand Down
2 changes: 1 addition & 1 deletion sdk/sdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func NewFromOAS(ctx context.Context, opaModuleConfig *core.OPAModuleConfig, oas
options = &FromOASOptions{}
}
if options.Logger == nil {
// TODO: default to a logger instead of return error
// TODO: default to a fake silent logger instead of return error
return nil, fmt.Errorf("logger is required inside options")
}

Expand Down
4 changes: 2 additions & 2 deletions sdk/sdk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ import (
"github.com/rond-authz/rond/internal/mocks"
"github.com/rond-authz/rond/openapi"
"github.com/rond-authz/rond/types"
"github.com/sirupsen/logrus"
"github.com/sirupsen/logrus/hooks/test"

"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
"github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
)

Expand Down
2 changes: 1 addition & 1 deletion service/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func rbacHandler(w http.ResponseWriter, req *http.Request) {
return
}

evaluatorSdk, err := sdk.GetEvaluatorSKD(requestContext)
evaluatorSdk, err := sdk.GetEvaluator(requestContext)
if err != nil {
logger.WithField("error", logrus.Fields{"message": err.Error()}).Error("no evaluatorSdk found in context")
utils.FailResponse(w, "no evaluators sdk found in context", utils.GENERIC_BUSINESS_ERROR_MESSAGE)
Expand Down
3 changes: 2 additions & 1 deletion service/opa_transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@ import (
"strconv"
"testing"

"github.com/prometheus/client_golang/prometheus"
"github.com/rond-authz/rond/core"
"github.com/rond-authz/rond/internal/mocks"
"github.com/rond-authz/rond/internal/mongoclient"
"github.com/rond-authz/rond/internal/utils"
"github.com/rond-authz/rond/openapi"
"github.com/rond-authz/rond/sdk"
"github.com/rond-authz/rond/types"

"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
"github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
Expand Down
2 changes: 1 addition & 1 deletion service/opamiddleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func OPAMiddleware(
return
}

ctx := sdk.WithEvaluatorSDK(r.Context(), evaluator)
ctx := sdk.WithEvaluator(r.Context(), evaluator)

next.ServeHTTP(w, r.WithContext(ctx))
})
Expand Down
14 changes: 7 additions & 7 deletions service/opamiddleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ foobar { true }`,

middleware := OPAMiddleware(opaModule, rondSDK, routesNotToProxy, "", nil)
builtHandler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
actual, err := sdk.GetEvaluatorSKD(r.Context())
actual, err := sdk.GetEvaluator(r.Context())
require.NoError(t, err, "Unexpected error")
require.Equal(t, openapi.RondConfig{RequestFlow: openapi.RequestFlow{PolicyName: "todo"}}, actual.Config())
w.WriteHeader(http.StatusOK)
Expand All @@ -209,7 +209,7 @@ foobar { true }`,

middleware := OPAMiddleware(opaModule, rondSDK, routesNotToProxy, "", nil)
builtHandler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
actual, err := sdk.GetEvaluatorSKD(r.Context())
actual, err := sdk.GetEvaluator(r.Context())
require.NoError(t, err)
require.Equal(t, openapi.RondConfig{RequestFlow: openapi.RequestFlow{PolicyName: "todo"}}, actual.Config())
w.WriteHeader(http.StatusOK)
Expand All @@ -232,7 +232,7 @@ very_very_composed_permission { true }`,

middleware := OPAMiddleware(opaModule, rondSDK, routesNotToProxy, "", nil)
builtHandler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
actual, err := sdk.GetEvaluatorSKD(r.Context())
actual, err := sdk.GetEvaluator(r.Context())
require.NoError(t, err)
require.Equal(t, openapi.RondConfig{RequestFlow: openapi.RequestFlow{PolicyName: "very.very.composed.permission"}}, actual.Config())
w.WriteHeader(http.StatusOK)
Expand Down Expand Up @@ -260,7 +260,7 @@ very_very_composed_permission_with_eval { true }`,

middleware := OPAMiddleware(opaModule, rondSDK, routesNotToProxy, "", options)
builtHandler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
actual, err := sdk.GetEvaluatorSKD(r.Context())
actual, err := sdk.GetEvaluator(r.Context())
require.NoError(t, err)
require.Equal(t, openapi.RondConfig{RequestFlow: openapi.RequestFlow{PolicyName: "very.very.composed.permission.with.eval"}}, actual.Config())
w.WriteHeader(http.StatusOK)
Expand All @@ -278,7 +278,7 @@ very_very_composed_permission_with_eval { true }`,
routesNotToProxy := []string{"/not/proxy"}
middleware := OPAMiddleware(nil, nil, routesNotToProxy, "", nil)
builtHandler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := sdk.GetEvaluatorSKD(r.Context())
_, err := sdk.GetEvaluator(r.Context())
require.EqualError(t, err, "no SDKEvaluator found in request context")
}))

Expand Down Expand Up @@ -322,7 +322,7 @@ func TestOPAMiddlewareStandaloneIntegration(t *testing.T) {
rondSDK := getSdk(t, opaModule)
middleware := OPAMiddleware(opaModule, rondSDK, routesNotToProxy, "", options)
builtHandler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
actual, err := sdk.GetEvaluatorSKD(r.Context())
actual, err := sdk.GetEvaluator(r.Context())
require.NoError(t, err)
require.Equal(t, openapi.RondConfig{RequestFlow: openapi.RequestFlow{PolicyName: "very.very.composed.permission"}}, actual.Config())
w.WriteHeader(http.StatusOK)
Expand All @@ -345,7 +345,7 @@ very_very_composed_permission_with_eval { true }`,
rondSDK := getSdk(t, opaModule)
middleware := OPAMiddleware(opaModule, rondSDK, routesNotToProxy, "", options)
builtHandler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
actual, err := sdk.GetEvaluatorSKD(r.Context())
actual, err := sdk.GetEvaluator(r.Context())
require.NoError(t, err)
require.Equal(t, openapi.RondConfig{RequestFlow: openapi.RequestFlow{PolicyName: "very.very.composed.permission.with.eval"}}, actual.Config())
w.WriteHeader(http.StatusOK)
Expand Down
2 changes: 1 addition & 1 deletion service/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ func createContext(
var partialContext context.Context
partialContext = context.WithValue(originalCtx, config.EnvKey{}, env)

partialContext = sdk.WithEvaluatorSDK(partialContext, evaluator)
partialContext = sdk.WithEvaluator(partialContext, evaluator)

if mongoClient != nil {
partialContext = context.WithValue(partialContext, types.MongoClientContextKey{}, mongoClient)
Expand Down

0 comments on commit 41c14b3

Please sign in to comment.