diff --git a/core/opa_transport.go b/core/opa_transport.go index 701ba741..8d6933f2 100644 --- a/core/opa_transport.go +++ b/core/opa_transport.go @@ -25,7 +25,6 @@ import ( "github.com/rond-authz/rond/internal/mongoclient" "github.com/rond-authz/rond/internal/utils" - "github.com/rond-authz/rond/openapi" "github.com/rond-authz/rond/types" "github.com/gorilla/mux" @@ -35,15 +34,13 @@ import ( type OPATransport struct { http.RoundTripper // FIXME: this overlaps with the req.Context used during RoundTrip. - context context.Context - logger *logrus.Entry - request *http.Request - permission *openapi.RondConfig - partialResultsEvaluators PartialResultsEvaluators - - clientHeaderKey string - userHeaders types.UserHeadersKeys - evaluatorOptions *EvaluatorOptions + context context.Context + logger *logrus.Entry + request *http.Request + + clientHeaderKey string + userHeaders types.UserHeadersKeys + evaluatorSDK SDKEvaluator } func NewOPATransport( @@ -51,23 +48,19 @@ func NewOPATransport( context context.Context, logger *logrus.Entry, req *http.Request, - permission *openapi.RondConfig, - partialResultsEvaluators PartialResultsEvaluators, clientHeaderKey string, userHeadersKeys types.UserHeadersKeys, - evaluatorOptions *EvaluatorOptions, + evaluatorSDK SDKEvaluator, ) *OPATransport { return &OPATransport{ - RoundTripper: transport, - context: req.Context(), - logger: logger, - request: req, - permission: permission, - partialResultsEvaluators: partialResultsEvaluators, - - clientHeaderKey: clientHeaderKey, - userHeaders: userHeadersKeys, - evaluatorOptions: evaluatorOptions, + RoundTripper: transport, + context: req.Context(), + logger: logger, + request: req, + + clientHeaderKey: clientHeaderKey, + userHeaders: userHeadersKeys, + evaluatorSDK: evaluatorSDK, } } @@ -116,42 +109,14 @@ func (t *OPATransport) RoundTrip(req *http.Request) (resp *http.Response, err er pathParams := mux.Vars(t.request) rondReq := NewRondInput(t.request, t.clientHeaderKey, pathParams) - input, err := rondReq.FromRequestInfo(userInfo, decodedBody) - if err != nil { - t.responseWithError(resp, err, http.StatusInternalServerError) - return resp, nil - } - - regoInput, err := CreateRegoQueryInput(t.logger, input, RegoInputOptions{ - EnableResourcePermissionsMapOptimization: t.permission.Options.EnableResourcePermissionsMapOptimization, - }) - if err != nil { - t.responseWithError(resp, err, http.StatusInternalServerError) - return resp, nil - } - - evaluator, err := t.partialResultsEvaluators.GetEvaluatorFromPolicy(t.context, t.permission.ResponseFlow.PolicyName, regoInput, t.evaluatorOptions) - if err != nil { - t.logger.WithField("error", logrus.Fields{ - "policyName": t.permission.ResponseFlow.PolicyName, - "message": err.Error(), - }).Error("RBAC policy evaluation on response failed") - t.responseWithError(resp, err, http.StatusInternalServerError) - return resp, nil - } - bodyToProxy, err := evaluator.Evaluate(t.logger) + responseBody, err := t.evaluatorSDK.EvaluateResponsePolicy(t.context, rondReq, userInfo, decodedBody) if err != nil { t.responseWithError(resp, err, http.StatusForbidden) return resp, nil } - marshalledBody, err := json.Marshal(bodyToProxy) - if err != nil { - t.responseWithError(resp, err, http.StatusInternalServerError) - return resp, nil - } - overwriteResponse(resp, marshalledBody) + overwriteResponse(resp, responseBody) return resp, nil } diff --git a/core/opa_transport_test.go b/core/opa_transport_test.go index 98d54867..56562375 100644 --- a/core/opa_transport_test.go +++ b/core/opa_transport_test.go @@ -16,7 +16,6 @@ package core import ( "bytes" - "context" "encoding/json" "fmt" "io" @@ -53,17 +52,15 @@ func TestRoundTripErrors(t *testing.T) { JSON(responseBody) req := httptest.NewRequest(http.MethodPost, "http://example.com/some-api", nil) - transport := &OPATransport{ + transport := NewOPATransport( http.DefaultTransport, req.Context(), logrus.NewEntry(logger), req, - nil, - nil, "", types.UserHeadersKeys{}, nil, - } + ) resp, err := transport.RoundTrip(req) require.NoError(t, err, "unexpected error") @@ -92,17 +89,15 @@ func TestOPATransportResponseWithError(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "http://example.com/some-api", nil) - transport := &OPATransport{ + transport := NewOPATransport( http.DefaultTransport, req.Context(), logrus.NewEntry(logger), req, - nil, - nil, "", types.UserHeadersKeys{}, nil, - } + ) t.Run("generic business error message", func(t *testing.T) { resp := &http.Response{ @@ -151,7 +146,7 @@ func TestOPATransportResponseWithError(t *testing.T) { func TestOPATransportRoundTrip(t *testing.T) { logger, _ := test.NewNullLogger() - req := httptest.NewRequest(http.MethodPost, "http://example.com/some-api", nil) + req := httptest.NewRequest(http.MethodGet, "/users", nil) t.Run("returns error on RoundTrip error", func(t *testing.T) { transport := NewOPATransport( @@ -159,7 +154,6 @@ func TestOPATransportRoundTrip(t *testing.T) { req.Context(), logrus.NewEntry(logger), req, - nil, nil, "", types.UserHeadersKeys{ IDHeaderKey: "useridheader", @@ -279,35 +273,36 @@ func TestOPATransportRoundTrip(t *testing.T) { }) t.Run("ok with filter response", func(t *testing.T) { - resp := http.Response{ + resp := &http.Response{ StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader([]byte(`{"some":"field"}`))), ContentLength: 16, Header: http.Header{"Content-Type": []string{"application/json"}}, } - req = req.Clone(metrics.WithValue(openapi.WithRouterInfo(logrus.NewEntry(logger), req.Context(), req), metrics.SetupMetrics(""))) + logEntry := logrus.NewEntry(logger) + req = req.Clone(metrics.WithValue(openapi.WithRouterInfo(logEntry, req.Context(), req), metrics.SetupMetrics(""))) - partialResult, err := NewPartialResultEvaluator(context.Background(), "my_policy", &OPAModuleConfig{ - Content: "package policies my_policy [resources] { resources := input.response.body }", - }, nil) + evaluator := getSdk(t, &sdkOptions{ + oasFilePath: "../mocks/rondOasConfig.json", + opaModuleContent: "package policies responsepolicy [resources] { resources := input.response.body }", + }) + evaluatorSDK, err := evaluator.FindEvaluator(logEntry, http.MethodGet, "/users/") require.NoError(t, err) - transport := &OPATransport{ - RoundTripper: &MockRoundTrip{Response: &resp}, - context: req.Context(), - logger: logrus.NewEntry(logger), - request: req, - permission: &openapi.RondConfig{ - ResponseFlow: openapi.ResponseFlow{PolicyName: "my_policy"}, - }, - partialResultsEvaluators: PartialResultsEvaluators{"my_policy": PartialEvaluator{partialResult}}, - userHeaders: types.UserHeadersKeys{ + transport := NewOPATransport( + &MockRoundTrip{Response: resp}, + req.Context(), + logrus.NewEntry(logger), + req, + "", + types.UserHeadersKeys{ IDHeaderKey: "useridheader", GroupsHeaderKey: "usergroupsheader", PropertiesHeaderKey: "userpropertiesheader", }, - } + evaluatorSDK, + ) actualResp, err := transport.RoundTrip(req) require.NoError(t, err, "response body is not valid") @@ -419,22 +414,28 @@ func TestOPATransportRoundTrip(t *testing.T) { ContentLength: 0, Header: http.Header{"Content-Type": []string{"application/json"}}, } - transport := &OPATransport{ - RoundTripper: &MockRoundTrip{Response: resp}, - context: req.Context(), - logger: logrus.NewEntry(logger), - request: req, - permission: &openapi.RondConfig{ - ResponseFlow: openapi.ResponseFlow{PolicyName: "my_policy"}, - }, - partialResultsEvaluators: PartialResultsEvaluators{"my_policy": {}}, - userHeaders: types.UserHeadersKeys{ + evaluator := getSdk(t, &sdkOptions{ + oasFilePath: "../mocks/rondOasConfig.json", + opaModuleContent: "package policies responsepolicy [resources] { resources := input.response.body }", + }) + logEntry := logrus.NewEntry(logger) + evaluatorSDK, err := evaluator.FindEvaluator(logEntry, http.MethodGet, "/users/") + require.NoError(t, err) + + transport := NewOPATransport( + &MockRoundTrip{Response: resp}, + req.Context(), + logrus.NewEntry(logger), + req, + "", + types.UserHeadersKeys{ IDHeaderKey: "useridheader", GroupsHeaderKey: "usergroupsheader", PropertiesHeaderKey: "userpropertiesheader", }, - } - resp, err := transport.RoundTrip(req) + evaluatorSDK, + ) + resp, err = transport.RoundTrip(req) require.Nil(t, err) require.Equal(t, http.StatusInternalServerError, resp.StatusCode) bodyBytes, err := io.ReadAll(resp.Body) diff --git a/service/handler.go b/service/handler.go index 0674e310..e47a9aa2 100644 --- a/service/handler.go +++ b/service/handler.go @@ -43,10 +43,8 @@ func ReverseProxyOrResponse( evaluatorSdk core.SDKEvaluator, ) { var permission openapi.RondConfig - var partialResultsEvaluators core.PartialResultsEvaluators if evaluatorSdk != nil { permission = evaluatorSdk.Config() - partialResultsEvaluators = evaluatorSdk.PartialResultsEvaluators() } if env.Standalone { @@ -64,7 +62,7 @@ func ReverseProxyOrResponse( } return } - ReverseProxy(logger, env, w, req, &permission, partialResultsEvaluators) + ReverseProxy(logger, env, w, req, &permission, evaluatorSdk) } func rbacHandler(w http.ResponseWriter, req *http.Request) { @@ -149,7 +147,7 @@ func ReverseProxy( w http.ResponseWriter, req *http.Request, permission *openapi.RondConfig, - partialResultsEvaluators core.PartialResultsEvaluators, + evaluatorSdk core.SDKEvaluator, ) { targetHostFromEnv := env.TargetServiceHost proxy := httputil.ReverseProxy{ @@ -164,10 +162,6 @@ func ReverseProxy( }, } - options := &core.EvaluatorOptions{ - EnablePrintStatements: env.IsTraceLogLevel(), - } - // Check on nil is performed to proxy the oas documentation path if permission == nil || permission.ResponseFlow.PolicyName == "" { proxy.ServeHTTP(w, req) @@ -178,8 +172,6 @@ func ReverseProxy( req.Context(), logger, req, - permission, - partialResultsEvaluators, env.ClientTypeHeader, types.UserHeadersKeys{ @@ -187,7 +179,7 @@ func ReverseProxy( GroupsHeaderKey: env.UserGroupsHeader, PropertiesHeaderKey: env.UserPropertiesHeader, }, - options, + evaluatorSdk, ) proxy.ServeHTTP(w, req) }