Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve access rule middleware #485

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions pkg/plugin/oauth2/middleware_access_rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,14 @@ func NewRevokeRulesMiddleware(parser *jwt.Parser, accessRules []*AccessRule) fun
for _, rule := range accessRules {
allowed, err := rule.IsAllowed(claims)
if err != nil {
log.WithError(err).Debug("Rule is not allowed")
continue
}

if allowed {
handler.ServeHTTP(w, r)
} else {
w.WriteHeader(http.StatusUnauthorized)
return
log.WithError(err).Debug("Rule is invalid")
} else if rule.matched {
if allowed {
break
} else {
w.WriteHeader(http.StatusUnauthorized)
return
}
}
}
}
Expand Down
217 changes: 86 additions & 131 deletions pkg/plugin/oauth2/middleware_access_rules_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,174 +15,129 @@ import (

const signingAlg = "HS256"

func TestBlockJWTByCountry(t *testing.T) {
secret := "secret"
func generateToken(alg, key string) (string, error) {
token := basejwt.NewWithClaims(basejwt.GetSigningMethod(alg), basejwt.MapClaims{
"country": "de",
"username": "test@hellofresh.com",

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mocked info?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, those are mocked / static claims for testing the access rules. (They were previously in between test cases, I moved them to the top.)

"iat": time.Now().Unix(),
})

revokeRules := []*AccessRule{
{Predicate: "country == 'de'", Action: "deny"},
}
return token.SignedString([]byte(key))
}

func expectRulesToProduceStatus(t *testing.T, statusCode int, rules []*AccessRule) {
secret := "secret"

parser := jwt.NewParser(jwt.NewParserConfig(0, jwt.SigningMethod{Alg: signingAlg, Key: secret}))

mw := NewRevokeRulesMiddleware(parser, revokeRules)
mw := NewRevokeRulesMiddleware(parser, rules)
token, err := generateToken(signingAlg, secret)
require.NoError(t, err)

w, err := test.Record(
"GET",
"/",
map[string]string{
"Content-Type": "application/json",
"Authorization": fmt.Sprintf("Bearer %s", token),
},
mw(http.HandlerFunc(test.Ping)),
)
assert.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, w.Code)
for i := 1; i <= 3; i++ { // middleware caches predicate and should return the same response every time
hits := 0
w, err := test.Record(
"GET",
"/",
map[string]string{
"Content-Type": "application/json",
"Authorization": fmt.Sprintf("Bearer %s", token),
},
mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
hits++
test.Ping(w, r)
})),
)

assert.NoError(t, err, "%d. pass", i)
assert.Equal(t, statusCode, w.Code, "%d. pass", i)
if statusCode == http.StatusOK {
assert.Equal(t, 1, hits, "%d. pass", i)
} else {
assert.Equal(t, 0, hits, "%d. pass", i)
}
}
}

func TestBlockJWTByUsername(t *testing.T) {
secret := "secret"
func TestBlockJWTByCountry(t *testing.T) {
expectRulesToProduceStatus(t, http.StatusUnauthorized, []*AccessRule{
{Predicate: "country == 'de'", Action: "deny"},
})
}

revokeRules := []*AccessRule{
func TestBlockJWTByUsername(t *testing.T) {
expectRulesToProduceStatus(t, http.StatusUnauthorized, []*AccessRule{
{Predicate: "username == 'test@hellofresh.com'", Action: "deny"},
}

parser := jwt.NewParser(jwt.NewParserConfig(0, jwt.SigningMethod{Alg: signingAlg, Key: secret}))

mw := NewRevokeRulesMiddleware(parser, revokeRules)
token, err := generateToken(signingAlg, secret)
require.NoError(t, err)

w, err := test.Record(
"GET",
"/",
map[string]string{
"Content-Type": "application/json",
"Authorization": fmt.Sprintf("Bearer %s", token),
},
mw(http.HandlerFunc(test.Ping)),
)
assert.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, w.Code)
})
}

func TestBlockJWTByIssueDate(t *testing.T) {
secret := "secret"

revokeRules := []*AccessRule{
expectRulesToProduceStatus(t, http.StatusUnauthorized, []*AccessRule{
{Predicate: fmt.Sprintf("iat < %d", time.Now().Add(1*time.Hour).Unix()), Action: "deny"},
}

parser := jwt.NewParser(jwt.NewParserConfig(0, jwt.SigningMethod{Alg: signingAlg, Key: secret}))

mw := NewRevokeRulesMiddleware(parser, revokeRules)
token, err := generateToken(signingAlg, secret)
require.NoError(t, err)

w, err := test.Record(
"GET",
"/",
map[string]string{
"Content-Type": "application/json",
"Authorization": fmt.Sprintf("Bearer %s", token),
},
mw(http.HandlerFunc(test.Ping)),
)
assert.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, w.Code)
})
}

func TestBlockJWTByCountryAndIssueDate(t *testing.T) {
secret := "secret"

revokeRules := []*AccessRule{
expectRulesToProduceStatus(t, http.StatusUnauthorized, []*AccessRule{
{Predicate: fmt.Sprintf("country == 'de' && iat < %d", time.Now().Add(1*time.Hour).Unix()), Action: "deny"},
}

parser := jwt.NewParser(jwt.NewParserConfig(0, jwt.SigningMethod{Alg: signingAlg, Key: secret}))

mw := NewRevokeRulesMiddleware(parser, revokeRules)
token, err := generateToken(signingAlg, secret)
require.NoError(t, err)
})
}

