Skip to content

Commit

Permalink
fix: issue 368
Browse files Browse the repository at this point in the history
  • Loading branch information
davidebianchi committed Jul 30, 2024
1 parent 621f407 commit a1c738c
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 37 deletions.
14 changes: 13 additions & 1 deletion mocks/nestedPathsConfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,19 @@
}
}
},
"/trailing-slash-with-variables/:id/": {
"/trailing-slash-with-variables/:id": {
"get": {
"x-rond": {
"requestFlow": {
"policyName": "allow_params_trailing_slash"
},
"options": {
"ignoreTrailingSlash": true
}
}
}
},
"/ends-with-trailing-slash/:id/": {
"get": {
"x-rond": {
"requestFlow": {
Expand Down
4 changes: 3 additions & 1 deletion service/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,9 @@ func EvaluateRequest(
"generateQuery": evaluationConfig.RequestFlow.GenerateQuery,
"resourceOptmizationMap": evaluationConfig.Options.EnableResourcePermissionsMapOptimization,
}).Trace("creating rond input")
rondInput, err := rondhttp.NewInput(&evaluationConfig, req, env.ClientTypeHeader, mux.Vars(req), rondInputUser, nil)
pathParams := mux.Vars(req)
delete(pathParams, trailingSlashVariable)
rondInput, err := rondhttp.NewInput(&evaluationConfig, req, env.ClientTypeHeader, pathParams, rondInputUser, nil)
if err != nil {
logger.WithField("error", logrus.Fields{"message": err.Error()}).Error("failed to create rond input")
utils.FailResponseWithCode(w, http.StatusInternalServerError, "failed to create rond input", utils.GENERIC_BUSINESS_ERROR_MESSAGE)
Expand Down
14 changes: 12 additions & 2 deletions service/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ import (
"github.com/sirupsen/logrus"
)

const serviceName = "rönd"
const (
serviceName = "rönd"
trailingSlashVariable = "/"
)

var routesToNotProxy = utils.Union(statusRoutes, []string{metricsRoutePath})

Expand Down Expand Up @@ -170,7 +173,7 @@ func setupEvalRoutes(router *mux.Router, oas *openapi.OpenAPISpec, env config.En
actualPathToRegister := openapi.ConvertPathVariablesToBrackets(pathToRegister)
shouldIgnoreTrailingSlash := ignoreTrailingSlashMap[path][method]
if shouldIgnoreTrailingSlash {
actualPathToRegister = fmt.Sprintf("/{%s:%s\\/?}", openapi.ConvertPathVariablesToBrackets(pathToRegister), openapi.ConvertPathVariablesToBrackets(pathToRegister))
actualPathToRegister = fmt.Sprintf("%s{%s:\\/?}", removeTrailingSlash(openapi.ConvertPathVariablesToBrackets(pathToRegister)), trailingSlashVariable)
}
router.HandleFunc(actualPathToRegister, rbacHandler).Methods(method)
}
Expand Down Expand Up @@ -257,3 +260,10 @@ func setupServiceRouter(
setupEvalRoutes(evalRouter, oas, env)
return nil
}

func removeTrailingSlash(s string) string {
if strings.HasSuffix(s, "/") {
return s[:len(s)-1]
}
return s
}
149 changes: 116 additions & 33 deletions service/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,36 @@ func TestSetupRoutes(t *testing.T) {
},
},
},
// "/trailing-slash-with-variables/{id}": openapi.PathVerbs{
// "get": openapi.VerbConfig{
// PermissionV2: &core.RondConfig{
// RequestFlow: core.RequestFlow{
// PolicyName: "filter_policy",
// },
// Options: core.PermissionOptions{IgnoreTrailingSlash: true},
// },
// },
// },
"/with/trailing/slash/": openapi.PathVerbs{
"get": openapi.VerbConfig{
PermissionV2: &core.RondConfig{
RequestFlow: core.RequestFlow{
PolicyName: "filter_policy",
},
Options: core.PermissionOptions{IgnoreTrailingSlash: true},
},
},
},
"/trailing-slash-with-variables/{id}": openapi.PathVerbs{
"get": openapi.VerbConfig{
PermissionV2: &core.RondConfig{
RequestFlow: core.RequestFlow{
PolicyName: "filter_policy",
},
Options: core.PermissionOptions{IgnoreTrailingSlash: true},
},
},
},
"/trailing-slash-with-variables/{id}/": openapi.PathVerbs{
"get": openapi.VerbConfig{
PermissionV2: &core.RondConfig{
RequestFlow: core.RequestFlow{
PolicyName: "filter_policy",
},
Options: core.PermissionOptions{IgnoreTrailingSlash: true},
},
},
},
},
}
expectedPaths := []string{
Expand All @@ -95,7 +115,10 @@ func TestSetupRoutes(t *testing.T) {
"/documentation/json",
"/foo",
"/foo/bar",
"/{/with/trailing/slash:/with/trailing/slash\\/?}",
"/trailing-slash-with-variables/{id}{/:\\/?}",
"/trailing-slash-with-variables/{id}{/:\\/?}",
"/with/trailing/slash{/:\\/?}",
"/with/trailing/slash{/:\\/?}",
}

setupEvalRoutes(router, oas, envs)
Expand Down Expand Up @@ -575,9 +598,9 @@ allow_params_trailing_slash {
`,
}

var invoked bool
var invokedTimes int
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
invoked = true
invokedTimes++
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
Expand All @@ -587,27 +610,87 @@ allow_params_trailing_slash {

serverURL, _ := url.Parse(server.URL)

evaluator := getEvaluator(t, mockOPAModule, nil, oas, http.MethodGet, "/trailing-slash-with-variables/id/", nil)
ctx := createContext(t,
context.Background(),
config.EnvironmentVariables{TargetServiceHost: serverURL.Host},
evaluator,
nil,
nil,
)

req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://host/trailing-slash-with-variables/my-id/", nil)
require.NoError(t, err, "Unexpected error")

var matchedRouted mux.RouteMatch
ok := router.Match(req, &matchedRouted)
require.True(t, ok, "Route not found")

w := httptest.NewRecorder()
router.ServeHTTP(w, req)
t.Run("/trailing-slash-with-variables/:id", func(t *testing.T) {
evaluator := getEvaluator(t, mockOPAModule, nil, oas, http.MethodGet, "/trailing-slash-with-variables/my-id", nil)
ctx := createContext(t,
context.Background(),
config.EnvironmentVariables{TargetServiceHost: serverURL.Host},
evaluator,
nil,
nil,
)

t.Run("with trailing slash", func(t *testing.T) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://host/trailing-slash-with-variables/my-id/", nil)
require.NoError(t, err, "Unexpected error")

var matchedRouted mux.RouteMatch
ok := router.Match(req, &matchedRouted)
require.True(t, ok, "Route not found")

w := httptest.NewRecorder()
router.ServeHTTP(w, req)

require.Equal(t, http.StatusOK, w.Result().StatusCode)
require.Equal(t, 1, invokedTimes, "mock server was not invoked")
})

t.Run("without trailing slash", func(t *testing.T) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://host/trailing-slash-with-variables/my-id", nil)
require.NoError(t, err, "Unexpected error")

var matchedRouted mux.RouteMatch
ok := router.Match(req, &matchedRouted)
require.True(t, ok, "Route not found")

w := httptest.NewRecorder()
router.ServeHTTP(w, req)

require.Equal(t, http.StatusOK, w.Result().StatusCode)
require.Equal(t, 2, invokedTimes, "mock server was not invoked")
})
})

require.True(t, invoked, "mock server was not invoked")
require.Equal(t, http.StatusOK, w.Result().StatusCode)
t.Run("/ends-with-trailing-slash/:id/", func(t *testing.T) {
evaluator := getEvaluator(t, mockOPAModule, nil, oas, http.MethodGet, "/ends-with-trailing-slash/my-id/", nil)
ctx := createContext(t,
context.Background(),
config.EnvironmentVariables{TargetServiceHost: serverURL.Host},
evaluator,
nil,
nil,
)

t.Run("with trailing slash", func(t *testing.T) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://host/ends-with-trailing-slash/my-id/", nil)
require.NoError(t, err, "Unexpected error")

var matchedRouted mux.RouteMatch
ok := router.Match(req, &matchedRouted)
require.True(t, ok, "Route not found")

w := httptest.NewRecorder()
router.ServeHTTP(w, req)

require.Equal(t, http.StatusOK, w.Result().StatusCode)
require.Equal(t, 3, invokedTimes, "mock server was not invoked")
})

t.Run("without trailing slash", func(t *testing.T) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://host/ends-with-trailing-slash/my-id", nil)
require.NoError(t, err, "Unexpected error")

var matchedRouted mux.RouteMatch
ok := router.Match(req, &matchedRouted)
require.True(t, ok, "Route not found")

w := httptest.NewRecorder()
router.ServeHTTP(w, req)

require.Equal(t, http.StatusOK, w.Result().StatusCode)
require.Equal(t, 4, invokedTimes, "mock server was not invoked")
})
})
})
}

Expand Down

0 comments on commit a1c738c

Please sign in to comment.