diff --git a/core/sdk_test.go b/core/sdk_test.go index 450944cd..e9e44c25 100644 --- a/core/sdk_test.go +++ b/core/sdk_test.go @@ -18,6 +18,7 @@ import ( "context" "net/http" "net/http/httptest" + "strings" "testing" "github.com/rond-authz/rond/internal/mocks" @@ -25,6 +26,7 @@ import ( "github.com/rond-authz/rond/types" "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/testutil" "github.com/sirupsen/logrus" "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/require" @@ -140,7 +142,6 @@ func TestSDK(t *testing.T) { }) } -// TODO: test metrics both in request and response func TestEvaluateRequestPolicy(t *testing.T) { logger := logrus.NewEntry(logrus.New()) @@ -374,32 +375,79 @@ func TestEvaluateRequestPolicy(t *testing.T) { }, } - for name, test := range testCases { + for name, testCase := range testCases { t.Run(name, func(t *testing.T) { + registry := prometheus.NewPedanticRegistry() sdk := getSdk(t, &sdkOptions{ - opaModuleContent: test.opaModuleContent, - oasFilePath: test.oasFilePath, - mongoClient: test.mongoClient, + opaModuleContent: testCase.opaModuleContent, + oasFilePath: testCase.oasFilePath, + mongoClient: testCase.mongoClient, + registry: registry, }) - evaluator, err := sdk.FindEvaluator(logger, test.method, test.path) + log, hook := test.NewNullLogger() + log.Level = logrus.DebugLevel + logger := logrus.NewEntry(log) + evaluate, err := sdk.FindEvaluator(logger, testCase.method, testCase.path) require.NoError(t, err) - req := httptest.NewRequest(test.method, test.path, nil) - if test.reqHeaders != nil { - for k, v := range test.reqHeaders { + req := httptest.NewRequest(testCase.method, testCase.path, nil) + if testCase.reqHeaders != nil { + for k, v := range testCase.reqHeaders { req.Header.Set(k, v) } } rondInput := NewRondInput(req, clientTypeHeaderKey, nil) - actual, err := evaluator.EvaluateRequestPolicy(context.Background(), rondInput, test.user) - if test.expectedErr { - require.EqualError(t, err, test.expectedErrMessage) + actual, err := evaluate.EvaluateRequestPolicy(context.Background(), rondInput, testCase.user) + if testCase.expectedErr { + require.EqualError(t, err, testCase.expectedErrMessage) } else { require.NoError(t, err) } - require.Equal(t, test.expectedPolicy, actual) + require.Equal(t, testCase.expectedPolicy, actual) + + t.Run("logger", func(t *testing.T) { + var actualEntry *logrus.Entry + for _, entry := range hook.AllEntries() { + if entry.Message == "policy evaluation completed" { + actualEntry = entry + } + } + evaluatorInfo := evaluate.(evaluator) + + require.NotNil(t, actual) + delete(actualEntry.Data, "evaluationTimeMicroseconds") + require.Equal(t, logrus.Fields{ + "allowed": actual.Allowed, + "requestedPath": testCase.path, + "matchedPath": evaluatorInfo.routeInfo.MatchedPath, + "method": testCase.method, + "partialEval": evaluate.Config().RequestFlow.GenerateQuery, + "policyName": evaluate.Config().RequestFlow.PolicyName, + }, actualEntry.Data) + }) + + t.Run("metrics", func(t *testing.T) { + metadata := ` + # HELP rond_policy_evaluation_duration_milliseconds A histogram of the policy evaluation durations in milliseconds. + # TYPE rond_policy_evaluation_duration_milliseconds histogram + ` + expected := strings.ReplaceAll(` + rond_policy_evaluation_duration_milliseconds_bucket{policy_name="POLICY_NAME",le="1"} 1 + rond_policy_evaluation_duration_milliseconds_bucket{policy_name="POLICY_NAME",le="5"} 1 + rond_policy_evaluation_duration_milliseconds_bucket{policy_name="POLICY_NAME",le="10"} 1 + rond_policy_evaluation_duration_milliseconds_bucket{policy_name="POLICY_NAME",le="50"} 1 + rond_policy_evaluation_duration_milliseconds_bucket{policy_name="POLICY_NAME",le="100"} 1 + rond_policy_evaluation_duration_milliseconds_bucket{policy_name="POLICY_NAME",le="250"} 1 + rond_policy_evaluation_duration_milliseconds_bucket{policy_name="POLICY_NAME",le="500"} 1 + rond_policy_evaluation_duration_milliseconds_bucket{policy_name="POLICY_NAME",le="+Inf"} 1 + rond_policy_evaluation_duration_milliseconds_sum{policy_name="POLICY_NAME"} 0 + rond_policy_evaluation_duration_milliseconds_count{policy_name="POLICY_NAME"} 1 + `, "POLICY_NAME", evaluate.Config().RequestFlow.PolicyName) + + require.NoError(t, testutil.GatherAndCompare(registry, strings.NewReader(metadata+expected), "rond_policy_evaluation_duration_milliseconds")) + }) }) } }) @@ -494,7 +542,7 @@ func TestEvaluateResponsePolicy(t *testing.T) { }, } - for name, test := range testCases { + for name, testCase := range testCases { t.Run(name, func(t *testing.T) { opaModuleContent := ` package policies @@ -502,34 +550,81 @@ func TestEvaluateResponsePolicy(t *testing.T) { body := input.response.body }` - if test.opaModuleContent != "" { - opaModuleContent = test.opaModuleContent + if testCase.opaModuleContent != "" { + opaModuleContent = testCase.opaModuleContent } + log, hook := test.NewNullLogger() + log.Level = logrus.DebugLevel + logger := logrus.NewEntry(log) + registry := prometheus.NewPedanticRegistry() sdk := getSdk(t, &sdkOptions{ opaModuleContent: opaModuleContent, oasFilePath: "../mocks/rondOasConfig.json", - mongoClient: test.mongoClient, + mongoClient: testCase.mongoClient, + registry: registry, }) - evaluator, err := sdk.FindEvaluator(logger, test.method, test.path) + evaluate, err := sdk.FindEvaluator(logger, testCase.method, testCase.path) require.NoError(t, err) - req := httptest.NewRequest(test.method, test.path, nil) - if test.reqHeaders != nil { - for k, v := range test.reqHeaders { + req := httptest.NewRequest(testCase.method, testCase.path, nil) + if testCase.reqHeaders != nil { + for k, v := range testCase.reqHeaders { req.Header.Set(k, v) } } rondInput := NewRondInput(req, clientTypeHeaderKey, nil) - actual, err := evaluator.EvaluateResponsePolicy(context.Background(), rondInput, test.user, test.decodedBody) - if test.expectedErr { - require.EqualError(t, err, test.expectedErrMessage) + actual, err := evaluate.EvaluateResponsePolicy(context.Background(), rondInput, testCase.user, testCase.decodedBody) + if testCase.expectedErr { + require.EqualError(t, err, testCase.expectedErrMessage) } else { require.NoError(t, err) } - require.JSONEq(t, test.expectedBody, string(actual)) + require.JSONEq(t, testCase.expectedBody, string(actual)) + + t.Run("logger", func(t *testing.T) { + var actual *logrus.Entry + for _, entry := range hook.AllEntries() { + if entry.Message == "policy evaluation completed" { + actual = entry + } + } + evaluatorInfo := evaluate.(evaluator) + + require.NotNil(t, actual) + delete(actual.Data, "evaluationTimeMicroseconds") + require.Equal(t, logrus.Fields{ + "allowed": true, + "requestedPath": testCase.path, + "matchedPath": evaluatorInfo.routeInfo.MatchedPath, + "method": testCase.method, + "partialEval": false, + "policyName": evaluate.Config().ResponseFlow.PolicyName, + }, actual.Data) + }) + + t.Run("metrics", func(t *testing.T) { + metadata := ` + # HELP rond_policy_evaluation_duration_milliseconds A histogram of the policy evaluation durations in milliseconds. + # TYPE rond_policy_evaluation_duration_milliseconds histogram + ` + expected := strings.ReplaceAll(` + rond_policy_evaluation_duration_milliseconds_bucket{policy_name="POLICY_NAME",le="1"} 1 + rond_policy_evaluation_duration_milliseconds_bucket{policy_name="POLICY_NAME",le="5"} 1 + rond_policy_evaluation_duration_milliseconds_bucket{policy_name="POLICY_NAME",le="10"} 1 + rond_policy_evaluation_duration_milliseconds_bucket{policy_name="POLICY_NAME",le="50"} 1 + rond_policy_evaluation_duration_milliseconds_bucket{policy_name="POLICY_NAME",le="100"} 1 + rond_policy_evaluation_duration_milliseconds_bucket{policy_name="POLICY_NAME",le="250"} 1 + rond_policy_evaluation_duration_milliseconds_bucket{policy_name="POLICY_NAME",le="500"} 1 + rond_policy_evaluation_duration_milliseconds_bucket{policy_name="POLICY_NAME",le="+Inf"} 1 + rond_policy_evaluation_duration_milliseconds_sum{policy_name="POLICY_NAME"} 0 + rond_policy_evaluation_duration_milliseconds_count{policy_name="POLICY_NAME"} 1 + `, "POLICY_NAME", evaluate.Config().ResponseFlow.PolicyName) + + require.NoError(t, testutil.GatherAndCompare(registry, strings.NewReader(metadata+expected), "rond_policy_evaluation_duration_milliseconds")) + }) }) } }) @@ -604,6 +699,7 @@ type sdkOptions struct { oasFilePath string mongoClient types.IMongoClient + registry *prometheus.Registry } type tHelper interface { @@ -635,11 +731,10 @@ func getSdk(t require.TestingT, options *sdkOptions) SDK { if options.opaModuleContent != "" { opaModule.Content = options.opaModuleContent } - registry := prometheus.NewRegistry() sdk, err := NewSDK(context.Background(), logger, openAPISpec, opaModule, &EvaluatorOptions{ EnablePrintStatements: true, MongoClient: options.mongoClient, - }, registry, "") + }, options.registry, "") require.NoError(t, err) return sdk