diff --git a/jwt/service/jwt_token_service.go b/jwt/service/jwt_token_service.go index 906dd9dd..51a4366f 100644 --- a/jwt/service/jwt_token_service.go +++ b/jwt/service/jwt_token_service.go @@ -92,10 +92,53 @@ func (ts *JWTokenService) Issuer() string { return ts.issuer } -// CreateToken creates token with specific types. -func (ts *JWTokenService) NewToken(tokenType model.TokenType, userID string, payload []any) (model.Token, error) { - // TODO: implement general token creation for all token types - return &model.JWToken{}, nil +func (ts *JWTokenService) NewToken(tokenType model.TokenType, u model.User, fields []string, payload map[string]any) (model.Token, error) { + payload := make(map[string]any) + jwt.StandardClaims + if model.SliceContains(app.TokenPayload, PayloadName) { + payload[PayloadName] = u.Username + } + + tokenType := model.TokenTypeAccess + if len(tokenPayload) > 0 { + for k, v := range tokenPayload { + payload[k] = v + } + } + + now := ijwt.TimeFunc().Unix() + + lifespan := app.TokenLifespan + if lifespan == 0 { + lifespan = TokenLifespan + } + + claims := model.Claims{ + Scopes: strings.Join(scopes, " "), + Payload: payload, + Type: string(tokenType), + StandardClaims: jwt.StandardClaims{ + ExpiresAt: (now + lifespan), + Issuer: ts.issuer, + Subject: u.ID, + Audience: app.ID, + IssuedAt: now, + }, + } + + sm := ts.jwtMethod() + if sm == nil { + return nil, errors.New("unable to creating signing method") + } + + token := model.NewTokenWithClaims(sm, ts.KeyID(), claims) + if token == nil { + return nil, ErrCreatingToken + } + return &model.JWToken{JWT: token, New: true}, nil +} + +func (ts *JWTokenService) SignToken(token model.Token) (string, error) { } // Algorithm returns signature algorithm. @@ -147,17 +190,13 @@ func (ts *JWTokenService) PublicKey() interface{} { return ts.cachedPublicKey } -func (ts *JWTokenService) SetPrivateKey(key interface{}) { +func (ts *JWTokenService) SetPrivateKey(key any) { fmt.Printf("Changing private key for Token service, all new tokens will be signed with a new key!!!\n") ts.privateKey = key ts.cachedPublicKey = nil ts.cachedAlgorithm = "" } -func (ts *JWTokenService) PrivateKey() interface{} { - return ts.privateKey -} - // KeyID returns public key ID, using SHA-1 fingerprint. func (ts *JWTokenService) KeyID() string { pk := ts.PublicKey() diff --git a/model/token.go b/model/token.go index 69030a73..9c3ed74b 100644 --- a/model/token.go +++ b/model/token.go @@ -1,6 +1,7 @@ package model import ( + "encoding/json" "time" jwt "github.com/golang-jwt/jwt/v4" @@ -50,7 +51,7 @@ type Token interface { // NewTokenWithClaims generates new JWT token with claims and keyID. func NewTokenWithClaims(method jwt.SigningMethod, kid string, claims jwt.Claims) *jwt.Token { return &jwt.Token{ - Header: map[string]interface{}{ + Header: map[string]any{ "typ": "JWT", "alg": method.Alg(), "kid": kid, @@ -189,5 +190,17 @@ type Claims struct { jwt.StandardClaims } +func (c *Claims) SC() *jwt.StandardClaims { + return &c.StandardClaims +} + +func (c *Claims) MarshalJSON() ([]byte, error) { + return json.Marshal(&struct { + *jwt.StandardClaims + }{ + StandardClaims: &c.StandardClaims, + }) +} + // Full example of how to use JWT tokens: // https://github.com/form3tech-oss/jwt-go/blob/master/cmd/jwt/app.go diff --git a/model/token_service.go b/model/token_service.go index 5abb6a2d..0627bf86 100644 --- a/model/token_service.go +++ b/model/token_service.go @@ -24,10 +24,9 @@ type TokenService interface { // keys management // replace the old private key with a new one - SetPrivateKey(key interface{}) - PrivateKey() interface{} + SetPrivateKey(key any) // not using crypto.PublicKey here to avoid dependencies - PublicKey() interface{} + PublicKey() any KeyID() string } diff --git a/model/user_scopes.go b/model/user_scopes.go index b018dbd2..104c8f79 100644 --- a/model/user_scopes.go +++ b/model/user_scopes.go @@ -14,8 +14,8 @@ const ( IDTokenScope = "id" TenantScope = "tenant" AccessTokenScopePrefix = "access:" - TenantScopePrefix = "tenant:" // tenant:123 - request tenant data only for tenant 123 - TenantScopeAll = "all" // "tenant:all" - return all scopes for all ten + TenantScopePrefix = "tenant:" // tenant:123 - request tenant data only for tenant 123 + TenantScopeAll = "tenant:all" // "tenant:all" - return all scopes for all ten ) func FieldsetForScopes(scopes []string) []string { diff --git a/server/controller/slices.go b/server/controller/slices.go deleted file mode 100644 index 7a0ea758..00000000 --- a/server/controller/slices.go +++ /dev/null @@ -1,58 +0,0 @@ -package controller - -// sliceContains checks if value is inside slice. -func sliceContains[T comparable](slice []T, value T) bool { - for _, v := range slice { - if v == value { - return true - } - } - return false -} - -// intersect sl2 from sl1 -func intersect[T comparable](sl1, sl2 []T) []T { - var intersect []T - - // Loop two times, first to find slice1 strings not in slice2, - // second loop to find slice2 strings not in slice1 - for i := 0; i < 2; i++ { - for _, s1 := range sl1 { - found := false - for _, s2 := range sl2 { - if s1 == s2 { - found = true - break - } - } - if found { - intersect = append(intersect, s1) - } - } - // Swap the slices, only if it was the first loop - if i == 0 { - sl1, sl2 = sl2, sl1 - } - } - - return intersect -} - -// removeDuplicate -func removeDuplicate[T comparable](slice []T) []T { - allKeys := make(map[T]bool) - list := []T{} - for _, item := range slice { - if _, ok := allKeys[item]; !ok { - allKeys[item] = true - list = append(list, item) - } - } - return list -} - -// concatUnique sl2 with sl1, only unique values -func concatUnique[T comparable](sl1, sl2 []T) []T { - res := append(sl1, sl2...) - return removeDuplicate(res) -} diff --git a/server/controller/user_controller_login.go b/server/controller/user_controller_login.go index e83b44d6..d1df0c06 100644 --- a/server/controller/user_controller_login.go +++ b/server/controller/user_controller_login.go @@ -5,19 +5,36 @@ import ( "strings" "github.com/madappgang/identifo/v2/model" + "golang.org/x/exp/maps" ) -// TODO! we need to add tenant related information flattered, as: -// "112233:admin:user", where 112233 - tenant ID, admin - a group, user - role in a group +// tenant related information flattered, as: +// "112233:admin" : "user", where 112233 - tenant ID, admin - a group, user - role in a group +// and tenant name added as well: +// "tenant:112233": "tenant corporation" func (c *UserStorageController) getJWTTokens(ctx context.Context, app model.AppData, u model.User, scopes []string) (model.AuthResponse, error) { // check if we are + var err error // TODO: implement custom payload provider for app resp := model.AuthResponse{} + ud := model.UserData{} + ap := AccessTokenScopes(scopes) // fields for access token apf := model.FieldsetForScopes(scopes) + data := map[string]any{} - at, err := c.ts.NewToken(model.TokenTypeAccess, u, apf, nil) + // access token needs tenant data in it + if needTenantInfo(ap) { + ud, err = c.u.UserData(ctx, u.ID, model.UserDataFieldTenantMembership) + if err != nil { + return resp, err + } + ti := TenantData(ud.TenantMembership, scopes) + maps.Copy(data, ti) + } + // create access token + at, err := c.ts.NewToken(model.TokenTypeAccess, u, apf, data) if err != nil { return resp, err } @@ -33,6 +50,20 @@ func (c *UserStorageController) getJWTTokens(ctx context.Context, app model.AppD f := model.FieldsetForScopes(scopes) data := map[string]any{} + // if we need tenant data in id token + if needTenantInfo(scopes) { + // we can already have userData fetched for access token + if len(ud.UserID) == 0 { + ud, err = c.u.UserData(ctx, u.ID, model.UserDataFieldTenantMembership) + if err != nil { + return resp, err + } + } + ti := TenantData(ud.TenantMembership, scopes) + maps.Copy(data, ti) + } + + // create id token idt, err := c.ts.NewToken(model.TokenTypeID, u, f, data) if err != nil { return resp, err @@ -79,10 +110,25 @@ func AccessTokenScopes(scopes []string) []string { return result } -func TenantData(ud model.UserData) map[string]any { +func TenantData(ud []model.TenantMembership, scopes []string) map[string]any { res := map[string]any{} - for _, t := range ud.TenantMembership { + filter := []string{} + getAll := false + for _, s := range scopes { + if s == model.TenantScopeAll { + getAll = true + break + } else if strings.HasPrefix(s, model.TenantScopePrefix) && len(s) > len(model.TenantScopePrefix) { + filter = append(filter, s[len(model.TenantScopePrefix):]) + } + } + for _, t := range ud { + // skip the scopes we don't need to have + if !getAll && !sliceContains(filter, t.TenantID) { + continue + } tid := t.TenantID + res["tenant:"+t.TenantID] = t.TenantName for k, v := range t.Groups { // "tenant_id:group_id" : "role" res[tid+":"+k] = v @@ -90,3 +136,12 @@ func TenantData(ud model.UserData) map[string]any { } return res } + +func needTenantInfo(scopes []string) bool { + for _, s := range scopes { + if strings.HasPrefix(s, model.TenantScopePrefix) { + return true + } + } + return false +} diff --git a/server/controller/user_controller_login_test.go b/server/controller/user_controller_login_test.go index 312bae29..8eba093d 100644 --- a/server/controller/user_controller_login_test.go +++ b/server/controller/user_controller_login_test.go @@ -5,6 +5,7 @@ import ( "github.com/madappgang/identifo/v2/model" "github.com/madappgang/identifo/v2/server/controller" + "github.com/madappgang/identifo/v2/storage/mock" "github.com/stretchr/testify/assert" ) @@ -12,26 +13,52 @@ func TestAddTenantData(t *testing.T) { ud := model.UserData{ TenantMembership: []model.TenantMembership{ { - TenantID: "tenant1", - Groups: map[string]string{"default": "admin", "group1": "user"}, + TenantID: "tenant1", + TenantName: "I am a tenant 1", + Groups: map[string]string{"default": "admin", "group1": "user"}, }, { - TenantID: "tenant2", - Groups: map[string]string{"default": "guest", "group33": "admin"}, + TenantID: "tenant2", + TenantName: "Apple corporation", + Groups: map[string]string{"default": "guest", "group33": "admin"}, }, }, } - flattenData := controller.TenantData(ud) + flattenData := controller.TenantData(ud.TenantMembership, nil) + assert.Len(t, flattenData, 0) + + flattenData = controller.TenantData(ud.TenantMembership, []string{"tenant:all"}) assert.Contains(t, flattenData, "tenant1:default") assert.Contains(t, flattenData, "tenant1:group1") assert.Contains(t, flattenData, "tenant2:default") assert.Contains(t, flattenData, "tenant2:group33") + assert.Contains(t, flattenData, "tenant:tenant1") + assert.Contains(t, flattenData, "tenant:tenant2") assert.Equal(t, flattenData["tenant1:default"], "admin") assert.Equal(t, flattenData["tenant1:group1"], "user") assert.Equal(t, flattenData["tenant2:default"], "guest") assert.Equal(t, flattenData["tenant2:group33"], "admin") - assert.Len(t, flattenData, 4) + assert.Equal(t, flattenData["tenant:tenant1"], "I am a tenant 1") + assert.Equal(t, flattenData["tenant:tenant2"], "Apple corporation") + assert.Len(t, flattenData, 6) + + scopes := []string{"tenant:tenant1", "tenant:tenant3"} + flattenData = controller.TenantData(ud.TenantMembership, scopes) + assert.Contains(t, flattenData, "tenant1:default") + assert.Contains(t, flattenData, "tenant1:group1") + assert.NotContains(t, flattenData, "tenant2:default") + assert.NotContains(t, flattenData, "tenant2:group33") + assert.Contains(t, flattenData, "tenant:tenant1") + assert.NotContains(t, flattenData, "tenant:tenant2") + assert.Equal(t, flattenData["tenant1:default"], "admin") + assert.Equal(t, flattenData["tenant1:group1"], "user") + assert.Equal(t, flattenData["tenant:tenant1"], "I am a tenant 1") + assert.Len(t, flattenData, 3) + + scopes = []string{"tenant:tenant1", "tenant:tenant3", "tenant:all"} + flattenData = controller.TenantData(ud.TenantMembership, scopes) + assert.Len(t, flattenData, 6) } func TestAccessTokenScopes(t *testing.T) { @@ -48,3 +75,39 @@ func TestAccessTokenScopes(t *testing.T) { assert.Contains(t, r, "profile") assert.Contains(t, r, "oidc") } + +func TestRequestJWT(t *testing.T) { + tm := []model.TenantMembership{ + { + TenantID: "tenant1", + TenantName: "I am a tenant 1", + Groups: map[string]string{"default": "admin", "group1": "user"}, + }, + { + TenantID: "tenant2", + TenantName: "Apple corporation", + Groups: map[string]string{"default": "guest", "group33": "admin"}, + }, + } + u := mock.UserStorage{ + UData: map[string]model.UserData{ + "user1": { + UserID: "user1", + TenantMembership: tm, + }, + }, + } + + + c := controller.NewUserStorageController( + &u, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + model.ServerSettings{}, + ) +} diff --git a/storage/mock/user_storage.go b/storage/mock/user_storage.go index cab2673f..7b965356 100644 --- a/storage/mock/user_storage.go +++ b/storage/mock/user_storage.go @@ -10,6 +10,7 @@ import ( type UserStorage struct { Storage Users []model.User + UData map[string]model.UserData } func (us *UserStorage) UserByID(ctx context.Context, id string) (model.User, error) { @@ -37,7 +38,12 @@ func (us *UserStorage) UserByFederatedID(ctx context.Context, idType model.UserF } func (us *UserStorage) UserData(ctx context.Context, userID string, fields ...model.UserDataField) (model.UserData, error) { - return model.UserData{}, l.ErrorLoginTypeNotSupported + d, ok := us.UData[userID] + if !ok { + return model.UserData{}, l.ErrorUserNotFound + } + + return d, nil } func (us *UserStorage) ImportJSON(data []byte, clearOldData bool) error { diff --git a/model/copy_fields.go b/tools/xmaps/copy_fields.go similarity index 99% rename from model/copy_fields.go rename to tools/xmaps/copy_fields.go index 38388b64..d9ac8524 100644 --- a/model/copy_fields.go +++ b/tools/xmaps/copy_fields.go @@ -1,4 +1,4 @@ -package model +package xmaps import ( "errors" diff --git a/model/copy_fields_test.go b/tools/xmaps/copy_fields_test.go similarity index 89% rename from model/copy_fields_test.go rename to tools/xmaps/copy_fields_test.go index 130c3fcc..3fa996cd 100644 --- a/model/copy_fields_test.go +++ b/tools/xmaps/copy_fields_test.go @@ -1,9 +1,10 @@ -package model_test +package xmaps_test import ( "testing" "github.com/madappgang/identifo/v2/model" + "github.com/madappgang/identifo/v2/tools/xmaps" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -26,7 +27,7 @@ func TestCopyFields(t *testing.T) { } fields := []string{"ID", "Name", "Password"} - result := model.CopyFields(u, fields) + result := xmaps.CopyFields(u, fields) assert.Empty(t, result.Phone) assert.Empty(t, result.Address) assert.Equal(t, u.ID, result.ID) @@ -51,7 +52,7 @@ func TestCopyDstFields(t *testing.T) { } dst := testShortUser{} - err := model.CopyDstFields(u, &dst) + err := xmaps.CopyDstFields(u, &dst) require.NoError(t, err) assert.Empty(t, dst.Other) @@ -88,23 +89,23 @@ func TestFilledValues(t *testing.T) { expected := []string{"ID", "Name", "Married", "Company.Name", "Company.People"} // pointer to struct should works - result := model.Filled(&tu) + result := xmaps.Filled(&tu) assert.EqualValues(t, result, expected) // reference to struct should works - result = model.Filled(tu) + result = xmaps.Filled(tu) assert.EqualValues(t, result, expected) // let's clear the Name tu.Name = nil expected = []string{"ID", "Married", "Company.Name", "Company.People"} - result = model.Filled(tu) + result = xmaps.Filled(tu) assert.EqualValues(t, result, expected) // empty value should be treated as non nil tu.Name = sp("") expected = []string{"ID", "Name", "Married", "Company.Name", "Company.People"} - result = model.Filled(tu) + result = xmaps.Filled(tu) assert.EqualValues(t, result, expected) } @@ -133,7 +134,7 @@ func TestCopyOnlyFilledValues(t *testing.T) { } dst := model.User{} - err := model.CopyDstFields(tu, &dst) + err := xmaps.CopyDstFields(tu, &dst) assert.NoError(t, err) assert.Equal(t, *tu.ID, dst.ID) @@ -151,7 +152,7 @@ func TestContainsFields(t *testing.T) { }, } - contains := model.ContainsFields(tu, []string{"ID", "Name", "NonPointer", "Whatever", "Company.Name", "Company.People", "Company"}) + contains := xmaps.ContainsFields(tu, []string{"ID", "Name", "NonPointer", "Whatever", "Company.Name", "Company.People", "Company"}) expected := []string{"ID", "Name", "NonPointer", "Company.Name", "Company.People"} assert.Equal(t, contains, expected) } diff --git a/tools/xmaps/maps.go b/tools/xmaps/maps.go new file mode 100644 index 00000000..cb72ebf3 --- /dev/null +++ b/tools/xmaps/maps.go @@ -0,0 +1,66 @@ +package xmaps + +import ( + "fmt" + "reflect" + + "golang.org/x/exp/maps" +) + +// FieldsToMap converts any struct to map[string]any. +func FieldsToMap(s any) map[string]any { + return fieldsToMapNested("", s) +} + +func fieldsToMapNested(prefix string, src any) map[string]any { + ur := reflect.ValueOf(src) + + // if the src is interface, get underlying value behind that. + if ur.Kind() == reflect.Interface && !ur.IsNil() { + elm := ur.Elem() + if elm.Kind() == reflect.Ptr && !elm.IsNil() && elm.Elem().Kind() == reflect.Ptr { + ur = elm + } + } + + // if type is pointer - get a value referenced by a pointer + // maybe do while? for pinter to pinter to pointer case? + if ur.Kind() == reflect.Ptr { + ur = ur.Elem() + } + + fn := ur.NumField() + if len(prefix) > 0 { + prefix += "." + } + + f := map[string]any{} + for i := 0; i < fn; i++ { + fv := ur.Field(i) + + if fv.Kind() == reflect.Pointer { + if fv.IsNil() { + continue + } + fv = fv.Elem() + } + + if fv.Kind() == reflect.Struct { + ff := fieldsToMapNested(prefix+ur.Type().Field(i).Name, fv.Interface()) + maps.Copy(f, ff) + } else if fv.Kind() == reflect.Slice { + if !fv.IsZero() { + for j := 0; j < fv.Len(); j++ { + pr := fmt.Sprintf("%s%s[%d]", prefix, ur.Type().Field(i).Name, j) + ff := fieldsToMapNested(pr, fv.Index(j).Interface()) + maps.Copy(f, ff) + } + } + } else { + if !fv.IsZero() { + f[prefix+ur.Type().Field(i).Name] = fv.Interface() + } + } + } + return f +} diff --git a/tools/xmaps/maps_test.go b/tools/xmaps/maps_test.go new file mode 100644 index 00000000..a92d697d --- /dev/null +++ b/tools/xmaps/maps_test.go @@ -0,0 +1,59 @@ +package xmaps_test + +import ( + "fmt" + "testing" + + "github.com/madappgang/identifo/v2/tools/xmaps" + "github.com/stretchr/testify/assert" +) + +type person struct { + Name string + Age int + Address address + Word *address + Other []address +} + +type address struct { + Street string + Apt int +} + +func TestFieldsToMap(t *testing.T) { + p := person{ + Name: "John", + Age: 30, + Address: address{ + Street: "321 Main St", + Apt: 123, + }, + } + m := xmaps.FieldsToMap(p) + fmt.Printf("%v\n", m) + assert.Len(t, m, 4) + assert.Equal(t, 123, m["Address.Apt"]) + assert.Equal(t, "321 Main St", m["Address.Street"]) +} + +func TestFieldsToMapWithArray(t *testing.T) { + p := person{ + Name: "John", + Age: 30, + Other: []address{ + { + Street: "321 Main St", + Apt: 123, + }, + { + Street: "Other street", + }, + }, + } + m := xmaps.FieldsToMap(p) + fmt.Printf("%v\n", m) + assert.Len(t, m, 5) + assert.Equal(t, nil, m["Address.Apt"]) + assert.Equal(t, "Other street", m["Other[1].Street"]) +} diff --git a/tools/xslices/slices.go b/tools/xslices/slices.go new file mode 100644 index 00000000..172bfa26 --- /dev/null +++ b/tools/xslices/slices.go @@ -0,0 +1,49 @@ +package xslices + +import ( + "golang.org/x/exp/maps" +) + +// Intersect returns intersection of sl2 and sl1 +// the result has only unique values +func Intersect[T comparable](sl1, sl2 []T) []T { + var intersect []T + + // Loop two times, first to find slice1 strings not in slice2, + // second loop to find slice2 strings not in slice1 + for i := 0; i < 2; i++ { + for _, s1 := range sl1 { + found := false + for _, s2 := range sl2 { + if s1 == s2 { + found = true + break + } + } + if found { + intersect = append(intersect, s1) + } + } + // Swap the slices, only if it was the first loop + if i == 0 { + sl1, sl2 = sl2, sl1 + } + } + + return Unique(intersect) +} + +// Unique - returns new slice with unique values only. +func Unique[T comparable](slice []T) []T { + m := map[T]bool{} + for _, v := range slice { + m[v] = true + } + return maps.Keys(m) +} + +// ConcatUnique concatenate sl2 with sl1, only unique values +func ConcatUnique[T comparable](sl1, sl2 []T) []T { + res := append(sl1, sl2...) + return Unique(res) +} diff --git a/tools/xslices/slices_test.go b/tools/xslices/slices_test.go new file mode 100644 index 00000000..3161c7c8 --- /dev/null +++ b/tools/xslices/slices_test.go @@ -0,0 +1,34 @@ +package xslices_test + +import ( + "testing" + + "github.com/madappgang/identifo/v2/tools/xslices" + "github.com/stretchr/testify/assert" + "golang.org/x/exp/slices" +) + +func TestSliceContains(t *testing.T) { + assert.True(t, slices.Contains([]string{"a", "b"}, "a")) + assert.True(t, slices.Contains([]string{"a", "b", "1"}, "1")) + assert.False(t, slices.Contains([]string{"a", "b", "1"}, "11")) +} + +func TestIntersect(t *testing.T) { + assert.Contains(t, xslices.Intersect([]string{"a", "b"}, []string{"a"}), "a") + assert.Len(t, xslices.Intersect([]string{"a", "b"}, []string{"a"}), 1) + + assert.Contains(t, xslices.Intersect([]string{"a", "b"}, []string{"a", "b", "c"}), "a") + assert.Contains(t, xslices.Intersect([]string{"a", "b"}, []string{"a", "b", "c"}), "b") + assert.Len(t, xslices.Intersect([]string{"a", "b"}, []string{"a", "b", "c"}), 2) + + assert.Len(t, xslices.Intersect([]string{"a", "b"}, []string{"c"}), 0) +} + +func TestConcatUnique(t *testing.T) { + assert.Contains(t, xslices.ConcatUnique([]string{"a", "b"}, []string{"a"}), "a") + assert.Len(t, xslices.Intersect([]string{"a", "b"}, []string{"a"}), 1) + + r := xslices.ConcatUnique([]string{"a", "b", "c", "d"}, []string{"a"}) + assert.Len(t, r, 4) +}