Skip to content

Commit

Permalink
refactor validator and implement validator tests
Browse files Browse the repository at this point in the history
  • Loading branch information
erudenko committed Jul 12, 2023
1 parent d588774 commit e82aa80
Show file tree
Hide file tree
Showing 19 changed files with 387 additions and 299 deletions.
4 changes: 2 additions & 2 deletions jwt/middleware/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func JWT(eh ErrorHandler, c validator.Config) (Handler, error) {

// TokenFromContext returns token from request context.
// Or nil if there is no token in context.
func TokenFromContext(ctx context.Context) model.Token {
v, _ := ctx.Value(model.TokenContextKey).(model.Token)
func TokenFromContext(ctx context.Context) model.JWToken {
v, _ := ctx.Value(model.TokenContextKey).(model.JWToken)
return v
}
10 changes: 5 additions & 5 deletions jwt/middleware/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,22 +97,22 @@ func TestTokenFromContext(t *testing.T) {
tests := []struct {
name string
ctx context.Context
want model.Token
want model.JWToken
}{
{
name: "no token",
ctx: context.Background(),
want: nil,
want: model.JWToken{},
},
{
name: "nil token",
ctx: context.WithValue(context.Background(), model.TokenContextKey, nil),
want: nil,
want: model.JWToken{},
},
{
name: "token exists",
ctx: context.WithValue(context.Background(), model.TokenContextKey, model.Token(&model.JWToken{})),
want: &model.JWToken{},
ctx: context.WithValue(context.Background(), model.TokenContextKey, model.JWToken{}),
want: model.JWToken{},
},
}
for _, tt := range tests {
Expand Down
6 changes: 3 additions & 3 deletions jwt/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ import (
)

// Parse parses token data from the string representation.
func ParseTokenString(str string) (model.Token, error) {
func ParseTokenString(str string) (model.JWToken, error) {
tokenString := strings.TrimSpace(str)
parser := jwt.Parser{}

token, _, err := parser.ParseUnverified(tokenString, &model.Claims{})
if err != nil {
return nil, err
return model.JWToken{}, err
}

return &model.JWToken{Token: *token}, nil
return model.JWToken{Token: *token}, nil
}
39 changes: 16 additions & 23 deletions jwt/service/jwt_token_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,13 @@ import (
// Arguments:
// - privateKeyPath - the path to the private key in pem format. Please keep it in a secret place.
// - publicKeyPath - the path to the public key.
func NewJWTokenService(privateKey any, issuer string, stor model.TokenStorage, settings model.SecurityServerSettings) (model.TokenService, error) {
func NewJWTokenService(privateKey any, issuer string, settings model.SecurityServerSettings) (model.TokenService, error) {
if privateKey == nil {
return nil, fmt.Errorf("private key is empty")
}

t := &JWTokenService{
iss: issuer,
ts: stor,
pk: privateKey,
settings: settings,
}
Expand All @@ -41,7 +40,6 @@ func NewJWTokenService(privateKey any, issuer string, stor model.TokenStorage, s
// JWTokenService is a JWT token service.
type JWTokenService struct {
pk any // *ecdsa.PrivateKey, or *rsa.PrivateKey
ts model.TokenStorage
settings model.SecurityServerSettings
iss string
aCache string // algorithm cache
Expand All @@ -58,10 +56,13 @@ func (ts *JWTokenService) PrivateKey() any {
return ts.pk
}

func (ts *JWTokenService) NewToken(tokenType model.TokenType, u model.User, aud []string, fields []string, payload map[string]any) (model.Token, error) {
func (ts *JWTokenService) NewToken(tokenType model.TokenType, u model.User, aud []string, fields []string, payload map[string]any) (model.JWToken, error) {
// we have to collect all payloads to one map
userPayload := xmaps.FieldsToMap(u)
userPayload = xmaps.FilterMap(userPayload, fields)
if payload == nil {
payload = map[string]any{}
}
maps.Copy(payload, userPayload)
lifespan := ts.settings.TokenLifetime(tokenType)
ia := jwt.NewNumericDate(j.TimeFunc())
Expand All @@ -81,30 +82,22 @@ func (ts *JWTokenService) NewToken(tokenType model.TokenType, u model.User, aud

sm := ts.jwtMethod()
if sm == nil {
return nil, l.ErrorTokenMethodInvalid
return model.JWToken{}, l.ErrorTokenMethodInvalid
}

token := model.TokenWithClaims(sm, ts.KeyID(), claims)
return &model.JWToken{
Token: token,
New: true,
}, nil
return token, nil
}

func (ts *JWTokenService) SignToken(t model.Token) (string, error) {
token, ok := t.(*model.JWToken)
if !ok {
return "", l.ErrorTokenInvalid
}

func (ts *JWTokenService) SignToken(t model.JWToken) (string, error) {
if err := t.Validate(); err != nil {
return "", l.LocalizedError{
ErrID: l.ErrorValidatingToken,
Details: []any{err},
}
}

str, err := token.SignedString(ts.pk)
str, err := t.SignedString(ts.pk)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -179,7 +172,7 @@ func (ts *JWTokenService) KeyID() string {
}

// Parse parses token data from the string representation.
func (ts *JWTokenService) Parse(s string) (model.Token, error) {
func (ts *JWTokenService) Parse(s string) (model.JWToken, error) {
tokenString := strings.TrimSpace(s)
token, err := jwt.ParseWithClaims(tokenString, &model.Claims{}, func(token *jwt.Token) (any, error) {
// since we only use the one private key to sign the tokens,
Expand All @@ -188,25 +181,25 @@ func (ts *JWTokenService) Parse(s string) (model.Token, error) {
return ts.PublicKey(), nil
})
if err != nil {
return nil, err
return model.JWToken{}, err
}

return &model.JWToken{Token: *token}, nil
return model.JWToken{Token: *token}, nil
}

// ValidateTokenString parses token and validates it.
func (ts *JWTokenService) ValidateTokenString(tstr string, v jv.Validator, tokenType string) (model.Token, error) {
func (ts *JWTokenService) ValidateTokenString(tstr string, v jv.Validator, tokenType string) (model.JWToken, error) {
token, err := ts.Parse(tstr)
if err != nil {
return nil, err
return model.JWToken{}, err
}

if err := v.Validate(token); err != nil {
return nil, err
return model.JWToken{}, err
}

if token.Type() != tokenType {
return nil, err
return model.JWToken{}, err
}

return token, nil
Expand Down
97 changes: 97 additions & 0 deletions jwt/service/jwt_token_service_test.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,100 @@
package service_test

import (
"reflect"
"testing"

"github.com/madappgang/identifo/v2/jwt/service"
"github.com/madappgang/identifo/v2/model"
"github.com/madappgang/identifo/v2/storage"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

const (
keyPath = "../test_artifacts/private.pem"
testIssuer = "aooth.madappgang.com"
tokenStringExample = "eyJhbGciOiJFUzI1NiIsInR5cCI6IkpXVCJ9.eyJpYXQiOjE1MTYyMzkwMjIsInN1YiI6IjEyMzQ1Njc4OTAifQ.AAlGn8m8YG3emPa8CIS6TS-ndqaZCGUydnhU8FznyZ1McYQKkLlcqDW2c04q9ZxKDZHeiSyNIDOKA-EP0GVthQ"
)

func createTokenService(t *testing.T) model.TokenService {
keyStorage, err := storage.NewKeyStorage(model.FileStorageSettings{
Type: model.FileStorageTypeLocal,
Local: model.FileStorageLocal{
Path: keyPath,
},
})
require.NoError(t, err)

privateKey, err := keyStorage.LoadPrivateKey()
require.NoError(t, err)

tokenService, err := service.NewJWTokenService(
privateKey,
testIssuer,
model.DefaultServerSettings.SecuritySettings,
)
require.NoError(t, err)
return tokenService
}

func TestParseString(t *testing.T) {
tokenService := createTokenService(t)

token, err := tokenService.Parse(tokenStringExample)
assert.NoError(t, err)
assert.NotEmpty(t, token)

// claims
_, ok := token.Claims.(*model.Claims)
require.True(t, ok)

// assert.Equal(t, string(model.TokenTypeAccess), token.Type())

assert.Equal(t, "1234567890", token.Subject())
assert.Equal(t, int64(1516239022), token.IssuedAt().Unix())
}

func TestTokenToString(t *testing.T) {
tokenService := createTokenService(t)

token, err := tokenService.Parse(tokenStringExample)
require.NoError(t, err)
require.NotNil(t, token)

tokenString, err := tokenService.SignToken(token)
assert.NoError(t, err)

token2, err := tokenService.Parse(tokenString)
assert.NoError(t, err)
assert.NotNil(t, token2)

claims1 := token.Claims
claims2 := token2.Claims

if !reflect.DeepEqual(token.Header, token2.Header) {
t.Errorf("Headers = %+v, want %+v", token.Header, token2.Header)
}
if !reflect.DeepEqual(claims1, claims2) {
t.Errorf("Claims = %+v, want %+v", claims1, claims2)
}
}

func TestNewToken(t *testing.T) {
tokenService := createTokenService(t)

user := model.User{
ID: "12345566",
Username: "username",
Email: "username@gmailc.om",
}
token, err := tokenService.NewToken(model.TokenTypeAccess, user, []string{"12345"}, []string{"Email"}, nil)
assert.NoError(t, err)

tokenString, err := tokenService.SignToken(token)
assert.NoError(t, err)

_, err = tokenService.Parse(tokenString)
require.Error(t, err)
assert.Contains(t, err.Error(), "expired")
}
4 changes: 0 additions & 4 deletions jwt/test_artifacts/public.pem

This file was deleted.

Loading

0 comments on commit e82aa80

Please sign in to comment.