Skip to content

Commit

Permalink
Delete tokens that the controller says are invalid (#3782)
Browse files Browse the repository at this point in the history
  • Loading branch information
talanknight committed Oct 16, 2023
1 parent 224e882 commit 049557a
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 19 deletions.
39 changes: 34 additions & 5 deletions internal/daemon/cache/repository_refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/hashicorp/boundary/api/targets"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/event"
"github.com/hashicorp/boundary/internal/util"
)

// TargetRetrievalFunc is a function that retrieves targets
Expand Down Expand Up @@ -64,6 +65,10 @@ func defaultSessionFunc(ctx context.Context, addr, token string) ([]*sessions.Se
// The returned auth tokens have not been validated against boundary
func (r *Repository) cleanAndPickAuthTokens(ctx context.Context, u *user) (map[AuthToken]string, error) {
const op = "cache.(Repository).cleanAndPickAuthTokens"
switch {
case util.IsNil(u):
return nil, errors.New(ctx, errors.InvalidParameter, op, "user is nil")
}
ret := make(map[AuthToken]string)

tokens, err := r.listTokens(ctx, u)
Expand All @@ -78,16 +83,41 @@ func (r *Repository) cleanAndPickAuthTokens(ctx context.Context, u *user) (map[A
for _, kt := range keyringTokens {
at := r.tokenKeyringFn(kt.KeyringType, kt.TokenName)
switch {
case at == nil, at.Id != kt.AuthTokenId:
case at == nil, at.Id != kt.AuthTokenId, at.UserId != t.UserId:
// delete the keyring token if the auth token in the keyring
// has changed since it was stored in the cache.
if err := r.deleteKeyringToken(ctx, *kt); err != nil {
return nil, errors.Wrap(ctx, err, op)
}
case at != nil:
_, err := r.tokenReadFromBoundaryFn(ctx, u.Address, at.Token)
var apiErr *api.Error
switch {
case err != nil && api.ErrUnauthorized.Is(err):
if err := r.deleteKeyringToken(ctx, *kt); err != nil {
return nil, errors.Wrap(ctx, err, op)
}
continue
case err != nil && !errors.Is(err, apiErr):
event.WriteError(ctx, op, err, event.WithInfoMsg("validating keyring stored token against boundary", "auth token id", at.Id))
continue
}
ret[*t] = at.Token
}
}
if atv, ok := r.idToKeyringlessAuthToken.Load(t.Id); ok {
if at, ok := atv.(*authtokens.AuthToken); ok {
_, err := r.tokenReadFromBoundaryFn(ctx, u.Address, at.Token)
var apiErr *api.Error
switch {
case err != nil && api.ErrUnauthorized.Is(err):
r.idToKeyringlessAuthToken.Delete(t.Id)
continue
case err != nil && !errors.Is(err, apiErr):
event.WriteError(ctx, op, err, event.WithInfoMsg("validating in memory stored token against boundary", "auth token id", at.Id))
continue
}

ret[*t] = at.Token
}
}
Expand All @@ -105,6 +135,9 @@ func (r *Repository) Refresh(ctx context.Context, opt ...Option) error {
if err := r.cleanExpiredOrOrphanedAuthTokens(ctx); err != nil {
return errors.Wrap(ctx, err, op)
}
if err := r.syncKeyringlessTokensWithDb(ctx); err != nil {
return errors.Wrap(ctx, err, op)
}

opts, err := getOpts(opt...)
if err != nil {
Expand Down Expand Up @@ -134,8 +167,6 @@ func (r *Repository) Refresh(ctx context.Context, opt ...Option) error {
for at, t := range tokens {
resp, err := opts.withTargetRetrievalFunc(ctx, u.Address, t)
if err != nil {
// TODO: If we get an error about the token no longer having
// permissions, remove it.
retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg("for token %q", at.Id)))
continue
}
Expand All @@ -151,8 +182,6 @@ func (r *Repository) Refresh(ctx context.Context, opt ...Option) error {
for at, t := range tokens {
resp, err := opts.withSessionRetrievalFunc(ctx, u.Address, t)
if err != nil {
// TODO: If we get an error about the token no longer having
// permissions, remove it.
retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg("for token %q", at.Id)))
continue
}
Expand Down
122 changes: 121 additions & 1 deletion internal/daemon/cache/repository_refresh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ package cache

import (
"context"
"errors"
stdErrors "errors"
"fmt"
"sync"
"testing"
"time"

"github.com/hashicorp/boundary/api"
"github.com/hashicorp/boundary/api/authtokens"
"github.com/hashicorp/boundary/api/sessions"
"github.com/hashicorp/boundary/api/targets"
Expand Down Expand Up @@ -62,11 +64,33 @@ func TestCleanAndPickTokens(t *testing.T) {
UserId: keyringOnlyUser.Id,
ExpirationTime: time.Now().Add(time.Minute),
}

boundaryAuthTokens := []*authtokens.AuthToken{at1a, keyringAuthToken1, at1b, keyringAuthToken2}
unauthorizedAuthTokens := []*authtokens.AuthToken{}
randomErrorAuthTokens := []*authtokens.AuthToken{}
fakeBoundaryLookupFn := func(ctx context.Context, addr, at string) (*authtokens.AuthToken, error) {
for _, v := range randomErrorAuthTokens {
if at == v.Token {
return nil, errors.New("test error")
}
}
for _, v := range unauthorizedAuthTokens {
if at == v.Token {
return nil, api.ErrUnauthorized
}
}
for _, v := range boundaryAuthTokens {
if at == v.Token {
return v, nil
}
}
return nil, stdErrors.New("not found")
}

atMap := make(map[ringToken]*authtokens.AuthToken)
r, err := NewRepository(ctx, s, &sync.Map{},
mapBasedAuthTokenKeyringLookup(atMap),
sliceBasedAuthTokenBoundaryReader(boundaryAuthTokens))
fakeBoundaryLookupFn)
require.NoError(t, err)

t.Run("unknown user", func(t *testing.T) {
Expand Down Expand Up @@ -105,6 +129,102 @@ func TestCleanAndPickTokens(t *testing.T) {
assert.ElementsMatch(t, maps.Values(got), []string{at1a.Token, at1b.Token})
})

t.Run("boundary in memory auth token expires", func(t *testing.T) {
require.NoError(t, r.AddRawToken(ctx, boundaryAddr, at1a.Token))
require.NoError(t, r.AddRawToken(ctx, boundaryAddr, at1b.Token))

got, err := r.cleanAndPickAuthTokens(ctx, u1)
assert.NoError(t, err)
assert.ElementsMatch(t, maps.Values(got), []string{at1a.Token, at1b.Token})

t.Cleanup(func() {
unauthorizedAuthTokens = nil
})

unauthorizedAuthTokens = []*authtokens.AuthToken{at1b}
got, err = r.cleanAndPickAuthTokens(ctx, u1)
assert.NoError(t, err)
assert.ElementsMatch(t, maps.Values(got), []string{at1a.Token})
})

t.Run("boundary keyring auths token expires", func(t *testing.T) {
key1 := ringToken{"k1", "t1"}
atMap[key1] = keyringAuthToken1
require.NoError(t, r.AddKeyringToken(ctx, boundaryAddr, KeyringToken{
KeyringType: key1.k,
TokenName: key1.t,
AuthTokenId: keyringAuthToken1.Id,
}))
key2 := ringToken{"k2", "t2"}
atMap[key2] = keyringAuthToken2
require.NoError(t, r.AddKeyringToken(ctx, boundaryAddr, KeyringToken{
KeyringType: key2.k,
TokenName: key2.t,
AuthTokenId: keyringAuthToken2.Id,
}))

got, err := r.cleanAndPickAuthTokens(ctx, keyringOnlyUser)
assert.NoError(t, err)
assert.ElementsMatch(t, maps.Values(got), []string{keyringAuthToken1.Token, keyringAuthToken2.Token})

t.Cleanup(func() {
unauthorizedAuthTokens = nil
})

unauthorizedAuthTokens = []*authtokens.AuthToken{keyringAuthToken2}
got, err = r.cleanAndPickAuthTokens(ctx, keyringOnlyUser)
assert.NoError(t, err)
assert.ElementsMatch(t, maps.Values(got), []string{keyringAuthToken1.Token})
})

t.Run("boundary in memory auth token check errors", func(t *testing.T) {
require.NoError(t, r.AddRawToken(ctx, boundaryAddr, at1a.Token))
require.NoError(t, r.AddRawToken(ctx, boundaryAddr, at1b.Token))

got, err := r.cleanAndPickAuthTokens(ctx, u1)
assert.NoError(t, err)
assert.ElementsMatch(t, maps.Values(got), []string{at1a.Token, at1b.Token})

t.Cleanup(func() {
randomErrorAuthTokens = nil
})

randomErrorAuthTokens = []*authtokens.AuthToken{at1b}
got, err = r.cleanAndPickAuthTokens(ctx, u1)
assert.NoError(t, err)
assert.ElementsMatch(t, maps.Values(got), []string{at1a.Token})
})

t.Run("boundary keyring auths token check errors", func(t *testing.T) {
key1 := ringToken{"k1", "t1"}
atMap[key1] = keyringAuthToken1
require.NoError(t, r.AddKeyringToken(ctx, boundaryAddr, KeyringToken{
KeyringType: key1.k,
TokenName: key1.t,
AuthTokenId: keyringAuthToken1.Id,
}))
key2 := ringToken{"k2", "t2"}
atMap[key2] = keyringAuthToken2
require.NoError(t, r.AddKeyringToken(ctx, boundaryAddr, KeyringToken{
KeyringType: key2.k,
TokenName: key2.t,
AuthTokenId: keyringAuthToken2.Id,
}))

got, err := r.cleanAndPickAuthTokens(ctx, keyringOnlyUser)
assert.NoError(t, err)
assert.ElementsMatch(t, maps.Values(got), []string{keyringAuthToken1.Token, keyringAuthToken2.Token})

t.Cleanup(func() {
randomErrorAuthTokens = nil
})

randomErrorAuthTokens = []*authtokens.AuthToken{keyringAuthToken2}
got, err = r.cleanAndPickAuthTokens(ctx, keyringOnlyUser)
assert.NoError(t, err)
assert.ElementsMatch(t, maps.Values(got), []string{keyringAuthToken1.Token})
})

t.Run("2 keyring tokens", func(t *testing.T) {
key1 := ringToken{"k1", "t1"}
atMap[key1] = keyringAuthToken1
Expand Down
16 changes: 5 additions & 11 deletions internal/daemon/cache/repository_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -436,26 +436,20 @@ func (r *Repository) listKeyringTokens(ctx context.Context, at *AuthToken) ([]*K

// syncKeyringlessTokensWithDb removes the in memory storage of auth tokens if
// they are no longer represented in the db.
func syncKeyringlessTokensWithDb(ctx context.Context, reader db.Reader, ringlessAuthTokens *sync.Map) error {
const op = "cache.syncKeyringlessTokensWithDb"
switch {
case util.IsNil(reader):
return errors.New(ctx, errors.InvalidParameter, op, "reader is nil")
case ringlessAuthTokens == nil:
return errors.New(ctx, errors.InvalidParameter, op, "keyringless auth token map is nil")
}
func (r *Repository) syncKeyringlessTokensWithDb(ctx context.Context) error {
const op = "cache.(Repository).syncKeyringlessTokensWithDb"
var ret []*AuthToken
if err := reader.SearchWhere(ctx, &ret, "true", nil); err != nil {
if err := r.rw.SearchWhere(ctx, &ret, "true", nil); err != nil {
return errors.Wrap(ctx, err, op)
}
authTokenIds := make(map[string]struct{})
for _, at := range ret {
authTokenIds[at.Id] = struct{}{}
}
ringlessAuthTokens.Range(func(key, value any) bool {
r.idToKeyringlessAuthToken.Range(func(key, value any) bool {
k := key.(string)
if _, ok := authTokenIds[k]; !ok {
ringlessAuthTokens.Delete(key)
r.idToKeyringlessAuthToken.Delete(key)
}
return true
})
Expand Down
4 changes: 2 additions & 2 deletions internal/daemon/cache/repository_token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ func TestRepository_AddToken_EvictsOverLimit_Keyringless(t *testing.T) {
_, ok = r.idToKeyringlessAuthToken.Load(boundaryAuthTokens[len(boundaryAuthTokens)-1].Id)
assert.True(t, ok)

assert.NoError(t, syncKeyringlessTokensWithDb(ctx, r.rw, r.idToKeyringlessAuthToken))
assert.NoError(t, r.syncKeyringlessTokensWithDb(ctx))
_, ok = r.idToKeyringlessAuthToken.Load(boundaryAuthTokens[0].Id)
assert.False(t, ok)
}
Expand Down Expand Up @@ -363,7 +363,7 @@ func TestRepository_CleanAuthTokens(t *testing.T) {
_, present = r.idToKeyringlessAuthToken.Load(at.Id)
assert.True(t, present)

assert.NoError(t, syncKeyringlessTokensWithDb(ctx, r.rw, r.idToKeyringlessAuthToken))
assert.NoError(t, r.syncKeyringlessTokensWithDb(ctx))

_, present = r.idToKeyringlessAuthToken.Load(at.Id)
assert.False(t, present)
Expand Down

0 comments on commit 049557a

Please sign in to comment.