From 2789116d7e032898fe302995730311822eea222a Mon Sep 17 00:00:00 2001 From: Gregor Weckbecker Date: Wed, 29 May 2019 11:13:30 +0200 Subject: [PATCH] Return empty claims if user claims endpoint responds with not found --- login/user_claims_provider.go | 14 ++++++++++---- login/user_claims_provider_test.go | 30 ++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 4 deletions(-) diff --git a/login/user_claims_provider.go b/login/user_claims_provider.go index d7cf62a4..afe1ec1d 100644 --- a/login/user_claims_provider.go +++ b/login/user_claims_provider.go @@ -46,6 +46,9 @@ func (provider *userClaimsProvider) Claims(userInfo model.UserInfo) (jwt.Claims, resp.Body.Close() }() + if resp.StatusCode == http.StatusNotFound { + return customClaims(userInfo.AsMap()), nil + } if resp.StatusCode != http.StatusOK { return nil, errors.Errorf("bad http response code %d", resp.StatusCode) } @@ -58,10 +61,7 @@ func (provider *userClaimsProvider) Claims(userInfo model.UserInfo) (jwt.Claims, return nil, err } - claims := customClaims(userInfo.AsMap()) - claims.merge(remoteClaims) - - return claims, nil + return mergeClaims(userInfo, remoteClaims), nil } func (provider *userClaimsProvider) buildURL(userInfo model.UserInfo) string { @@ -91,6 +91,12 @@ func (provider *userClaimsProvider) buildURL(userInfo model.UserInfo) string { return u.String() } +func mergeClaims(userInfo model.UserInfo, remoteClaims map[string]interface{}) customClaims { + claims := customClaims(userInfo.AsMap()) + claims.merge(remoteClaims) + return claims +} + func validateURL(s string) error { _, err := url.Parse(s) return errors.Wrap(err, "invalid claims provider url") diff --git a/login/user_claims_provider_test.go b/login/user_claims_provider_test.go index 108ea5cd..ae964ba1 100644 --- a/login/user_claims_provider_test.go +++ b/login/user_claims_provider_test.go @@ -78,6 +78,36 @@ func Test_userClaimsProvider_Claims(t *testing.T) { ) } +func Test_userClaimsProvider_Claims_NotFound(t *testing.T) { + mock := createMockServer( + mockResponse{ + url: endpointPath, + status: http.StatusNotFound, + body: ``, + }, + ) + defer mock.Close() + provider, err := newUserClaimsProvider(mock.URL+endpointPath, token, time.Minute) + require.NoError(t, err) + + claims, err := provider.Claims(model.UserInfo{ + Sub: "test@example.com", + Origin: "origin", + Domain: "example.com", + }) + + require.NoError(t, err) + + assert.Equal(t, + customClaims{ + "domain": "example.com", + "origin": "origin", + "sub": "test@example.com", + }, + claims, + ) +} + func Test_userClaimsProvider_Claims_EndpointNotReachable(t *testing.T) { provider, err := newUserClaimsProvider("http://not-exists.example.com", token, time.Millisecond) require.NoError(t, err)