From 618e5f1f4b4eff4f9567bd6a093567a4afa80703 Mon Sep 17 00:00:00 2001 From: Davide Bianchi Date: Thu, 29 Jun 2023 19:47:15 +0200 Subject: [PATCH] feat: move core and mux specific function in specific package --- core/input.go | 56 ----------- core/input_test.go | 103 +++++--------------- core/sdk_test.go | 42 +++++--- routers/mux/input.go | 80 +++++++++++++++ routers/mux/input_test.go | 103 ++++++++++++++++++++ {core => routers/mux}/opamiddleware.go | 9 +- {core => routers/mux}/opamiddleware_test.go | 53 +++++----- service/handler.go | 5 +- {core => service}/opa_transport.go | 10 +- {core => service}/opa_transport_test.go | 51 +++++++++- service/router.go | 3 +- 11 files changed, 332 insertions(+), 183 deletions(-) create mode 100644 routers/mux/input.go create mode 100644 routers/mux/input_test.go rename {core => routers/mux}/opamiddleware.go (94%) rename {core => routers/mux}/opamiddleware_test.go (90%) rename {core => service}/opa_transport.go (94%) rename {core => service}/opa_transport_test.go (92%) diff --git a/core/input.go b/core/input.go index 4a2b9318..67d82dd0 100644 --- a/core/input.go +++ b/core/input.go @@ -15,15 +15,12 @@ package core import ( - "bytes" "encoding/json" "fmt" - "io" "net/http" "net/url" "time" - "github.com/rond-authz/rond/internal/utils" "github.com/rond-authz/rond/types" "github.com/sirupsen/logrus" @@ -115,56 +112,3 @@ func CreateRegoQueryInput( type RondInput interface { Input(user types.User, responseBody any) (Input, error) } - -type requestInfo struct { - *http.Request - clientTypeHeaderKey string - pathParams map[string]string -} - -func (req requestInfo) Input(user types.User, responseBody any) (Input, error) { - shouldParseJSONBody := utils.HasApplicationJSONContentType(req.Header) && - req.ContentLength > 0 && - (req.Method == http.MethodPatch || req.Method == http.MethodPost || req.Method == http.MethodPut || req.Method == http.MethodDelete) - - var requestBody any - if shouldParseJSONBody { - bodyBytes, err := io.ReadAll(req.Body) - if err != nil { - return Input{}, fmt.Errorf("failed request body parse: %s", err.Error()) - } - if err := json.Unmarshal(bodyBytes, &requestBody); err != nil { - return Input{}, fmt.Errorf("failed request body deserialization: %s", err.Error()) - } - req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) - } - - return Input{ - ClientType: req.Header.Get(req.clientTypeHeaderKey), - Request: InputRequest{ - Method: req.Method, - Path: req.URL.Path, - Headers: req.Header, - Query: req.URL.Query(), - PathParams: req.pathParams, - Body: requestBody, - }, - Response: InputResponse{ - Body: responseBody, - }, - User: InputUser{ - Properties: user.Properties, - Groups: user.UserGroups, - Bindings: user.UserBindings, - Roles: user.UserRoles, - }, - }, nil -} - -func NewRondInput(req *http.Request, clientTypeHeaderKey string, pathParams map[string]string) RondInput { - return requestInfo{ - Request: req, - clientTypeHeaderKey: clientTypeHeaderKey, - pathParams: pathParams, - } -} diff --git a/core/input_test.go b/core/input_test.go index 5da56d4f..5b9d4c8d 100644 --- a/core/input_test.go +++ b/core/input_test.go @@ -15,14 +15,9 @@ package core import ( - "bytes" - "encoding/json" "fmt" - "net/http" - "net/http/httptest" "testing" - "github.com/rond-authz/rond/internal/utils" "github.com/rond-authz/rond/types" "github.com/sirupsen/logrus" @@ -247,78 +242,34 @@ func BenchmarkBuildOptimizedResourcePermissionsMap(b *testing.B) { } } -func TestRondInput(t *testing.T) { - user := types.User{} - clientTypeHeaderKey := "clienttypeheader" - pathParams := map[string]string{} - - t.Run("request body integration", func(t *testing.T) { - expectedRequestBody := map[string]interface{}{ - "Key": float64(42), - } - reqBody := struct{ Key int }{ - Key: 42, - } - reqBodyBytes, err := json.Marshal(reqBody) - require.Nil(t, err, "Unexpected error") - - t.Run("ignored on method GET", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/", bytes.NewReader(reqBodyBytes)) - - rondRequest := NewRondInput(req, clientTypeHeaderKey, pathParams) - input, err := rondRequest.Input(user, nil) - require.NoError(t, err, "Unexpected error") - require.Nil(t, input.Request.Body) - }) - - t.Run("ignore nil body on method POST", func(t *testing.T) { - req := httptest.NewRequest(http.MethodPost, "/", nil) - req.Header.Set(utils.ContentTypeHeaderKey, "application/json") - - rondRequest := NewRondInput(req, clientTypeHeaderKey, pathParams) - input, err := rondRequest.Input(user, nil) - require.NoError(t, err, "Unexpected error") - require.Nil(t, input.Request.Body) - }) - - t.Run("added on accepted methods", func(t *testing.T) { - acceptedMethods := []string{http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete} - - for _, method := range acceptedMethods { - req := httptest.NewRequest(method, "/", bytes.NewReader(reqBodyBytes)) - req.Header.Set(utils.ContentTypeHeaderKey, "application/json") - rondRequest := NewRondInput(req, clientTypeHeaderKey, pathParams) - input, err := rondRequest.Input(user, nil) - require.NoError(t, err, "Unexpected error") - require.Equal(t, expectedRequestBody, input.Request.Body) - } - }) - - t.Run("added with content-type specifying charset", func(t *testing.T) { - req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBodyBytes)) - req.Header.Set(utils.ContentTypeHeaderKey, "application/json;charset=UTF-8") - rondRequest := NewRondInput(req, clientTypeHeaderKey, pathParams) - input, err := rondRequest.Input(user, nil) - require.NoError(t, err, "Unexpected error") - require.Equal(t, expectedRequestBody, input.Request.Body) - }) +type FakeInput struct { + request InputRequest + clientType string +} - t.Run("reject on method POST but with invalid body", func(t *testing.T) { - req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("{notajson}"))) - req.Header.Set(utils.ContentTypeHeaderKey, "application/json") - rondRequest := NewRondInput(req, clientTypeHeaderKey, pathParams) - _, err := rondRequest.Input(user, nil) - require.ErrorContains(t, err, "failed request body deserialization:") - }) +func (i FakeInput) Input(user types.User, responseBody any) (Input, error) { + return Input{ + User: InputUser{ + Properties: user.Properties, + Groups: user.UserGroups, + Bindings: user.UserBindings, + Roles: user.UserRoles, + }, + Request: i.request, + Response: InputResponse{ + Body: responseBody, + }, + ClientType: i.clientType, + }, nil +} - t.Run("ignore body on method POST but with another content type", func(t *testing.T) { - req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("{notajson}"))) - req.Header.Set(utils.ContentTypeHeaderKey, "multipart/form-data") +func getFakeInput(t require.TestingT, request InputRequest, clientType string) RondInput { + if h, ok := t.(tHelper); ok { + h.Helper() + } - rondRequest := NewRondInput(req, clientTypeHeaderKey, pathParams) - input, err := rondRequest.Input(user, nil) - require.NoError(t, err, "Unexpected error") - require.Nil(t, input.Request.Body) - }) - }) + return FakeInput{ + request: request, + clientType: clientType, + } } diff --git a/core/sdk_test.go b/core/sdk_test.go index 8d7cfcac..d9e03829 100644 --- a/core/sdk_test.go +++ b/core/sdk_test.go @@ -141,8 +141,6 @@ func TestSDK(t *testing.T) { func TestEvaluateRequestPolicy(t *testing.T) { logger := logrus.NewEntry(logrus.New()) - clientTypeHeaderKey := "client-header-key" - t.Run("throws without RondInput", func(t *testing.T) { sdk := getSdk(t, nil) evaluator, err := sdk.FindEvaluator(logger, http.MethodGet, "/users/") @@ -387,13 +385,17 @@ func TestEvaluateRequestPolicy(t *testing.T) { evaluate, err := sdk.FindEvaluator(logger, testCase.method, testCase.path) require.NoError(t, err) - req := httptest.NewRequest(testCase.method, testCase.path, nil) + headers := http.Header{} if testCase.reqHeaders != nil { for k, v := range testCase.reqHeaders { - req.Header.Set(k, v) + headers.Set(k, v) } } - rondInput := NewRondInput(req, clientTypeHeaderKey, nil) + rondInput := getFakeInput(t, InputRequest{ + Headers: headers, + Path: testCase.path, + Method: testCase.method, + }, "") actual, err := evaluate.EvaluateRequestPolicy(context.Background(), rondInput, testCase.user) if testCase.expectedErr { @@ -469,8 +471,6 @@ func assertCorrectMetrics(t *testing.T, registry *prometheus.Registry, expected func TestEvaluateResponsePolicy(t *testing.T) { logger := logrus.NewEntry(logrus.New()) - clientTypeHeaderKey := "client-header-key" - t.Run("throws without RondInput", func(t *testing.T) { sdk := getSdk(t, nil) evaluator, err := sdk.FindEvaluator(logger, http.MethodGet, "/users/") @@ -587,7 +587,17 @@ func TestEvaluateResponsePolicy(t *testing.T) { req.Header.Set(k, v) } } - rondInput := NewRondInput(req, clientTypeHeaderKey, nil) + headers := http.Header{} + if testCase.reqHeaders != nil { + for k, v := range testCase.reqHeaders { + headers.Set(k, v) + } + } + rondInput := getFakeInput(t, InputRequest{ + Headers: headers, + Path: testCase.path, + Method: testCase.method, + }, "") actual, err := evaluate.EvaluateResponsePolicy(context.Background(), rondInput, testCase.user, testCase.decodedBody) if testCase.expectedErr { @@ -677,12 +687,18 @@ func BenchmarkEvaluateRequest(b *testing.B) { for n := 0; n < b.N; n++ { b.StopTimer() - req := httptest.NewRequest(http.MethodGet, "/projects/project123", nil) - req.Header.Set("my-header", "value") + headers := http.Header{} + headers.Set("my-header", "value") recorder := httptest.NewRecorder() - rondInput := NewRondInput(req, "", map[string]string{ - "projectId": "project123", - }) + + rondInput := getFakeInput(b, InputRequest{ + Path: "/projects/project123", + Headers: headers, + Method: http.MethodGet, + PathParams: map[string]string{ + "projectId": "project123", + }, + }, "") b.StartTimer() evaluator, err := sdk.FindEvaluator(logger, http.MethodGet, "/projects/project123") require.NoError(b, err) diff --git a/routers/mux/input.go b/routers/mux/input.go new file mode 100644 index 00000000..30c83ef9 --- /dev/null +++ b/routers/mux/input.go @@ -0,0 +1,80 @@ +// Copyright 2023 Mia srl +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rondmux + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/rond-authz/rond/core" + "github.com/rond-authz/rond/internal/utils" + "github.com/rond-authz/rond/types" +) + +type requestInfo struct { + *http.Request + clientTypeHeaderKey string + pathParams map[string]string +} + +func (req requestInfo) Input(user types.User, responseBody any) (core.Input, error) { + shouldParseJSONBody := utils.HasApplicationJSONContentType(req.Header) && + req.ContentLength > 0 && + (req.Method == http.MethodPatch || req.Method == http.MethodPost || req.Method == http.MethodPut || req.Method == http.MethodDelete) + + var requestBody any + if shouldParseJSONBody { + bodyBytes, err := io.ReadAll(req.Body) + if err != nil { + return core.Input{}, fmt.Errorf("failed request body parse: %s", err.Error()) + } + if err := json.Unmarshal(bodyBytes, &requestBody); err != nil { + return core.Input{}, fmt.Errorf("failed request body deserialization: %s", err.Error()) + } + req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + } + + return core.Input{ + ClientType: req.Header.Get(req.clientTypeHeaderKey), + Request: core.InputRequest{ + Method: req.Method, + Path: req.URL.Path, + Headers: req.Header, + Query: req.URL.Query(), + PathParams: req.pathParams, + Body: requestBody, + }, + Response: core.InputResponse{ + Body: responseBody, + }, + User: core.InputUser{ + Properties: user.Properties, + Groups: user.UserGroups, + Bindings: user.UserBindings, + Roles: user.UserRoles, + }, + }, nil +} + +func NewRondInput(req *http.Request, clientTypeHeaderKey string, pathParams map[string]string) core.RondInput { + return requestInfo{ + Request: req, + clientTypeHeaderKey: clientTypeHeaderKey, + pathParams: pathParams, + } +} diff --git a/routers/mux/input_test.go b/routers/mux/input_test.go new file mode 100644 index 00000000..01d315b2 --- /dev/null +++ b/routers/mux/input_test.go @@ -0,0 +1,103 @@ +// Copyright 2023 Mia srl +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rondmux + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/rond-authz/rond/internal/utils" + "github.com/rond-authz/rond/types" + "github.com/stretchr/testify/require" +) + +func TestRondInput(t *testing.T) { + user := types.User{} + clientTypeHeaderKey := "clienttypeheader" + pathParams := map[string]string{} + + t.Run("request body integration", func(t *testing.T) { + expectedRequestBody := map[string]interface{}{ + "Key": float64(42), + } + reqBody := struct{ Key int }{ + Key: 42, + } + reqBodyBytes, err := json.Marshal(reqBody) + require.Nil(t, err, "Unexpected error") + + t.Run("ignored on method GET", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", bytes.NewReader(reqBodyBytes)) + + rondRequest := NewRondInput(req, clientTypeHeaderKey, pathParams) + input, err := rondRequest.Input(user, nil) + require.NoError(t, err, "Unexpected error") + require.Nil(t, input.Request.Body) + }) + + t.Run("ignore nil body on method POST", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/", nil) + req.Header.Set(utils.ContentTypeHeaderKey, "application/json") + + rondRequest := NewRondInput(req, clientTypeHeaderKey, pathParams) + input, err := rondRequest.Input(user, nil) + require.NoError(t, err, "Unexpected error") + require.Nil(t, input.Request.Body) + }) + + t.Run("added on accepted methods", func(t *testing.T) { + acceptedMethods := []string{http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete} + + for _, method := range acceptedMethods { + req := httptest.NewRequest(method, "/", bytes.NewReader(reqBodyBytes)) + req.Header.Set(utils.ContentTypeHeaderKey, "application/json") + rondRequest := NewRondInput(req, clientTypeHeaderKey, pathParams) + input, err := rondRequest.Input(user, nil) + require.NoError(t, err, "Unexpected error") + require.Equal(t, expectedRequestBody, input.Request.Body) + } + }) + + t.Run("added with content-type specifying charset", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBodyBytes)) + req.Header.Set(utils.ContentTypeHeaderKey, "application/json;charset=UTF-8") + rondRequest := NewRondInput(req, clientTypeHeaderKey, pathParams) + input, err := rondRequest.Input(user, nil) + require.NoError(t, err, "Unexpected error") + require.Equal(t, expectedRequestBody, input.Request.Body) + }) + + t.Run("reject on method POST but with invalid body", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("{notajson}"))) + req.Header.Set(utils.ContentTypeHeaderKey, "application/json") + rondRequest := NewRondInput(req, clientTypeHeaderKey, pathParams) + _, err := rondRequest.Input(user, nil) + require.ErrorContains(t, err, "failed request body deserialization:") + }) + + t.Run("ignore body on method POST but with another content type", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("{notajson}"))) + req.Header.Set(utils.ContentTypeHeaderKey, "multipart/form-data") + + rondRequest := NewRondInput(req, clientTypeHeaderKey, pathParams) + input, err := rondRequest.Input(user, nil) + require.NoError(t, err, "Unexpected error") + require.Nil(t, input.Request.Body) + }) + }) +} diff --git a/core/opamiddleware.go b/routers/mux/opamiddleware.go similarity index 94% rename from core/opamiddleware.go rename to routers/mux/opamiddleware.go index 62adc022..a1fbb717 100644 --- a/core/opamiddleware.go +++ b/routers/mux/opamiddleware.go @@ -12,13 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -package core +package rondmux import ( "errors" "net/http" "strings" + "github.com/rond-authz/rond/core" "github.com/rond-authz/rond/internal/utils" "github.com/rond-authz/rond/openapi" @@ -33,8 +34,8 @@ type OPAMiddlewareOptions struct { } func OPAMiddleware( - opaModuleConfig *OPAModuleConfig, - sdk SDK, + opaModuleConfig *core.OPAModuleConfig, + sdk core.SDK, routesToNotProxy []string, targetServiceOASPath string, options *OPAMiddlewareOptions, @@ -91,7 +92,7 @@ func OPAMiddleware( return } - ctx := WithEvaluatorSDK(r.Context(), evaluator) + ctx := core.WithEvaluatorSDK(r.Context(), evaluator) next.ServeHTTP(w, r.WithContext(ctx)) }) diff --git a/core/opamiddleware_test.go b/routers/mux/opamiddleware_test.go similarity index 90% rename from core/opamiddleware_test.go rename to routers/mux/opamiddleware_test.go index c61806e3..1b8fa3c8 100644 --- a/core/opamiddleware_test.go +++ b/routers/mux/opamiddleware_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package core +package rondmux import ( "context" @@ -23,6 +23,7 @@ import ( "os" "testing" + "github.com/rond-authz/rond/core" "github.com/rond-authz/rond/internal/utils" "github.com/rond-authz/rond/openapi" "github.com/rond-authz/rond/types" @@ -33,11 +34,11 @@ import ( ) func TestOPAMiddleware(t *testing.T) { - getSDK := func(t *testing.T, oas *openapi.OpenAPISpec, opaModule *OPAModuleConfig) SDK { + getSDK := func(t *testing.T, oas *openapi.OpenAPISpec, opaModule *core.OPAModuleConfig) core.SDK { t.Helper() logger, _ := test.NewNullLogger() - sdk, err := NewSDK( + sdk, err := core.NewSDK( context.Background(), logrus.NewEntry(logger), oas, @@ -53,13 +54,13 @@ func TestOPAMiddleware(t *testing.T) { routesNotToProxy := make([]string, 0) t.Run(`strict mode failure`, func(t *testing.T) { - opaModule := &OPAModuleConfig{ + opaModule := &core.OPAModuleConfig{ Name: "example.rego", Content: `package policies todo { true }`, } var openAPISpec *openapi.OpenAPISpec - openAPISpecContent, err := os.ReadFile("../mocks/simplifiedMock.json") + openAPISpecContent, err := os.ReadFile("../../mocks/simplifiedMock.json") require.NoError(t, err) err = json.Unmarshal(openAPISpecContent, &openAPISpec) require.NoError(t, err) @@ -117,14 +118,14 @@ todo { true }`, }) t.Run(`documentation request`, func(t *testing.T) { - opaModule := &OPAModuleConfig{ + opaModule := &core.OPAModuleConfig{ Name: "example.rego", Content: `package policies foobar { true }`, } t.Run(`ok - path is known on oas with no permission declared`, func(t *testing.T) { - openAPISpec, err := openapi.LoadOASFile("../mocks/documentationPathMock.json") + openAPISpec, err := openapi.LoadOASFile("../../mocks/documentationPathMock.json") require.NoError(t, err) targetServiceOASPath := "/documentation/json" sdk := getSDK(t, openAPISpec, opaModule) @@ -142,7 +143,7 @@ foobar { true }`, }) t.Run(`ok - path is missing on oas and request is equal to serviceTargetOASPath`, func(t *testing.T) { - openAPISpec, err := openapi.LoadOASFile("../mocks/simplifiedMock.json") + openAPISpec, err := openapi.LoadOASFile("../../mocks/simplifiedMock.json") require.NoError(t, err) targetServiceOASPath := "/documentation/json" sdk := getSDK(t, openAPISpec, opaModule) @@ -160,7 +161,7 @@ foobar { true }`, }) t.Run(`ok - path is NOT known on oas but is proxied anyway`, func(t *testing.T) { - openAPISpec, err := openapi.LoadOASFile("../mocks/simplifiedMock.json") + openAPISpec, err := openapi.LoadOASFile("../../mocks/simplifiedMock.json") require.NoError(t, err) targetServiceOASPath := "/documentation/custom/json" sdk := getSDK(t, openAPISpec, opaModule) @@ -179,11 +180,11 @@ foobar { true }`, }) t.Run(`injects opa instance with correct query`, func(t *testing.T) { - openAPISpec, err := openapi.LoadOASFile("../mocks/simplifiedMock.json") + openAPISpec, err := openapi.LoadOASFile("../../mocks/simplifiedMock.json") require.NoError(t, err) t.Run(`rego package doesn't contain expected policy`, func(t *testing.T) { - opaModule := &OPAModuleConfig{ + opaModule := &core.OPAModuleConfig{ Name: "example.rego", Content: `package policies another { true }`, } @@ -191,7 +192,7 @@ foobar { true }`, middleware := OPAMiddleware(opaModule, sdk, routesNotToProxy, "", nil) builtHandler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - actual, err := GetEvaluatorSKD(r.Context()) + actual, err := core.GetEvaluatorSKD(r.Context()) require.NoError(t, err, "Unexpected error") require.Equal(t, openapi.RondConfig{RequestFlow: openapi.RequestFlow{PolicyName: "todo"}}, actual.Config()) w.WriteHeader(http.StatusOK) @@ -205,7 +206,7 @@ foobar { true }`, }) t.Run(`rego package contains expected permission`, func(t *testing.T) { - opaModule := &OPAModuleConfig{ + opaModule := &core.OPAModuleConfig{ Name: "example.rego", Content: `package policies todo { true }`, } @@ -213,7 +214,7 @@ foobar { true }`, middleware := OPAMiddleware(opaModule, sdk, routesNotToProxy, "", nil) builtHandler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - actual, err := GetEvaluatorSKD(r.Context()) + actual, err := core.GetEvaluatorSKD(r.Context()) require.NoError(t, err) require.Equal(t, openapi.RondConfig{RequestFlow: openapi.RequestFlow{PolicyName: "todo"}}, actual.Config()) w.WriteHeader(http.StatusOK) @@ -227,7 +228,7 @@ foobar { true }`, }) t.Run(`rego package contains composed permission`, func(t *testing.T) { - opaModule := &OPAModuleConfig{ + opaModule := &core.OPAModuleConfig{ Name: "example.rego", Content: `package policies very_very_composed_permission { true }`, @@ -236,7 +237,7 @@ very_very_composed_permission { true }`, middleware := OPAMiddleware(opaModule, sdk, routesNotToProxy, "", nil) builtHandler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - actual, err := GetEvaluatorSKD(r.Context()) + actual, err := core.GetEvaluatorSKD(r.Context()) require.NoError(t, err) require.Equal(t, openapi.RondConfig{RequestFlow: openapi.RequestFlow{PolicyName: "very.very.composed.permission"}}, actual.Config()) w.WriteHeader(http.StatusOK) @@ -250,7 +251,7 @@ very_very_composed_permission { true }`, }) t.Run("injects correct permission", func(t *testing.T) { - opaModule := &OPAModuleConfig{ + opaModule := &core.OPAModuleConfig{ Name: "example.rego", Content: `package policies very_very_composed_permission_with_eval { true }`, @@ -264,7 +265,7 @@ very_very_composed_permission_with_eval { true }`, middleware := OPAMiddleware(opaModule, sdk, routesNotToProxy, "", options) builtHandler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - actual, err := GetEvaluatorSKD(r.Context()) + actual, err := core.GetEvaluatorSKD(r.Context()) require.NoError(t, err) require.Equal(t, openapi.RondConfig{RequestFlow: openapi.RequestFlow{PolicyName: "very.very.composed.permission.with.eval"}}, actual.Config()) w.WriteHeader(http.StatusOK) @@ -282,7 +283,7 @@ very_very_composed_permission_with_eval { true }`, routesNotToProxy := []string{"/not/proxy"} middleware := OPAMiddleware(nil, nil, routesNotToProxy, "", nil) builtHandler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, err := GetEvaluatorSKD(r.Context()) + _, err := core.GetEvaluatorSKD(r.Context()) require.EqualError(t, err, "no SDKEvaluator found in request context") })) @@ -297,18 +298,18 @@ very_very_composed_permission_with_eval { true }`, func TestOPAMiddlewareStandaloneIntegration(t *testing.T) { var routesNotToProxy = []string{} - openAPISpec, err := openapi.LoadOASFile("../mocks/simplifiedMock.json") + openAPISpec, err := openapi.LoadOASFile("../../mocks/simplifiedMock.json") require.Nil(t, err) options := &OPAMiddlewareOptions{ IsStandalone: true, PathPrefixStandalone: "/eval", } - getSdk := func(t *testing.T, opaModule *OPAModuleConfig) SDK { + getSdk := func(t *testing.T, opaModule *core.OPAModuleConfig) core.SDK { t.Helper() log, _ := test.NewNullLogger() logger := logrus.NewEntry(log) - sdk, err := NewSDK( + sdk, err := core.NewSDK( context.Background(), logger, openAPISpec, @@ -322,7 +323,7 @@ func TestOPAMiddlewareStandaloneIntegration(t *testing.T) { } t.Run("injects correct path removing prefix", func(t *testing.T) { - opaModule := &OPAModuleConfig{ + opaModule := &core.OPAModuleConfig{ Name: "example.rego", Content: `package policies very_very_composed_permission { true }`, @@ -331,7 +332,7 @@ func TestOPAMiddlewareStandaloneIntegration(t *testing.T) { sdk := getSdk(t, opaModule) middleware := OPAMiddleware(opaModule, sdk, routesNotToProxy, "", options) builtHandler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - actual, err := GetEvaluatorSKD(r.Context()) + actual, err := core.GetEvaluatorSKD(r.Context()) require.NoError(t, err) require.Equal(t, openapi.RondConfig{RequestFlow: openapi.RequestFlow{PolicyName: "very.very.composed.permission"}}, actual.Config()) w.WriteHeader(http.StatusOK) @@ -345,7 +346,7 @@ func TestOPAMiddlewareStandaloneIntegration(t *testing.T) { }) t.Run("injects correct path removing only one prefix", func(t *testing.T) { - opaModule := &OPAModuleConfig{ + opaModule := &core.OPAModuleConfig{ Name: "example.rego", Content: `package policies very_very_composed_permission_with_eval { true }`, @@ -354,7 +355,7 @@ very_very_composed_permission_with_eval { true }`, sdk := getSdk(t, opaModule) middleware := OPAMiddleware(opaModule, sdk, routesNotToProxy, "", options) builtHandler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - actual, err := GetEvaluatorSKD(r.Context()) + actual, err := core.GetEvaluatorSKD(r.Context()) require.NoError(t, err) require.Equal(t, openapi.RondConfig{RequestFlow: openapi.RequestFlow{PolicyName: "very.very.composed.permission.with.eval"}}, actual.Config()) w.WriteHeader(http.StatusOK) diff --git a/service/handler.go b/service/handler.go index 01b3b5f7..02b330c1 100644 --- a/service/handler.go +++ b/service/handler.go @@ -25,6 +25,7 @@ import ( "github.com/rond-authz/rond/internal/opatranslator" "github.com/rond-authz/rond/internal/utils" "github.com/rond-authz/rond/openapi" + rondmux "github.com/rond-authz/rond/routers/mux" "github.com/rond-authz/rond/types" "github.com/gorilla/mux" @@ -111,7 +112,7 @@ func EvaluateRequest( return err } - rondInput := core.NewRondInput(req, env.ClientTypeHeader, mux.Vars(req)) + rondInput := rondmux.NewRondInput(req, env.ClientTypeHeader, mux.Vars(req)) result, err := evaluatorSdk.EvaluateRequestPolicy(req.Context(), rondInput, userInfo) if err != nil { if errors.Is(err, opatranslator.ErrEmptyQuery) && utils.HasApplicationJSONContentType(req.Header) { @@ -167,7 +168,7 @@ func ReverseProxy( proxy.ServeHTTP(w, req) return } - proxy.Transport = core.NewOPATransport( + proxy.Transport = NewOPATransport( http.DefaultTransport, req.Context(), logger, diff --git a/core/opa_transport.go b/service/opa_transport.go similarity index 94% rename from core/opa_transport.go rename to service/opa_transport.go index 0620aceb..06ad7d3c 100644 --- a/core/opa_transport.go +++ b/service/opa_transport.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package core +package service import ( "bytes" @@ -23,8 +23,10 @@ import ( "net/http" "strconv" + "github.com/rond-authz/rond/core" "github.com/rond-authz/rond/internal/mongoclient" "github.com/rond-authz/rond/internal/utils" + rondmux "github.com/rond-authz/rond/routers/mux" "github.com/rond-authz/rond/types" "github.com/gorilla/mux" @@ -40,7 +42,7 @@ type OPATransport struct { clientHeaderKey string userHeaders types.UserHeadersKeys - evaluatorSDK SDKEvaluator + evaluatorSDK core.SDKEvaluator } func NewOPATransport( @@ -50,7 +52,7 @@ func NewOPATransport( req *http.Request, clientHeaderKey string, userHeadersKeys types.UserHeadersKeys, - evaluatorSDK SDKEvaluator, + evaluatorSDK core.SDKEvaluator, ) *OPATransport { return &OPATransport{ RoundTripper: transport, @@ -108,7 +110,7 @@ func (t *OPATransport) RoundTrip(req *http.Request) (resp *http.Response, err er } pathParams := mux.Vars(t.request) - input := NewRondInput(t.request, t.clientHeaderKey, pathParams) + input := rondmux.NewRondInput(t.request, t.clientHeaderKey, pathParams) responseBody, err := t.evaluatorSDK.EvaluateResponsePolicy(t.context, input, userInfo, decodedBody) if err != nil { diff --git a/core/opa_transport_test.go b/service/opa_transport_test.go similarity index 92% rename from core/opa_transport_test.go rename to service/opa_transport_test.go index faa1ee87..87727909 100644 --- a/core/opa_transport_test.go +++ b/service/opa_transport_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package core +package service import ( "bytes" @@ -25,9 +25,12 @@ import ( "strconv" "testing" + "github.com/prometheus/client_golang/prometheus" + "github.com/rond-authz/rond/core" "github.com/rond-authz/rond/internal/mocks" "github.com/rond-authz/rond/internal/mongoclient" "github.com/rond-authz/rond/internal/utils" + "github.com/rond-authz/rond/openapi" "github.com/rond-authz/rond/types" "github.com/sirupsen/logrus" "github.com/sirupsen/logrus/hooks/test" @@ -465,3 +468,49 @@ func (m *MockReader) Read(p []byte) (n int, err error) { func (m *MockReader) Close() error { return m.CloseError } + +type sdkOptions struct { + opaModuleContent string + oasFilePath string + + mongoClient types.IMongoClient + registry *prometheus.Registry +} + +type tHelper interface { + Helper() +} + +func getSdk(t require.TestingT, options *sdkOptions) core.SDK { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + logger := logrus.NewEntry(logrus.New()) + if options == nil { + options = &sdkOptions{} + } + + var oasFilePath = "../mocks/simplifiedMock.json" + if options.oasFilePath != "" { + oasFilePath = options.oasFilePath + } + + openAPISpec, err := openapi.LoadOASFile(oasFilePath) + require.NoError(t, err) + opaModule := &core.OPAModuleConfig{ + Name: "example.rego", + Content: `package policies + todo { true }`, + } + if options.opaModuleContent != "" { + opaModule.Content = options.opaModuleContent + } + sdk, err := core.NewSDK(context.Background(), logger, openAPISpec, opaModule, &core.EvaluatorOptions{ + EnablePrintStatements: true, + MongoClient: options.mongoClient, + }, options.registry, "") + require.NoError(t, err) + + return sdk +} diff --git a/service/router.go b/service/router.go index 1318c7af..ebfc3c57 100644 --- a/service/router.go +++ b/service/router.go @@ -33,6 +33,7 @@ import ( "github.com/rond-authz/rond/internal/mongoclient" "github.com/rond-authz/rond/internal/utils" "github.com/rond-authz/rond/openapi" + rondmux "github.com/rond-authz/rond/routers/mux" "github.com/rond-authz/rond/types" "github.com/gorilla/mux" @@ -152,7 +153,7 @@ func SetupRouter( } } - evalRouter.Use(core.OPAMiddleware(opaModuleConfig, sdk, routesToNotProxy, env.TargetServiceOASPath, &core.OPAMiddlewareOptions{ + evalRouter.Use(rondmux.OPAMiddleware(opaModuleConfig, sdk, routesToNotProxy, env.TargetServiceOASPath, &rondmux.OPAMiddlewareOptions{ IsStandalone: env.Standalone, PathPrefixStandalone: env.PathPrefixStandalone, }))