Skip to content

Commit

Permalink
Option to skip client token signature verification (#708)
Browse files Browse the repository at this point in the history
  • Loading branch information
FZambia authored Sep 27, 2023
1 parent 5de3b4d commit 15adaaa
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 54 deletions.
4 changes: 2 additions & 2 deletions internal/cli/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func verify(config jwtverify.VerifierConfig, ruleConfig rule.Config, token strin
if err != nil {
return jwtverify.ConnectToken{}, err
}
return verifier.VerifyConnectToken(token)
return verifier.VerifyConnectToken(token, false)
}

func verifySub(config jwtverify.VerifierConfig, ruleConfig rule.Config, token string) (jwtverify.SubscribeToken, error) {
Expand All @@ -85,7 +85,7 @@ func verifySub(config jwtverify.VerifierConfig, ruleConfig rule.Config, token st
if err != nil {
return jwtverify.SubscribeToken{}, err
}
return verifier.VerifySubscribeToken(token)
return verifier.VerifySubscribeToken(token, false)
}

// CheckToken checks JWT for user.
Expand Down
8 changes: 4 additions & 4 deletions internal/client/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ func (h *Handler) OnClientConnecting(
storage := map[string]any{}

if e.Token != "" {
token, err := h.tokenVerifier.VerifyConnectToken(e.Token)
token, err := h.tokenVerifier.VerifyConnectToken(e.Token, h.ruleContainer.Config().ClientInsecureSkipTokenSignatureVerify)
if err != nil {
if err == jwtverify.ErrTokenExpired {
return centrifuge.ConnectReply{}, centrifuge.ErrorTokenExpired
Expand Down Expand Up @@ -473,7 +473,7 @@ func (h *Handler) OnRefresh(c Client, e centrifuge.RefreshEvent, refreshProxyHan
}
return r, RefreshExtra{}, err
}
token, err := h.tokenVerifier.VerifyConnectToken(e.Token)
token, err := h.tokenVerifier.VerifyConnectToken(e.Token, h.ruleContainer.Config().ClientInsecureSkipTokenSignatureVerify)
if err != nil {
if err == jwtverify.ErrTokenExpired {
return centrifuge.RefreshReply{Expired: true}, RefreshExtra{}, nil
Expand Down Expand Up @@ -531,7 +531,7 @@ func (h *Handler) OnSubRefresh(c Client, subRefreshProxyHandler proxy.SubRefresh
if h.subTokenVerifier != nil {
tokenVerifier = h.subTokenVerifier
}
token, err := tokenVerifier.VerifySubscribeToken(e.Token)
token, err := tokenVerifier.VerifySubscribeToken(e.Token, h.ruleContainer.Config().ClientInsecureSkipTokenSignatureVerify)
if err != nil {
if err == jwtverify.ErrTokenExpired {
return centrifuge.SubRefreshReply{Expired: true}, SubRefreshExtra{}, nil
Expand Down Expand Up @@ -648,7 +648,7 @@ func (h *Handler) OnSubscribe(c Client, e centrifuge.SubscribeEvent, subscribePr
if h.subTokenVerifier != nil {
tokenVerifier = h.subTokenVerifier
}
token, err := tokenVerifier.VerifySubscribeToken(e.Token)
token, err := tokenVerifier.VerifySubscribeToken(e.Token, h.ruleContainer.Config().ClientInsecureSkipTokenSignatureVerify)
if err != nil {
if err == jwtverify.ErrTokenExpired {
return centrifuge.SubscribeReply{}, SubscribeExtra{}, centrifuge.ErrorTokenExpired
Expand Down
4 changes: 2 additions & 2 deletions internal/jwtverify/token_verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ import (
)

type Verifier interface {
VerifyConnectToken(token string) (ConnectToken, error)
VerifySubscribeToken(token string) (SubscribeToken, error)
VerifyConnectToken(token string, skipVerify bool) (ConnectToken, error)
VerifySubscribeToken(token string, skipVerify bool) (SubscribeToken, error)
}

type ConnectToken struct {
Expand Down
36 changes: 20 additions & 16 deletions internal/jwtverify/token_verifier_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ func (verifier *VerifierJWT) verifySignatureByJWK(token *jwt.Token, tokenVars ma
return verifier.jwksManager.verify(token, tokenVars)
}

func (verifier *VerifierJWT) VerifyConnectToken(t string) (ConnectToken, error) {
func (verifier *VerifierJWT) VerifyConnectToken(t string, skipVerify bool) (ConnectToken, error) {
token, err := jwt.ParseNoVerify([]byte(t)) // Will be verified later.
if err != nil {
return ConnectToken{}, fmt.Errorf("%w: %v", ErrInvalidToken, err)
Expand Down Expand Up @@ -421,13 +421,15 @@ func (verifier *VerifierJWT) VerifyConnectToken(t string) (ConnectToken, error)
}
}

if verifier.jwksManager != nil {
err = verifier.verifySignatureByJWK(token, tokenVars)
} else {
err = verifier.verifySignature(token)
}
if err != nil {
return ConnectToken{}, fmt.Errorf("%w: %v", ErrInvalidToken, err)
if !skipVerify {
if verifier.jwksManager != nil {
err = verifier.verifySignatureByJWK(token, tokenVars)
} else {
err = verifier.verifySignature(token)
}
if err != nil {
return ConnectToken{}, fmt.Errorf("%w: %v", ErrInvalidToken, err)
}
}

if claims.Channel != "" {
Expand Down Expand Up @@ -556,7 +558,7 @@ func (verifier *VerifierJWT) VerifyConnectToken(t string) (ConnectToken, error)
return ct, nil
}

func (verifier *VerifierJWT) VerifySubscribeToken(t string) (SubscribeToken, error) {
func (verifier *VerifierJWT) VerifySubscribeToken(t string, skipVerify bool) (SubscribeToken, error) {
token, err := jwt.ParseNoVerify([]byte(t)) // Will be verified later.
if err != nil {
return SubscribeToken{}, fmt.Errorf("%w: %v", ErrInvalidToken, err)
Expand Down Expand Up @@ -609,13 +611,15 @@ func (verifier *VerifierJWT) VerifySubscribeToken(t string) (SubscribeToken, err
}
}

if verifier.jwksManager != nil {
err = verifier.verifySignatureByJWK(token, tokenVars)
} else {
err = verifier.verifySignature(token)
}
if err != nil {
return SubscribeToken{}, fmt.Errorf("%w: %v", ErrInvalidToken, err)
if !skipVerify {
if verifier.jwksManager != nil {
err = verifier.verifySignatureByJWK(token, tokenVars)
} else {
err = verifier.verifySignature(token)
}
if err != nil {
return SubscribeToken{}, fmt.Errorf("%w: %v", ErrInvalidToken, err)
}
}

now := time.Now()
Expand Down
104 changes: 75 additions & 29 deletions internal/jwtverify/token_verifier_jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ func Test_tokenVerifierJWT_Valid(t *testing.T) {
require.NoError(t, err)
verifier, err := NewTokenVerifierJWT(VerifierConfig{"secret", nil, nil, "", "", "", "", ""}, ruleContainer)
require.NoError(t, err)
ct, err := verifier.VerifyConnectToken(jwtValid)
ct, err := verifier.VerifyConnectToken(jwtValid, false)
require.NoError(t, err)
require.Equal(t, "2694", ct.UserID)
require.NotNil(t, ct.Info)
Expand All @@ -219,7 +219,7 @@ func Test_tokenVerifierJWT_Audience(t *testing.T) {
require.NoError(t, err)

// Token without aud.
_, err = verifier.VerifyConnectToken(jwtValid)
_, err = verifier.VerifyConnectToken(jwtValid, false)
require.ErrorIs(t, err, ErrInvalidToken)

// Generate token with aud.
Expand All @@ -229,25 +229,25 @@ func Test_tokenVerifierJWT_Audience(t *testing.T) {
verifier, err = NewTokenVerifierJWT(VerifierConfig{"secret", nil, nil, "", "test2", "", "", ""}, ruleContainer)
require.NoError(t, err)

_, err = verifier.VerifyConnectToken(token)
_, err = verifier.VerifyConnectToken(token, false)
require.ErrorIs(t, err, ErrInvalidToken)

// Verifier with token audience.
verifier, err = NewTokenVerifierJWT(VerifierConfig{"secret", nil, nil, "", "test", "", "", ""}, ruleContainer)
require.NoError(t, err)
_, err = verifier.VerifyConnectToken(token)
_, err = verifier.VerifyConnectToken(token, false)
require.NoError(t, err)

// Verifier with token audience - valid.
verifier, err = NewTokenVerifierJWT(VerifierConfig{"secret", nil, nil, "", "", "test", "", ""}, ruleContainer)
require.NoError(t, err)
_, err = verifier.VerifyConnectToken(token)
_, err = verifier.VerifyConnectToken(token, false)
require.NoError(t, err)

// Verifier with token audience - invalid.
verifier, err = NewTokenVerifierJWT(VerifierConfig{"secret", nil, nil, "", "", "test2", "", ""}, ruleContainer)
require.NoError(t, err)
_, err = verifier.VerifyConnectToken(token)
_, err = verifier.VerifyConnectToken(token, false)
require.Error(t, err)
}

Expand All @@ -259,7 +259,7 @@ func Test_tokenVerifierJWT_Issuer(t *testing.T) {
require.NoError(t, err)

// Token without iss.
_, err = verifier.VerifyConnectToken(jwtValid)
_, err = verifier.VerifyConnectToken(jwtValid, false)
require.ErrorIs(t, err, ErrInvalidToken)

// Generate token with iss.
Expand All @@ -268,25 +268,25 @@ func Test_tokenVerifierJWT_Issuer(t *testing.T) {
// Verifier with issuer which does not match token iss.
verifier, err = NewTokenVerifierJWT(VerifierConfig{"secret", nil, nil, "", "", "", "test2", ""}, ruleContainer)
require.NoError(t, err)
_, err = verifier.VerifyConnectToken(token)
_, err = verifier.VerifyConnectToken(token, false)
require.ErrorIs(t, err, ErrInvalidToken)

// Verifier with token issuer.
verifier, err = NewTokenVerifierJWT(VerifierConfig{"secret", nil, nil, "", "", "", "test", ""}, ruleContainer)
require.NoError(t, err)
_, err = verifier.VerifyConnectToken(token)
_, err = verifier.VerifyConnectToken(token, false)
require.NoError(t, err)

// Verifier with token issuer regex - valid.
verifier, err = NewTokenVerifierJWT(VerifierConfig{"secret", nil, nil, "", "", "", "", "test"}, ruleContainer)
require.NoError(t, err)
_, err = verifier.VerifyConnectToken(token)
_, err = verifier.VerifyConnectToken(token, false)
require.NoError(t, err)

// Verifier with token issuer regex - invalid.
verifier, err = NewTokenVerifierJWT(VerifierConfig{"secret", nil, nil, "", "", "", "", "test2"}, ruleContainer)
require.NoError(t, err)
_, err = verifier.VerifyConnectToken(token)
_, err = verifier.VerifyConnectToken(token, false)
require.Error(t, err)
}

Expand All @@ -296,7 +296,7 @@ func Test_tokenVerifierJWT_Expired(t *testing.T) {
require.NoError(t, err)
verifier, err := NewTokenVerifierJWT(VerifierConfig{"secret", nil, nil, "", "", "", "", ""}, ruleContainer)
require.NoError(t, err)
_, err = verifier.VerifyConnectToken(jwtExpired)
_, err = verifier.VerifyConnectToken(jwtExpired, false)
require.Error(t, err)
require.Equal(t, ErrTokenExpired, err)
}
Expand All @@ -307,7 +307,7 @@ func Test_tokenVerifierJWT_DisabledAlgorithm(t *testing.T) {
require.NoError(t, err)
verifier, err := NewTokenVerifierJWT(VerifierConfig{"", nil, nil, "", "", "", "", ""}, ruleContainer)
require.NoError(t, err)
_, err = verifier.VerifyConnectToken(jwtExpired)
_, err = verifier.VerifyConnectToken(jwtExpired, false)
require.Error(t, err)
require.True(t, errors.Is(err, ErrInvalidToken), err.Error())
}
Expand All @@ -318,8 +318,24 @@ func Test_tokenVerifierJWT_InvalidSignature(t *testing.T) {
require.NoError(t, err)
verifier, err := NewTokenVerifierJWT(VerifierConfig{"secret", nil, nil, "", "", "", "", ""}, ruleContainer)
require.NoError(t, err)
_, err = verifier.VerifyConnectToken(jwtInvalidSignature)
_, err = verifier.VerifyConnectToken(jwtInvalidSignature, false)
require.Error(t, err)

// Test that skipVerify results into accepted token.
ct, err := verifier.VerifyConnectToken(jwtValid+"xxx", true)
require.NoError(t, err)
require.Equal(t, "2694", ct.UserID)
}

func Test_tokenVerifierJWT_InvalidSignature_SkipVerify(t *testing.T) {
ruleConfig := rule.DefaultConfig
ruleContainer, err := rule.NewContainer(ruleConfig)
require.NoError(t, err)
verifier, err := NewTokenVerifierJWT(VerifierConfig{"secret", nil, nil, "", "", "", "", ""}, ruleContainer)
require.NoError(t, err)
ct, err := verifier.VerifyConnectToken(jwtValid+"xxx", true)
require.NoError(t, err)
require.Equal(t, "2694", ct.UserID)
}

func Test_tokenVerifierJWT_WithNotBefore(t *testing.T) {
Expand All @@ -328,7 +344,11 @@ func Test_tokenVerifierJWT_WithNotBefore(t *testing.T) {
require.NoError(t, err)
verifier, err := NewTokenVerifierJWT(VerifierConfig{"secret", nil, nil, "", "", "", "", ""}, ruleContainer)
require.NoError(t, err)
_, err = verifier.VerifyConnectToken(jwtNotBefore)
_, err = verifier.VerifyConnectToken(jwtNotBefore, false)
require.Error(t, err)

// Test that skipVerify still results into unaccepted token if it's expired.
_, err = verifier.VerifyConnectToken(jwtNotBefore, true)
require.Error(t, err)
}

Expand All @@ -338,7 +358,7 @@ func Test_tokenVerifierJWT_StringAudience(t *testing.T) {
require.NoError(t, err)
verifier, err := NewTokenVerifierJWT(VerifierConfig{"secret", nil, nil, "", "", "", "", ""}, ruleContainer)
require.NoError(t, err)
ct, err := verifier.VerifyConnectToken(jwtStringAud)
ct, err := verifier.VerifyConnectToken(jwtStringAud, false)
require.NoError(t, err)
require.Equal(t, "2694", ct.UserID)
}
Expand All @@ -349,7 +369,7 @@ func Test_tokenVerifierJWT_ArrayAudience(t *testing.T) {
require.NoError(t, err)
verifier, err := NewTokenVerifierJWT(VerifierConfig{"secret", nil, nil, "", "", "", "", ""}, ruleContainer)
require.NoError(t, err)
ct, err := verifier.VerifyConnectToken(jwtArrayAud)
ct, err := verifier.VerifyConnectToken(jwtArrayAud, false)
require.NoError(t, err)
require.Equal(t, "2694", ct.UserID)
}
Expand Down Expand Up @@ -441,7 +461,7 @@ func Test_tokenVerifierJWT_VerifyConnectToken(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.verifier.VerifyConnectToken(tt.args.token)
got, err := tt.verifier.VerifyConnectToken(tt.args.token, false)
if tt.wantErr && err == nil {
t.Errorf("VerifyConnectToken() should return error")
}
Expand Down Expand Up @@ -533,7 +553,7 @@ func Test_tokenVerifierJWT_VerifyConnectTokenWithJWK(t *testing.T) {

token := getRSAConnToken(tt.token.user, tt.token.exp, privKey, jwt.WithKeyID(tt.jwk.kid))

got, err := verifier.VerifyConnectToken(token)
got, err := verifier.VerifyConnectToken(token, false)
if tt.wantErr {
r.Error(err)
return
Expand Down Expand Up @@ -567,12 +587,13 @@ func Test_tokenVerifierJWT_VerifySubscribeToken(t *testing.T) {

_time := time.Now()
tests := []struct {
name string
verifier Verifier
args args
want SubscribeToken
wantErr bool
expired bool
name string
verifier Verifier
args args
want SubscribeToken
wantErr bool
expired bool
skipVerify bool
}{
{
name: "Empty JWT",
Expand Down Expand Up @@ -614,6 +635,31 @@ func Test_tokenVerifierJWT_VerifySubscribeToken(t *testing.T) {
},
},
wantErr: false,
}, {
name: "Invalid JWT HS",
verifier: verifierJWT,
args: args{
token: getRSASubscribeToken("channel1", "client1", _time.Add(24*time.Hour).Unix(), nil) + "xxx",
},
want: SubscribeToken{},
wantErr: true,
skipVerify: false,
}, {
name: "Invalid JWT HS but verify skipped",
verifier: verifierJWT,
args: args{
token: getRSASubscribeToken("channel1", "client1", _time.Add(24*time.Hour).Unix(), nil) + "xxx",
},
want: SubscribeToken{
Client: "client1",
Channel: "channel1",
Options: centrifuge.SubscribeOptions{
ExpireAt: _time.Add(24 * time.Hour).Unix(),
ChannelInfo: []byte("{}"),
},
},
wantErr: false,
skipVerify: true,
}, {
name: "Valid JWT RS",
verifier: verifierJWT,
Expand Down Expand Up @@ -649,12 +695,12 @@ func Test_tokenVerifierJWT_VerifySubscribeToken(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.verifier.VerifySubscribeToken(tt.args.token)
got, err := tt.verifier.VerifySubscribeToken(tt.args.token, tt.skipVerify)
if tt.wantErr && err == nil {
t.Errorf("VerifySubscribeToken() should return error")
}
if !tt.wantErr && err != nil {
t.Errorf("VerifySubscribeToken() should not return error")
t.Errorf("VerifySubscribeToken() should not return error, but returned: %v", err)
}
if tt.expired && err != ErrTokenExpired {
t.Errorf("VerifySubscribeToken() should return token expired error")
Expand All @@ -674,7 +720,7 @@ func BenchmarkConnectTokenVerify_Valid(b *testing.B) {
require.NoError(b, err)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := verifierJWT.VerifyConnectToken(jwtValid)
_, err := verifierJWT.VerifyConnectToken(jwtValid, false)
if err != nil {
b.Fatal(err)
}
Expand All @@ -690,7 +736,7 @@ func BenchmarkConnectTokenVerify_Expired(b *testing.B) {
verifier, err := NewTokenVerifierJWT(VerifierConfig{"secret", nil, nil, "", "", "", "", ""}, ruleContainer)
require.NoError(b, err)
for i := 0; i < b.N; i++ {
_, err := verifier.VerifyConnectToken(jwtExpired)
_, err := verifier.VerifyConnectToken(jwtExpired, false)
if err != ErrTokenExpired {
panic(err)
}
Expand Down
Loading

0 comments on commit 15adaaa

Please sign in to comment.