From 1bf6dd5074db1d20a2623d1201880a670abc5ecd Mon Sep 17 00:00:00 2001 From: Todd Date: Mon, 25 Sep 2023 16:05:17 -0700 Subject: [PATCH] Add support for keyringless auth tokens (#3765) * Add support for keyringless auth tokens --- internal/cmd/commands/daemon/addtoken.go | 37 +- internal/cmd/commands/daemon/options.go | 13 +- internal/cmd/commands/daemon/options_test.go | 46 ++ internal/cmd/commands/daemon/server.go | 58 +- internal/cmd/commands/daemon/server_test.go | 97 +++ internal/cmd/commands/daemon/testing.go | 22 +- internal/cmd/commands/daemon/token_handler.go | 49 +- internal/cmd/commands/daemon/token_test.go | 219 +++++-- internal/cmd/commands/search/search_test.go | 26 +- internal/daemon/cache/options.go | 10 - internal/daemon/cache/options_test.go | 22 +- internal/daemon/cache/repository.go | 47 +- internal/daemon/cache/repository_refresh.go | 10 +- .../daemon/cache/repository_refresh_test.go | 154 ++++- internal/daemon/cache/repository_sessions.go | 2 +- .../daemon/cache/repository_sessions_test.go | 134 ++-- internal/daemon/cache/repository_targets.go | 2 +- .../daemon/cache/repository_targets_test.go | 118 ++-- internal/daemon/cache/repository_test.go | 52 +- internal/daemon/cache/repository_token.go | 449 +++++++++----- .../daemon/cache/repository_token_test.go | 573 +++++++++++++----- internal/daemon/cache/schema.sql | 28 +- internal/daemon/cache/store_test.go | 10 +- 23 files changed, 1638 insertions(+), 540 deletions(-) create mode 100644 internal/cmd/commands/daemon/options_test.go create mode 100644 internal/cmd/commands/daemon/server_test.go diff --git a/internal/cmd/commands/daemon/addtoken.go b/internal/cmd/commands/daemon/addtoken.go index e52e6e52e6..167fab9506 100644 --- a/internal/cmd/commands/daemon/addtoken.go +++ b/internal/cmd/commands/daemon/addtoken.go @@ -79,26 +79,39 @@ func (c *AddTokenCommand) Run(args []string) int { func (c *AddTokenCommand) Add(ctx context.Context) (*api.Error, error) { const op = "daemon.(AddTokenCommand).Add" - keyringType, tokenName, err := c.DiscoverKeyringTokenInfo() + client, err := c.Client() if err != nil { return nil, err } - at := c.ReadTokenFromKeyring(keyringType, tokenName) - if at == nil { - return nil, errors.New(ctx, errors.Conflict, op, "no auth token available to send to daemon") - } - client, err := c.Client() + + keyringType, tokenName, err := c.DiscoverKeyringTokenInfo() if err != nil { return nil, err } - pa := userTokenToAdd{ - Keyring: &keyringToken{ + pa := upsertTokenRequest{ + BoundaryAddr: client.Addr(), + } + switch keyringType { + case "", base.NoneKeyring: + keyringType = base.NoneKeyring + token := client.Token() + if parts := strings.SplitN(token, "_", 4); len(parts) == 3 { + pa.AuthTokenId = strings.Join(parts[:2], "_") + } else { + return nil, errors.New(ctx, errors.InvalidParameter, op, "found auth token is not in the proper format") + } + pa.AuthToken = token + default: + at := c.ReadTokenFromKeyring(keyringType, tokenName) + if at == nil { + return nil, errors.New(ctx, errors.Conflict, op, "no auth token available to send to daemon") + } + pa.Keyring = &keyringToken{ KeyringType: keyringType, TokenName: tokenName, - }, - BoundaryAddr: client.Addr(), - AuthTokenId: at.Id, + } + pa.AuthTokenId = at.Id } dotPath, err := DefaultDotDirectory(ctx) @@ -109,7 +122,7 @@ func (c *AddTokenCommand) Add(ctx context.Context) (*api.Error, error) { return addToken(ctx, dotPath, &pa) } -func addToken(ctx context.Context, daemonPath string, p *userTokenToAdd) (*api.Error, error) { +func addToken(ctx context.Context, daemonPath string, p *upsertTokenRequest) (*api.Error, error) { const op = "daemon.addToken" client, err := api.NewClient(nil) if err != nil { diff --git a/internal/cmd/commands/daemon/options.go b/internal/cmd/commands/daemon/options.go index d107a3cf08..07ba3b6a71 100644 --- a/internal/cmd/commands/daemon/options.go +++ b/internal/cmd/commands/daemon/options.go @@ -5,10 +5,13 @@ package daemon import ( "context" + + "github.com/hashicorp/boundary/internal/daemon/cache" ) type options struct { - withDebug bool + withDebug bool + withBoundaryTokenReaderFunc cache.BoundaryTokenReaderFn } // Option - how options are passed as args @@ -36,3 +39,11 @@ func WithDebug(_ context.Context, debug bool) Option { return nil } } + +// WithBoundaryTokenReaderFunc provides an option for specifying a BoundaryTokenReaderFn +func WithBoundaryTokenReaderFunc(_ context.Context, fn cache.BoundaryTokenReaderFn) Option { + return func(o *options) error { + o.withBoundaryTokenReaderFunc = fn + return nil + } +} diff --git a/internal/cmd/commands/daemon/options_test.go b/internal/cmd/commands/daemon/options_test.go new file mode 100644 index 0000000000..1c307e709e --- /dev/null +++ b/internal/cmd/commands/daemon/options_test.go @@ -0,0 +1,46 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package daemon + +import ( + "context" + "testing" + + "github.com/hashicorp/boundary/api/authtokens" + "github.com/hashicorp/boundary/internal/daemon/cache" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_GetOpts(t *testing.T) { + t.Parallel() + ctx := context.Background() + + t.Run("default", func(t *testing.T) { + opts, err := getOpts() + require.NoError(t, err) + testOpts := options{} + assert.Equal(t, opts, testOpts) + }) + t.Run("WithDebug", func(t *testing.T) { + opts, err := getOpts(WithDebug(ctx, true)) + require.NoError(t, err) + testOpts := getDefaultOptions() + testOpts.withDebug = true + assert.Equal(t, opts, testOpts) + }) + t.Run("WithBoundaryTokenReaderFunc", func(t *testing.T) { + var f cache.BoundaryTokenReaderFn = func(ctx context.Context, addr, token string) (*authtokens.AuthToken, error) { + return nil, nil + } + opts, err := getOpts(WithBoundaryTokenReaderFunc(ctx, f)) + require.NoError(t, err) + + assert.NotNil(t, opts.withBoundaryTokenReaderFunc) + opts.withBoundaryTokenReaderFunc = nil + + testOpts := getDefaultOptions() + assert.Equal(t, opts, testOpts) + }) +} diff --git a/internal/cmd/commands/daemon/server.go b/internal/cmd/commands/daemon/server.go index b17668c3b4..18f5cf854e 100644 --- a/internal/cmd/commands/daemon/server.go +++ b/internal/cmd/commands/daemon/server.go @@ -32,11 +32,16 @@ import ( // Commander is an interface that provides a way to get an apiClient // and retrieve the keyring and token information used by a command. type Commander interface { - Client(opt ...base.Option) (*api.Client, error) + ClientProvider DiscoverKeyringTokenInfo() (string, string, error) ReadTokenFromKeyring(keyringType, tokenName string) *authtokens.AuthToken } +// ClientProvider is an interface that provides an api.Client +type ClientProvider interface { + Client(opt ...base.Option) (*api.Client, error) +} + type cacheServer struct { conf *serverConfig @@ -117,22 +122,67 @@ func (s *cacheServer) shutdown(ctx context.Context) error { return shutdownErr } +func defaultBoundaryTokenReader(ctx context.Context, cp ClientProvider) (cache.BoundaryTokenReaderFn, error) { + const op = "daemon.defaultBoundaryTokenReader" + switch { + case util.IsNil(cp): + return nil, errors.New(ctx, errors.InvalidParameter, op, "client provider is nil") + } + return func(ctx context.Context, addr, tok string) (*authtokens.AuthToken, error) { + switch { + case addr == "": + return nil, errors.New(ctx, errors.InvalidParameter, op, "address is missing") + case tok == "": + return nil, errors.New(ctx, errors.InvalidParameter, op, "auth token is missing") + } + atIdParts := strings.SplitN(tok, "_", 4) + if len(atIdParts) != 3 { + return nil, errors.New(ctx, errors.InvalidParameter, op, "auth token is malformed") + } + atId := strings.Join(atIdParts[:cache.AuthTokenIdSegmentCount], "_") + + c, err := cp.Client() + if err != nil { + return nil, err + } + c.SetAddr(addr) + c.SetToken(tok) + atClient := authtokens.NewClient(c) + + at, err := atClient.Read(ctx, atId) + if err != nil { + return nil, errors.Wrap(ctx, err, op) + } + return at.GetItem(), nil + }, nil +} + // start will fire up the refresh goroutine and the caching API http server as a // daemon. The daemon bits are included so it's easy for CLI cmds to start the // a cache server -func (s *cacheServer) serve(ctx context.Context, cmd Commander, l net.Listener) error { +func (s *cacheServer) serve(ctx context.Context, cmd Commander, l net.Listener, opt ...Option) error { const op = "daemon.(cacheServer).start" switch { case util.IsNil(ctx): return errors.New(ctx, errors.InvalidParameter, op, "context is missing") } + opts, err := getOpts(opt...) + if err != nil { + return errors.Wrap(ctx, err, op) + } + if opts.withBoundaryTokenReaderFunc == nil { + opts.withBoundaryTokenReaderFunc, err = defaultBoundaryTokenReader(ctx, cmd) + if err != nil { + return errors.Wrap(ctx, err, op) + } + } + s.info["Listening address"] = l.Addr().String() s.infoKeys = append(s.infoKeys, "Listening address") s.info["Store debug"] = strconv.FormatBool(s.conf.flagStoreDebug) s.infoKeys = append(s.infoKeys, "Store debug") - var err error if s.store, s.storeUrl, err = openStore(ctx, s.conf.flagDatabaseUrl, s.conf.flagStoreDebug); err != nil { return errors.Wrap(ctx, err, op) } @@ -141,7 +191,7 @@ func (s *cacheServer) serve(ctx context.Context, cmd Commander, l net.Listener) s.printInfo(ctx) - repo, err := cache.NewRepository(ctx, s.store, cmd.ReadTokenFromKeyring) + repo, err := cache.NewRepository(ctx, s.store, &sync.Map{}, cmd.ReadTokenFromKeyring, opts.withBoundaryTokenReaderFunc) if err != nil { return errors.Wrap(ctx, err, op) } diff --git a/internal/cmd/commands/daemon/server_test.go b/internal/cmd/commands/daemon/server_test.go new file mode 100644 index 0000000000..273cd5c802 --- /dev/null +++ b/internal/cmd/commands/daemon/server_test.go @@ -0,0 +1,97 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package daemon + +import ( + "context" + "testing" + + "github.com/hashicorp/boundary/api" + "github.com/hashicorp/boundary/internal/cmd/base" + "github.com/hashicorp/boundary/internal/daemon/controller" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Note: the name of this test must remain short because the temp dir created +// includes the name of the test and there is a 108 character limit in allowed +// unix socket path names. +func TestDefaultBoundaryTokenReader(t *testing.T) { + ctx := context.Background() + + t.Run("nil client provider", func(t *testing.T) { + resFn, err := defaultBoundaryTokenReader(ctx, nil) + assert.Error(t, err) + assert.ErrorContains(t, err, "client provider is nil") + assert.Nil(t, resFn) + }) + + tc := controller.NewTestController(t, nil) + cp := fakeClientProvider{tc} + + cases := []struct { + name string + address string + token string + errContains string + }{ + { + name: "success", + address: tc.ApiAddrs()[0], + token: tc.Token().Token, + errContains: "", + }, + { + name: "empty address", + address: "", + token: "at_123_testtoken", + errContains: "address is missing", + }, + { + name: "empty token", + address: tc.ApiAddrs()[0], + token: "", + errContains: "auth token is missing", + }, + { + name: "malformed token to many sections", + address: tc.ApiAddrs()[0], + token: "at_123_ignoredtoken_tomanysections", + errContains: "auth token is malformed", + }, + { + name: "malformed token to few sections", + address: tc.ApiAddrs()[0], + token: "at_123", + errContains: "auth token is malformed", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + retFn, err := defaultBoundaryTokenReader(ctx, cp) + require.NoError(t, err) + require.NotNil(t, retFn) + + at, err := retFn(ctx, tc.address, tc.token) + switch tc.errContains { + case "": + assert.NoError(t, err) + assert.NotNil(t, at) + default: + assert.Error(t, err) + assert.ErrorContains(t, err, tc.errContains) + assert.Nil(t, at) + } + }) + } +} + +type fakeClientProvider struct { + *controller.TestController +} + +func (fcp fakeClientProvider) Client(opt ...base.Option) (*api.Client, error) { + return fcp.TestController.Client(), nil +} diff --git a/internal/cmd/commands/daemon/testing.go b/internal/cmd/commands/daemon/testing.go index fc677665f9..1f0db437fa 100644 --- a/internal/cmd/commands/daemon/testing.go +++ b/internal/cmd/commands/daemon/testing.go @@ -5,10 +5,12 @@ package daemon import ( "context" + stdErrors "errors" "io" - "strings" + "sync" "testing" + "github.com/hashicorp/boundary/api/authtokens" "github.com/hashicorp/boundary/api/sessions" "github.com/hashicorp/boundary/api/targets" "github.com/hashicorp/boundary/internal/daemon/cache" @@ -54,7 +56,7 @@ func (s *TestServer) BaseSocketDir() string { // Serve runs the cache server. This is a blocking call and returns when the // server is shutdown or stops for any other reason. -func (s *TestServer) Serve(t *testing.T) error { +func (s *TestServer) Serve(t *testing.T, opt ...Option) error { t.Helper() ctx := context.Background() @@ -64,28 +66,34 @@ func (s *TestServer) Serve(t *testing.T) error { t.Cleanup(func() { s.shutdown(ctx) }) - return s.cacheServer.serve(ctx, s.cmd, l) + return s.cacheServer.serve(ctx, s.cmd, l, opt...) } // AddResources adds targets to the cache for the provided address, token name, // and keyring type. They token info must already be known to the server. -func (s *TestServer) AddResources(t *testing.T, p *cache.AuthToken, tars []*targets.Target, sess []*sessions.Session) { +func (s *TestServer) AddResources(t *testing.T, p *authtokens.AuthToken, tars []*targets.Target, sess []*sessions.Session) { t.Helper() ctx := context.Background() - r, err := cache.NewRepository(ctx, s.cacheServer.store, s.cmd.ReadTokenFromKeyring) + r, err := cache.NewRepository(ctx, s.cacheServer.store, &sync.Map{}, s.cmd.ReadTokenFromKeyring, unimplementedAuthTokenReader) require.NoError(t, err) tarFn := func(ctx context.Context, _, tok string) ([]*targets.Target, error) { - if !strings.HasPrefix(tok, p.Id) { + if tok != p.Token { return nil, nil } return tars, nil } sessFn := func(ctx context.Context, _, tok string) ([]*sessions.Session, error) { - if !strings.HasPrefix(tok, p.Id) { + if tok != p.Token { return nil, nil } return sess, nil } require.NoError(t, r.Refresh(ctx, cache.WithTargetRetrievalFunc(tarFn), cache.WithSessionRetrievalFunc(sessFn))) } + +// unimplementedAuthTokenReader is an unimplemented function for reading auth +// tokens from a provided boundary address. +func unimplementedAuthTokenReader(ctx context.Context, addr string, authToken string) (*authtokens.AuthToken, error) { + return nil, stdErrors.New("unimplemented") +} diff --git a/internal/cmd/commands/daemon/token_handler.go b/internal/cmd/commands/daemon/token_handler.go index a4dca8d70b..c64f15eac7 100644 --- a/internal/cmd/commands/daemon/token_handler.go +++ b/internal/cmd/commands/daemon/token_handler.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "net/http" + "strings" "github.com/hashicorp/boundary/internal/cmd/base" "github.com/hashicorp/boundary/internal/daemon/cache" @@ -29,11 +30,14 @@ type keyringToken struct { } // userTokenToAdd is the request body to this handler. -type userTokenToAdd struct { +type upsertTokenRequest struct { // BoundaryAddr is a required field for all requests BoundaryAddr string // The id of the auth token asserted to be attempted to be added AuthTokenId string + // The raw auth token for this user. Either this field or the Keyring field + // must be set but not both. + AuthToken string // Keyring is the keyring info used when adding an auth token held in // keyring to the daemon. Keyring *keyringToken @@ -54,7 +58,7 @@ func newTokenHandlerFunc(ctx context.Context, repo *cache.Repository, refresher writeError(w, "method not allowed", http.StatusMethodNotAllowed) return } - var perReq userTokenToAdd + var perReq upsertTokenRequest data, err := io.ReadAll(r.Body) if err != nil { @@ -70,12 +74,12 @@ func newTokenHandlerFunc(ctx context.Context, repo *cache.Repository, refresher case perReq.BoundaryAddr == "": writeError(w, "BoundaryAddr is a required field but was empty", http.StatusBadRequest) return - case perReq.Keyring == nil: - writeError(w, "TokenName is a required field but was empty", http.StatusBadRequest) - return case perReq.AuthTokenId == "": writeError(w, "AuthTokenId is a required field but was empty", http.StatusBadRequest) return + case perReq.Keyring == nil && perReq.AuthToken == "": + writeError(w, "Either keyring info or the authtoken must be provided but were empty", http.StatusBadRequest) + return case perReq.Keyring != nil: switch { case perReq.Keyring.TokenName == "": @@ -85,10 +89,15 @@ func newTokenHandlerFunc(ctx context.Context, repo *cache.Repository, refresher writeError(w, "KeyringType is a required field but was empty", http.StatusBadRequest) return case perReq.Keyring.KeyringType == base.NoneKeyring: - // TODO: Support personas that have tokens not stored in a keyring writeError(w, fmt.Sprintf("KeyringType is set to %s which is not supported", perReq.Keyring.KeyringType), http.StatusBadRequest) return } + case perReq.AuthToken != "": + switch { + case !strings.HasPrefix(perReq.AuthToken, perReq.AuthTokenId): + writeError(w, "The auth token id doesn't match the auth token's prefix", http.StatusBadRequest) + return + } } tok, err := repo.LookupToken(ctx, perReq.AuthTokenId) @@ -97,21 +106,27 @@ func newTokenHandlerFunc(ctx context.Context, repo *cache.Repository, refresher return } - kt := cache.KeyringToken{ - KeyringType: perReq.Keyring.KeyringType, - TokenName: perReq.Keyring.TokenName, - AuthTokenId: perReq.AuthTokenId, - } - if err = repo.AddKeyringToken(ctx, perReq.BoundaryAddr, kt); err != nil { - writeError(w, "Failed to add a token", http.StatusInternalServerError) - return + switch { + case perReq.Keyring != nil: + kt := cache.KeyringToken{ + KeyringType: perReq.Keyring.KeyringType, + TokenName: perReq.Keyring.TokenName, + AuthTokenId: perReq.AuthTokenId, + } + if err = repo.AddKeyringToken(ctx, perReq.BoundaryAddr, kt); err != nil { + writeError(w, "Failed to add a keyring stored token", http.StatusInternalServerError) + return + } + case perReq.AuthToken != "": + if err = repo.AddRawToken(ctx, perReq.BoundaryAddr, perReq.AuthToken); err != nil { + writeError(w, "Failed to add a raw token", http.StatusInternalServerError) + return + } } w.WriteHeader(http.StatusNoContent) - // TODO: Figure out how to refresh only when the user id has changed - // and not every time the auth token changes. - if tok == nil || tok.Id != perReq.AuthTokenId { + if tok == nil { refresher.refresh() } }, nil diff --git a/internal/cmd/commands/daemon/token_test.go b/internal/cmd/commands/daemon/token_test.go index eb110cf138..1e40ea90c8 100644 --- a/internal/cmd/commands/daemon/token_test.go +++ b/internal/cmd/commands/daemon/token_test.go @@ -5,49 +5,75 @@ package daemon import ( "context" - "fmt" + stdErrors "errors" "net/http" "sync" "testing" "github.com/hashicorp/boundary/api/authtokens" + "github.com/hashicorp/boundary/internal/cmd/base" "github.com/hashicorp/boundary/internal/daemon/cache" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -type testRefresher struct { - called bool +// ringToken is a test struct used to group a keyring type and token name +// so it can be used in an authtoken lookup function. +type ringToken struct { + k string + t string } -func (r *testRefresher) refresh() { - r.called = true +// mapBasedAuthTokenKeyringLookup provides a fake KeyringTokenLookupFn that uses +// the provided map to perform lookups for the tokens +func mapBasedAuthTokenKeyringLookup(m map[ringToken]*authtokens.AuthToken) cache.KeyringTokenLookupFn { + return func(k, t string) *authtokens.AuthToken { + return m[ringToken{k, t}] + } +} + +// sliceBasedAuthTokenBoundaryReader provides a fake BoundaryTokenReaderFn that uses +// the provided map to lookup an auth tokens information. +func sliceBasedAuthTokenBoundaryReader(s []*authtokens.AuthToken) cache.BoundaryTokenReaderFn { + return func(ctx context.Context, addr, at string) (*authtokens.AuthToken, error) { + for _, v := range s { + if at == v.Token { + return v, nil + } + } + return nil, stdErrors.New("not found") + } } -type testAtReader struct { - atId string +type testRefresher struct { + called bool } -func (r *testAtReader) ReadTokenFromKeyring(k, a string) *authtokens.AuthToken { - return &authtokens.AuthToken{ - Id: r.atId, - AuthMethodId: "test_auth_method", - Token: fmt.Sprintf("%s_%s", r.atId, a), - UserId: r.atId, - } +func (r *testRefresher) refresh() { + r.called = true } -func TestToken(t *testing.T) { +func TestKeyringToken(t *testing.T) { ctx := context.Background() - s, _, err := openStore(ctx, "", true) + s, _, err := openStore(ctx, "", false) require.NoError(t, err) - atReader := &testAtReader{"at_1234567890"} - repo, err := cache.NewRepository(ctx, s, atReader.ReadTokenFromKeyring) + at := &authtokens.AuthToken{ + Id: "at_1", + Token: "at_1_token", + UserId: "user", + } + boundaryAuthTokens := []*authtokens.AuthToken{at} + keyring := "k" + tokenName := "t" + atMap := map[ringToken]*authtokens.AuthToken{ + {keyring, tokenName}: at, + } + r, err := cache.NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(boundaryAuthTokens)) require.NoError(t, err) tr := &testRefresher{} - ph, err := newTokenHandlerFunc(ctx, repo, tr) + ph, err := newTokenHandlerFunc(ctx, r, tr) require.NoError(t, err) mux := http.NewServeMux() @@ -67,13 +93,13 @@ func TestToken(t *testing.T) { }() t.Run("missing keyring", func(t *testing.T) { - pa := &userTokenToAdd{ + pa := &upsertTokenRequest{ Keyring: &keyringToken{ KeyringType: "", - TokenName: "default", + TokenName: tokenName, }, BoundaryAddr: "http://127.0.0.1", - AuthTokenId: atReader.atId, + AuthTokenId: at.Id, } apiErr, err := addToken(ctx, tmpdir, pa) assert.NoError(t, err) @@ -82,14 +108,30 @@ func TestToken(t *testing.T) { assert.False(t, tr.called) }) + t.Run("none keyring", func(t *testing.T) { + pa := &upsertTokenRequest{ + Keyring: &keyringToken{ + KeyringType: base.NoneKeyring, + TokenName: tokenName, + }, + BoundaryAddr: "http://127.0.0.1", + AuthTokenId: at.Id, + } + apiErr, err := addToken(ctx, tmpdir, pa) + assert.NoError(t, err) + require.NotNil(t, apiErr) + assert.Contains(t, apiErr.Message, "KeyringType is set to none which is not supported") + assert.False(t, tr.called) + }) + t.Run("missing token name", func(t *testing.T) { - pa := &userTokenToAdd{ + pa := &upsertTokenRequest{ Keyring: &keyringToken{ - KeyringType: "akeyringtype", + KeyringType: keyring, TokenName: "", }, BoundaryAddr: "http://127.0.0.1", - AuthTokenId: atReader.atId, + AuthTokenId: at.Id, } apiErr, err := addToken(ctx, tmpdir, pa) assert.NoError(t, err) @@ -99,13 +141,13 @@ func TestToken(t *testing.T) { }) t.Run("missing boundary address", func(t *testing.T) { - pa := &userTokenToAdd{ + pa := &upsertTokenRequest{ Keyring: &keyringToken{ - KeyringType: "akeyringtype", - TokenName: "default", + KeyringType: keyring, + TokenName: tokenName, }, BoundaryAddr: "", - AuthTokenId: atReader.atId, + AuthTokenId: at.Id, } apiErr, err := addToken(ctx, tmpdir, pa) assert.NoError(t, err) @@ -115,10 +157,10 @@ func TestToken(t *testing.T) { }) t.Run("missing auth token id", func(t *testing.T) { - pa := &userTokenToAdd{ + pa := &upsertTokenRequest{ Keyring: &keyringToken{ - KeyringType: "akeyringtype", - TokenName: "default", + KeyringType: keyring, + TokenName: tokenName, }, BoundaryAddr: "http://127.0.0.1", AuthTokenId: "", @@ -131,10 +173,10 @@ func TestToken(t *testing.T) { }) t.Run("mismatched auth token id", func(t *testing.T) { - pa := &userTokenToAdd{ + pa := &upsertTokenRequest{ Keyring: &keyringToken{ - KeyringType: "akeyringtype", - TokenName: "default", + KeyringType: keyring, + TokenName: tokenName, }, BoundaryAddr: "http://127.0.0.1", AuthTokenId: "at_doesntmatch", @@ -142,31 +184,122 @@ func TestToken(t *testing.T) { apiErr, err := addToken(ctx, tmpdir, pa) assert.NoError(t, err) assert.NotNil(t, apiErr) - assert.Contains(t, apiErr.Message, "Failed to add a token") + assert.Contains(t, apiErr.Message, "Failed to add a keyring stored token") assert.False(t, tr.called) }) t.Run("success", func(t *testing.T) { - pa := &userTokenToAdd{ + pa := &upsertTokenRequest{ Keyring: &keyringToken{ - KeyringType: "akeyringtype", - TokenName: "default", + KeyringType: keyring, + TokenName: tokenName, }, BoundaryAddr: "http://127.0.0.1", - AuthTokenId: atReader.atId, + AuthTokenId: at.Id, } apiErr, err := addToken(ctx, tmpdir, pa) assert.NoError(t, err) assert.Nil(t, apiErr) assert.True(t, tr.called) - repo, err := cache.NewRepository(ctx, s, (&testAtReader{"at_1234"}).ReadTokenFromKeyring) + p, err := r.LookupToken(ctx, pa.AuthTokenId) require.NoError(t, err) + assert.NotNil(t, p) + assert.Equal(t, at.Id, p.Id) + }) + srv.Shutdown(ctx) + wg.Wait() +} + +func TestKeyringlessToken(t *testing.T) { + ctx := context.Background() + s, _, err := openStore(ctx, "", false) + require.NoError(t, err) + + at := &authtokens.AuthToken{ + Id: "at_1", + Token: "at_1_token", + UserId: "user", + } + boundaryAuthTokens := []*authtokens.AuthToken{at} + atMap := map[ringToken]*authtokens.AuthToken{} + r, err := cache.NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(boundaryAuthTokens)) + require.NoError(t, err) + + tr := &testRefresher{} + ph, err := newTokenHandlerFunc(ctx, r, tr) + require.NoError(t, err) + + mux := http.NewServeMux() + mux.HandleFunc("/v1/tokens", ph) + + tmpdir := t.TempDir() + l, err := listener(ctx, tmpdir) + require.NoError(t, err) + srv := &http.Server{ + Handler: mux, + } + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + assert.ErrorIs(t, srv.Serve(l), http.ErrServerClosed) + }() + + t.Run("missing boundary address", func(t *testing.T) { + pa := &upsertTokenRequest{ + BoundaryAddr: "", + AuthTokenId: at.Id, + AuthToken: at.Token, + } + apiErr, err := addToken(ctx, tmpdir, pa) + assert.NoError(t, err) + assert.NotNil(t, apiErr) + assert.Contains(t, apiErr.Message, "BoundaryAddr is a required field but was empty") + assert.False(t, tr.called) + }) + + t.Run("missing auth token id", func(t *testing.T) { + pa := &upsertTokenRequest{ + BoundaryAddr: "http://127.0.0.1", + AuthTokenId: "", + AuthToken: at.Token, + } + apiErr, err := addToken(ctx, tmpdir, pa) + assert.NoError(t, err) + assert.NotNil(t, apiErr) + assert.Contains(t, apiErr.Message, "AuthTokenId is a required field but was empty") + assert.False(t, tr.called) + }) + + t.Run("mismatched auth token id", func(t *testing.T) { + pa := &upsertTokenRequest{ + BoundaryAddr: "http://127.0.0.1", + AuthTokenId: "at_doesntmatch", + AuthToken: at.Token, + } + apiErr, err := addToken(ctx, tmpdir, pa) + assert.NoError(t, err) + assert.NotNil(t, apiErr) + assert.Contains(t, apiErr.Message, "The auth token id doesn't match the auth token's prefix") + assert.False(t, tr.called) + }) + + t.Run("success", func(t *testing.T) { + pa := &upsertTokenRequest{ + BoundaryAddr: "http://127.0.0.1", + AuthTokenId: at.Id, + AuthToken: at.Token, + } + apiErr, err := addToken(ctx, tmpdir, pa) + assert.NoError(t, err) + assert.Nil(t, apiErr) + assert.True(t, tr.called) - p, err := repo.LookupToken(ctx, pa.AuthTokenId) + p, err := r.LookupToken(ctx, pa.AuthTokenId) require.NoError(t, err) assert.NotNil(t, p) - assert.Equal(t, atReader.atId, p.Id) + assert.Equal(t, at.Id, p.Id) }) srv.Shutdown(ctx) wg.Wait() diff --git a/internal/cmd/commands/search/search_test.go b/internal/cmd/commands/search/search_test.go index 815518f769..e50bd247fe 100644 --- a/internal/cmd/commands/search/search_test.go +++ b/internal/cmd/commands/search/search_test.go @@ -5,7 +5,7 @@ package search import ( "context" - "fmt" + "errors" "sync" "testing" "time" @@ -16,14 +16,13 @@ import ( "github.com/hashicorp/boundary/api/targets" "github.com/hashicorp/boundary/internal/cmd/base" "github.com/hashicorp/boundary/internal/cmd/commands/daemon" - "github.com/hashicorp/boundary/internal/daemon/cache" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) type testCommander struct { t *testing.T - at *cache.AuthToken + at *authtokens.AuthToken } func (r *testCommander) keyring() string { @@ -45,19 +44,15 @@ func (r *testCommander) DiscoverKeyringTokenInfo() (string, string, error) { } func (r *testCommander) ReadTokenFromKeyring(k, a string) *authtokens.AuthToken { - return &authtokens.AuthToken{ - Id: r.at.Id, - AuthMethodId: "test_auth_method", - Token: fmt.Sprintf("%s_restofthetoken", r.at.Id), - UserId: r.at.UserId, - } + return r.at } func TestSearch(t *testing.T) { ctx := context.Background() - at := &cache.AuthToken{ - UserId: "u_1234567890", - Id: "at_authtokenid", + at := &authtokens.AuthToken{ + Id: "at_1", + UserId: "user_1", + Token: "at_1_token", } cmd := &testCommander{t: t, at: at} @@ -66,7 +61,12 @@ func TestSearch(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - srv.Serve(t) + srv.Serve(t, daemon.WithBoundaryTokenReaderFunc(ctx, func(ctx context.Context, addr, authToken string) (*authtokens.AuthToken, error) { + if authToken == at.Token { + return at, nil + } + return nil, errors.New("test not found error") + })) }() // Give the store some time to get initialized time.Sleep(100 * time.Millisecond) diff --git a/internal/daemon/cache/options.go b/internal/daemon/cache/options.go index e2cc979fb0..7b37febd8a 100644 --- a/internal/daemon/cache/options.go +++ b/internal/daemon/cache/options.go @@ -8,8 +8,6 @@ import ( ) type options struct { - withBoundaryAddress string - withAuthTokenId string withDebug bool withUrl string withUpdateLastAccessedTime bool @@ -62,14 +60,6 @@ func WithUpdateLastAccessedTime(b bool) Option { } } -// WithAuthTokenId provides an option for specifying an auth token id -func WithAuthTokenId(id string) Option { - return func(o *options) error { - o.withAuthTokenId = id - return nil - } -} - // WithTargetRetrievalFunc provides an option for specifying a targetRetrievalFunc func WithTargetRetrievalFunc(fn TargetRetrievalFunc) Option { return func(o *options) error { diff --git a/internal/daemon/cache/options_test.go b/internal/daemon/cache/options_test.go index d19c1660ab..d28769611f 100644 --- a/internal/daemon/cache/options_test.go +++ b/internal/daemon/cache/options_test.go @@ -7,6 +7,7 @@ import ( "context" "testing" + "github.com/hashicorp/boundary/api/sessions" "github.com/hashicorp/boundary/api/targets" "github.com/hashicorp/go-dbw" "github.com/stretchr/testify/assert" @@ -46,14 +47,6 @@ func Test_GetOpts(t *testing.T) { testOpts.withUpdateLastAccessedTime = true assert.Equal(t, opts, testOpts) }) - t.Run("WithAuthTokenId", func(t *testing.T) { - id := "something" - opts, err := getOpts(WithAuthTokenId(id)) - require.NoError(t, err) - testOpts := getDefaultOptions() - testOpts.withAuthTokenId = id - assert.Equal(t, opts, testOpts) - }) t.Run("WithTargetRetrievalFunc", func(t *testing.T) { var f TargetRetrievalFunc = func(ctx context.Context, keyringstring, tokenName string) ([]*targets.Target, error) { return nil, nil } opts, err := getOpts(WithTargetRetrievalFunc(f)) @@ -62,6 +55,19 @@ func Test_GetOpts(t *testing.T) { assert.NotNil(t, opts.withTargetRetrievalFunc) opts.withTargetRetrievalFunc = nil + testOpts := getDefaultOptions() + assert.Equal(t, opts, testOpts) + }) + t.Run("WithSessionRetrievalFunc", func(t *testing.T) { + var f SessionRetrievalFunc = func(ctx context.Context, keyringstring, tokenName string) ([]*sessions.Session, error) { + return nil, nil + } + opts, err := getOpts(WithSessionRetrievalFunc(f)) + require.NoError(t, err) + + assert.NotNil(t, opts.withSessionRetrievalFunc) + opts.withSessionRetrievalFunc = nil + testOpts := getDefaultOptions() assert.Equal(t, opts, testOpts) }) diff --git a/internal/daemon/cache/repository.go b/internal/daemon/cache/repository.go index b276343c51..cf3ee618fe 100644 --- a/internal/daemon/cache/repository.go +++ b/internal/daemon/cache/repository.go @@ -5,6 +5,7 @@ package cache import ( "context" + "sync" "time" "github.com/hashicorp/boundary/api/authtokens" @@ -18,40 +19,62 @@ const ( tokenStalenessLimit = 36 * time.Hour ) -// TokenLookupFn takes a keyring type and token name and returns the token -type TokenLookupFn func(keyring string, tokenName string) *authtokens.AuthToken +// KeyringTokenLookupFn takes a token name and returns the token from the keyring +type KeyringTokenLookupFn func(keyring string, tokenName string) *authtokens.AuthToken + +// BoundaryTokenReaderFn reads an auth token's resource information from boundary +type BoundaryTokenReaderFn func(ctx context.Context, addr string, authToken string) (*authtokens.AuthToken, error) type Repository struct { - rw *db.Db - tokenLookupFn TokenLookupFn + rw *db.Db + tokenKeyringFn KeyringTokenLookupFn + tokenReadFromBoundaryFn BoundaryTokenReaderFn + // idToKeyringlessAuthToken maps an auth token id to an *authtokens.AuthToken + idToKeyringlessAuthToken *sync.Map } -func NewRepository(ctx context.Context, s *Store, tFn TokenLookupFn, opt ...Option) (*Repository, error) { +// NewRepository returns a cache repository. The provided Store must be +func NewRepository(ctx context.Context, s *Store, idToAuthToken *sync.Map, keyringFn KeyringTokenLookupFn, atReadFn BoundaryTokenReaderFn, opt ...Option) (*Repository, error) { const op = "cache.NewRepository" switch { case util.IsNil(s): return nil, errors.New(ctx, errors.InvalidParameter, op, "missing store") - case util.IsNil(tFn): - return nil, errors.New(ctx, errors.InvalidParameter, op, "missing token lookup function") + case util.IsNil(idToAuthToken): + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing keyringless auth token map") + case util.IsNil(keyringFn): + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing token keyring function") + case util.IsNil(atReadFn): + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing auth token read function") } - - return &Repository{rw: db.New(s.conn), tokenLookupFn: tFn}, nil + return &Repository{ + rw: db.New(s.conn), + tokenKeyringFn: keyringFn, + tokenReadFromBoundaryFn: atReadFn, + // This is passed in instead of being fully owned by the repo so multiple + // instances of the repo can operate on the same backing data + idToKeyringlessAuthToken: idToAuthToken, + }, nil } -func (r *Repository) SaveError(ctx context.Context, resourceType string, err error) error { +func (r *Repository) SaveError(ctx context.Context, u *user, resourceType string, err error) error { const op = "cache.(Repository).StoreError" switch { case resourceType == "": return errors.New(ctx, errors.InvalidParameter, op, "resource type is empty") case err == nil: return errors.New(ctx, errors.InvalidParameter, op, "error is nil") + case u == nil: + return errors.New(ctx, errors.InvalidParameter, op, "user is nil") + case u.Id == "": + return errors.New(ctx, errors.InvalidParameter, op, "user id is empty") } apiErr := &ApiError{ + UserId: u.Id, ResourceType: resourceType, Error: err.Error(), } onConflict := db.OnConflict{ - Target: db.Columns{"token_name", "resource_type"}, + Target: db.Columns{"user_id", "resource_type"}, Action: db.SetColumns([]string{"error", "create_time"}), } if err := r.rw.Create(ctx, apiErr, db.WithOnConflict(&onConflict)); err != nil { @@ -61,7 +84,7 @@ func (r *Repository) SaveError(ctx context.Context, resourceType string, err err } type ApiError struct { - TokenName string `gorm:"primaryKey"` + UserId string `gorm:"primaryKey"` ResourceType string `gorm:"primaryKey"` Error string CreateTime time.Time diff --git a/internal/daemon/cache/repository_refresh.go b/internal/daemon/cache/repository_refresh.go index 9082728c85..42781392ab 100644 --- a/internal/daemon/cache/repository_refresh.go +++ b/internal/daemon/cache/repository_refresh.go @@ -9,6 +9,7 @@ import ( "fmt" "github.com/hashicorp/boundary/api" + "github.com/hashicorp/boundary/api/authtokens" "github.com/hashicorp/boundary/api/sessions" "github.com/hashicorp/boundary/api/targets" "github.com/hashicorp/boundary/internal/errors" @@ -75,7 +76,7 @@ func (r *Repository) cleanAndPickAuthTokens(ctx context.Context, u *user) (map[A return nil, errors.Wrap(ctx, err, op, errors.WithMsg("for user %v, auth token %q", u, t.Id)) } for _, kt := range keyringTokens { - at := r.tokenLookupFn(kt.KeyringType, kt.TokenName) + at := r.tokenKeyringFn(kt.KeyringType, kt.TokenName) switch { case at == nil, at.Id != kt.AuthTokenId: if err := r.deleteKeyringToken(ctx, *kt); err != nil { @@ -85,6 +86,11 @@ func (r *Repository) cleanAndPickAuthTokens(ctx context.Context, u *user) (map[A ret[*t] = at.Token } } + if atv, ok := r.idToKeyringlessAuthToken.Load(t.Id); ok { + if at, ok := atv.(*authtokens.AuthToken); ok { + ret[*t] = at.Token + } + } } return ret, nil } @@ -96,7 +102,7 @@ func (r *Repository) cleanAndPickAuthTokens(ctx context.Context, u *user) (map[A // targets from a boundary address. func (r *Repository) Refresh(ctx context.Context, opt ...Option) error { const op = "cache.(Repository).Refresh" - if err := r.removeStaleTokens(ctx); err != nil { + if err := r.cleanOrphanedAuthTokens(ctx); err != nil { return errors.Wrap(ctx, err, op) } diff --git a/internal/daemon/cache/repository_refresh_test.go b/internal/daemon/cache/repository_refresh_test.go index ebef9d101d..981a439526 100644 --- a/internal/daemon/cache/repository_refresh_test.go +++ b/internal/daemon/cache/repository_refresh_test.go @@ -5,8 +5,9 @@ package cache import ( "context" - "errors" + stdErrors "errors" "fmt" + "sync" "testing" "github.com/hashicorp/boundary/api/authtokens" @@ -16,6 +17,7 @@ import ( "github.com/hashicorp/boundary/internal/daemon/worker" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/exp/maps" _ "github.com/hashicorp/boundary/internal/daemon/controller/handlers/targets/tcp" ) @@ -26,27 +28,140 @@ func noopRetrievalFn[T any](context.Context, string, string) ([]T, error) { return nil, nil } -func TestRefresh(t *testing.T) { +func TestCleanAndPickTokens(t *testing.T) { ctx := context.Background() s, err := Open(ctx) require.NoError(t, err) - internalAuthTokenFn := testAuthTokenLookup - atLookupFunc := func(k, t string) *authtokens.AuthToken { - return internalAuthTokenFn(k, t) + boundaryAddr := "address" + u1 := &user{Id: "u1", Address: boundaryAddr} + at1a := &authtokens.AuthToken{ + Id: "at_1a", + Token: "at_1a_token", + UserId: u1.Id, + } + at1b := &authtokens.AuthToken{ + Id: "at_1b", + Token: "at_1b_token", + UserId: u1.Id, + } + + keyringOnlyUser := &user{Id: "keyringUser", Address: boundaryAddr} + keyringAuthToken1 := &authtokens.AuthToken{ + Id: "at_2a", + Token: "at_2a_token", + UserId: keyringOnlyUser.Id, + } + keyringAuthToken2 := &authtokens.AuthToken{ + Id: "at_2b", + Token: "at_2b_token", + UserId: keyringOnlyUser.Id, } - r, err := NewRepository(ctx, s, atLookupFunc) + boundaryAuthTokens := []*authtokens.AuthToken{at1a, keyringAuthToken1, at1b, keyringAuthToken2} + atMap := make(map[ringToken]*authtokens.AuthToken) + r, err := NewRepository(ctx, s, &sync.Map{}, + mapBasedAuthTokenKeyringLookup(atMap), + sliceBasedAuthTokenBoundaryReader(boundaryAuthTokens)) + require.NoError(t, err) + + t.Run("unknown user", func(t *testing.T) { + got, err := r.cleanAndPickAuthTokens(ctx, &user{Id: "unknownuser", Address: "unknown"}) + assert.NoError(t, err) + assert.Empty(t, got) + }) + + t.Run("both memory and keyring stored token", func(t *testing.T) { + key := ringToken{"k1", "t1"} + atMap[key] = at1a + require.NoError(t, r.AddKeyringToken(ctx, boundaryAddr, KeyringToken{ + KeyringType: key.k, + TokenName: key.t, + AuthTokenId: at1a.Id, + })) + require.NoError(t, r.AddRawToken(ctx, boundaryAddr, at1b.Token)) + + got, err := r.cleanAndPickAuthTokens(ctx, u1) + assert.NoError(t, err) + assert.ElementsMatch(t, maps.Values(got), []string{at1a.Token, at1b.Token}) + + // delete the keyringToken from the keyring and see it get removed from the response + delete(atMap, key) + got, err = r.cleanAndPickAuthTokens(ctx, u1) + assert.NoError(t, err) + assert.ElementsMatch(t, maps.Values(got), []string{at1b.Token}) + }) + + t.Run("2 memory tokens", func(t *testing.T) { + require.NoError(t, r.AddRawToken(ctx, boundaryAddr, at1a.Token)) + require.NoError(t, r.AddRawToken(ctx, boundaryAddr, at1b.Token)) + + got, err := r.cleanAndPickAuthTokens(ctx, u1) + assert.NoError(t, err) + assert.ElementsMatch(t, maps.Values(got), []string{at1a.Token, at1b.Token}) + }) + + t.Run("2 keyring tokens", func(t *testing.T) { + key1 := ringToken{"k1", "t1"} + atMap[key1] = keyringAuthToken1 + require.NoError(t, r.AddKeyringToken(ctx, boundaryAddr, KeyringToken{ + KeyringType: key1.k, + TokenName: key1.t, + AuthTokenId: keyringAuthToken1.Id, + })) + key2 := ringToken{"k2", "t2"} + atMap[key2] = keyringAuthToken2 + require.NoError(t, r.AddKeyringToken(ctx, boundaryAddr, KeyringToken{ + KeyringType: key2.k, + TokenName: key2.t, + AuthTokenId: keyringAuthToken2.Id, + })) + + got, err := r.cleanAndPickAuthTokens(ctx, keyringOnlyUser) + assert.NoError(t, err) + assert.ElementsMatch(t, maps.Values(got), []string{keyringAuthToken1.Token, keyringAuthToken2.Token}) + + // Removing all keyring references and then cleaning auth tokens + // removes all auth tokens, along with the user + gotU, err := r.listUsers(ctx) + assert.NoError(t, err) + assert.Contains(t, gotU, keyringOnlyUser) + + delete(atMap, key1) + delete(atMap, key2) + got, err = r.cleanAndPickAuthTokens(ctx, keyringOnlyUser) + assert.NoError(t, err) + assert.Empty(t, got) + + gotT, err := r.listTokens(ctx, keyringOnlyUser) + assert.NoError(t, err) + assert.Empty(t, gotT) + gotU, err = r.listUsers(ctx) + assert.NoError(t, err) + assert.NotContains(t, gotU, keyringOnlyUser) + }) +} + +func TestRefresh(t *testing.T) { + ctx := context.Background() + s, err := Open(ctx) require.NoError(t, err) boundaryAddr := "address" - p := KeyringToken{ - KeyringType: "keyring", - TokenName: "token", + u := &user{Id: "u1", Address: boundaryAddr} + at := &authtokens.AuthToken{ + Id: "at_1", + Token: "at_1_token", + UserId: u.Id, + AuthMethodId: "am_1", } - at := testAuthTokenLookup(p.KeyringType, p.TokenName) - p.AuthTokenId = at.Id - require.NoError(t, r.AddKeyringToken(ctx, boundaryAddr, p)) - u := &user{Id: at.UserId, Address: boundaryAddr} + + boundaryAuthTokens := []*authtokens.AuthToken{at} + atMap := make(map[ringToken]*authtokens.AuthToken) + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(boundaryAuthTokens)) + require.NoError(t, err) + + atMap[ringToken{"k", "t"}] = at + require.NoError(t, r.AddKeyringToken(ctx, boundaryAddr, KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id})) t.Run("set targets", func(t *testing.T) { retTargets := []*targets.Target{ @@ -115,7 +230,7 @@ func TestRefresh(t *testing.T) { }) t.Run("error propogates up", func(t *testing.T) { - innerErr := errors.New("test error") + innerErr := stdErrors.New("test error") err := r.Refresh(ctx, WithSessionRetrievalFunc(noopRetrievalFn[*sessions.Session]), WithTargetRetrievalFunc(func(ctx context.Context, addr, token string) ([]*targets.Target, error) { @@ -135,13 +250,10 @@ func TestRefresh(t *testing.T) { }) t.Run("tokens that are no longer in the ring is deleted", func(t *testing.T) { - internalAuthTokenFn = func(k, t string) *authtokens.AuthToken { - return nil - } - t.Cleanup(func() { - internalAuthTokenFn = testAuthTokenLookup - assert.NoError(t, r.AddKeyringToken(ctx, boundaryAddr, p)) - }) + // Remove the token from the keyring, see that we can still see the + // token and then user until a Refresh happens which causes them to be + // cleaned up. + delete(atMap, ringToken{"k", "t"}) ps, err := r.listTokens(ctx, u) require.NoError(t, err) diff --git a/internal/daemon/cache/repository_sessions.go b/internal/daemon/cache/repository_sessions.go index 340ca045a0..58d969c11c 100644 --- a/internal/daemon/cache/repository_sessions.go +++ b/internal/daemon/cache/repository_sessions.go @@ -64,7 +64,7 @@ func (r *Repository) refreshSessions(ctx context.Context, u *user, sessions []*s return nil }) if err != nil { - if saveErr := r.SaveError(ctx, resource.Session.String(), err); saveErr != nil { + if saveErr := r.SaveError(ctx, u, resource.Session.String(), err); saveErr != nil { return stdErrors.Join(err, errors.Wrap(ctx, saveErr, op)) } return errors.Wrap(ctx, err, op) diff --git a/internal/daemon/cache/repository_sessions_test.go b/internal/daemon/cache/repository_sessions_test.go index 891fee3d28..af3755c200 100644 --- a/internal/daemon/cache/repository_sessions_test.go +++ b/internal/daemon/cache/repository_sessions_test.go @@ -5,12 +5,15 @@ package cache import ( "context" + "sync" "testing" + "github.com/hashicorp/boundary/api/authtokens" "github.com/hashicorp/boundary/api/sessions" "github.com/hashicorp/boundary/internal/db" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/exp/maps" ) func TestRepository_refreshSessions(t *testing.T) { @@ -18,13 +21,26 @@ func TestRepository_refreshSessions(t *testing.T) { s, err := Open(ctx) require.NoError(t, err) - r, err := NewRepository(ctx, s, testAuthTokenLookup) - require.NoError(t, err) - addr := "address" - kt := KeyringToken{KeyringType: "keyring", TokenName: "token"} - at := testAuthTokenLookup(kt.KeyringType, kt.TokenName) - kt.AuthTokenId = at.Id + u := user{ + Id: "u1", + Address: addr, + } + at := &authtokens.AuthToken{ + Id: "at_1", + Token: "at_1_token", + UserId: u.Id, + } + kt := KeyringToken{ + KeyringType: "keyring", + TokenName: "token", + AuthTokenId: at.Id, + } + atMap := map[ringToken]*authtokens.AuthToken{ + {kt.KeyringType, kt.TokenName}: at, + } + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(maps.Values(atMap))) + require.NoError(t, err) require.NoError(t, r.AddKeyringToken(ctx, addr, kt)) ss := []*sessions.Session{ @@ -112,8 +128,43 @@ func TestRepository_ListSessions(t *testing.T) { s, err := Open(ctx) require.NoError(t, err) - r, err := NewRepository(ctx, s, testAuthTokenLookup) + addr := "address" + u1 := &user{ + Id: "u1", + Address: addr, + } + at1 := &authtokens.AuthToken{ + Id: "at_1", + Token: "at_1_token", + UserId: u1.Id, + } + kt1 := KeyringToken{ + KeyringType: "k1", + TokenName: "t1", + AuthTokenId: at1.Id, + } + u2 := &user{ + Id: "u2", + Address: addr, + } + at2 := &authtokens.AuthToken{ + Id: "at_2", + Token: "at_2_token", + UserId: u2.Id, + } + kt2 := KeyringToken{ + KeyringType: "k2", + TokenName: "t2", + AuthTokenId: at2.Id, + } + atMap := map[ringToken]*authtokens.AuthToken{ + {"k1", "t1"}: at1, + {"k2", "t2"}: at2, + } + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(maps.Values(atMap))) require.NoError(t, err) + require.NoError(t, r.AddKeyringToken(ctx, addr, kt1)) + require.NoError(t, r.AddKeyringToken(ctx, addr, kt2)) t.Run("auth token id is missing", func(t *testing.T) { l, err := r.ListSessions(ctx, "") @@ -121,17 +172,6 @@ func TestRepository_ListSessions(t *testing.T) { assert.ErrorContains(t, err, "auth token id is missing") }) - addr := "address" - t1 := KeyringToken{KeyringType: "keyring", TokenName: "token"} - at := testAuthTokenLookup(t1.KeyringType, t1.TokenName) - t1.AuthTokenId = at.Id - require.NoError(t, r.AddKeyringToken(ctx, addr, t1)) - - t2 := KeyringToken{KeyringType: "keyring", TokenName: "token2"} - at2 := testAuthTokenLookup(t2.KeyringType, t2.TokenName) - t2.AuthTokenId = at2.Id - require.NoError(t, r.AddKeyringToken(ctx, addr, t2)) - ss := []*sessions.Session{ { Id: "ttcp_1", @@ -152,15 +192,15 @@ func TestRepository_ListSessions(t *testing.T) { Type: "tcp", }, } - require.NoError(t, r.refreshSessions(ctx, &user{Address: addr, Id: at.UserId}, ss)) + require.NoError(t, r.refreshSessions(ctx, u1, ss)) t.Run("wrong user gets no sessions", func(t *testing.T) { - l, err := r.ListSessions(ctx, t2.AuthTokenId) + l, err := r.ListSessions(ctx, kt2.AuthTokenId) assert.NoError(t, err) assert.Empty(t, l) }) t.Run("correct token gets sessions", func(t *testing.T) { - l, err := r.ListSessions(ctx, t1.AuthTokenId) + l, err := r.ListSessions(ctx, kt1.AuthTokenId) assert.NoError(t, err) assert.Len(t, l, len(ss)) assert.ElementsMatch(t, l, ss) @@ -172,8 +212,43 @@ func TestRepository_QuerySessions(t *testing.T) { s, err := Open(ctx) require.NoError(t, err) - r, err := NewRepository(ctx, s, testAuthTokenLookup) + addr := "address" + u1 := &user{ + Id: "u1", + Address: addr, + } + at1 := &authtokens.AuthToken{ + Id: "at_1", + Token: "at_1_token", + UserId: u1.Id, + } + kt1 := KeyringToken{ + KeyringType: "k1", + TokenName: "t1", + AuthTokenId: at1.Id, + } + u2 := &user{ + Id: "u2", + Address: addr, + } + at2 := &authtokens.AuthToken{ + Id: "at_2", + Token: "at_2_token", + UserId: u2.Id, + } + kt2 := KeyringToken{ + KeyringType: "k2", + TokenName: "t2", + AuthTokenId: at2.Id, + } + atMap := map[ringToken]*authtokens.AuthToken{ + {"k1", "t1"}: at1, + {"k2", "t2"}: at2, + } + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(maps.Values(atMap))) require.NoError(t, err) + require.NoError(t, r.AddKeyringToken(ctx, addr, kt1)) + require.NoError(t, r.AddKeyringToken(ctx, addr, kt2)) query := "status % status1 or status % status2" @@ -203,17 +278,6 @@ func TestRepository_QuerySessions(t *testing.T) { }) } - addr := "address" - kt := KeyringToken{KeyringType: "keyring", TokenName: "token"} - at := testAuthTokenLookup(kt.KeyringType, kt.TokenName) - kt.AuthTokenId = at.Id - require.NoError(t, r.AddKeyringToken(ctx, addr, kt)) - - kt2 := KeyringToken{KeyringType: "keyring", TokenName: "token2"} - at2 := testAuthTokenLookup(kt2.KeyringType, kt2.TokenName) - kt2.AuthTokenId = at2.Id - require.NoError(t, r.AddKeyringToken(ctx, addr, kt2)) - ss := []*sessions.Session{ { Id: "ttcp_1", @@ -234,7 +298,7 @@ func TestRepository_QuerySessions(t *testing.T) { Type: "tcp", }, } - require.NoError(t, r.refreshSessions(ctx, &user{Id: at.UserId, Address: addr}, ss)) + require.NoError(t, r.refreshSessions(ctx, u1, ss)) t.Run("wrong token gets no sessions", func(t *testing.T) { l, err := r.QuerySessions(ctx, kt2.AuthTokenId, query) @@ -242,7 +306,7 @@ func TestRepository_QuerySessions(t *testing.T) { assert.Empty(t, l) }) t.Run("correct token gets sessions", func(t *testing.T) { - l, err := r.QuerySessions(ctx, kt.AuthTokenId, query) + l, err := r.QuerySessions(ctx, kt1.AuthTokenId, query) assert.NoError(t, err) assert.Len(t, l, 2) assert.ElementsMatch(t, l, ss[0:2]) diff --git a/internal/daemon/cache/repository_targets.go b/internal/daemon/cache/repository_targets.go index 4a0a38910d..fc0331f7b5 100644 --- a/internal/daemon/cache/repository_targets.go +++ b/internal/daemon/cache/repository_targets.go @@ -64,7 +64,7 @@ func (r *Repository) refreshTargets(ctx context.Context, u *user, targets []*tar return nil }) if err != nil { - if saveErr := r.SaveError(ctx, resource.Target.String(), err); saveErr != nil { + if saveErr := r.SaveError(ctx, u, resource.Target.String(), err); saveErr != nil { return stdErrors.Join(err, errors.Wrap(ctx, saveErr, op)) } return errors.Wrap(ctx, err, op) diff --git a/internal/daemon/cache/repository_targets_test.go b/internal/daemon/cache/repository_targets_test.go index 746748d86f..5063f0488c 100644 --- a/internal/daemon/cache/repository_targets_test.go +++ b/internal/daemon/cache/repository_targets_test.go @@ -5,12 +5,15 @@ package cache import ( "context" + "sync" "testing" + "github.com/hashicorp/boundary/api/authtokens" "github.com/hashicorp/boundary/api/targets" "github.com/hashicorp/boundary/internal/db" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/exp/maps" ) func TestRepository_refreshTargets(t *testing.T) { @@ -18,15 +21,22 @@ func TestRepository_refreshTargets(t *testing.T) { s, err := Open(ctx) require.NoError(t, err) - r, err := NewRepository(ctx, s, testAuthTokenLookup) - require.NoError(t, err) - addr := "address" - keyringType := "keyring" - tokenName := "token" - kt := KeyringToken{KeyringType: "keyring", TokenName: "token"} - at := testAuthTokenLookup(keyringType, tokenName) - kt.AuthTokenId = at.Id + u := user{ + Id: "u1", + Address: addr, + } + at := &authtokens.AuthToken{ + Id: "at_1", + Token: "at_1_token", + UserId: u.Id, + } + kt := KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id} + atMap := map[ringToken]*authtokens.AuthToken{ + {"k", "t"}: at, + } + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(maps.Values(atMap))) + require.NoError(t, err) require.NoError(t, r.AddKeyringToken(ctx, addr, kt)) ts := []*targets.Target{ @@ -116,8 +126,36 @@ func TestRepository_ListTargets(t *testing.T) { s, err := Open(ctx) require.NoError(t, err) - r, err := NewRepository(ctx, s, testAuthTokenLookup) + addr := "address" + u1 := &user{ + Id: "u1", + Address: addr, + } + at1 := &authtokens.AuthToken{ + Id: "at_1", + Token: "at_1_token", + UserId: u1.Id, + } + kt1 := KeyringToken{KeyringType: "k1", TokenName: "t1", AuthTokenId: at1.Id} + + u2 := &user{ + Id: "u2", + Address: addr, + } + at2 := &authtokens.AuthToken{ + Id: "at_2", + Token: "at_2_token", + UserId: u2.Id, + } + kt2 := KeyringToken{KeyringType: "k2", TokenName: "t2", AuthTokenId: at2.Id} + atMap := map[ringToken]*authtokens.AuthToken{ + {"k1", "t1"}: at1, + {"k2", "t2"}: at2, + } + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(maps.Values(atMap))) require.NoError(t, err) + require.NoError(t, r.AddKeyringToken(ctx, addr, kt1)) + require.NoError(t, r.AddKeyringToken(ctx, addr, kt2)) t.Run("token is missing", func(t *testing.T) { l, err := r.ListTargets(ctx, "") @@ -125,17 +163,6 @@ func TestRepository_ListTargets(t *testing.T) { assert.ErrorContains(t, err, "auth token id is missing") }) - addr := "address" - p1 := KeyringToken{KeyringType: "keyring", TokenName: "token"} - at := testAuthTokenLookup(p1.KeyringType, p1.TokenName) - p1.AuthTokenId = at.Id - require.NoError(t, r.AddKeyringToken(ctx, addr, p1)) - - p2 := KeyringToken{KeyringType: "keyring", TokenName: "token2"} - at2 := testAuthTokenLookup(p2.KeyringType, p2.TokenName) - p2.AuthTokenId = at2.Id - require.NoError(t, r.AddKeyringToken(ctx, addr, p2)) - ts := []*targets.Target{ { Id: "ttcp_1", @@ -159,15 +186,15 @@ func TestRepository_ListTargets(t *testing.T) { SessionMaxSeconds: 333, }, } - require.NoError(t, r.refreshTargets(ctx, &user{Id: at.UserId, Address: addr}, ts)) + require.NoError(t, r.refreshTargets(ctx, u1, ts)) t.Run("wrong user gets no targets", func(t *testing.T) { - l, err := r.ListTargets(ctx, p2.AuthTokenId) + l, err := r.ListTargets(ctx, kt2.AuthTokenId) assert.NoError(t, err) assert.Empty(t, l) }) t.Run("correct token gets targets", func(t *testing.T) { - l, err := r.ListTargets(ctx, p1.AuthTokenId) + l, err := r.ListTargets(ctx, kt1.AuthTokenId) assert.NoError(t, err) assert.Len(t, l, len(ts)) assert.ElementsMatch(t, l, ts) @@ -179,8 +206,36 @@ func TestRepository_QueryTargets(t *testing.T) { s, err := Open(ctx) require.NoError(t, err) - r, err := NewRepository(ctx, s, testAuthTokenLookup) + addr := "address" + u1 := &user{ + Id: "u1", + Address: addr, + } + at1 := &authtokens.AuthToken{ + Id: "at_1", + Token: "at_1_token", + UserId: u1.Id, + } + kt1 := KeyringToken{KeyringType: "k1", TokenName: "t1", AuthTokenId: at1.Id} + + u2 := &user{ + Id: "u2", + Address: addr, + } + at2 := &authtokens.AuthToken{ + Id: "at_2", + Token: "at_2_token", + UserId: u2.Id, + } + kt2 := KeyringToken{KeyringType: "k2", TokenName: "t2", AuthTokenId: at2.Id} + atMap := map[ringToken]*authtokens.AuthToken{ + {"k1", "t1"}: at1, + {"k2", "t2"}: at2, + } + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(maps.Values(atMap))) require.NoError(t, err) + require.NoError(t, r.AddKeyringToken(ctx, addr, kt1)) + require.NoError(t, r.AddKeyringToken(ctx, addr, kt2)) query := "name % name1 or name % name2" @@ -211,17 +266,6 @@ func TestRepository_QueryTargets(t *testing.T) { }) } - addr := "address" - kt := KeyringToken{KeyringType: "keyring", TokenName: "token"} - at := testAuthTokenLookup(kt.KeyringType, kt.TokenName) - kt.AuthTokenId = at.Id - require.NoError(t, r.AddKeyringToken(ctx, addr, kt)) - - kt2 := KeyringToken{KeyringType: "keyring", TokenName: "token2"} - at2 := testAuthTokenLookup(kt2.KeyringType, kt2.TokenName) - kt2.AuthTokenId = at2.Id - require.NoError(t, r.AddKeyringToken(ctx, addr, kt2)) - ts := []*targets.Target{ { Id: "ttcp_1", @@ -245,7 +289,7 @@ func TestRepository_QueryTargets(t *testing.T) { SessionMaxSeconds: 333, }, } - require.NoError(t, r.refreshTargets(ctx, &user{Id: at.UserId, Address: addr}, ts)) + require.NoError(t, r.refreshTargets(ctx, u1, ts)) t.Run("wrong token gets no targets", func(t *testing.T) { l, err := r.QueryTargets(ctx, kt2.AuthTokenId, query) @@ -253,7 +297,7 @@ func TestRepository_QueryTargets(t *testing.T) { assert.Empty(t, l) }) t.Run("correct token gets targets", func(t *testing.T) { - l, err := r.QueryTargets(ctx, kt.AuthTokenId, query) + l, err := r.QueryTargets(ctx, kt1.AuthTokenId, query) assert.NoError(t, err) assert.Len(t, l, 2) assert.ElementsMatch(t, l, ts[0:2]) diff --git a/internal/daemon/cache/repository_test.go b/internal/daemon/cache/repository_test.go index 6401cfd51c..5cc7403bd7 100644 --- a/internal/daemon/cache/repository_test.go +++ b/internal/daemon/cache/repository_test.go @@ -5,33 +5,73 @@ package cache import ( "context" + stdErrors "errors" "fmt" + "sync" "testing" + "github.com/hashicorp/boundary/api/authtokens" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +// ringToken is a test struct used to group a keyring type and token name +// so it can be used in an authtoken lookup function. +type ringToken struct { + k string + t string +} + +// mapBasedAuthTokenKeyringLookup provides a fake KeyringTokenLookupFn that uses +// the provided map to perform lookups for the tokens +func mapBasedAuthTokenKeyringLookup(m map[ringToken]*authtokens.AuthToken) KeyringTokenLookupFn { + return func(k, t string) *authtokens.AuthToken { + return m[ringToken{k, t}] + } +} + +// sliceBasedAuthTokenBoundaryReader provides a fake BoundaryTokenReaderFn that uses +// the provided map to lookup an auth tokens information. +func sliceBasedAuthTokenBoundaryReader(s []*authtokens.AuthToken) BoundaryTokenReaderFn { + return func(ctx context.Context, addr, at string) (*authtokens.AuthToken, error) { + for _, v := range s { + if at == v.Token { + return v, nil + } + } + return nil, stdErrors.New("not found") + } +} + func TestRepository_SaveError(t *testing.T) { ctx := context.Background() s, err := Open(ctx) require.NoError(t, err) - r, err := NewRepository(ctx, s, testAuthTokenLookup) + r, err := NewRepository(ctx, s, &sync.Map{}, + mapBasedAuthTokenKeyringLookup(map[ringToken]*authtokens.AuthToken{}), + sliceBasedAuthTokenBoundaryReader(nil)) require.NoError(t, err) testResource := "test_resource_type" testErr := fmt.Errorf("test error for %q", testResource) + u := &user{ + Id: "u1", + Address: "addr", + } + require.NoError(t, r.rw.Create(ctx, u)) + t.Run("empty resource type", func(t *testing.T) { - assert.ErrorContains(t, r.SaveError(ctx, "", testErr), "resource type is empty") + assert.ErrorContains(t, r.SaveError(ctx, u, "", testErr), "resource type is empty") }) t.Run("nil error", func(t *testing.T) { - assert.ErrorContains(t, r.SaveError(ctx, testResource, nil), "error is nil") + assert.ErrorContains(t, r.SaveError(ctx, u, testResource, nil), "error is nil") + }) + t.Run("nil user", func(t *testing.T) { + assert.ErrorContains(t, r.SaveError(ctx, nil, testResource, testErr), "user is nil") }) t.Run("success", func(t *testing.T) { - assert.NoError(t, r.SaveError(ctx, testResource, testErr)) + assert.NoError(t, r.SaveError(ctx, u, testResource, testErr)) }) - - assert.NoError(t, r.SaveError(ctx, testResource, testErr)) } diff --git a/internal/daemon/cache/repository_token.go b/internal/daemon/cache/repository_token.go index eea6cee483..c01d69f2a4 100644 --- a/internal/daemon/cache/repository_token.go +++ b/internal/daemon/cache/repository_token.go @@ -6,125 +6,231 @@ package cache import ( "context" "database/sql" + "fmt" + "strings" + "sync" "time" + "github.com/hashicorp/boundary/api/authtokens" "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/util" ) -// AddToken adds a token to the repository. If the token in the -// keyring doesn't match the id provided an error is returned. If the number of -// tokens now exceed a limit, the token retrieved least recently is deleted. -func (r *Repository) AddKeyringToken(ctx context.Context, bAddr string, token KeyringToken) error { - const op = "cache.(Repository).AddKeyringToken" +// AuthTokenIdSegmentCount are the number of segments, delineated by "_", that +// make up the auth token id inside an auth token. +// For example, an authtoken format should look something like at_1234567890_sometokenpayload +const AuthTokenIdSegmentCount = 2 + +// upsertUserAndAuthToken upserts a user and authToken using the data in the provided authtoken. +// If creating this user results in the number of users exceeding the limit of +// allowed users it deletes the oldest one. +func upsertUserAndAuthToken(ctx context.Context, reader db.Reader, writer db.Writer, bAddr string, at *authtokens.AuthToken) error { + const op = "cache.upsertUserAndAuthToken" switch { - case token.TokenName == "": - return errors.New(ctx, errors.InvalidParameter, op, "token name is empty") - case token.KeyringType == "": - return errors.New(ctx, errors.InvalidParameter, op, "keyring type is empty") - case token.AuthTokenId == "": - return errors.New(ctx, errors.InvalidParameter, op, "boundary auth token id is empty") + // TODO: add check for reader and writer being part of an inflight tx. + case util.IsNil(reader): + return errors.New(ctx, errors.InvalidParameter, op, "reader is nil") + case util.IsNil(writer): + return errors.New(ctx, errors.InvalidParameter, op, "writer is nil") + case util.IsNil(at): + return errors.New(ctx, errors.InvalidParameter, op, "auth token is nil") case bAddr == "": return errors.New(ctx, errors.InvalidParameter, op, "boundary address is empty") + case at.Id == "": + return errors.New(ctx, errors.InvalidParameter, op, "auth token id is empty") + case at.UserId == "": + return errors.New(ctx, errors.InvalidParameter, op, "auth token user id is empty") + } + { + // always make sure the user exists when adding a token + u := &user{ + Id: at.UserId, + Address: bAddr, + } + onConflict := &db.OnConflict{ + Target: db.Columns{"id"}, + Action: db.DoNothing(true), + } + if err := writer.Create(ctx, u, db.WithOnConflict(onConflict)); err != nil { + return errors.Wrap(ctx, err, op) + } } - kt := token.clone() - at := r.tokenLookupFn(kt.KeyringType, kt.TokenName) - if at == nil { - return errors.New(ctx, errors.InvalidParameter, op, "unable to find token in the keyring specified") + { + st := &AuthToken{ + Id: at.Id, + UserId: at.UserId, + LastAccessedTime: time.Now(), + } + onConflict := &db.OnConflict{ + Target: db.Columns{"id"}, + Action: db.SetColumns([]string{"last_accessed_time"}), + } + if err := writer.Create(ctx, st, db.WithOnConflict(onConflict)); err != nil { + return errors.Wrap(ctx, err, op) + } } - if kt.AuthTokenId != at.Id { - return errors.New(ctx, errors.InvalidParameter, op, "provided auth token id doesn't match the one stored") + + var users []*user + if err := reader.SearchWhere(ctx, &users, "true", []any{}, db.WithLimit(-1)); err != nil { + return errors.Wrap(ctx, err, op) + } + if len(users) <= usersLimit { + return nil } - // Even though the auth token is already stored, we still call create so - // the last accessed timestamps can get updated since calling this method - // indicates that the token was used and is still valid. - _, err := r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(reader db.Reader, writer db.Writer) error { - { - t := &AuthToken{ - Id: token.AuthTokenId, - } - err := reader.LookupById(ctx, t) - switch { - case err != nil && !errors.IsNotFoundError(err): - return errors.Wrap(ctx, err, op) - case errors.IsNotFoundError(err): - // TODO: This is the first time this auth token is associated with - // this keyring/token name, lookup the auth token from boundary to - // verify that it is for the user specified. - case t.UserId != at.UserId: - return errors.New(ctx, errors.InvalidParameter, op, "user id doesn't match what is specified in the stored auth token") - } + var oldestUser *user + var oldestUsersTime *time.Time + for _, u := range users { + ats, err := listTokens(ctx, reader, u) + if err != nil { + return errors.Wrap(ctx, err, op) } - - { - // always make sure the user exists when adding a token - u := &user{ - Id: at.UserId, - Address: bAddr, - } - onConflict := &db.OnConflict{ - Target: db.Columns{"id"}, - Action: db.DoNothing(true), - } - if err := writer.Create(ctx, u, db.WithOnConflict(onConflict)); err != nil { - return errors.Wrap(ctx, err, op) + for _, at := range ats { + if oldestUsersTime == nil || oldestUsersTime.After(at.LastAccessedTime) { + oldestUser = u + oldestUsersTime = &at.LastAccessedTime } } - - { - st := &AuthToken{ - Id: at.Id, - UserId: at.UserId, - LastAccessedTime: time.Now(), - } - onConflict := &db.OnConflict{ - Target: db.Columns{"id"}, - Action: db.SetColumns([]string{"last_accessed_time"}), - } - if err := writer.Create(ctx, st, db.WithOnConflict(onConflict)); err != nil { - return errors.Wrap(ctx, err, op) - } + } + if oldestUser != nil { + if _, err := deleteUser(ctx, writer, oldestUser); err != nil { + return errors.Wrap(ctx, err, op) } + } + return nil +} - { - onConflict := &db.OnConflict{ - Target: db.Columns{"keyring_type", "token_name"}, - Action: db.SetColumns([]string{"auth_token_id"}), +// AddRawToken upserts the auth token's user and auth token in the db and +// stores the actual raw auth token in the repositories in memory storage. +// The raw token must be valid and present in boundary and be for a user that +// has permission to send a Read request for itself to boundary. +func (r *Repository) AddRawToken(ctx context.Context, bAddr string, rawToken string) error { + const op = "cache.(Repository).AddRawToken" + switch { + case rawToken == "": + return errors.New(ctx, errors.InvalidParameter, op, "boundary auth token is empty") + case bAddr == "": + return errors.New(ctx, errors.InvalidParameter, op, "boundary address is empty") + } + + // rawToken should look something like at_1234567890_someencryptedpayload + atIdParts := strings.SplitN(rawToken, "_", 4) + if len(atIdParts) != 3 { + return errors.New(ctx, errors.InvalidParameter, op, "boundary auth token is is malformed") + } + atId := strings.Join(atIdParts[:AuthTokenIdSegmentCount], "_") + + var at *authtokens.AuthToken + { + var inMemAuthToken *authtokens.AuthToken + atV, inMem := r.idToKeyringlessAuthToken.Load(atId) + if inMem { + var ok bool + inMemAuthToken, ok = atV.(*authtokens.AuthToken) + if !ok { + return errors.New(ctx, errors.Internal, op, "unable to cast in memory auth token to *authtoken.AuthToken") } - if err := writer.Create(ctx, kt, db.WithOnConflict(onConflict)); err != nil { + } + t := &AuthToken{ + Id: atId, + } + err := r.rw.LookupById(ctx, t) + switch { + case err != nil && !errors.IsNotFoundError(err): + return errors.Wrap(ctx, err, op) + case errors.IsNotFoundError(err) || !inMem: + // if we don't know about it in the cache or we don't know about + // this auth token in memory, get it from boundary to sure up the + // cache information about this auth token. + at, err = r.tokenReadFromBoundaryFn(ctx, bAddr, rawToken) + if err != nil { return errors.Wrap(ctx, err, op) } + case t.UserId != inMemAuthToken.UserId: + return errors.New(ctx, errors.InvalidParameter, op, "user id doesn't match what is specified in the stored auth token") + case inMem: + at = inMemAuthToken } + } - var users []*user - if err := reader.SearchWhere(ctx, &users, "true", []any{}, db.WithLimit(-1)); err != nil { + if at != nil { + // The token is never returned from boundary except in the original auth + // request so we must set the token. + at.Token = rawToken + } + + _, err := r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(reader db.Reader, writer db.Writer) error { + if err := upsertUserAndAuthToken(ctx, reader, writer, bAddr, at); err != nil { return errors.Wrap(ctx, err, op) } - if len(users) <= usersLimit { - return nil - } + r.idToKeyringlessAuthToken.Store(at.Id, at) + return nil + }) + if err != nil { + return err + } + + return nil +} - var oldestUser *user - var oldestUsersTime *time.Time - for _, u := range users { - ats, err := listTokens(ctx, reader, u) +// AddKeyringToken adds a token to the repository. If the token's id in the +// keyring doesn't match the id provided an error is returned. +// The token must be valid and present in boundary and be for a user that +// has permission to send a self-Read request to boundary. The user id +// stored in the keyring must also match the user id returned from boundary. +func (r *Repository) AddKeyringToken(ctx context.Context, bAddr string, token KeyringToken) error { + const op = "cache.(Repository).AddKeyringToken" + switch { + case token.TokenName == "": + return errors.New(ctx, errors.InvalidParameter, op, "token name is empty") + case token.KeyringType == "": + return errors.New(ctx, errors.InvalidParameter, op, "keyring type is empty") + case token.AuthTokenId == "": + return errors.New(ctx, errors.InvalidParameter, op, "boundary auth token id is empty") + case bAddr == "": + return errors.New(ctx, errors.InvalidParameter, op, "boundary address is empty") + } + kt := token.clone() + keyringStoredAt := r.tokenKeyringFn(kt.KeyringType, kt.TokenName) + if keyringStoredAt == nil { + return errors.New(ctx, errors.InvalidParameter, op, "unable to find token in the keyring specified") + } + if kt.AuthTokenId != keyringStoredAt.Id { + return errors.New(ctx, errors.InvalidParameter, op, "provided auth token id doesn't match the one stored") + } + + var at *authtokens.AuthToken + { + cachedAt := &AuthToken{ + Id: kt.AuthTokenId, + } + err := r.rw.LookupById(ctx, cachedAt) + switch { + case err != nil && !errors.IsNotFoundError(err): + return errors.Wrap(ctx, err, op) + case errors.IsNotFoundError(err): + at, err = r.tokenReadFromBoundaryFn(ctx, bAddr, keyringStoredAt.Token) if err != nil { return errors.Wrap(ctx, err, op) } - for _, at := range ats { - if oldestUsersTime == nil || oldestUsersTime.After(at.LastAccessedTime) { - oldestUser = u - oldestUsersTime = &at.LastAccessedTime - } - } + case cachedAt.UserId != keyringStoredAt.UserId: + return errors.New(ctx, errors.InvalidParameter, op, "user id doesn't match what is specified in the stored auth token") + default: + at = keyringStoredAt } - if oldestUser != nil { - if _, err := deleteUser(ctx, writer, oldestUser); err != nil { - return errors.Wrap(ctx, err, op) - } + } + _, err := r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(reader db.Reader, writer db.Writer) error { + if err := upsertUserAndAuthToken(ctx, reader, writer, bAddr, at); err != nil { + return errors.Wrap(ctx, err, op) + } + onConflict := &db.OnConflict{ + Target: db.Columns{"keyring_type", "token_name"}, + Action: db.SetColumns([]string{"auth_token_id"}), + } + if err := writer.Create(ctx, kt, db.WithOnConflict(onConflict)); err != nil { + return errors.Wrap(ctx, err, op) } return nil }) @@ -136,10 +242,10 @@ func (r *Repository) AddKeyringToken(ctx context.Context, bAddr string, token Ke } // LookupToken returns the Token in the cache if one exists. -// Accepts withUpdateLastAccessedTime, WithAuthTokenId, and WithBoundaryAddress -// options. If withUpdateLastAccessedTime is provided, the last update time -// of the returned token will be updated to the current time and reflected -// in the db. The returned AuthToken will not have the updated time. +// Accepts withUpdateLastAccessedTime options. If withUpdateLastAccessedTime +// is provided, the last update time of the returned token will be updated to +// the current time and reflected in the db. The returned AuthToken will not +// have the updated time. func (r *Repository) LookupToken(ctx context.Context, authTokenId string, opt ...Option) (*AuthToken, error) { const op = "cache.(Repository).LookupToken" switch { @@ -175,7 +281,7 @@ func (r *Repository) LookupToken(ctx context.Context, authTokenId string, opt .. } // deleteKeyringToken deletes a keyring token -func (r *Repository) deleteKeyringToken(ctx context.Context, kt KeyringToken) (retErr error) { +func (r *Repository) deleteKeyringToken(ctx context.Context, kt KeyringToken) error { const op = "cache.(Repository).deleteKeyringToken" switch { case kt.KeyringType == "": @@ -184,26 +290,84 @@ func (r *Repository) deleteKeyringToken(ctx context.Context, kt KeyringToken) (r return errors.New(ctx, errors.InvalidParameter, op, "token name type") } - n, err := deleteKeyringToken(ctx, r.rw, kt) + _, err := r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(reader db.Reader, writer db.Writer) error { + // TODO(https://github.com/go-gorm/gorm/issues/4879): Use the + // writer.Delete() function once the gorm bug is fixed. Until then + // the gorm driver for sqlite has an error which wont execute a + // delete correctly. as a work around we manually execute the + // query here. + n, err := writer.Exec(ctx, "delete from keyring_token where (keyring_type, token_name) = (?, ?)", []any{kt.KeyringType, kt.TokenName}) + if err != nil { + return errors.Wrap(ctx, err, op) + } + + switch n { + case 1: + if err := cleanOrphanedAuthTokens(ctx, writer, r.idToKeyringlessAuthToken); err != nil { + return errors.Wrap(ctx, err, op) + } + return nil + case 0: + return errors.New(ctx, errors.RecordNotFound, op, "token not found when attempting deletion") + default: + return errors.New(ctx, errors.MultipleRecords, op, "multiple tokens deleted when one was requested") + } + }) if err != nil { - return errors.Wrap(ctx, err, op) + return err } + return nil +} - switch n { - case 1: +// cleanAuthTokens removes all tokens which are older than the staleness limit +// or does not have either a keyring or keyringless reference to it. +func (r *Repository) cleanOrphanedAuthTokens(ctx context.Context) error { + const op = "cache.Repository.cleanOrphanedAuthTokens" + _, err := r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(reader db.Reader, writer db.Writer) error { + if err := cleanOrphanedAuthTokens(ctx, writer, r.idToKeyringlessAuthToken); err != nil { + return errors.Wrap(ctx, err, op) + } return nil - case 0: - return errors.New(ctx, errors.RecordNotFound, op, "token not found when attempting deletion") - default: - return errors.New(ctx, errors.MultipleRecords, op, "multiple tokens deleted when one was requested") - } + }) + return err } -// removeStaleTokens removes all tokens which are older than the staleness -func (r *Repository) removeStaleTokens(ctx context.Context, opt ...Option) error { - const op = "cache.(Repository).removeStaleTokens" - if _, err := r.rw.Exec(ctx, "delete from auth_token where last_accessed_time < @last_accessed_time", - []any{sql.Named("last_accessed_time", time.Now().Add(-tokenStalenessLimit))}); err != nil { +// cleanAuthTokens removes all tokens which are older than the staleness limit +// or does not have either a keyring or keyringless reference to it. +func cleanOrphanedAuthTokens(ctx context.Context, writer db.Writer, idToKeyringlessAuthToken *sync.Map) error { + const op = "cache.cleanAuthTokens" + switch { + // TODO: Add check here to see if a transaction is in flight. + case util.IsNil(writer): + return errors.New(ctx, errors.InvalidParameter, op, "writer is nil") + case idToKeyringlessAuthToken == nil: + return errors.New(ctx, errors.InvalidParameter, op, "keyringless auth token map is nil") + } + + var keyringlessAuthTokens []string + idToKeyringlessAuthToken.Range(func(key, _ any) bool { + keyringlessAuthTokens = append(keyringlessAuthTokens, key.(string)) + return true + }) + + deleteOrphanedAuthTokens := ` + delete from auth_token + where + last_accessed_time < @last_accessed_time + or + %s + ` + args := []any{sql.Named("last_accessed_time", time.Now().Add(-tokenStalenessLimit))} + + idInSection := "id not in (select auth_token_id from keyring_token)" + if len(keyringlessAuthTokens) > 0 { + // Note: We have to build the statement like this because if the slice of string + // is empty this gets converted to a query that says " and id not in (NULL)" + idInSection = fmt.Sprintf("(%s and id not in @keyringless_token_ids)", idInSection) + args = append(args, sql.Named("keyringless_token_ids", keyringlessAuthTokens)) + } + + if _, err := writer.Exec(ctx, fmt.Sprintf(deleteOrphanedAuthTokens, idInSection), args); err != nil { return errors.Wrap(ctx, err, op) } return nil @@ -219,22 +383,6 @@ func (r *Repository) listUsers(ctx context.Context) ([]*user, error) { return ret, nil } -// listKeyringTokens returns all known keyring tokens in the cache for the provided auth token -func (r *Repository) listKeyringTokens(ctx context.Context, at *AuthToken) ([]*KeyringToken, error) { - const op = "cache.(Repository).listTokens" - switch { - case util.IsNil(at): - return nil, errors.New(ctx, errors.InvalidParameter, op, "auth token is nil") - case at.Id == "": - return nil, errors.New(ctx, errors.InvalidParameter, op, "auth token id is empty") - } - var ret []*KeyringToken - if err := r.rw.SearchWhere(ctx, &ret, "auth_token_id = ?", []any{at.Id}); err != nil { - return nil, errors.Wrap(ctx, err, op) - } - return ret, nil -} - // listTokens returns all known tokens in the cache for the provided user func (r *Repository) listTokens(ctx context.Context, u *user) ([]*AuthToken, error) { const op = "cache.(Repository).listTokens" @@ -261,27 +409,48 @@ func listTokens(ctx context.Context, reader db.Reader, u *user) ([]*AuthToken, e return ret, nil } -// deleteKeyringToken executes a delete command using the provided db.Writer for the provided token. -func deleteKeyringToken(ctx context.Context, w db.Writer, kt KeyringToken) (int, error) { - const op = "cache.deleteKeyringToken" +// listKeyringTokens returns all known keyring tokens in the cache for the provided auth token +func (r *Repository) listKeyringTokens(ctx context.Context, at *AuthToken) ([]*KeyringToken, error) { + const op = "cache.(Repository).listTokens" switch { - case util.IsNil(w): - return db.NoRowsAffected, errors.New(ctx, errors.InvalidParameter, op, "writer is nil") - case kt.KeyringType == "": - return db.NoRowsAffected, errors.New(ctx, errors.InvalidParameter, op, "missing keyring type") - case kt.TokenName == "": - return db.NoRowsAffected, errors.New(ctx, errors.InvalidParameter, op, "token name type") + case util.IsNil(at): + return nil, errors.New(ctx, errors.InvalidParameter, op, "auth token is nil") + case at.Id == "": + return nil, errors.New(ctx, errors.InvalidParameter, op, "auth token id is empty") } - // TODO(https://github.com/go-gorm/gorm/issues/4879): Use the - // writer.Delete() function once the gorm bug is fixed. Until then - // the gorm driver for sqlite has an error which wont execute a - // delete correctly. as a work around we manually execute the - // query here. - n, err := w.Exec(ctx, "delete from keyring_token where (keyring_type, token_name) = (?, ?)", []any{kt.KeyringType, kt.TokenName}) - if err != nil { - err = errors.Wrap(ctx, err, op) + var ret []*KeyringToken + if err := r.rw.SearchWhere(ctx, &ret, "auth_token_id = ?", []any{at.Id}); err != nil { + return nil, errors.Wrap(ctx, err, op) } - return n, err + return ret, nil +} + +// syncKeyringlessTokensWithDb removes the in memory storage of auth tokens if +// they are no longer represented in the db. +func syncKeyringlessTokensWithDb(ctx context.Context, reader db.Reader, ringlessAuthTokens *sync.Map) error { + const op = "cache.syncKeyringlessTokensWithDb" + switch { + case util.IsNil(reader): + return errors.New(ctx, errors.InvalidParameter, op, "reader is nil") + case ringlessAuthTokens == nil: + return errors.New(ctx, errors.InvalidParameter, op, "keyringless auth token map is nil") + } + var ret []*AuthToken + if err := reader.SearchWhere(ctx, &ret, "true", nil); err != nil { + return errors.Wrap(ctx, err, op) + } + authTokenIds := make(map[string]struct{}) + for _, at := range ret { + authTokenIds[at.Id] = struct{}{} + } + ringlessAuthTokens.Range(func(key, value any) bool { + k := key.(string) + if _, ok := authTokenIds[k]; !ok { + ringlessAuthTokens.Delete(key) + } + return true + }) + return nil } // deleteUser executes a delete command using the provided db.Writer for the provided user. diff --git a/internal/daemon/cache/repository_token_test.go b/internal/daemon/cache/repository_token_test.go index 86a066d630..4ccdff4291 100644 --- a/internal/daemon/cache/repository_token_test.go +++ b/internal/daemon/cache/repository_token_test.go @@ -6,48 +6,60 @@ package cache import ( "context" "fmt" + "sync" "testing" "time" "github.com/hashicorp/boundary/api/authtokens" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/exp/maps" ) -// ringToken is a test struct used to group a keyring type and token name -// so it can be used in an authtoken lookup function. -type ringToken struct { - k string - t string -} - -func mapBasedAuthTokenLookup(m map[ringToken]*authtokens.AuthToken) func(k, t string) *authtokens.AuthToken { - return func(k, t string) *authtokens.AuthToken { - return m[ringToken{k, t}] - } -} - -func testAuthTokenLookup(k, t string) *authtokens.AuthToken { - return &authtokens.AuthToken{ - Id: fmt.Sprintf("at_%s", t), - Token: fmt.Sprintf("at_%s_%s", t, k), - UserId: fmt.Sprintf("u_%s", t), - AuthMethodId: fmt.Sprintf("ampw_%s", t), - AccountId: fmt.Sprintf("acctpw_%s", t), - } -} - func TestRepository_AddKeyringToken(t *testing.T) { ctx := context.Background() s, err := Open(ctx) require.NoError(t, err) - r, err := NewRepository(ctx, s, testAuthTokenLookup) + addr := "address" + u := user{ + Id: "u1", + Address: addr, + } + at := &authtokens.AuthToken{ + Id: "at_1", + Token: "at_1_token", + UserId: u.Id, + } + // used to test mismatched user ids between keyring and cache + mismatchingAt := &authtokens.AuthToken{ + Id: "at_mismatch", + Token: "at_mismatch_token", + UserId: u.Id, + } + boundaryAuthTokens := []*authtokens.AuthToken{at, mismatchingAt} + keyring := "k" + tokenName := "t" + atMap := make(map[ringToken]*authtokens.AuthToken) + atMap[ringToken{keyring, tokenName}] = at + atMap[ringToken{"mismatch", "mismatch"}] = mismatchingAt + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(boundaryAuthTokens)) require.NoError(t, err) - keyring := "keyring" - tokenName := "token" - authTokenId := testAuthTokenLookup(keyring, tokenName).Id + t.Run("userid mismatch between db and keyring", func(t *testing.T) { + require.NoError(t, r.AddKeyringToken(ctx, "address", KeyringToken{ + KeyringType: "mismatch", + TokenName: "mismatch", + AuthTokenId: mismatchingAt.Id, + })) + + mismatchingAt.UserId = "changedToMismatch" + assert.ErrorContains(t, r.AddKeyringToken(ctx, "address", KeyringToken{ + KeyringType: "mismatch", + TokenName: "mismatch", + AuthTokenId: mismatchingAt.Id, + }), "user id doesn't match what is specified in the stored auth token") + }) errCases := []struct { name string @@ -61,16 +73,26 @@ func TestRepository_AddKeyringToken(t *testing.T) { kt: KeyringToken{ KeyringType: keyring, TokenName: tokenName, - AuthTokenId: authTokenId, + AuthTokenId: at.Id, }, errorContains: "", }, + { + name: "not in keyring", + addr: "address", + kt: KeyringToken{ + KeyringType: keyring, + TokenName: "unknowntokenname", + AuthTokenId: at.Id, + }, + errorContains: "unable to find token in the keyring specified", + }, { name: "missing address", kt: KeyringToken{ KeyringType: keyring, TokenName: tokenName, - AuthTokenId: authTokenId, + AuthTokenId: at.Id, }, errorContains: "boundary address is empty", }, @@ -79,7 +101,7 @@ func TestRepository_AddKeyringToken(t *testing.T) { addr: "address", kt: KeyringToken{ KeyringType: keyring, - AuthTokenId: authTokenId, + AuthTokenId: at.Id, }, errorContains: "token name is empty", }, @@ -88,7 +110,7 @@ func TestRepository_AddKeyringToken(t *testing.T) { addr: "address", kt: KeyringToken{ TokenName: tokenName, - AuthTokenId: authTokenId, + AuthTokenId: at.Id, }, errorContains: "keyring type is empty", }, @@ -125,29 +147,129 @@ func TestRepository_AddKeyringToken(t *testing.T) { } } -func TestRepository_AddToken_EvictsOverLimit(t *testing.T) { +func TestRepository_AddRawToken(t *testing.T) { + ctx := context.Background() + s, err := Open(ctx) + require.NoError(t, err) + + at := &authtokens.AuthToken{ + Id: "at_1", + Token: "at_1_token", + UserId: "u1", + } + existingAt := &authtokens.AuthToken{ + Id: "at_existing", + Token: "at_existing_token", + UserId: "u2", + } + boundaryAuthTokens := []*authtokens.AuthToken{at, existingAt} + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(map[ringToken]*authtokens.AuthToken{}), sliceBasedAuthTokenBoundaryReader(boundaryAuthTokens)) + require.NoError(t, err) + + t.Run("mismatched userid between memory and db", func(t *testing.T) { + require.NoError(t, r.AddRawToken(ctx, "address", existingAt.Token)) + loadedExistingV, loaded := r.idToKeyringlessAuthToken.Load(existingAt.Id) + require.True(t, loaded) + loadedExistingAt := loadedExistingV.(*authtokens.AuthToken) + loadedExistingAt.UserId = "mismatchingUserId" + r.idToKeyringlessAuthToken.Store(existingAt.Id, loadedExistingAt) + + err := r.AddRawToken(ctx, "address", loadedExistingAt.Token) + assert.Error(t, err) + assert.ErrorContains(t, err, "user id doesn't match what is specified in the stored auth token") + }) + + errCases := []struct { + name string + addr string + rawAt string + errorContains string + }{ + { + name: "success", + addr: "address", + rawAt: at.Token, + errorContains: "", + }, + { + name: "missing address", + rawAt: at.Token, + errorContains: "boundary address is empty", + }, + { + name: "missing token", + addr: "address", + errorContains: "auth token is empty", + }, + { + name: "malformed auth token", + addr: "address", + rawAt: fmt.Sprintf("%s_extraunderscore", at.Token), + errorContains: "boundary auth token is is malformed", + }, + { + name: "not found in boundary", + addr: "address", + rawAt: "at_123_notfoundinboundary", + errorContains: "not found", + }, + } + + for _, tc := range errCases { + t.Run(tc.name, func(t *testing.T) { + err := r.AddRawToken(ctx, tc.addr, tc.rawAt) + if tc.errorContains == "" { + require.NoError(t, err) + } else { + assert.ErrorContains(t, err, tc.errorContains) + } + }) + } +} + +func TestRepository_AddToken_EvictsOverLimitUsers(t *testing.T) { ctx := context.Background() s, err := Open(ctx) require.NoError(t, err) - r, err := NewRepository(ctx, s, testAuthTokenLookup) + boundaryAuthTokens := []*authtokens.AuthToken{ + { + UserId: "user_base", + Id: "at_base", + Token: "at_base_token", + }, + } + for i := 0; i < usersLimit; i++ { + iAt := &authtokens.AuthToken{ + UserId: fmt.Sprintf("user%d", i), + Id: fmt.Sprintf("at_%d", i), + } + iAt.Token = fmt.Sprintf("%s_token", iAt.Id) + boundaryAuthTokens = append(boundaryAuthTokens, iAt) + } + + atMap := make(map[ringToken]*authtokens.AuthToken) + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(boundaryAuthTokens)) require.NoError(t, err) addr := "address" - kt := KeyringToken{KeyringType: "keyring", TokenName: "token"} - at := testAuthTokenLookup(kt.KeyringType, kt.TokenName) - kt.AuthTokenId = at.Id + kt := KeyringToken{ + KeyringType: "keyring", + TokenName: "token", + AuthTokenId: boundaryAuthTokens[0].Id, + } + atMap[ringToken{kt.KeyringType, kt.TokenName}] = boundaryAuthTokens[0] assert.NoError(t, r.AddKeyringToken(ctx, addr, kt)) assert.NoError(t, r.AddKeyringToken(ctx, addr, kt)) lastKtAdded := kt - for i := 0; i < usersLimit; i++ { + for i, at := range boundaryAuthTokens[1:] { kr := fmt.Sprintf("%s%d", kt.KeyringType, i) tn := fmt.Sprintf("%s%d", kt.TokenName, i) - ikt := KeyringToken{KeyringType: kr, TokenName: tn} - at := testAuthTokenLookup(kr, tn) - ikt.AuthTokenId = at.Id + ikt := KeyringToken{KeyringType: kr, TokenName: tn, AuthTokenId: at.Id} + + atMap[ringToken{ikt.KeyringType, ikt.TokenName}] = at assert.NoError(t, r.AddKeyringToken(ctx, addr, ikt)) lastKtAdded = ikt } @@ -163,76 +285,220 @@ func TestRepository_AddToken_EvictsOverLimit(t *testing.T) { assert.NotEmpty(t, gotP) } -func TestRepository_AddToken_AddingExistingUpdatesLastAccessedTime(t *testing.T) { +func TestRepository_AddToken_EvictsOverLimit_Keyringless(t *testing.T) { ctx := context.Background() s, err := Open(ctx) require.NoError(t, err) - addr := "someaddr" - r, err := NewRepository(ctx, s, testAuthTokenLookup) + boundaryAuthTokens := []*authtokens.AuthToken{ + { + UserId: "user_base", + Id: "at_base", + Token: "at_base_token", + }, + } + for i := 0; i < usersLimit; i++ { + iAt := &authtokens.AuthToken{ + UserId: fmt.Sprintf("user%d", i), + Id: fmt.Sprintf("at_%d", i), + } + iAt.Token = fmt.Sprintf("%s_token", iAt.Id) + boundaryAuthTokens = append(boundaryAuthTokens, iAt) + } + + atMap := make(map[ringToken]*authtokens.AuthToken) + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(boundaryAuthTokens)) require.NoError(t, err) - p1 := KeyringToken{ - TokenName: "default", - KeyringType: "keyring", + addr := "address" + + assert.NoError(t, r.AddRawToken(ctx, addr, boundaryAuthTokens[0].Token)) + assert.NoError(t, r.AddRawToken(ctx, addr, boundaryAuthTokens[0].Token)) + for _, at := range boundaryAuthTokens[1:] { + assert.NoError(t, r.AddRawToken(ctx, addr, at.Token)) } - at1 := testAuthTokenLookup(p1.KeyringType, p1.TokenName) - p1.AuthTokenId = at1.Id - assert.NoError(t, r.AddKeyringToken(ctx, addr, p1)) - p2 := KeyringToken{ - TokenName: "default2", - KeyringType: "keyring", + // Lookup the first persona added. It should have been evicted from the db + // for being used the least recently. It is only removed from the db once + // cleanAuthTokens is called. + gotP, err := r.LookupToken(ctx, boundaryAuthTokens[0].Id) + assert.NoError(t, err) + assert.Nil(t, gotP) + _, ok := r.idToKeyringlessAuthToken.Load(boundaryAuthTokens[0].Id) + assert.True(t, ok) + + gotP, err = r.LookupToken(ctx, boundaryAuthTokens[len(boundaryAuthTokens)-1].Id) + assert.NoError(t, err) + assert.NotEmpty(t, gotP) + _, ok = r.idToKeyringlessAuthToken.Load(boundaryAuthTokens[len(boundaryAuthTokens)-1].Id) + assert.True(t, ok) + + assert.NoError(t, syncKeyringlessTokensWithDb(ctx, r.rw, r.idToKeyringlessAuthToken)) + _, ok = r.idToKeyringlessAuthToken.Load(boundaryAuthTokens[0].Id) + assert.False(t, ok) +} + +func TestRepository_CleanAuthTokens(t *testing.T) { + ctx := context.Background() + s, err := Open(ctx) + require.NoError(t, err) + + at := &authtokens.AuthToken{ + UserId: "user_base", + Id: "at_base", + Token: "at_base_token", } - at2 := testAuthTokenLookup(p2.KeyringType, p2.TokenName) - p2.AuthTokenId = at2.Id - assert.NoError(t, r.AddKeyringToken(ctx, addr, p2)) + boundaryAuthTokens := []*authtokens.AuthToken{at} + + atMap := make(map[ringToken]*authtokens.AuthToken) + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(boundaryAuthTokens)) + require.NoError(t, err) + assert.NoError(t, r.AddRawToken(ctx, "baddr", at.Token)) + _, present := r.idToKeyringlessAuthToken.Load(at.Id) + assert.True(t, present) + + _, err = r.rw.Delete(ctx, &user{Id: at.UserId}) + require.NoError(t, err) + + _, present = r.idToKeyringlessAuthToken.Load(at.Id) + assert.True(t, present) + + assert.NoError(t, syncKeyringlessTokensWithDb(ctx, r.rw, r.idToKeyringlessAuthToken)) + + _, present = r.idToKeyringlessAuthToken.Load(at.Id) + assert.False(t, present) +} + +func TestRepository_AddKeyringToken_AddingExistingUpdatesLastAccessedTime(t *testing.T) { + ctx := context.Background() + s, err := Open(ctx) + require.NoError(t, err) + + addr := "address" + at1 := &authtokens.AuthToken{ + Id: "at_1", + Token: "at_1_token", + UserId: "u_1", + } + kt1 := KeyringToken{ + TokenName: "t1", + KeyringType: "k1", + AuthTokenId: at1.Id, + } + at2 := &authtokens.AuthToken{ + Id: "at_2", + Token: "at_2_token", + UserId: "u_2", + } + kt2 := KeyringToken{ + TokenName: "t2", + KeyringType: "k2", + AuthTokenId: "at_2", + } + + boundaryAuthTokens := []*authtokens.AuthToken{at1, at2} + atMap := map[ringToken]*authtokens.AuthToken{ + {kt1.KeyringType, kt1.TokenName}: at1, + {kt2.KeyringType, kt2.TokenName}: at2, + } + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(boundaryAuthTokens)) + require.NoError(t, err) + + assert.NoError(t, r.AddKeyringToken(ctx, addr, kt1)) + assert.NoError(t, r.AddKeyringToken(ctx, addr, kt2)) time.Sleep(10 * time.Millisecond) - assert.NoError(t, r.AddKeyringToken(ctx, addr, p1)) + assert.NoError(t, r.AddKeyringToken(ctx, addr, kt1)) - gotP1, err := r.LookupToken(ctx, p1.AuthTokenId) + gotP1, err := r.LookupToken(ctx, kt1.AuthTokenId) require.NoError(t, err) require.NotNil(t, gotP1) - gotP2, err := r.LookupToken(ctx, p2.AuthTokenId) + gotP2, err := r.LookupToken(ctx, kt2.AuthTokenId) require.NoError(t, err) require.NotNil(t, gotP2) assert.Greater(t, gotP1.LastAccessedTime, gotP2.LastAccessedTime) } -func TestRepository_ListTokens(t *testing.T) { +func TestRepository_AddRawToken_AddingExistingUpdatesLastAccessedTime(t *testing.T) { ctx := context.Background() s, err := Open(ctx) require.NoError(t, err) - r, err := NewRepository(ctx, s, testAuthTokenLookup) + addr := "address" + at1 := &authtokens.AuthToken{ + Id: "at_1", + Token: "at_1_token", + UserId: "u_1", + } + at2 := &authtokens.AuthToken{ + Id: "at_2", + Token: "at_2_token", + UserId: "u_2", + } + + boundaryAuthTokens := []*authtokens.AuthToken{at1, at2} + atMap := map[ringToken]*authtokens.AuthToken{} + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(boundaryAuthTokens)) + require.NoError(t, err) + + assert.NoError(t, r.AddRawToken(ctx, addr, at1.Token)) + assert.NoError(t, r.AddRawToken(ctx, addr, at2.Token)) + + time.Sleep(10 * time.Millisecond) + assert.NoError(t, r.AddRawToken(ctx, addr, at1.Token)) + + gotP1, err := r.LookupToken(ctx, at1.Id) + require.NoError(t, err) + require.NotNil(t, gotP1) + gotP2, err := r.LookupToken(ctx, at2.Id) + require.NoError(t, err) + require.NotNil(t, gotP2) + + assert.Greater(t, gotP1.LastAccessedTime, gotP2.LastAccessedTime) +} + +func TestRepository_ListTokens(t *testing.T) { + ctx := context.Background() + s, err := Open(ctx) require.NoError(t, err) addr := "address" - kt := KeyringToken{KeyringType: "keyring", TokenName: "token"} - at := testAuthTokenLookup(kt.KeyringType, kt.TokenName) - kt.AuthTokenId = at.Id u := &user{ - Id: at.UserId, + Id: "u1", Address: addr, } + at := &authtokens.AuthToken{ + Id: "at_1", + Token: "at_1_token", + UserId: u.Id, + } - t.Run("no token", func(t *testing.T) { - gotP, err := r.listTokens(ctx, u) - assert.NoError(t, err) - assert.Empty(t, gotP) - }) + atMap := make(map[ringToken]*authtokens.AuthToken) ktTokenCount := 15 - for i := 0; i < ktTokenCount; i++ { - thisKeyringType := fmt.Sprintf("%s%d", kt.KeyringType, i) - ikt := KeyringToken{KeyringType: thisKeyringType, TokenName: kt.TokenName} - at := testAuthTokenLookup(ikt.KeyringType, ikt.TokenName) - ikt.AuthTokenId = at.Id - require.NoError(t, r.AddKeyringToken(ctx, addr, ikt)) + k := fmt.Sprintf("k%d", i) + t := fmt.Sprintf("t%d", i) + atMap[ringToken{k, t}] = at + } + + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(maps.Values(atMap))) + require.NoError(t, err) + + for k, v := range atMap { + require.NoError(t, r.AddKeyringToken(ctx, addr, KeyringToken{ + KeyringType: k.k, + TokenName: k.t, + AuthTokenId: v.Id, + })) } + t.Run("no token", func(t *testing.T) { + gotP, err := r.listTokens(ctx, &user{Id: "tokenless"}) + assert.NoError(t, err) + assert.Empty(t, gotP) + }) + t.Run("many tokens", func(t *testing.T) { gotAt, err := r.listTokens(ctx, u) assert.NoError(t, err) @@ -249,7 +515,23 @@ func TestRepository_DeleteKeyringToken(t *testing.T) { s, err := Open(ctx) require.NoError(t, err) - r, err := NewRepository(ctx, s, testAuthTokenLookup) + addr := "address" + at1 := &authtokens.AuthToken{ + Id: "at_1", + Token: "at_1_token", + UserId: "u_1", + } + kt1 := KeyringToken{ + TokenName: "t1", + KeyringType: "k1", + AuthTokenId: at1.Id, + } + + boundaryAuthTokens := []*authtokens.AuthToken{at1} + atMap := map[ringToken]*authtokens.AuthToken{ + {kt1.KeyringType, kt1.TokenName}: at1, + } + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(boundaryAuthTokens)) require.NoError(t, err) t.Run("delete non existing", func(t *testing.T) { @@ -257,18 +539,14 @@ func TestRepository_DeleteKeyringToken(t *testing.T) { }) t.Run("delete existing", func(t *testing.T) { - addr := "address" - kt := KeyringToken{KeyringType: "keyring", TokenName: "token"} - at := testAuthTokenLookup(kt.KeyringType, kt.TokenName) - kt.AuthTokenId = at.Id - assert.NoError(t, r.AddKeyringToken(ctx, addr, kt)) - p, err := r.LookupToken(ctx, kt.AuthTokenId) + assert.NoError(t, r.AddKeyringToken(ctx, addr, kt1)) + p, err := r.LookupToken(ctx, kt1.AuthTokenId) require.NoError(t, err) require.NotNil(t, p) - assert.NoError(t, r.deleteKeyringToken(ctx, kt)) + assert.NoError(t, r.deleteKeyringToken(ctx, kt1)) - got, err := r.LookupToken(ctx, kt.AuthTokenId) + got, err := r.LookupToken(ctx, kt1.AuthTokenId) require.NoError(t, err) require.Nil(t, got) }) @@ -279,7 +557,23 @@ func TestRepository_LookupToken(t *testing.T) { s, err := Open(ctx) require.NoError(t, err) - r, err := NewRepository(ctx, s, testAuthTokenLookup) + addr := "address" + at := &authtokens.AuthToken{ + Id: "at_1", + Token: "at_1_token", + UserId: "u_1", + } + kt := KeyringToken{ + TokenName: "t1", + KeyringType: "k1", + AuthTokenId: at.Id, + } + + boundaryAuthTokens := []*authtokens.AuthToken{at} + atMap := map[ringToken]*authtokens.AuthToken{ + {kt.KeyringType, kt.TokenName}: at, + } + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(boundaryAuthTokens)) require.NoError(t, err) t.Run("empty token id", func(t *testing.T) { @@ -293,11 +587,6 @@ func TestRepository_LookupToken(t *testing.T) { assert.Nil(t, p) }) t.Run("found", func(t *testing.T) { - addr := "address" - kt := KeyringToken{KeyringType: "keyring", TokenName: "token"} - at := testAuthTokenLookup(kt.KeyringType, kt.TokenName) - kt.AuthTokenId = at.Id - assert.NoError(t, r.AddKeyringToken(ctx, addr, kt)) p, err := r.LookupToken(ctx, kt.AuthTokenId) assert.NoError(t, err) @@ -305,11 +594,6 @@ func TestRepository_LookupToken(t *testing.T) { }) t.Run("withUpdateLastAccessedTime", func(t *testing.T) { - addr := "address" - kt := KeyringToken{KeyringType: "keyring", TokenName: "token"} - at := testAuthTokenLookup(kt.KeyringType, kt.TokenName) - kt.AuthTokenId = at.Id - assert.NoError(t, r.AddKeyringToken(ctx, addr, kt)) time.Sleep(1 * time.Millisecond) @@ -329,59 +613,66 @@ func TestRepository_RemoveStaleTokens(t *testing.T) { s, err := Open(ctx) require.NoError(t, err) - atMap := make(map[ringToken]*authtokens.AuthToken) - atLookupFn := mapBasedAuthTokenLookup(atMap) + u := &user{ + Id: "user", + Address: "address", + } + at1 := &authtokens.AuthToken{ + Id: "at_1", + Token: "at_1_token", + UserId: u.Id, + } + kt1 := KeyringToken{ + TokenName: "t1", + KeyringType: "k1", + AuthTokenId: at1.Id, + } + at2 := &authtokens.AuthToken{ + Id: "at_2", + Token: "at_2_token", + UserId: u.Id, + } + kt2 := KeyringToken{ + TokenName: "t2", + KeyringType: "k2", + AuthTokenId: "at_2", + } - r, err := NewRepository(ctx, s, atLookupFn) + boundaryAuthTokens := []*authtokens.AuthToken{at1, at2} + atMap := map[ringToken]*authtokens.AuthToken{ + {kt1.KeyringType, kt1.TokenName}: at1, + {kt2.KeyringType, kt2.TokenName}: at2, + } + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(boundaryAuthTokens)) require.NoError(t, err) + assert.NoError(t, r.AddKeyringToken(ctx, u.Address, kt1)) + assert.NoError(t, r.AddKeyringToken(ctx, u.Address, kt2)) staleTime := time.Now().Add(-(tokenStalenessLimit + 1*time.Hour)) - oldNotStaleTime := time.Now().Add(-(tokenStalenessLimit - 1*time.Hour)) - - userId := "userId" - addr := "address" - keyringType := "keyring" - tokenName := "token" - authTokenId := "authTokenId" - for i := 0; i < usersLimit; i++ { - iKeyringType := fmt.Sprintf("%s%d", keyringType, i) - iTokenName := fmt.Sprintf("%s%d", tokenName, i) - iAuthTokenId := fmt.Sprintf("%s%d", authTokenId, i) - - kt := KeyringToken{ - KeyringType: iKeyringType, - TokenName: iTokenName, - AuthTokenId: iAuthTokenId, - } + freshTime := time.Now().Add(-(tokenStalenessLimit - 1*time.Hour)) - atMap[ringToken{kt.KeyringType, kt.TokenName}] = &authtokens.AuthToken{ - Id: kt.AuthTokenId, - UserId: userId, - Token: fmt.Sprintf("%s_sometokenvalue", kt.AuthTokenId), - } + freshAt := &AuthToken{ + Id: at1.Id, + LastAccessedTime: freshTime, + } + _, err = r.rw.Update(ctx, freshAt, []string{"LastAccessedTime"}, nil) + require.NoError(t, err) - assert.NoError(t, r.AddKeyringToken(ctx, addr, kt)) - p := &AuthToken{ - UserId: userId, - Id: kt.AuthTokenId, - } - switch i % 3 { - case 0: - p.LastAccessedTime = staleTime - _, err := r.rw.Update(ctx, p, []string{"LastAccessedTime"}, nil) - require.NoError(t, err) - case 1: - p.LastAccessedTime = oldNotStaleTime - _, err := r.rw.Update(ctx, p, []string{"LastAccessedTime"}, nil) - require.NoError(t, err) - } + staleAt := &AuthToken{ + Id: at2.Id, + LastAccessedTime: staleTime, } + _, err = r.rw.Update(ctx, staleAt, []string{"LastAccessedTime"}, nil) + require.NoError(t, err) - assert.NoError(t, r.removeStaleTokens(ctx)) - lAt, err := r.listTokens(ctx, &user{ - Id: userId, - Address: addr, - }) + lAt, err := r.listTokens(ctx, u) + assert.NoError(t, err) + assert.Len(t, lAt, 2) + + assert.NoError(t, r.cleanOrphanedAuthTokens(ctx)) + + lAt, err = r.listTokens(ctx, u) assert.NoError(t, err) - assert.Len(t, lAt, usersLimit*2/3) + assert.Len(t, lAt, 1) + assert.Equal(t, lAt[0].Id, at1.Id) } diff --git a/internal/daemon/cache/schema.sql b/internal/daemon/cache/schema.sql index f3ec8551c0..94acdcc8de 100644 --- a/internal/daemon/cache/schema.sql +++ b/internal/daemon/cache/schema.sql @@ -58,22 +58,6 @@ create table if not exists keyring_token ( primary key (keyring_type, token_name) ); --- *delete_orphaned_auth_tokens triggers deletes all auth tokens when it no --- longer has any storage for the auth tokens in the db -create trigger keyring_token_update_delete_orphaned_auth_tokens after update on keyring_token -begin -delete from auth_token -where - id not in (select auth_token_id from keyring_token); -end; - -create trigger keyring_token_delete_delete_orphaned_auth_tokens after delete on keyring_token -begin -delete from auth_token -where - id not in (select auth_token_id from keyring_token); -end; - -- target contains cached boundary target resource for a specific user and with -- specific fields extracted to facilitate searching over those fields create table if not exists target ( @@ -119,11 +103,13 @@ create table if not exists session ( -- contains errors from the last attempt to sync data from boundary for a -- specific resource type create table if not exists api_error ( - user_id text not null, - resource_type text not null, - error text not null, - create_time timestamp not null default current_timestamp, - primary key (user_id, resource_type) + user_id text not null + references user(id) + on delete cascade, + resource_type text not null, + error text not null, + create_time timestamp not null default current_timestamp, + primary key (user_id, resource_type) ); commit; diff --git a/internal/daemon/cache/store_test.go b/internal/daemon/cache/store_test.go index 9f81f9f7f4..e5a1b12cae 100644 --- a/internal/daemon/cache/store_test.go +++ b/internal/daemon/cache/store_test.go @@ -126,17 +126,11 @@ func TestAuthToken_NoMoreKeyringTokens(t *testing.T) { require.NoError(t, rw.Create(ctx, kt2)) assert.NoError(t, rw.LookupById(ctx, u)) - // deleting a single token doesn't remove the user - _, err = rw.Exec(ctx, "delete from keyring_token where (keyring_type, token_name) = (?, ?)", []any{kt1.KeyringType, kt1.TokenName}) + // deleting the keyring tokens doesn't remove the user + _, err = rw.Exec(ctx, "delete from keyring_token", nil) require.NoError(t, err) assert.NoError(t, rw.LookupById(ctx, at)) assert.NoError(t, rw.LookupById(ctx, u)) - - // deleting both tokens _does_ remove the user - _, err = rw.Exec(ctx, "delete from keyring_token", nil) - require.NoError(t, err) - assert.True(t, errors.IsNotFoundError(rw.LookupById(ctx, at))) - assert.True(t, errors.IsNotFoundError(rw.LookupById(ctx, u))) } func TestAuthToken(t *testing.T) {