w, err := test.Record(
"GET",
"/",
map[string]string{
"Content-Type": "application/json",
"Authorization": fmt.Sprintf("Bearer %s", token),
},
mw(http.HandlerFunc(test.Ping)),
)
assert.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, w.Code)
func TestEmptyAccessRules(t *testing.T) {
expectRulesToProduceStatus(t, http.StatusOK, []*AccessRule{})
}

func generateToken(alg, key string) (string, error) {
token := basejwt.NewWithClaims(basejwt.GetSigningMethod(alg), basejwt.MapClaims{
"country": "de",
"username": "test@hellofresh.com",
"iat": time.Now().Unix(),
func TestWrongRule(t *testing.T) {
expectRulesToProduceStatus(t, http.StatusOK, []*AccessRule{
{Predicate: "country == 'wrong'", Action: "deny"},
})

return token.SignedString([]byte(key))
}

func TestEmptyAccessRules(t *testing.T) {
secret := "secret"

revokeRules := []*AccessRule{}
func TestMultipleRulesSecondMatchesAndDenies(t *testing.T) {
expectRulesToProduceStatus(t, http.StatusUnauthorized, []*AccessRule{
{Predicate: "country == 'us'", Action: "deny"},
{Predicate: "country == 'de'", Action: "deny"},
})
}

parser := jwt.NewParser(jwt.NewParserConfig(0, jwt.SigningMethod{Alg: signingAlg, Key: secret}))
func TestMultipleRulesSecondMatchesAndAllows(t *testing.T) {
expectRulesToProduceStatus(t, http.StatusOK, []*AccessRule{
{Predicate: "country == 'us'", Action: "allow"},
{Predicate: "country == 'de'", Action: "allow"},
{Predicate: "true", Action: "deny"},
})
}

mw := NewRevokeRulesMiddleware(parser, revokeRules)
func TestMultipleRulesLastMatchesAndDenies(t *testing.T) {
expectRulesToProduceStatus(t, http.StatusUnauthorized, []*AccessRule{
{Predicate: "country == 'us'", Action: "allow"},
{Predicate: "country == 'gb'", Action: "allow"},
{Predicate: "true", Action: "deny"},
})
}

w, err := test.Record(
"GET",
"/",
nil,
mw(http.HandlerFunc(test.Ping)),
)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, w.Code)
func TestMultipleRulesNoneMatch(t *testing.T) {
expectRulesToProduceStatus(t, http.StatusOK, []*AccessRule{
{Predicate: "country == 'us'", Action: "deny"},
{Predicate: "country == 'gb'", Action: "deny"},
})
}
func TestMultipleRulesMatchAndAllow(t *testing.T) {
expectRulesToProduceStatus(t, http.StatusOK, []*AccessRule{
{Predicate: "country == 'de'", Action: "allow"},
{Predicate: "true", Action: "allow"},
})
}

func TestWrongJWT(t *testing.T) {
revokeRules := []*AccessRule{
{Predicate: fmt.Sprintf("country == 'de' && iat < %d", time.Now().Add(1*time.Hour).Unix()), Action: "deny"},
}

parser := jwt.NewParser(jwt.NewParserConfig(0, jwt.SigningMethod{Alg: signingAlg, Key: "wrong secret"}))
parser := jwt.NewParser(jwt.NewParserConfig(0, jwt.SigningMethod{Alg: signingAlg, Key: "secret"}))

mw := NewRevokeRulesMiddleware(parser, revokeRules)
token, err := generateToken(signingAlg, "secret")
require.NoError(t, err)

w, err := test.Record(
"GET",
"/",
map[string]string{
"Content-Type": "application/json",
"Authorization": fmt.Sprintf("Bearer %s", token),
},
mw(http.HandlerFunc(test.Ping)),
)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, w.Code)
}

func TestWrongRule(t *testing.T) {
secret := "secret"

revokeRules := []*AccessRule{
{Predicate: "country == 'wrong'", Action: "deny"},
}

parser := jwt.NewParser(jwt.NewParserConfig(0, jwt.SigningMethod{Alg: signingAlg, Key: secret}))

mw := NewRevokeRulesMiddleware(parser, revokeRules)
token, err := generateToken(signingAlg, secret)
token, err := generateToken(signingAlg, "wrong secret")
require.NoError(t, err)

w, err := test.Record(
Expand Down
9 changes: 5 additions & 4 deletions pkg/plugin/oauth2/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,21 +139,22 @@ type AccessRule struct {
Predicate string `bson:"predicate" json:"predicate"`
Action string `bson:"action" json:"action"`
parsed bool
matched bool
}

// IsAllowed checks if the rule is allowed to
func (r *AccessRule) IsAllowed(claims map[string]interface{}) (bool, error) {
var err error

if !r.parsed {
matched, err := r.parse(claims)
r.matched, err = r.parse(claims)
if err != nil {
return false, err
}
}

if !matched {
return true, nil
}
if !r.matched {
return true, nil
}

return r.Action == "allow", err
Expand Down