From 384fc332be61b2592aa7ca0af626e8430c0aec0a Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Wed, 28 Aug 2024 10:15:04 -0700 Subject: [PATCH 01/15] Restrict cache results to 250 entries by default (#5049) A query parameter `max_result_set_size` can be specified to set a per-request limit, with `-1` meaning all values. If there were more entries than the max size used for the request, an `incomplete` boolean is set in the response. (cherry picked from commit 3cdd07889be5e4e6f21d044cb60fad1face4be08) --- internal/clientcache/cmd/search/search.go | 43 ++++- internal/clientcache/internal/cache/consts.go | 17 ++ .../clientcache/internal/cache/options.go | 21 ++- .../internal/cache/options_test.go | 15 +- .../internal/cache/refresh_test.go | 166 +++++++++++------- .../cache/repository_resolvable_aliases.go | 22 ++- .../repository_resolvable_aliases_test.go | 142 +++++++++++++-- .../internal/cache/repository_sessions.go | 22 ++- .../cache/repository_sessions_test.go | 142 +++++++++++++-- .../internal/cache/repository_targets.go | 22 ++- .../internal/cache/repository_targets_test.go | 142 +++++++++++++-- internal/clientcache/internal/cache/search.go | 68 ++++--- .../internal/daemon/404_handler.go | 2 +- .../internal/daemon/search_handler.go | 37 ++-- 14 files changed, 703 insertions(+), 158 deletions(-) create mode 100644 internal/clientcache/internal/cache/consts.go diff --git a/internal/clientcache/cmd/search/search.go b/internal/clientcache/cmd/search/search.go index f5fee5e475..e095671811 100644 --- a/internal/clientcache/cmd/search/search.go +++ b/internal/clientcache/cmd/search/search.go @@ -7,6 +7,7 @@ import ( "context" stderrors "errors" "fmt" + "math" "net/url" "os" "strings" @@ -40,9 +41,10 @@ var ( type SearchCommand struct { *base.Command - flagQuery string - flagResource string - flagForceRefresh bool + flagQuery string + flagResource string + flagForceRefresh bool + flagMaxResultSetSize int64 } func (c *SearchCommand) Synopsis() string { @@ -112,6 +114,12 @@ func (c *SearchCommand) Flags() *base.FlagSets { Usage: `Specifies the resource type to search over`, Completion: complete.PredictSet(supportedResourceTypes...), }) + f.Int64Var(&base.Int64Var{ + Name: "max-result-set-size", + Target: &c.flagMaxResultSetSize, + Usage: `Specifies an override to the default maximum result set size. Set to -1 to disable the limit. 0 will use the default.`, + Completion: complete.PredictNothing, + }) f.BoolVar(&base.BoolVar{ Name: "force-refresh", Target: &c.flagForceRefresh, @@ -148,6 +156,15 @@ func (c *SearchCommand) Run(args []string) int { return base.CommandUserError } + switch { + case c.flagMaxResultSetSize < -1: + c.PrintCliError(stderrors.New("Max result set size must be greater than or equal to -1")) + return base.CommandUserError + case c.flagMaxResultSetSize > math.MaxInt: + c.PrintCliError(stderrors.New(fmt.Sprintf("Max result set size must be less than or equal to the %v", math.MaxInt))) + return base.CommandUserError + } + resp, result, apiErr, err := c.Search(ctx) if err != nil { c.PrintCliError(err) @@ -164,6 +181,9 @@ func (c *SearchCommand) Run(args []string) int { return base.CommandCliError } default: + if result.Incomplete { + c.UI.Warn("The maximum result set size was reached and the search results are incomplete. 
Please narrow your search or adjust the -max-result-set-size parameter.") + } switch { case len(result.ResolvableAliases) > 0: c.UI.Output(printAliasListTable(result.ResolvableAliases)) @@ -199,6 +219,9 @@ func (c *SearchCommand) Search(ctx context.Context) (*api.Response, *daemon.Sear authTokenId: strings.Join(tSlice[:2], "_"), forceRefresh: c.flagForceRefresh, } + if c.flagMaxResultSetSize != 0 { + tf.maxResultSetSize = int(c.flagMaxResultSetSize) + } var opts []client.Option if c.FlagOutputCurlString { opts = append(opts, client.WithOutputCurlString()) @@ -230,6 +253,9 @@ func search(ctx context.Context, daemonPath string, fb filterBy, opt ...client.O if fb.forceRefresh { q.Add("force_refresh", "true") } + if fb.maxResultSetSize != 0 { + q.Add("max_result_set_size", fmt.Sprintf("%d", fb.maxResultSetSize)) + } resp, err := c.Get(ctx, "/v1/search", q, opt...) if err != nil { return nil, nil, nil, fmt.Errorf("Error when sending request to the cache: %w.", err) @@ -424,9 +450,10 @@ func printSessionListTable(items []*sessions.Session) string { } type filterBy struct { - flagFilter string - flagQuery string - authTokenId string - resource string - forceRefresh bool + flagFilter string + flagQuery string + authTokenId string + resource string + forceRefresh bool + maxResultSetSize int } diff --git a/internal/clientcache/internal/cache/consts.go b/internal/clientcache/internal/cache/consts.go new file mode 100644 index 0000000000..2321907d6d --- /dev/null +++ b/internal/clientcache/internal/cache/consts.go @@ -0,0 +1,17 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +/* +Package cache contains the domain logic for the client cache. +*/ +package cache + +const ( + // defaultLimitedResultSetSize is the default number of results to + // return when limiting + defaultLimitedResultSetSize = 250 + + // unlimitedMaxResultSetSize is the value to use when we want to return all + // results + unlimitedMaxResultSetSize = -1 +) diff --git a/internal/clientcache/internal/cache/options.go b/internal/clientcache/internal/cache/options.go index 6f68ebcea3..09e34ee9d8 100644 --- a/internal/clientcache/internal/cache/options.go +++ b/internal/clientcache/internal/cache/options.go @@ -4,6 +4,8 @@ package cache import ( + stderrors "errors" + "github.com/hashicorp/go-dbw" ) @@ -16,6 +18,7 @@ type options struct { withTargetRetrievalFunc TargetRetrievalFunc withSessionRetrievalFunc SessionRetrievalFunc withIgnoreSearchStaleness bool + withMaxResultSetSize int } // Option - how options are passed as args @@ -23,7 +26,8 @@ type Option func(*options) error func getDefaultOptions() options { return options{ - withDbType: dbw.Sqlite, + withDbType: dbw.Sqlite, + withMaxResultSetSize: defaultLimitedResultSetSize, } } @@ -94,3 +98,18 @@ func WithIgnoreSearchStaleness(b bool) Option { return nil } } + +// WithMaxResultSetSize provides an option for limiting the result set, e.g. +// when no filter is provided on a list. A 0 does nothing (keeps the default). 
+func WithMaxResultSetSize(with int) Option { + return func(o *options) error { + switch { + case with == 0: + return nil + case with < -1: + return stderrors.New("max result set size must be -1 or greater") + } + o.withMaxResultSetSize = with + return nil + } +} diff --git a/internal/clientcache/internal/cache/options_test.go b/internal/clientcache/internal/cache/options_test.go index 2475356203..aed96cd5ba 100644 --- a/internal/clientcache/internal/cache/options_test.go +++ b/internal/clientcache/internal/cache/options_test.go @@ -22,7 +22,8 @@ func Test_GetOpts(t *testing.T) { opts, err := getOpts() require.NoError(t, err) testOpts := options{ - withDbType: dbw.Sqlite, + withDbType: dbw.Sqlite, + withMaxResultSetSize: defaultLimitedResultSetSize, } assert.Equal(t, opts, testOpts) }) @@ -93,4 +94,16 @@ func Test_GetOpts(t *testing.T) { testOpts.withIgnoreSearchStaleness = true assert.Equal(t, opts, testOpts) }) + t.Run("withMaxResultSetSize", func(t *testing.T) { + opts, err := getOpts(WithMaxResultSetSize(defaultLimitedResultSetSize)) + require.NoError(t, err) + testOpts := getDefaultOptions() + testOpts.withMaxResultSetSize = defaultLimitedResultSetSize + assert.Equal(t, opts, testOpts) + opts, err = getOpts(WithMaxResultSetSize(0)) + require.Nil(t, err) + assert.Equal(t, opts, testOpts) + _, err = getOpts(WithMaxResultSetSize(-2)) + require.Error(t, err) + }) } diff --git a/internal/clientcache/internal/cache/refresh_test.go b/internal/clientcache/internal/cache/refresh_test.go index f7c0c003aa..a24dedf198 100644 --- a/internal/clientcache/internal/cache/refresh_test.go +++ b/internal/clientcache/internal/cache/refresh_test.go @@ -440,15 +440,18 @@ func TestRefreshForSearch(t *testing.T) { assert.NoError(t, rs.RefreshForSearch(ctx, at.Id, Targets, opts...)) cachedTargets, err := r.ListTargets(ctx, at.Id) - assert.NoError(t, err) - assert.Empty(t, cachedTargets) + require.NoError(t, err) + assert.Empty(t, cachedTargets.Targets) + assert.Empty(t, cachedTargets.ResolvableAliases) + assert.Empty(t, cachedTargets.Sessions) + assert.False(t, cachedTargets.Incomplete) // Now load up a few resources and a token, and trying again should // see the RefreshForSearch update more fields. assert.NoError(t, rs.Refresh(ctx, opts...)) cachedTargets, err = r.ListTargets(ctx, at.Id) assert.NoError(t, err) - assert.ElementsMatch(t, retTargets[:3], cachedTargets) + assert.ElementsMatch(t, retTargets[:3], cachedTargets.Targets) // Let 2 milliseconds pass so the items are stale enough time.Sleep(2 * time.Millisecond) @@ -456,7 +459,7 @@ func TestRefreshForSearch(t *testing.T) { assert.NoError(t, rs.RefreshForSearch(ctx, at.Id, Targets, opts...)) cachedTargets, err = r.ListTargets(ctx, at.Id) assert.NoError(t, err) - assert.ElementsMatch(t, retTargets[2:], cachedTargets) + assert.ElementsMatch(t, retTargets[2:], cachedTargets.Targets) }) t.Run("targets forced refreshed for searching", func(t *testing.T) { @@ -492,27 +495,30 @@ func TestRefreshForSearch(t *testing.T) { assert.NoError(t, rs.RefreshForSearch(ctx, at.Id, Targets, opts...)) cachedTargets, err := r.ListTargets(ctx, at.Id) - assert.NoError(t, err) - assert.Empty(t, cachedTargets) + require.NoError(t, err) + assert.Empty(t, cachedTargets.Targets) + assert.Empty(t, cachedTargets.ResolvableAliases) + assert.Empty(t, cachedTargets.Sessions) + assert.False(t, cachedTargets.Incomplete) // Now load up a few resources and a token, and trying again should // see the RefreshForSearch update more fields. 
assert.NoError(t, rs.Refresh(ctx, opts...)) cachedTargets, err = r.ListTargets(ctx, at.Id) assert.NoError(t, err) - assert.ElementsMatch(t, retTargets[:3], cachedTargets) + assert.ElementsMatch(t, retTargets[:3], cachedTargets.Targets) // No refresh happened because it is not considered stale assert.NoError(t, rs.RefreshForSearch(ctx, at.Id, Targets, opts...)) cachedTargets, err = r.ListTargets(ctx, at.Id) assert.NoError(t, err) - assert.ElementsMatch(t, retTargets[:3], cachedTargets) + assert.ElementsMatch(t, retTargets[:3], cachedTargets.Targets) // Now force refresh assert.NoError(t, rs.RefreshForSearch(ctx, at.Id, Targets, append(opts, WithIgnoreSearchStaleness(true))...)) cachedTargets, err = r.ListTargets(ctx, at.Id) assert.NoError(t, err) - assert.ElementsMatch(t, retTargets[2:], cachedTargets) + assert.ElementsMatch(t, retTargets[2:], cachedTargets.Targets) }) t.Run("no refresh token no refresh for search", func(t *testing.T) { @@ -537,8 +543,11 @@ func TestRefreshForSearch(t *testing.T) { assert.ErrorContains(t, err, ErrRefreshNotSupported.Error()) got, err := r.ListTargets(ctx, at.Id) - assert.NoError(t, err) - assert.Empty(t, got) + require.NoError(t, err) + assert.Empty(t, got.Targets) + assert.Empty(t, got.ResolvableAliases) + assert.Empty(t, got.Sessions) + assert.False(t, got.Incomplete) // Now that we know that this user doesn't support refresh tokens, they // wont be refreshed any more, and we wont see the error when refreshing @@ -556,8 +565,11 @@ func TestRefreshForSearch(t *testing.T) { assert.Nil(t, err) got, err = r.ListTargets(ctx, at.Id) - assert.NoError(t, err) - assert.Empty(t, got) + require.NoError(t, err) + assert.Empty(t, got.Targets) + assert.Empty(t, got.ResolvableAliases) + assert.Empty(t, got.Sessions) + assert.False(t, got.Incomplete) // Now simulate the controller updating to support refresh tokens and // the resources starting to be cached. 
@@ -569,7 +581,7 @@ func TestRefreshForSearch(t *testing.T) { got, err = r.ListTargets(ctx, at.Id) assert.NoError(t, err) - assert.Len(t, got, 2) + assert.Len(t, got.Targets, 2) }) t.Run("sessions refreshed for searching", func(t *testing.T) { @@ -605,19 +617,22 @@ func TestRefreshForSearch(t *testing.T) { // First call doesn't sync anything because no sessions were already synced yet assert.NoError(t, rs.RefreshForSearch(ctx, at.Id, Sessions, opts...)) cachedSessions, err := r.ListSessions(ctx, at.Id) - assert.NoError(t, err) - assert.Empty(t, cachedSessions) + require.NoError(t, err) + assert.Empty(t, cachedSessions.Targets) + assert.Empty(t, cachedSessions.ResolvableAliases) + assert.Empty(t, cachedSessions.Sessions) + assert.False(t, cachedSessions.Incomplete) assert.NoError(t, rs.Refresh(ctx, opts...)) cachedSessions, err = r.ListSessions(ctx, at.Id) assert.NoError(t, err) - assert.ElementsMatch(t, retSess[:3], cachedSessions) + assert.ElementsMatch(t, retSess[:3], cachedSessions.Sessions) // Second call removes the first 2 resources from the cache and adds the last assert.NoError(t, rs.RefreshForSearch(ctx, at.Id, Sessions, opts...)) cachedSessions, err = r.ListSessions(ctx, at.Id) assert.NoError(t, err) - assert.ElementsMatch(t, retSess[2:], cachedSessions) + assert.ElementsMatch(t, retSess[2:], cachedSessions.Sessions) }) t.Run("sessions forced refreshed for searching", func(t *testing.T) { @@ -654,25 +669,28 @@ func TestRefreshForSearch(t *testing.T) { // First call doesn't sync anything because no sessions were already synced yet assert.NoError(t, rs.RefreshForSearch(ctx, at.Id, Sessions, opts...)) cachedSessions, err := r.ListSessions(ctx, at.Id) - assert.NoError(t, err) - assert.Empty(t, cachedSessions) + require.NoError(t, err) + assert.Empty(t, cachedSessions.Targets) + assert.Empty(t, cachedSessions.ResolvableAliases) + assert.Empty(t, cachedSessions.Sessions) + assert.False(t, cachedSessions.Incomplete) assert.NoError(t, rs.Refresh(ctx, opts...)) cachedSessions, err = r.ListSessions(ctx, at.Id) assert.NoError(t, err) - assert.ElementsMatch(t, retSess[:3], cachedSessions) + assert.ElementsMatch(t, retSess[:3], cachedSessions.Sessions) // Refresh for search doesn't refresh anything because it isn't stale assert.NoError(t, rs.RefreshForSearch(ctx, at.Id, Sessions, opts...)) cachedSessions, err = r.ListSessions(ctx, at.Id) assert.NoError(t, err) - assert.ElementsMatch(t, retSess[:3], cachedSessions) + assert.ElementsMatch(t, retSess[:3], cachedSessions.Sessions) // Now force the refresh and see things get updated assert.NoError(t, rs.RefreshForSearch(ctx, at.Id, Sessions, append(opts, WithIgnoreSearchStaleness(true))...)) cachedSessions, err = r.ListSessions(ctx, at.Id) assert.NoError(t, err) - assert.ElementsMatch(t, retSess[2:], cachedSessions) + assert.ElementsMatch(t, retSess[2:], cachedSessions.Sessions) }) t.Run("aliases refreshed for searching", func(t *testing.T) { @@ -708,19 +726,22 @@ func TestRefreshForSearch(t *testing.T) { // First call doesn't sync anything because no aliases were already synced yet assert.NoError(t, rs.RefreshForSearch(ctx, at.Id, ResolvableAliases, opts...)) cachedAliases, err := r.ListResolvableAliases(ctx, at.Id) - assert.NoError(t, err) - assert.Empty(t, cachedAliases) + require.NoError(t, err) + assert.Empty(t, cachedAliases.Targets) + assert.Empty(t, cachedAliases.ResolvableAliases) + assert.Empty(t, cachedAliases.Sessions) + assert.False(t, cachedAliases.Incomplete) assert.NoError(t, rs.Refresh(ctx, opts...)) cachedAliases, err = 
r.ListResolvableAliases(ctx, at.Id) - assert.NoError(t, err) - assert.ElementsMatch(t, retAl[:3], cachedAliases) + require.NoError(t, err) + assert.ElementsMatch(t, retAl[:3], cachedAliases.ResolvableAliases) // Second call removes the first 2 resources from the cache and adds the last assert.NoError(t, rs.RefreshForSearch(ctx, at.Id, ResolvableAliases, opts...)) cachedAliases, err = r.ListResolvableAliases(ctx, at.Id) - assert.NoError(t, err) - assert.ElementsMatch(t, retAl[2:], cachedAliases) + require.NoError(t, err) + assert.ElementsMatch(t, retAl[2:], cachedAliases.ResolvableAliases) }) t.Run("aliases forced refreshed for searching", func(t *testing.T) { @@ -757,25 +778,28 @@ func TestRefreshForSearch(t *testing.T) { // First call doesn't sync anything because no aliases were already synced yet assert.NoError(t, rs.RefreshForSearch(ctx, at.Id, ResolvableAliases, opts...)) cachedAliases, err := r.ListResolvableAliases(ctx, at.Id) - assert.NoError(t, err) - assert.Empty(t, cachedAliases) + require.NoError(t, err) + assert.Empty(t, cachedAliases.Targets) + assert.Empty(t, cachedAliases.ResolvableAliases) + assert.Empty(t, cachedAliases.Sessions) + assert.False(t, cachedAliases.Incomplete) assert.NoError(t, rs.Refresh(ctx, opts...)) cachedAliases, err = r.ListResolvableAliases(ctx, at.Id) - assert.NoError(t, err) - assert.ElementsMatch(t, retAls[:3], cachedAliases) + require.NoError(t, err) + assert.ElementsMatch(t, retAls[:3], cachedAliases.ResolvableAliases) // Refresh for search doesn't refresh anything because it isn't stale assert.NoError(t, rs.RefreshForSearch(ctx, at.Id, ResolvableAliases, opts...)) cachedAliases, err = r.ListResolvableAliases(ctx, at.Id) - assert.NoError(t, err) - assert.ElementsMatch(t, retAls[:3], cachedAliases) + require.NoError(t, err) + assert.ElementsMatch(t, retAls[:3], cachedAliases.ResolvableAliases) // Now force the refresh and see things get updated assert.NoError(t, rs.RefreshForSearch(ctx, at.Id, ResolvableAliases, append(opts, WithIgnoreSearchStaleness(true))...)) cachedAliases, err = r.ListResolvableAliases(ctx, at.Id) - assert.NoError(t, err) - assert.ElementsMatch(t, retAls[2:], cachedAliases) + require.NoError(t, err) + assert.ElementsMatch(t, retAls[2:], cachedAliases.ResolvableAliases) }) } @@ -829,13 +853,13 @@ func TestRefresh(t *testing.T) { cachedTargets, err := r.ListTargets(ctx, at.Id) assert.NoError(t, err) - assert.ElementsMatch(t, retTargets[:3], cachedTargets) + assert.ElementsMatch(t, retTargets[:3], cachedTargets.Targets) // Second call removes the first 2 resources from the cache and adds the last assert.NoError(t, rs.Refresh(ctx, opts...)) cachedTargets, err = r.ListTargets(ctx, at.Id) assert.NoError(t, err) - assert.ElementsMatch(t, retTargets[2:], cachedTargets) + assert.ElementsMatch(t, retTargets[2:], cachedTargets.Targets) }) t.Run("set sessions", func(t *testing.T) { @@ -870,13 +894,13 @@ func TestRefresh(t *testing.T) { assert.NoError(t, rs.Refresh(ctx, opts...)) cachedSessions, err := r.ListSessions(ctx, at.Id) assert.NoError(t, err) - assert.ElementsMatch(t, retSess[:3], cachedSessions) + assert.ElementsMatch(t, retSess[:3], cachedSessions.Sessions) // Second call removes the first 2 resources from the cache and adds the last assert.NoError(t, rs.Refresh(ctx, opts...)) cachedSessions, err = r.ListSessions(ctx, at.Id) assert.NoError(t, err) - assert.ElementsMatch(t, retSess[2:], cachedSessions) + assert.ElementsMatch(t, retSess[2:], cachedSessions.Sessions) }) t.Run("set aliases", func(t *testing.T) { @@ -911,13 +935,13 @@ 
func TestRefresh(t *testing.T) { assert.NoError(t, rs.Refresh(ctx, opts...)) cachedAliases, err := r.ListResolvableAliases(ctx, at.Id) assert.NoError(t, err) - assert.ElementsMatch(t, retAls[:3], cachedAliases) + assert.ElementsMatch(t, retAls[:3], cachedAliases.ResolvableAliases) // Second call removes the first 2 resources from the cache and adds the last assert.NoError(t, rs.Refresh(ctx, opts...)) cachedAliases, err = r.ListResolvableAliases(ctx, at.Id) assert.NoError(t, err) - assert.ElementsMatch(t, retAls[2:], cachedAliases) + assert.ElementsMatch(t, retAls[2:], cachedAliases.ResolvableAliases) }) t.Run("error propagates up", func(t *testing.T) { @@ -1022,8 +1046,11 @@ func TestRecheckCachingSupport(t *testing.T) { WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)))) got, err := r.ListTargets(ctx, at.Id) - assert.NoError(t, err) - assert.Empty(t, got) + require.NoError(t, err) + assert.Empty(t, got.Targets) + assert.Empty(t, got.ResolvableAliases) + assert.Empty(t, got.Sessions) + assert.False(t, got.Incomplete) err = rs.Refresh(ctx, WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), @@ -1032,8 +1059,11 @@ func TestRecheckCachingSupport(t *testing.T) { assert.ErrorIs(t, err, ErrRefreshNotSupported) got, err = r.ListTargets(ctx, at.Id) - assert.NoError(t, err) - assert.Empty(t, got) + require.NoError(t, err) + assert.Empty(t, got.Targets) + assert.Empty(t, got.ResolvableAliases) + assert.Empty(t, got.Sessions) + assert.False(t, got.Incomplete) // now a full fetch will work since the user has resources and no refresh token assert.NoError(t, rs.RecheckCachingSupport(ctx, @@ -1057,8 +1087,11 @@ func TestRecheckCachingSupport(t *testing.T) { WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))) got, err := r.ListSessions(ctx, at.Id) - assert.NoError(t, err) - assert.Empty(t, got) + require.NoError(t, err) + assert.Empty(t, got.Targets) + assert.Empty(t, got.ResolvableAliases) + assert.Empty(t, got.Sessions) + assert.False(t, got.Incomplete) err = rs.Refresh(ctx, WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), @@ -1067,16 +1100,22 @@ func TestRecheckCachingSupport(t *testing.T) { assert.ErrorIs(t, err, ErrRefreshNotSupported) got, err = r.ListSessions(ctx, at.Id) - assert.NoError(t, err) - assert.Empty(t, got) + require.NoError(t, err) + assert.Empty(t, got.Targets) + assert.Empty(t, got.ResolvableAliases) + assert.Empty(t, got.Sessions) + assert.False(t, got.Incomplete) assert.NoError(t, rs.RecheckCachingSupport(ctx, WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)), WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))) got, err = r.ListSessions(ctx, at.Id) - assert.NoError(t, err) - assert.Empty(t, got) + require.NoError(t, err) + assert.Empty(t, got.Targets) + assert.Empty(t, got.ResolvableAliases) + assert.Empty(t, got.Sessions) + assert.False(t, got.Incomplete) }) t.Run("aliases", func(t *testing.T) { @@ -1094,8 +1133,11 @@ func TestRecheckCachingSupport(t *testing.T) { WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))) got, err := r.ListResolvableAliases(ctx, at.Id) - assert.NoError(t, err) - assert.Empty(t, got) + require.NoError(t, err) + assert.Empty(t, got.Targets) + assert.Empty(t, got.ResolvableAliases) + assert.Empty(t, got.Sessions) + assert.False(t, got.Incomplete) err = rs.Refresh(ctx, 
WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), @@ -1104,16 +1146,22 @@ func TestRecheckCachingSupport(t *testing.T) { assert.ErrorIs(t, err, ErrRefreshNotSupported) got, err = r.ListResolvableAliases(ctx, at.Id) - assert.NoError(t, err) - assert.Empty(t, got) + require.NoError(t, err) + assert.Empty(t, got.Targets) + assert.Empty(t, got.ResolvableAliases) + assert.Empty(t, got.Sessions) + assert.False(t, got.Incomplete) assert.NoError(t, rs.RecheckCachingSupport(ctx, WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)), WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))) got, err = r.ListResolvableAliases(ctx, at.Id) - assert.NoError(t, err) - assert.Empty(t, got) + require.NoError(t, err) + assert.Empty(t, got.Targets) + assert.Empty(t, got.ResolvableAliases) + assert.Empty(t, got.Sessions) + assert.False(t, got.Incomplete) }) t.Run("error propagates up", func(t *testing.T) { diff --git a/internal/clientcache/internal/cache/repository_resolvable_aliases.go b/internal/clientcache/internal/cache/repository_resolvable_aliases.go index 1bec59e9ae..39e3bf5f36 100644 --- a/internal/clientcache/internal/cache/repository_resolvable_aliases.go +++ b/internal/clientcache/internal/cache/repository_resolvable_aliases.go @@ -288,20 +288,20 @@ func upsertResolvableAliases(ctx context.Context, w db.Writer, u *user, in []*al return nil } -func (r *Repository) ListResolvableAliases(ctx context.Context, authTokenId string) ([]*aliases.Alias, error) { +func (r *Repository) ListResolvableAliases(ctx context.Context, authTokenId string, opt ...Option) (*SearchResult, error) { const op = "cache.(Repository).ListResolvableAliases" switch { case authTokenId == "": return nil, errors.New(ctx, errors.InvalidParameter, op, "auth token id is missing") } - ret, err := r.searchResolvableAliases(ctx, "true", nil, withAuthTokenId(authTokenId)) + ret, err := r.searchResolvableAliases(ctx, "true", nil, append(opt, withAuthTokenId(authTokenId))...) if err != nil { return nil, errors.Wrap(ctx, err, op) } return ret, nil } -func (r *Repository) QueryResolvableAliases(ctx context.Context, authTokenId, query string) ([]*aliases.Alias, error) { +func (r *Repository) QueryResolvableAliases(ctx context.Context, authTokenId, query string, opt ...Option) (*SearchResult, error) { const op = "cache.(Repository).QueryResolvableAliases" switch { case authTokenId == "": @@ -314,14 +314,14 @@ func (r *Repository) QueryResolvableAliases(ctx context.Context, authTokenId, qu if err != nil { return nil, errors.Wrap(ctx, err, op, errors.WithCode(errors.InvalidParameter)) } - ret, err := r.searchResolvableAliases(ctx, w.Condition, w.Args, withAuthTokenId(authTokenId)) + ret, err := r.searchResolvableAliases(ctx, w.Condition, w.Args, append(opt, withAuthTokenId(authTokenId))...) 
if err != nil { return nil, errors.Wrap(ctx, err, op) } return ret, nil } -func (r *Repository) searchResolvableAliases(ctx context.Context, condition string, searchArgs []any, opt ...Option) ([]*aliases.Alias, error) { +func (r *Repository) searchResolvableAliases(ctx context.Context, condition string, searchArgs []any, opt ...Option) (*SearchResult, error) { const op = "cache.(Repository).searchResolvableAliases" switch { case condition == "": @@ -346,7 +346,7 @@ func (r *Repository) searchResolvableAliases(ctx context.Context, condition stri } var cachedResolvableAliases []*ResolvableAlias - if err := r.rw.SearchWhere(ctx, &cachedResolvableAliases, condition, searchArgs, db.WithLimit(-1)); err != nil { + if err := r.rw.SearchWhere(ctx, &cachedResolvableAliases, condition, searchArgs, db.WithLimit(opts.withMaxResultSetSize+1)); err != nil { return nil, errors.Wrap(ctx, err, op) } @@ -358,7 +358,15 @@ func (r *Repository) searchResolvableAliases(ctx context.Context, condition stri } retAliases = append(retAliases, &a) } - return retAliases, nil + + sr := &SearchResult{ + ResolvableAliases: retAliases, + } + if opts.withMaxResultSetSize > 0 && len(sr.ResolvableAliases) > opts.withMaxResultSetSize { + sr.ResolvableAliases = sr.ResolvableAliases[:opts.withMaxResultSetSize] + sr.Incomplete = true + } + return sr, nil } type ResolvableAlias struct { diff --git a/internal/clientcache/internal/cache/repository_resolvable_aliases_test.go b/internal/clientcache/internal/cache/repository_resolvable_aliases_test.go index 948241ef61..086bf17c4c 100644 --- a/internal/clientcache/internal/cache/repository_resolvable_aliases_test.go +++ b/internal/clientcache/internal/cache/repository_resolvable_aliases_test.go @@ -6,6 +6,7 @@ package cache import ( "context" "encoding/json" + "strconv" "sync" "testing" @@ -222,7 +223,7 @@ func TestRepository_RefreshAliases_withRefreshTokens(t *testing.T) { got, err := r.ListResolvableAliases(ctx, at.Id) require.NoError(t, err) - assert.Len(t, got, 2) + assert.Len(t, got.ResolvableAliases, 2) // Refreshing again uses the refresh token and get additional aliases, appending // them to the response @@ -232,7 +233,7 @@ func TestRepository_RefreshAliases_withRefreshTokens(t *testing.T) { got, err = r.ListResolvableAliases(ctx, at.Id) require.NoError(t, err) - assert.Len(t, got, 3) + assert.Len(t, got.ResolvableAliases, 3) // Refreshing again wont return any more resources, but also none should be // removed @@ -242,7 +243,7 @@ func TestRepository_RefreshAliases_withRefreshTokens(t *testing.T) { got, err = r.ListResolvableAliases(ctx, at.Id) require.NoError(t, err) - assert.Len(t, got, 3) + assert.Len(t, got.ResolvableAliases, 3) // Refresh again with the refresh token being reported as invalid. 
require.NoError(t, r.refreshResolvableAliases(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"}, @@ -251,7 +252,7 @@ func TestRepository_RefreshAliases_withRefreshTokens(t *testing.T) { got, err = r.ListResolvableAliases(ctx, at.Id) require.NoError(t, err) - assert.Len(t, got, 2) + assert.Len(t, got.ResolvableAliases, 2) } func TestRepository_ListAliases(t *testing.T) { @@ -332,13 +333,73 @@ func TestRepository_ListAliases(t *testing.T) { t.Run("wrong user gets no aliases", func(t *testing.T) { l, err := r.ListResolvableAliases(ctx, kt2.AuthTokenId) assert.NoError(t, err) - assert.Empty(t, l) + assert.Empty(t, l.ResolvableAliases) }) t.Run("correct token gets aliases", func(t *testing.T) { l, err := r.ListResolvableAliases(ctx, kt1.AuthTokenId) assert.NoError(t, err) - assert.Len(t, l, len(ss)) - assert.ElementsMatch(t, l, ss) + assert.Len(t, l.ResolvableAliases, len(ss)) + assert.ElementsMatch(t, l.ResolvableAliases, ss) + }) +} + +func TestRepository_ListAliasesLimiting(t *testing.T) { + ctx := context.Background() + s, err := cachedb.Open(ctx) + require.NoError(t, err) + + addr := "address" + u := &user{ + Id: "u", + Address: addr, + } + at := &authtokens.AuthToken{ + Id: "at", + Token: "at_token", + UserId: u.Id, + } + kt := KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id} + + atMap := map[ringToken]*authtokens.AuthToken{ + {"k", "t"}: at, + } + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(maps.Values(atMap))) + require.NoError(t, err) + require.NoError(t, r.AddKeyringToken(ctx, addr, kt)) + + var ts []*aliases.Alias + for i := 0; i < defaultLimitedResultSetSize*2; i++ { + ts = append(ts, alias("s"+strconv.Itoa(i))) + } + require.NoError(t, r.refreshResolvableAliases(ctx, u, map[AuthToken]string{{Id: "id"}: "something"}, + WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId(t, [][]*aliases.Alias{ts}, [][]string{nil})))) + + searchService, err := NewSearchService(ctx, r) + require.NoError(t, err) + params := SearchParams{ + Resource: ResolvableAliases, + AuthTokenId: kt.AuthTokenId, + } + + t.Run("default limit", func(t *testing.T) { + searchResult, err := searchService.Search(ctx, params) + require.NoError(t, err) + assert.Len(t, searchResult.ResolvableAliases, defaultLimitedResultSetSize) + assert.True(t, searchResult.Incomplete) + }) + t.Run("custom limit", func(t *testing.T) { + params.MaxResultSetSize = 20 + searchResult, err := searchService.Search(ctx, params) + require.NoError(t, err) + assert.Len(t, searchResult.ResolvableAliases, params.MaxResultSetSize) + assert.True(t, searchResult.Incomplete) + }) + t.Run("no limit", func(t *testing.T) { + params.MaxResultSetSize = -1 + searchResult, err := searchService.Search(ctx, params) + require.NoError(t, err) + assert.Len(t, searchResult.ResolvableAliases, defaultLimitedResultSetSize*2) + assert.False(t, searchResult.Incomplete) }) } @@ -442,13 +503,74 @@ func TestRepository_QueryAliases(t *testing.T) { t.Run("wrong token gets no aliases", func(t *testing.T) { l, err := r.QueryResolvableAliases(ctx, kt2.AuthTokenId, query) assert.NoError(t, err) - assert.Empty(t, l) + assert.Empty(t, l.ResolvableAliases) }) t.Run("correct token gets aliases", func(t *testing.T) { l, err := r.QueryResolvableAliases(ctx, kt1.AuthTokenId, query) assert.NoError(t, err) - assert.Len(t, l, 2) - assert.ElementsMatch(t, l, ss[0:2]) + assert.Len(t, l.ResolvableAliases, 2) + assert.ElementsMatch(t, l.ResolvableAliases, ss[0:2]) + }) +} + +func 
TestRepository_QueryResolvableAliasesLimiting(t *testing.T) { + ctx := context.Background() + s, err := cachedb.Open(ctx) + require.NoError(t, err) + + addr := "address" + u := &user{ + Id: "u", + Address: addr, + } + at := &authtokens.AuthToken{ + Id: "at", + Token: "at_token", + UserId: u.Id, + } + kt := KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id} + + atMap := map[ringToken]*authtokens.AuthToken{ + {"k", "t"}: at, + } + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(maps.Values(atMap))) + require.NoError(t, err) + require.NoError(t, r.AddKeyringToken(ctx, addr, kt)) + + var ts []*aliases.Alias + for i := 0; i < defaultLimitedResultSetSize*2; i++ { + ts = append(ts, alias("s"+strconv.Itoa(i))) + } + require.NoError(t, r.refreshResolvableAliases(ctx, u, map[AuthToken]string{{Id: "id"}: "something"}, + WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId(t, [][]*aliases.Alias{ts}, [][]string{nil})))) + + searchService, err := NewSearchService(ctx, r) + require.NoError(t, err) + params := SearchParams{ + Resource: ResolvableAliases, + AuthTokenId: kt.AuthTokenId, + Query: `(type % 'target')`, + } + + t.Run("default limit", func(t *testing.T) { + searchResult, err := searchService.Search(ctx, params) + require.NoError(t, err) + assert.Len(t, searchResult.ResolvableAliases, defaultLimitedResultSetSize) + assert.True(t, searchResult.Incomplete) + }) + t.Run("custom limit", func(t *testing.T) { + params.MaxResultSetSize = 20 + searchResult, err := searchService.Search(ctx, params) + require.NoError(t, err) + assert.Len(t, searchResult.ResolvableAliases, params.MaxResultSetSize) + assert.True(t, searchResult.Incomplete) + }) + t.Run("no limit", func(t *testing.T) { + params.MaxResultSetSize = -1 + searchResult, err := searchService.Search(ctx, params) + require.NoError(t, err) + assert.Len(t, searchResult.ResolvableAliases, defaultLimitedResultSetSize*2) + assert.False(t, searchResult.Incomplete) }) } diff --git a/internal/clientcache/internal/cache/repository_sessions.go b/internal/clientcache/internal/cache/repository_sessions.go index 99f5bad1c3..904202252a 100644 --- a/internal/clientcache/internal/cache/repository_sessions.go +++ b/internal/clientcache/internal/cache/repository_sessions.go @@ -290,20 +290,20 @@ func upsertSessions(ctx context.Context, w db.Writer, u *user, in []*sessions.Se return nil } -func (r *Repository) ListSessions(ctx context.Context, authTokenId string) ([]*sessions.Session, error) { +func (r *Repository) ListSessions(ctx context.Context, authTokenId string, opt ...Option) (*SearchResult, error) { const op = "cache.(Repository).ListSessions" switch { case authTokenId == "": return nil, errors.New(ctx, errors.InvalidParameter, op, "auth token id is missing") } - ret, err := r.searchSessions(ctx, "true", nil, withAuthTokenId(authTokenId)) + ret, err := r.searchSessions(ctx, "true", nil, append(opt, withAuthTokenId(authTokenId))...) 
if err != nil { return nil, errors.Wrap(ctx, err, op) } return ret, nil } -func (r *Repository) QuerySessions(ctx context.Context, authTokenId, query string) ([]*sessions.Session, error) { +func (r *Repository) QuerySessions(ctx context.Context, authTokenId, query string, opt ...Option) (*SearchResult, error) { const op = "cache.(Repository).QuerySessions" switch { case authTokenId == "": @@ -316,14 +316,14 @@ func (r *Repository) QuerySessions(ctx context.Context, authTokenId, query strin if err != nil { return nil, errors.Wrap(ctx, err, op, errors.WithCode(errors.InvalidParameter)) } - ret, err := r.searchSessions(ctx, w.Condition, w.Args, withAuthTokenId(authTokenId)) + ret, err := r.searchSessions(ctx, w.Condition, w.Args, append(opt, withAuthTokenId(authTokenId))...) if err != nil { return nil, errors.Wrap(ctx, err, op) } return ret, nil } -func (r *Repository) searchSessions(ctx context.Context, condition string, searchArgs []any, opt ...Option) ([]*sessions.Session, error) { +func (r *Repository) searchSessions(ctx context.Context, condition string, searchArgs []any, opt ...Option) (*SearchResult, error) { const op = "cache.(Repository).searchSessions" switch { case condition == "": @@ -348,7 +348,7 @@ func (r *Repository) searchSessions(ctx context.Context, condition string, searc } var cachedSessions []*Session - if err := r.rw.SearchWhere(ctx, &cachedSessions, condition, searchArgs, db.WithLimit(-1)); err != nil { + if err := r.rw.SearchWhere(ctx, &cachedSessions, condition, searchArgs, db.WithLimit(opts.withMaxResultSetSize+1)); err != nil { return nil, errors.Wrap(ctx, err, op) } @@ -360,7 +360,15 @@ func (r *Repository) searchSessions(ctx context.Context, condition string, searc } retSessions = append(retSessions, &sess) } - return retSessions, nil + + sr := &SearchResult{ + Sessions: retSessions, + } + if opts.withMaxResultSetSize > 0 && len(sr.Sessions) > opts.withMaxResultSetSize { + sr.Sessions = sr.Sessions[:opts.withMaxResultSetSize] + sr.Incomplete = true + } + return sr, nil } type Session struct { diff --git a/internal/clientcache/internal/cache/repository_sessions_test.go b/internal/clientcache/internal/cache/repository_sessions_test.go index 45e6847951..4627f8cd88 100644 --- a/internal/clientcache/internal/cache/repository_sessions_test.go +++ b/internal/clientcache/internal/cache/repository_sessions_test.go @@ -6,6 +6,7 @@ package cache import ( "context" "encoding/json" + "strconv" "sync" "testing" "time" @@ -239,7 +240,7 @@ func TestRepository_RefreshSessions_withRefreshTokens(t *testing.T) { got, err := r.ListSessions(ctx, at.Id) require.NoError(t, err) - assert.Len(t, got, 2) + assert.Len(t, got.Sessions, 2) // Refreshing again uses the refresh token and get additional sessions, appending // them to the response @@ -249,7 +250,7 @@ func TestRepository_RefreshSessions_withRefreshTokens(t *testing.T) { got, err = r.ListSessions(ctx, at.Id) require.NoError(t, err) - assert.Len(t, got, 3) + assert.Len(t, got.Sessions, 3) // Refreshing again wont return any more resources, but also none should be // removed @@ -259,7 +260,7 @@ func TestRepository_RefreshSessions_withRefreshTokens(t *testing.T) { got, err = r.ListSessions(ctx, at.Id) require.NoError(t, err) - assert.Len(t, got, 3) + assert.Len(t, got.Sessions, 3) // Refresh again with the refresh token being reported as invalid. 
require.NoError(t, r.refreshSessions(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"}, @@ -268,7 +269,7 @@ func TestRepository_RefreshSessions_withRefreshTokens(t *testing.T) { got, err = r.ListSessions(ctx, at.Id) require.NoError(t, err) - assert.Len(t, got, 2) + assert.Len(t, got.Sessions, 2) } func TestRepository_ListSessions(t *testing.T) { @@ -355,13 +356,73 @@ func TestRepository_ListSessions(t *testing.T) { t.Run("wrong user gets no sessions", func(t *testing.T) { l, err := r.ListSessions(ctx, kt2.AuthTokenId) assert.NoError(t, err) - assert.Empty(t, l) + assert.Empty(t, l.Sessions) }) t.Run("correct token gets sessions", func(t *testing.T) { l, err := r.ListSessions(ctx, kt1.AuthTokenId) assert.NoError(t, err) - assert.Len(t, l, len(ss)) - assert.ElementsMatch(t, l, ss) + assert.Len(t, l.Sessions, len(ss)) + assert.ElementsMatch(t, l.Sessions, ss) + }) +} + +func TestRepository_ListSessionsLimiting(t *testing.T) { + ctx := context.Background() + s, err := cachedb.Open(ctx) + require.NoError(t, err) + + addr := "address" + u := &user{ + Id: "u", + Address: addr, + } + at := &authtokens.AuthToken{ + Id: "at", + Token: "at_token", + UserId: u.Id, + } + kt := KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id} + + atMap := map[ringToken]*authtokens.AuthToken{ + {"k", "t"}: at, + } + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(maps.Values(atMap))) + require.NoError(t, err) + require.NoError(t, r.AddKeyringToken(ctx, addr, kt)) + + var ts []*sessions.Session + for i := 0; i < defaultLimitedResultSetSize*2; i++ { + ts = append(ts, session("s"+strconv.Itoa(i))) + } + require.NoError(t, r.refreshSessions(ctx, u, map[AuthToken]string{{Id: "id"}: "something"}, + WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{ts}, [][]string{nil})))) + + searchService, err := NewSearchService(ctx, r) + require.NoError(t, err) + params := SearchParams{ + Resource: Sessions, + AuthTokenId: kt.AuthTokenId, + } + + t.Run("default limit", func(t *testing.T) { + searchResult, err := searchService.Search(ctx, params) + require.NoError(t, err) + assert.Len(t, searchResult.Sessions, defaultLimitedResultSetSize) + assert.True(t, searchResult.Incomplete) + }) + t.Run("custom limit", func(t *testing.T) { + params.MaxResultSetSize = 20 + searchResult, err := searchService.Search(ctx, params) + require.NoError(t, err) + assert.Len(t, searchResult.Sessions, params.MaxResultSetSize) + assert.True(t, searchResult.Incomplete) + }) + t.Run("no limit", func(t *testing.T) { + params.MaxResultSetSize = -1 + searchResult, err := searchService.Search(ctx, params) + require.NoError(t, err) + assert.Len(t, searchResult.Sessions, defaultLimitedResultSetSize*2) + assert.False(t, searchResult.Incomplete) }) } @@ -471,13 +532,74 @@ func TestRepository_QuerySessions(t *testing.T) { t.Run("wrong token gets no sessions", func(t *testing.T) { l, err := r.QuerySessions(ctx, kt2.AuthTokenId, query) assert.NoError(t, err) - assert.Empty(t, l) + assert.Empty(t, l.Sessions) }) t.Run("correct token gets sessions", func(t *testing.T) { l, err := r.QuerySessions(ctx, kt1.AuthTokenId, query) assert.NoError(t, err) - assert.Len(t, l, 2) - assert.ElementsMatch(t, l, ss[0:2]) + assert.Len(t, l.Sessions, 2) + assert.ElementsMatch(t, l.Sessions, ss[0:2]) + }) +} + +func TestRepository_QuerySessionsLimiting(t *testing.T) { + ctx := context.Background() + s, err := cachedb.Open(ctx) + require.NoError(t, err) + + addr := "address" + 
u := &user{ + Id: "u", + Address: addr, + } + at := &authtokens.AuthToken{ + Id: "at", + Token: "at_token", + UserId: u.Id, + } + kt := KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id} + + atMap := map[ringToken]*authtokens.AuthToken{ + {"k", "t"}: at, + } + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(maps.Values(atMap))) + require.NoError(t, err) + require.NoError(t, r.AddKeyringToken(ctx, addr, kt)) + + var ts []*sessions.Session + for i := 0; i < defaultLimitedResultSetSize*2; i++ { + ts = append(ts, session("t"+strconv.Itoa(i))) + } + require.NoError(t, r.refreshSessions(ctx, u, map[AuthToken]string{{Id: "id"}: "something"}, + WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{ts}, [][]string{nil})))) + + searchService, err := NewSearchService(ctx, r) + require.NoError(t, err) + params := SearchParams{ + Resource: Sessions, + AuthTokenId: kt.AuthTokenId, + Query: `(id % 'session')`, + } + + t.Run("default limit", func(t *testing.T) { + searchResult, err := searchService.Search(ctx, params) + require.NoError(t, err) + assert.Len(t, searchResult.Sessions, defaultLimitedResultSetSize) + assert.True(t, searchResult.Incomplete) + }) + t.Run("custom limit", func(t *testing.T) { + params.MaxResultSetSize = 20 + searchResult, err := searchService.Search(ctx, params) + require.NoError(t, err) + assert.Len(t, searchResult.Sessions, params.MaxResultSetSize) + assert.True(t, searchResult.Incomplete) + }) + t.Run("no limit", func(t *testing.T) { + params.MaxResultSetSize = -1 + searchResult, err := searchService.Search(ctx, params) + require.NoError(t, err) + assert.Len(t, searchResult.Sessions, defaultLimitedResultSetSize*2) + assert.False(t, searchResult.Incomplete) }) } diff --git a/internal/clientcache/internal/cache/repository_targets.go b/internal/clientcache/internal/cache/repository_targets.go index f52c9c0785..e5cfa6e1f9 100644 --- a/internal/clientcache/internal/cache/repository_targets.go +++ b/internal/clientcache/internal/cache/repository_targets.go @@ -291,20 +291,20 @@ func upsertTargets(ctx context.Context, w db.Writer, u *user, in []*targets.Targ return nil } -func (r *Repository) ListTargets(ctx context.Context, authTokenId string) ([]*targets.Target, error) { +func (r *Repository) ListTargets(ctx context.Context, authTokenId string, opt ...Option) (*SearchResult, error) { const op = "cache.(Repository).ListTargets" switch { case authTokenId == "": return nil, errors.New(ctx, errors.InvalidParameter, op, "auth token id is missing") } - ret, err := r.searchTargets(ctx, "true", nil, withAuthTokenId(authTokenId)) + ret, err := r.searchTargets(ctx, "true", nil, append(opt, withAuthTokenId(authTokenId))...) 
if err != nil { return nil, errors.Wrap(ctx, err, op) } return ret, nil } -func (r *Repository) QueryTargets(ctx context.Context, authTokenId, query string) ([]*targets.Target, error) { +func (r *Repository) QueryTargets(ctx context.Context, authTokenId, query string, opt ...Option) (*SearchResult, error) { const op = "cache.(Repository).QueryTargets" switch { case authTokenId == "": @@ -317,14 +317,14 @@ func (r *Repository) QueryTargets(ctx context.Context, authTokenId, query string if err != nil { return nil, errors.Wrap(ctx, err, op, errors.WithCode(errors.InvalidParameter)) } - ret, err := r.searchTargets(ctx, w.Condition, w.Args, withAuthTokenId(authTokenId)) + ret, err := r.searchTargets(ctx, w.Condition, w.Args, append(opt, withAuthTokenId(authTokenId))...) if err != nil { return nil, errors.Wrap(ctx, err, op) } return ret, nil } -func (r *Repository) searchTargets(ctx context.Context, condition string, searchArgs []any, opt ...Option) ([]*targets.Target, error) { +func (r *Repository) searchTargets(ctx context.Context, condition string, searchArgs []any, opt ...Option) (*SearchResult, error) { const op = "cache.(Repository).searchTargets" switch { case condition == "": @@ -349,7 +349,7 @@ func (r *Repository) searchTargets(ctx context.Context, condition string, search } var cachedTargets []*Target - if err := r.rw.SearchWhere(ctx, &cachedTargets, condition, searchArgs, db.WithLimit(-1)); err != nil { + if err := r.rw.SearchWhere(ctx, &cachedTargets, condition, searchArgs, db.WithLimit(opts.withMaxResultSetSize+1)); err != nil { return nil, errors.Wrap(ctx, err, op) } @@ -361,7 +361,15 @@ func (r *Repository) searchTargets(ctx context.Context, condition string, search } retTargets = append(retTargets, &tar) } - return retTargets, nil + + sr := &SearchResult{ + Targets: retTargets, + } + if opts.withMaxResultSetSize > 0 && len(sr.Targets) > opts.withMaxResultSetSize { + sr.Targets = sr.Targets[:opts.withMaxResultSetSize] + sr.Incomplete = true + } + return sr, nil } type Target struct { diff --git a/internal/clientcache/internal/cache/repository_targets_test.go b/internal/clientcache/internal/cache/repository_targets_test.go index a6923ac815..a3c705b513 100644 --- a/internal/clientcache/internal/cache/repository_targets_test.go +++ b/internal/clientcache/internal/cache/repository_targets_test.go @@ -6,6 +6,7 @@ package cache import ( "context" "encoding/json" + "strconv" "sync" "testing" @@ -264,7 +265,7 @@ func TestRepository_RefreshTargets_withRefreshTokens(t *testing.T) { got, err := r.ListTargets(ctx, at.Id) require.NoError(t, err) - assert.Len(t, got, 2) + assert.Len(t, got.Targets, 2) // Refreshing again uses the refresh token and get additional sessions, appending // them to the response @@ -274,7 +275,7 @@ func TestRepository_RefreshTargets_withRefreshTokens(t *testing.T) { got, err = r.ListTargets(ctx, at.Id) require.NoError(t, err) - assert.Len(t, got, 3) + assert.Len(t, got.Targets, 3) // Refreshing again wont return any more resources, but also none should be // removed @@ -284,7 +285,7 @@ func TestRepository_RefreshTargets_withRefreshTokens(t *testing.T) { got, err = r.ListTargets(ctx, at.Id) require.NoError(t, err) - assert.Len(t, got, 3) + assert.Len(t, got.Targets, 3) // Refresh again with the refresh token being reported as invalid. 
require.NoError(t, r.refreshTargets(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"}, @@ -293,7 +294,7 @@ func TestRepository_RefreshTargets_withRefreshTokens(t *testing.T) { got, err = r.ListTargets(ctx, at.Id) require.NoError(t, err) - assert.Len(t, got, 2) + assert.Len(t, got.Targets, 2) } func TestRepository_ListTargets(t *testing.T) { @@ -349,13 +350,73 @@ func TestRepository_ListTargets(t *testing.T) { t.Run("wrong user gets no targets", func(t *testing.T) { l, err := r.ListTargets(ctx, kt2.AuthTokenId) assert.NoError(t, err) - assert.Empty(t, l) + assert.Empty(t, l.Targets) }) t.Run("correct token gets targets", func(t *testing.T) { l, err := r.ListTargets(ctx, kt1.AuthTokenId) assert.NoError(t, err) - assert.Len(t, l, len(ts)) - assert.ElementsMatch(t, l, ts) + assert.Len(t, l.Targets, len(ts)) + assert.ElementsMatch(t, l.Targets, ts) + }) +} + +func TestRepository_ListTargetsLimiting(t *testing.T) { + ctx := context.Background() + s, err := cachedb.Open(ctx) + require.NoError(t, err) + + addr := "address" + u := &user{ + Id: "u", + Address: addr, + } + at := &authtokens.AuthToken{ + Id: "at", + Token: "at_token", + UserId: u.Id, + } + kt := KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id} + + atMap := map[ringToken]*authtokens.AuthToken{ + {"k", "t"}: at, + } + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(maps.Values(atMap))) + require.NoError(t, err) + require.NoError(t, r.AddKeyringToken(ctx, addr, kt)) + + var ts []*targets.Target + for i := 0; i < defaultLimitedResultSetSize*2; i++ { + ts = append(ts, target("t"+strconv.Itoa(i))) + } + require.NoError(t, r.refreshTargets(ctx, u, map[AuthToken]string{{Id: "id"}: "something"}, + WithTargetRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*targets.Target{ts}, [][]string{nil})))) + + searchService, err := NewSearchService(ctx, r) + require.NoError(t, err) + params := SearchParams{ + Resource: Targets, + AuthTokenId: kt.AuthTokenId, + } + + t.Run("default limit", func(t *testing.T) { + searchResult, err := searchService.Search(ctx, params) + require.NoError(t, err) + assert.Len(t, searchResult.Targets, defaultLimitedResultSetSize) + assert.True(t, searchResult.Incomplete) + }) + t.Run("custom limit", func(t *testing.T) { + params.MaxResultSetSize = 20 + searchResult, err := searchService.Search(ctx, params) + require.NoError(t, err) + assert.Len(t, searchResult.Targets, params.MaxResultSetSize) + assert.True(t, searchResult.Incomplete) + }) + t.Run("no limit", func(t *testing.T) { + params.MaxResultSetSize = -1 + searchResult, err := searchService.Search(ctx, params) + require.NoError(t, err) + assert.Len(t, searchResult.Targets, defaultLimitedResultSetSize*2) + assert.False(t, searchResult.Incomplete) }) } @@ -456,13 +517,74 @@ func TestRepository_QueryTargets(t *testing.T) { t.Run("wrong token gets no targets", func(t *testing.T) { l, err := r.QueryTargets(ctx, kt2.AuthTokenId, query) assert.NoError(t, err) - assert.Empty(t, l) + assert.Empty(t, l.Targets) }) t.Run("correct token gets targets", func(t *testing.T) { l, err := r.QueryTargets(ctx, kt1.AuthTokenId, query) assert.NoError(t, err) - assert.Len(t, l, 2) - assert.ElementsMatch(t, l, ts[0:2]) + assert.Len(t, l.Targets, 2) + assert.ElementsMatch(t, l.Targets, ts[0:2]) + }) +} + +func TestRepository_QueryTargetsLimiting(t *testing.T) { + ctx := context.Background() + s, err := cachedb.Open(ctx) + require.NoError(t, err) + + addr := "address" + u := &user{ + Id: "u", + Address: 
addr, + } + at := &authtokens.AuthToken{ + Id: "at", + Token: "at_token", + UserId: u.Id, + } + kt := KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id} + + atMap := map[ringToken]*authtokens.AuthToken{ + {"k", "t"}: at, + } + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(maps.Values(atMap))) + require.NoError(t, err) + require.NoError(t, r.AddKeyringToken(ctx, addr, kt)) + + var ts []*targets.Target + for i := 0; i < defaultLimitedResultSetSize*2; i++ { + ts = append(ts, target("t"+strconv.Itoa(i))) + } + require.NoError(t, r.refreshTargets(ctx, u, map[AuthToken]string{{Id: "id"}: "something"}, + WithTargetRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*targets.Target{ts}, [][]string{nil})))) + + searchService, err := NewSearchService(ctx, r) + require.NoError(t, err) + params := SearchParams{ + Resource: Targets, + AuthTokenId: kt.AuthTokenId, + Query: `(name % 'name')`, + } + + t.Run("default limit", func(t *testing.T) { + searchResult, err := searchService.Search(ctx, params) + require.NoError(t, err) + assert.Len(t, searchResult.Targets, defaultLimitedResultSetSize) + assert.True(t, searchResult.Incomplete) + }) + t.Run("custom limit", func(t *testing.T) { + params.MaxResultSetSize = 20 + searchResult, err := searchService.Search(ctx, params) + require.NoError(t, err) + assert.Len(t, searchResult.Targets, params.MaxResultSetSize) + assert.True(t, searchResult.Incomplete) + }) + t.Run("no limit", func(t *testing.T) { + params.MaxResultSetSize = -1 + searchResult, err := searchService.Search(ctx, params) + require.NoError(t, err) + assert.Len(t, searchResult.Targets, defaultLimitedResultSetSize*2) + assert.False(t, searchResult.Incomplete) }) } diff --git a/internal/clientcache/internal/cache/search.go b/internal/clientcache/internal/cache/search.go index 2239030228..a26aa1a9c9 100644 --- a/internal/clientcache/internal/cache/search.go +++ b/internal/clientcache/internal/cache/search.go @@ -55,6 +55,8 @@ type SearchParams struct { Query string // the optional bexpr filter string that all results will be filtered by Filter string + // Max result set size is an override to the default max result set size + MaxResultSetSize int } // SearchResult returns the results from searching the cache. 
@@ -62,6 +64,10 @@ type SearchResult struct { ResolvableAliases []*aliases.Alias `json:"resolvable_aliases,omitempty"` Targets []*targets.Target `json:"targets,omitempty"` Sessions []*sessions.Session `json:"sessions,omitempty"` + + // Incomplete is true if the search results are incomplete, that is, we are + // returning only a subset based on the max result set size + Incomplete bool `json:"incomplete,omitempty"` } // SearchService is a domain service that can search across all resources in the @@ -83,22 +89,40 @@ func NewSearchService(ctx context.Context, repo *Repository) (*SearchService, er ResolvableAliases: &resourceSearchFns[*aliases.Alias]{ list: repo.ListResolvableAliases, query: repo.QueryResolvableAliases, - searchResult: func(a []*aliases.Alias) *SearchResult { - return &SearchResult{ResolvableAliases: a} + filter: func(in *SearchResult, e *bexpr.Evaluator) { + finalResults := make([]*aliases.Alias, 0, len(in.ResolvableAliases)) + for _, item := range in.ResolvableAliases { + if m, err := e.Evaluate(filterItem{item}); err == nil && m { + finalResults = append(finalResults, item) + } + } + in.ResolvableAliases = finalResults }, }, Targets: &resourceSearchFns[*targets.Target]{ list: repo.ListTargets, query: repo.QueryTargets, - searchResult: func(t []*targets.Target) *SearchResult { - return &SearchResult{Targets: t} + filter: func(in *SearchResult, e *bexpr.Evaluator) { + finalResults := make([]*targets.Target, 0, len(in.Targets)) + for _, item := range in.Targets { + if m, err := e.Evaluate(filterItem{item}); err == nil && m { + finalResults = append(finalResults, item) + } + } + in.Targets = finalResults }, }, Sessions: &resourceSearchFns[*sessions.Session]{ list: repo.ListSessions, query: repo.QuerySessions, - searchResult: func(s []*sessions.Session) *SearchResult { - return &SearchResult{Sessions: s} + filter: func(in *SearchResult, e *bexpr.Evaluator) { + finalResults := make([]*sessions.Session, 0, len(in.Sessions)) + for _, item := range in.Sessions { + if m, err := e.Evaluate(filterItem{item}); err == nil && m { + finalResults = append(finalResults, item) + } + } + in.Sessions = finalResults }, }, }, @@ -160,19 +184,15 @@ type resourceSearchFns[T any] struct { // list takes a context and an auth token and returns all resources for the // user of that auth token. If the provided auth token is not in the cache // an empty slice and no error is returned. - list func(context.Context, string) ([]T, error) + list func(context.Context, string, ...Option) (*SearchResult, error) // query takes a context, an auth token, and a query string and returns all // resources for that auth token that matches the provided query parameter. // If the provided auth token is not in the cache an empty slice and no // error is returned. - query func(context.Context, string, string) ([]T, error) - // searchResult is a function which provides a SearchResult based on the - // type of T. SearchResult contains different fields for the different - // resource types returned, so for example if T is *targets.Target the - // returned SearchResult will have it's "Targets" field populated so the - // searchResult should take the passed in paramater and assign it to the - // appropriate field in the SearchResult. 
- searchResult func([]T) *SearchResult + query func(context.Context, string, string, ...Option) (*SearchResult, error) + // filter takes results and a ready-to-use evaluator and filters the items + // in the result + filter func(*SearchResult, *bexpr.Evaluator) } // resourceSearcher is an interface that only resourceSearchFns[T] is expected @@ -191,33 +211,29 @@ type resourceSearcher interface { func (l *resourceSearchFns[T]) search(ctx context.Context, p SearchParams) (*SearchResult, error) { const op = "cache.(resourceSearchFns).search" - var found []T + var found *SearchResult var err error switch p.Query { case "": - found, err = l.list(ctx, p.AuthTokenId) + found, err = l.list(ctx, p.AuthTokenId, WithMaxResultSetSize(p.MaxResultSetSize)) default: - found, err = l.query(ctx, p.AuthTokenId, p.Query) + found, err = l.query(ctx, p.AuthTokenId, p.Query, WithMaxResultSetSize(p.MaxResultSetSize)) } if err != nil { return nil, errors.Wrap(ctx, err, op) } if p.Filter == "" { - return l.searchResult(found), nil + return found, nil } e, err := bexpr.CreateEvaluator(p.Filter, bexpr.WithTagName("json")) if err != nil { return nil, errors.Wrap(ctx, err, op, errors.WithMsg("couldn't build filter"), errors.WithCode(errors.InvalidParameter)) } - finalResults := make([]T, 0, len(found)) - for _, item := range found { - if m, err := e.Evaluate(filterItem{item}); err == nil && m { - finalResults = append(finalResults, item) - } - } - return l.searchResult(finalResults), nil + + l.filter(found, e) + return found, nil } type filterItem struct { diff --git a/internal/clientcache/internal/daemon/404_handler.go b/internal/clientcache/internal/daemon/404_handler.go index e373319afe..4216cb9ef4 100644 --- a/internal/clientcache/internal/daemon/404_handler.go +++ b/internal/clientcache/internal/daemon/404_handler.go @@ -9,7 +9,7 @@ import ( ) // new404Func creates a handler that returns a custom 404 error message. 
-func new404Func(ctx context.Context) http.HandlerFunc { +func new404Func(_ context.Context) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { writeError(w, "Not found", http.StatusNotFound) } diff --git a/internal/clientcache/internal/daemon/search_handler.go b/internal/clientcache/internal/daemon/search_handler.go index d79a74b781..f450a76491 100644 --- a/internal/clientcache/internal/daemon/search_handler.go +++ b/internal/clientcache/internal/daemon/search_handler.go @@ -26,14 +26,16 @@ type SearchResult struct { ResolvableAliases []*aliases.Alias `json:"resolvable_aliases,omitempty"` Targets []*targets.Target `json:"targets,omitempty"` Sessions []*sessions.Session `json:"sessions,omitempty"` + Incomplete bool `json:"incomplete,omitempty"` } const ( - filterKey = "filter" - queryKey = "query" - resourceKey = "resource" - forceRefreshKey = "force_refresh" - authTokenIdKey = "auth_token_id" + filterKey = "filter" + queryKey = "query" + resourceKey = "resource" + forceRefreshKey = "force_refresh" + authTokenIdKey = "auth_token_id" + maxResultSetSizeKey = "max_result_set_size" ) func newSearchHandlerFunc(ctx context.Context, repo *cache.Repository, refreshService *cache.RefreshService, logger hclog.Logger) (http.HandlerFunc, error) { @@ -54,8 +56,11 @@ func newSearchHandlerFunc(ctx context.Context, repo *cache.Repository, refreshSe return func(w http.ResponseWriter, r *http.Request) { reqCtx := r.Context() - resource := r.URL.Query().Get(resourceKey) - authTokenId := r.URL.Query().Get(authTokenIdKey) + q := r.URL.Query() + resource := q.Get(resourceKey) + authTokenId := q.Get(authTokenIdKey) + maxResultSetSizeStr := q.Get(maxResultSetSizeKey) + maxResultSetSizeInt, maxResultSetSizeIntErr := strconv.Atoi(maxResultSetSizeStr) searchableResource := cache.ToSearchableResource(resource) switch { @@ -71,6 +76,14 @@ func newSearchHandlerFunc(ctx context.Context, repo *cache.Repository, refreshSe event.WriteError(ctx, op, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("%s is a required field but was empty", authTokenIdKey))) writeError(w, fmt.Sprintf("%s is a required field but was empty", authTokenIdKey), http.StatusBadRequest) return + case maxResultSetSizeStr != "" && maxResultSetSizeIntErr != nil: + event.WriteError(ctx, op, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("%s is not able to be parsed as an integer", maxResultSetSizeStr))) + writeError(w, fmt.Sprintf("%s is not able to be parsed as an integer", maxResultSetSizeStr), http.StatusBadRequest) + return + case maxResultSetSizeInt < -1: + event.WriteError(ctx, op, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("%s must be greater than or equal to -1", maxResultSetSizeStr))) + writeError(w, fmt.Sprintf("%s must be greater than or equal to -1", maxResultSetSizeStr), http.StatusBadRequest) + return } t, err := repo.LookupToken(reqCtx, authTokenId, cache.WithUpdateLastAccessedTime(true)) @@ -114,10 +127,11 @@ func newSearchHandlerFunc(ctx context.Context, repo *cache.Repository, refreshSe filter := r.URL.Query().Get(filterKey) res, err := s.Search(reqCtx, cache.SearchParams{ - AuthTokenId: authTokenId, - Resource: searchableResource, - Query: query, - Filter: filter, + AuthTokenId: authTokenId, + Resource: searchableResource, + Query: query, + Filter: filter, + MaxResultSetSize: maxResultSetSizeInt, }) if err != nil { event.WriteError(ctx, op, err, event.WithInfoMsg("when performing search", "auth_token_id", authTokenId, "resource", searchableResource, "query", query, "filter", filter)) 
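To make the handler's inputs concrete, here is a minimal sketch (not code from the patch) of how a client could assemble the query string that the search handler above reads; every value is a placeholder, and only the parameter names come from the `*Key` constants added in this change:

package main

import (
	"fmt"
	"net/url"
)

func main() {
	// Illustrative placeholders only; the parameter names mirror the
	// constants read by newSearchHandlerFunc above.
	q := url.Values{}
	q.Set("resource", "targets")
	q.Set("auth_token_id", "at_1234567890")
	q.Set("max_result_set_size", "100") // -1 disables the cap, 0/absent keeps the default
	fmt.Println("/v1/search?" + q.Encode())
}

A value below -1, or one that does not parse as an integer, is rejected with a 400, matching the validation cases above.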
@@ -152,6 +166,7 @@ func toApiResult(sr *cache.SearchResult) *SearchResult { ResolvableAliases: sr.ResolvableAliases, Targets: sr.Targets, Sessions: sr.Sessions, + Incomplete: sr.Incomplete, } } From 4d8a7dd39add297a28334f65ea1daa9f65007de6 Mon Sep 17 00:00:00 2001 From: Jim Date: Wed, 28 Aug 2024 13:26:50 -0400 Subject: [PATCH 02/15] feat (clientcache): use a persistent cache (#5051) By default, we will use a persistent cache for the client cache. This will allow us to keep the cache across restarts of the server and it will also reduce the amount of time it takes to start the server while reducing the amount of memory used. This change also includes a validation of the cache schema version. If the schema version is different from the one expected, the cache will be reset/recreated. (cherry picked from commit cf8a51f3834033b7f986a9a7e091ab31c4328415) --- go.mod | 2 +- internal/clientcache/cmd/cache/start.go | 3 +- .../clientcache/internal/daemon/options.go | 13 ++- .../internal/daemon/options_test.go | 8 +- .../clientcache/internal/daemon/server.go | 39 +++++++++ .../internal/daemon/server_test.go | 18 ++++ internal/clientcache/internal/db/db.go | 84 ++++++++++++++++++- internal/clientcache/internal/db/db_test.go | 66 +++++++++++++++ internal/clientcache/internal/db/options.go | 10 +++ .../clientcache/internal/db/options_test.go | 8 ++ internal/clientcache/internal/db/schema.sql | 34 ++++++++ .../clientcache/internal/db/schema_reset.sql | 11 +++ 12 files changed, 289 insertions(+), 7 deletions(-) create mode 100644 internal/clientcache/internal/db/db_test.go create mode 100644 internal/clientcache/internal/db/schema_reset.sql diff --git a/go.mod b/go.mod index 0ac8763466..2864f17342 100644 --- a/go.mod +++ b/go.mod @@ -100,7 +100,6 @@ require ( github.com/kelseyhightower/envconfig v1.4.0 github.com/miekg/dns v1.1.58 github.com/mikesmitty/edkey v0.0.0-20170222072505-3356ea4e686a - github.com/mitchellh/go-homedir v1.1.0 github.com/sevlyar/go-daemon v0.1.6 golang.org/x/exp v0.0.0-20240205201215-2c58cdc269a3 golang.org/x/net v0.25.0 @@ -127,6 +126,7 @@ require ( github.com/go-logr/stdr v1.2.2 // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 // indirect + github.com/mitchellh/go-homedir v1.1.0 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect github.com/moby/sys/user v0.1.0 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect diff --git a/internal/clientcache/cmd/cache/start.go b/internal/clientcache/cmd/cache/start.go index 93a914c5f7..c94d507e09 100644 --- a/internal/clientcache/cmd/cache/start.go +++ b/internal/clientcache/cmd/cache/start.go @@ -19,7 +19,6 @@ import ( "github.com/hashicorp/boundary/internal/cmd/base" "github.com/hashicorp/boundary/internal/errors" "github.com/mitchellh/cli" - "github.com/mitchellh/go-homedir" "github.com/posener/complete" "gopkg.in/natefinch/lumberjack.v2" ) @@ -243,7 +242,7 @@ func (c *StartCommand) Run(args []string) int { // DefaultDotDirectory returns the default path to the boundary dot directory. 
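For orientation, a minimal sketch (assuming the constants added later in this patch: a `.boundary` dot directory, a `cache.db` file name, and a foreign-key pragma suffix) of where the persistent cache ends up by default:

package main

import (
	"fmt"
	"os"
	"path/filepath"
)

func main() {
	// Sketch only: mirrors what defaultDbUrl below assembles from the
	// user's home directory; the exact path depends on the platform.
	home, err := os.UserHomeDir()
	if err != nil {
		fmt.Println("no home dir:", err)
		return
	}
	fmt.Println(filepath.Join(home, ".boundary", "cache.db") + "?_pragma=foreign_keys(1)")
}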
func DefaultDotDirectory(ctx context.Context) (string, error) { const op = "cache.DefaultDotDirectory" - homeDir, err := homedir.Dir() + homeDir, err := os.UserHomeDir() if err != nil { return "", errors.Wrap(ctx, err, op) } diff --git a/internal/clientcache/internal/daemon/options.go b/internal/clientcache/internal/daemon/options.go index ba20725409..daed203df3 100644 --- a/internal/clientcache/internal/daemon/options.go +++ b/internal/clientcache/internal/daemon/options.go @@ -20,8 +20,9 @@ type options struct { WithReadyToServeNotificationCh chan struct{} withBoundaryTokenReaderFunc cache.BoundaryTokenReaderFn - withUrl string - withLogger hclog.Logger + withUrl string + withLogger hclog.Logger + withHomeDir string } // Option - how options are passed as args @@ -42,6 +43,14 @@ func getOpts(opt ...Option) (options, error) { return opts, nil } +// WithHomeDir provides an optional home directory to use. +func WithHomeDir(_ context.Context, dir string) Option { + return func(o *options) error { + o.withHomeDir = dir + return nil + } +} + // withRefreshInterval provides an optional refresh interval. func withRefreshInterval(_ context.Context, d time.Duration) Option { return func(o *options) error { diff --git a/internal/clientcache/internal/daemon/options_test.go b/internal/clientcache/internal/daemon/options_test.go index 1025aa0381..0afcbab1b5 100644 --- a/internal/clientcache/internal/daemon/options_test.go +++ b/internal/clientcache/internal/daemon/options_test.go @@ -101,6 +101,13 @@ func Test_GetOpts(t *testing.T) { testOpts := getDefaultOptions() assert.Equal(t, opts, testOpts) }) + t.Run("WithHomeDir", func(t *testing.T) { + opts, err := getOpts(WithHomeDir(ctx, "/tmp")) + require.NoError(t, err) + testOpts := getDefaultOptions() + testOpts.withHomeDir = "/tmp" + assert.Equal(t, opts, testOpts) + }) t.Run("WithReadyToServeNotificationCh", func(t *testing.T) { ch := make(chan struct{}) opts, err := getOpts(WithReadyToServeNotificationCh(ctx, ch)) @@ -109,6 +116,5 @@ func Test_GetOpts(t *testing.T) { testOpts := getDefaultOptions() assert.Nil(t, testOpts.WithReadyToServeNotificationCh) testOpts.WithReadyToServeNotificationCh = ch - assert.Equal(t, opts, testOpts) }) } diff --git a/internal/clientcache/internal/daemon/server.go b/internal/clientcache/internal/daemon/server.go index 79a80e67a8..cd0f13e77c 100644 --- a/internal/clientcache/internal/daemon/server.go +++ b/internal/clientcache/internal/daemon/server.go @@ -10,6 +10,7 @@ import ( "net" "net/http" "os" + "path/filepath" "sort" "strings" "sync" @@ -499,6 +500,10 @@ func setupEventing(ctx context.Context, logger hclog.Logger, serializationLock * return nil } +// openStore will open the underlying store for the db. If no options are +// provided, it will default to an on disk store using the user's home dir + +// ".boundary/cache.db". If a url is provided, it will use that as the store. +// Supported options: WithUrl, WithLogger, WithHomeDir func openStore(ctx context.Context, opt ...Option) (*db.DB, error) { const op = "daemon.openStore" opts, err := getOpts(opt...) @@ -514,6 +519,12 @@ func openStore(ctx context.Context, opt ...Option) (*db.DB, error) { return nil, errors.Wrap(ctx, err, op) } dbOpts = append(dbOpts, cachedb.WithUrl(url)) + default: + url, err := defaultDbUrl(ctx, opt...) 
+ if err != nil { + return nil, errors.Wrap(ctx, err, op) + } + dbOpts = append(dbOpts, cachedb.WithUrl(url)) } if !util.IsNil(opts.withLogger) { dbOpts = append(dbOpts, cachedb.WithGormFormatter(opts.withLogger)) @@ -524,3 +535,31 @@ func openStore(ctx context.Context, opt ...Option) (*db.DB, error) { } return store, nil } + +// defaultDbUrl returns the default db name including the path. It will ensure +// the directory exists by creating it if it doesn't. +func defaultDbUrl(ctx context.Context, opt ...Option) (string, error) { + const op = "daemon.DefaultDotDirectory" + opts, err := getOpts(opt...) + if err != nil { + return "", errors.Wrap(ctx, err, op) + } + if opts.withHomeDir == "" { + opts.withHomeDir, err = os.UserHomeDir() + if err != nil { + return "", errors.Wrap(ctx, err, op) + } + } + dotDir := filepath.Join(opts.withHomeDir, dotDirname) + if err := os.MkdirAll(dotDir, 0o700); err != nil { + return "", errors.Wrap(ctx, err, op) + } + fileName := filepath.Join(dotDir, dbFileName) + return fmt.Sprintf("%s%s", fileName, fkPragma), nil +} + +const ( + dotDirname = ".boundary" + dbFileName = "cache.db" + fkPragma = "?_pragma=foreign_keys(1)" +) diff --git a/internal/clientcache/internal/daemon/server_test.go b/internal/clientcache/internal/daemon/server_test.go index 1afafedcaf..2f7e947c90 100644 --- a/internal/clientcache/internal/daemon/server_test.go +++ b/internal/clientcache/internal/daemon/server_test.go @@ -15,6 +15,24 @@ import ( "github.com/stretchr/testify/require" ) +func Test_openStore(t *testing.T) { + ctx := context.Background() + t.Run("success", func(t *testing.T) { + tmpDir := t.TempDir() + db, err := openStore(ctx, WithUrl(ctx, tmpDir+"/test.db"+fkPragma)) + require.NoError(t, err) + require.NotNil(t, db) + assert.FileExists(t, tmpDir+"/test.db") + }) + t.Run("homedir", func(t *testing.T) { + tmpDir := t.TempDir() + db, err := openStore(ctx, WithHomeDir(ctx, tmpDir)) + require.NoError(t, err) + require.NotNil(t, db) + assert.FileExists(t, tmpDir+"/"+dotDirname+"/"+dbFileName) + }) +} + // Note: the name of this test must remain short because the temp dir created // includes the name of the test and there is a 108 character limit in allowed // unix socket path names. diff --git a/internal/clientcache/internal/db/db.go b/internal/clientcache/internal/db/db.go index 31907a3b0f..974faa7fbf 100644 --- a/internal/clientcache/internal/db/db.go +++ b/internal/clientcache/internal/db/db.go @@ -7,6 +7,8 @@ import ( "context" _ "embed" "fmt" + "strings" + "time" "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/errors" @@ -17,11 +19,16 @@ import ( //go:embed schema.sql var cacheSchema string +//go:embed schema_reset.sql +var cacheSchemaReset string + // DefaultStoreUrl uses a temp in-memory sqlite database see: https://www.sqlite.org/inmemorydb.html const DefaultStoreUrl = "file::memory:?_pragma=foreign_keys(1)" // Open creates a database connection. WithUrl is supported, but by default it // uses an in memory sqlite table. Sqlite is the only supported dbtype. +// Supported options: WithUrl, WithGormFormatter, WithDebug, +// WithTestValidSchemaVersion (for testing purposes) func Open(ctx context.Context, opt ...Option) (*db.DB, error) { const op = "db.Open" opts, err := getOpts(opt...) 
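As a non-authoritative usage sketch of `Open` as documented above (the import alias and file path are assumptions for illustration; the pragma suffix matches the one used by the tests in this patch):

package main

import (
	"context"
	"log"

	cachedb "github.com/hashicorp/boundary/internal/clientcache/internal/db"
)

func main() {
	ctx := context.Background()
	// With WithUrl the store lives on disk and its schema version is checked
	// on open; without it, Open falls back to the in-memory DefaultStoreUrl.
	conn, err := cachedb.Open(ctx, cachedb.WithUrl("/tmp/cache.db?_pragma=foreign_keys(1)"))
	if err != nil {
		log.Fatal(err)
	}
	_ = conn
}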
@@ -50,16 +57,38 @@ func Open(ctx context.Context, opt ...Option) (*db.DB, error) { conn.Debug(opts.withDebug) switch { - case opts.withDbType == dbw.Sqlite: + case opts.withDbType == dbw.Sqlite && url == DefaultStoreUrl: if err := createTables(ctx, conn); err != nil { return nil, errors.Wrap(ctx, err, op) } + case opts.withDbType == dbw.Sqlite && url != DefaultStoreUrl: + ok, err := validSchema(ctx, conn, opt...) + if err != nil { + return nil, errors.Wrap(ctx, err, op) + } + if !ok { + if err := resetSchema(ctx, conn); err != nil { + return nil, errors.Wrap(ctx, err, op) + } + if err := createTables(ctx, conn); err != nil { + return nil, errors.Wrap(ctx, err, op) + } + } default: return nil, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("%q is not a supported cache store type", opts.withDbType)) } return conn, nil } +func resetSchema(ctx context.Context, conn *db.DB) error { + const op = "db.resetSchema" + rw := db.New(conn) + if _, err := rw.Exec(ctx, cacheSchemaReset, nil); err != nil { + return errors.Wrap(ctx, err, op) + } + return nil +} + func createTables(ctx context.Context, conn *db.DB) error { const op = "db.createTables" rw := db.New(conn) @@ -68,3 +97,56 @@ func createTables(ctx context.Context, conn *db.DB) error { } return nil } + +// validSchema checks of the schema is valid based on its version. Options +// supported: withTestValidSchemaVersion (for testing purposes) +func validSchema(ctx context.Context, conn *db.DB, opt ...Option) (bool, error) { + const op = "validateSchema" + switch { + case conn == nil: + return false, errors.New(ctx, errors.InvalidParameter, op, "conn is missing") + } + opts, err := getOpts(opt...) + if err != nil { + return false, errors.Wrap(ctx, err, op) + } + if opts.withSchemaVersion == "" { + opts.withSchemaVersion = schemaCurrentVersion + } + + rw := db.New(conn) + s := schema{} + err = rw.LookupWhere(ctx, &s, "1=1", nil) + switch { + case err != nil && strings.Contains(err.Error(), "no such table: schema_version"): + return false, nil + case err != nil: + // not sure if we should return the error or just return false so the + // schema is recreated... for now return the error. + return false, fmt.Errorf("%s: unable to get version: %w", op, err) + case s.Version != opts.withSchemaVersion: + return false, nil + default: + return true, nil + } +} + +// schema represents the current schema in the database +type schema struct { + // Version of the schema + Version string + // UpdateTime is the last update of the version + UpdateTime time.Time + // CreateTime is the create time of the initial version + CreateTime time.Time +} + +const ( + schemaTableName = "schema_version" + schemaCurrentVersion = "v0.0.1" +) + +// TableName returns the table name +func (s *schema) TableName() string { + return schemaTableName +} diff --git a/internal/clientcache/internal/db/db_test.go b/internal/clientcache/internal/db/db_test.go new file mode 100644 index 0000000000..062f55caa1 --- /dev/null +++ b/internal/clientcache/internal/db/db_test.go @@ -0,0 +1,66 @@ +// Copyright (c) HashiCorp, Inc. 
+// SPDX-License-Identifier: BUSL-1.1 + +package db + +import ( + "context" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestOpen(t *testing.T) { + ctx := context.Background() + t.Run("success-file-url-with-reopening", func(t *testing.T) { + tmpDir := t.TempDir() + db, err := Open(ctx, WithUrl(tmpDir+"/test.db"+fkPragma)) + require.NoError(t, err) + require.NotNil(t, db) + assert.FileExists(t, tmpDir+"/test.db") + + info, err := os.Stat(tmpDir + "/test.db") + require.NoError(t, err) + origCreatedAt := info.ModTime() + + // Reopen the db and make sure the file is not recreated + db, err = Open(ctx, WithUrl(tmpDir+"/test.db"+fkPragma)) + require.NoError(t, err) + require.NotNil(t, db) + info, err = os.Stat(tmpDir + "/test.db") + require.NoError(t, err) + assert.Equal(t, origCreatedAt, info.ModTime()) + }) + t.Run("success-mem-default-url", func(t *testing.T) { + db, err := Open(ctx) + require.NoError(t, err) + require.NotNil(t, db) + }) + t.Run("recreate-on-version-mismatch", func(t *testing.T) { + tmpDir := t.TempDir() + db, err := Open(ctx, WithUrl(tmpDir+"/test.db"+fkPragma)) + require.NoError(t, err) + require.NotNil(t, db) + assert.FileExists(t, tmpDir+"/test.db") + info, err := os.Stat(tmpDir + "/test.db") + require.NoError(t, err) + origCreatedAt := info.ModTime() + + // Reopen the db with a different schema version: forcing the db to be recreated + db, err = Open(ctx, WithUrl(tmpDir+"/test.db"+fkPragma), withTestValidSchemaVersion("2")) + require.NoError(t, err) + require.NotNil(t, db) + info, err = os.Stat(tmpDir + "/test.db") + require.NoError(t, err) + // The file should have been recreated with a new timestamp + assert.NotEqual(t, origCreatedAt, info.ModTime()) + }) +} + +const ( + dotDirname = ".boundary" + dbFileName = "cache.db" + fkPragma = "?_pragma=foreign_keys(1)" +) diff --git a/internal/clientcache/internal/db/options.go b/internal/clientcache/internal/db/options.go index ba8c3382bd..30186c90b1 100644 --- a/internal/clientcache/internal/db/options.go +++ b/internal/clientcache/internal/db/options.go @@ -9,6 +9,7 @@ import ( ) type options struct { + withSchemaVersion string withDebug bool withUrl string withDbType dbw.DbType @@ -42,6 +43,15 @@ func WithGormFormatter(logger hclog.Logger) Option { } } +// withTestValidSchemaVersion provides optional valid schema version for testing +// purposes. This is used to simulate a schema version that is valid/invalid. 
+func withTestValidSchemaVersion(useVersion string) Option { + return func(o *options) error { + o.withSchemaVersion = useVersion + return nil + } +} + // WithUrls provides optional url func WithUrl(url string) Option { return func(o *options) error { diff --git a/internal/clientcache/internal/db/options_test.go b/internal/clientcache/internal/db/options_test.go index f57591ef60..b67096b081 100644 --- a/internal/clientcache/internal/db/options_test.go +++ b/internal/clientcache/internal/db/options_test.go @@ -37,4 +37,12 @@ func Test_GetOpts(t *testing.T) { testOpts.withDebug = true assert.Equal(t, opts, testOpts) }) + t.Run("withTestValidSchemaVersion", func(t *testing.T) { + version := "v1" + opts, err := getOpts(withTestValidSchemaVersion(version)) + require.NoError(t, err) + testOpts := getDefaultOptions() + testOpts.withSchemaVersion = version + assert.Equal(t, opts, testOpts) + }) } diff --git a/internal/clientcache/internal/db/schema.sql b/internal/clientcache/internal/db/schema.sql index bf5023b196..ddd635b2d1 100644 --- a/internal/clientcache/internal/db/schema.sql +++ b/internal/clientcache/internal/db/schema.sql @@ -2,6 +2,40 @@ -- SPDX-License-Identifier: BUSL-1.1 begin; + +-- schema_version is a one row table to keep the version +create table if not exists schema_version ( + version text not null, + create_time timestamp not null default current_timestamp, + update_time timestamp not null default current_timestamp +); + +-- ensure that it's only ever one row +create unique index schema_version_one_row +ON schema_version((version is not null)); + +create trigger immutable_columns_schema_version +before update on schema_version +for each row + when + new.create_time <> old.create_time + begin + select raise(abort, 'immutable column'); + end; + + +create trigger update_time_column_schema_version +before update on schema_version +for each row +when + new.version <> old.version + begin + update schema_version set update_time = datetime('now','localtime') where rowid == new.rowid; + end; + + +insert into schema_version(version) values('v0.0.1'); + -- user contains the boundary user information for the boundary user that owns -- the information in the cache. create table if not exists user ( diff --git a/internal/clientcache/internal/db/schema_reset.sql b/internal/clientcache/internal/db/schema_reset.sql new file mode 100644 index 0000000000..8c58d56d4b --- /dev/null +++ b/internal/clientcache/internal/db/schema_reset.sql @@ -0,0 +1,11 @@ +-- Copyright (c) HashiCorp, Inc. +-- SPDX-License-Identifier: BUSL-1.1 + +-- cannot vacuum from within a transaction, so we're not using a transaction +-- when running these statements +PRAGMA writable_schema = 1; +DELETE FROM sqlite_master; +PRAGMA writable_schema = 0; +VACUUM; +PRAGMA integrity_check; + From b04de5ac4b4c1ea5b0411ae1f9ecae10f9e9b392 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Wed, 28 Aug 2024 21:18:38 -0400 Subject: [PATCH 03/15] Add `implicit-scopes` resource type to cache (#5053) This allows getting a list of scope IDs known by the cache via cached targets and sessions. It is not refreshed from the controller. The name contains `implicit` in case we ever want e.g. `scopes`. 
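A rough usage sketch (the resource name comes from this patch and the flag from the search command's existing flag set; treat the exact invocation as illustrative):

  boundary search -resource implicit-scopes

Because the scope list is derived purely from cached targets and sessions, the handler below rejects `query`, `filter`, and `max_result_set_size` for this resource and skips the inline refresh.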
(cherry picked from commit c0a318c58367d8ec55365447f41cf007eab64ce3) --- internal/clientcache/cmd/search/search.go | 40 +++++- .../cache/repository_implicit_scopes.go | 114 ++++++++++++++++ .../cache/repository_implicit_scopes_test.go | 124 ++++++++++++++++++ internal/clientcache/internal/cache/search.go | 20 ++- .../internal/daemon/search_handler.go | 30 ++++- 5 files changed, 318 insertions(+), 10 deletions(-) create mode 100644 internal/clientcache/internal/cache/repository_implicit_scopes.go create mode 100644 internal/clientcache/internal/cache/repository_implicit_scopes_test.go diff --git a/internal/clientcache/cmd/search/search.go b/internal/clientcache/cmd/search/search.go index e095671811..7075a31c7a 100644 --- a/internal/clientcache/cmd/search/search.go +++ b/internal/clientcache/cmd/search/search.go @@ -15,6 +15,7 @@ import ( "github.com/hashicorp/boundary/api" "github.com/hashicorp/boundary/api/aliases" + "github.com/hashicorp/boundary/api/scopes" "github.com/hashicorp/boundary/api/sessions" "github.com/hashicorp/boundary/api/targets" cachecmd "github.com/hashicorp/boundary/internal/clientcache/cmd/cache" @@ -34,6 +35,7 @@ var ( "resolvable-aliases", "targets", "sessions", + "implicit-scopes", } errCacheNotRunning = stderrors.New("The cache process is not running.") @@ -181,9 +183,6 @@ func (c *SearchCommand) Run(args []string) int { return base.CommandCliError } default: - if result.Incomplete { - c.UI.Warn("The maximum result set size was reached and the search results are incomplete. Please narrow your search or adjust the -max-result-set-size parameter.") - } switch { case len(result.ResolvableAliases) > 0: c.UI.Output(printAliasListTable(result.ResolvableAliases)) @@ -191,9 +190,17 @@ func (c *SearchCommand) Run(args []string) int { c.UI.Output(printTargetListTable(result.Targets)) case len(result.Sessions) > 0: c.UI.Output(printSessionListTable(result.Sessions)) + case len(result.ImplicitScopes) > 0: + c.UI.Output(printImplicitScopesListTable(result.ImplicitScopes)) default: c.UI.Output("No items found") } + + // Put this at the end or people may not see it as they may not scroll + // all the way up. + if result.Incomplete { + c.UI.Warn("The maximum result set size was reached and the search results are incomplete. Please narrow your search or adjust the -max-result-set-size parameter.") + } } return base.CommandSuccess } @@ -449,6 +456,33 @@ func printSessionListTable(items []*sessions.Session) string { return base.WrapForHelpText(output) } +func printImplicitScopesListTable(items []*scopes.Scope) string { + if len(items) == 0 { + return "No implicit scopes found" + } + var output []string + output = []string{ + "", + "Scope information:", + } + for i, item := range items { + if i > 0 { + output = append(output, "") + } + if item.Id != "" { + output = append(output, + fmt.Sprintf(" ID: %s", item.Id), + ) + } else { + output = append(output, + fmt.Sprintf(" ID: %s", "(not available)"), + ) + } + } + + return base.WrapForHelpText(output) +} + type filterBy struct { flagFilter string flagQuery string diff --git a/internal/clientcache/internal/cache/repository_implicit_scopes.go b/internal/clientcache/internal/cache/repository_implicit_scopes.go new file mode 100644 index 0000000000..2a031839e6 --- /dev/null +++ b/internal/clientcache/internal/cache/repository_implicit_scopes.go @@ -0,0 +1,114 @@ +// Copyright (c) HashiCorp, Inc. 
+// SPDX-License-Identifier: BUSL-1.1 + +package cache + +import ( + "context" + "fmt" + "slices" + + "github.com/hashicorp/boundary/api/scopes" + "github.com/hashicorp/boundary/internal/errors" +) + +func (r *Repository) ListImplicitScopes(ctx context.Context, authTokenId string, opt ...Option) (*SearchResult, error) { + const op = "cache.(Repository).ListImplicitScopes" + switch { + case authTokenId == "": + return nil, errors.New(ctx, errors.InvalidParameter, op, "auth token id is missing") + } + ret, err := r.searchImplicitScopes(ctx, "true", nil, append(opt, withAuthTokenId(authTokenId))...) + if err != nil { + return nil, errors.Wrap(ctx, err, op) + } + return ret, nil +} + +// QueryImplicitScopes is not supported currently so we return an error message +func (r *Repository) QueryImplicitScopes(ctx context.Context, authTokenId, query string, opt ...Option) (*SearchResult, error) { + const op = "cache.(Repository).QueryImplicitScopes" + + // Internal is used as we have checks at the handler level to ensure this + // can't be used so it's an internal error if we actually call this + // function. + return nil, errors.New(ctx, errors.Internal, op, "querying implicit scopes is not supported") +} + +func (r *Repository) searchImplicitScopes(ctx context.Context, condition string, searchArgs []any, opt ...Option) (*SearchResult, error) { + const op = "cache.(Repository).searchImplicitScopes" + switch { + case condition == "": + return nil, errors.New(ctx, errors.InvalidParameter, op, "condition is missing") + } + + opts, err := getOpts(opt...) + if err != nil { + return nil, errors.Wrap(ctx, err, op) + } + switch { + case opts.withAuthTokenId != "" && opts.withUserId != "": + return nil, errors.New(ctx, errors.InvalidParameter, op, "both user id and auth token id were provided") + case opts.withAuthTokenId == "" && opts.withUserId == "": + return nil, errors.New(ctx, errors.InvalidParameter, op, "neither user id nor auth token id were provided") + + // In these cases we append twice because we're doing a union of two tables + case opts.withAuthTokenId != "": + condition = "where fk_user_id in (select user_id from auth_token where id = ?)" + searchArgs = append(searchArgs, opts.withAuthTokenId, opts.withAuthTokenId) + case opts.withUserId != "": + condition = "where fk_user_id = ?" 
+ searchArgs = append(searchArgs, opts.withUserId, opts.withUserId) + } + + const unionQueryBase = ` + select distinct fk_user_id, scope_id from session + %s + union + select distinct fk_user_id, scope_id from target + %s +` + unionQuery := fmt.Sprintf(unionQueryBase, condition, condition) + + rows, err := r.rw.Query(ctx, unionQuery, searchArgs) + if err != nil { + return nil, errors.Wrap(ctx, err, op) + } + defer rows.Close() + + type ScopeIdsResult struct { + FkUserId string `gorm:"primaryKey"` + ScopeId string `gorm:"default:null"` + } + + var scopeIdsResults []ScopeIdsResult + for rows.Next() { + var res ScopeIdsResult + if err := r.rw.ScanRows(ctx, rows, &res); err != nil { + return nil, errors.Wrap(ctx, err, op) + } + scopeIdsResults = append(scopeIdsResults, res) + } + if err := rows.Err(); err != nil { + return nil, errors.Wrap(ctx, err, op) + } + + dedupMap := make(map[string]struct{}, len(scopeIdsResults)) + for _, res := range scopeIdsResults { + dedupMap[res.ScopeId] = struct{}{} + } + scopeIds := make([]string, 0, len(dedupMap)) + for k := range dedupMap { + scopeIds = append(scopeIds, k) + } + slices.Sort(scopeIds) + + sr := &SearchResult{ + ImplicitScopes: make([]*scopes.Scope, 0, len(dedupMap)), + } + for _, scopeId := range scopeIds { + sr.ImplicitScopes = append(sr.ImplicitScopes, &scopes.Scope{Id: scopeId}) + } + + return sr, nil +} diff --git a/internal/clientcache/internal/cache/repository_implicit_scopes_test.go b/internal/clientcache/internal/cache/repository_implicit_scopes_test.go new file mode 100644 index 0000000000..f5bc226337 --- /dev/null +++ b/internal/clientcache/internal/cache/repository_implicit_scopes_test.go @@ -0,0 +1,124 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package cache + +import ( + "context" + "sync" + "testing" + + "github.com/hashicorp/boundary/api/authtokens" + "github.com/hashicorp/boundary/api/scopes" + "github.com/hashicorp/boundary/api/sessions" + "github.com/hashicorp/boundary/api/targets" + cachedb "github.com/hashicorp/boundary/internal/clientcache/internal/db" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/exp/maps" +) + +func TestRepository_ImplicitScopes(t *testing.T) { + ctx := context.Background() + s, err := cachedb.Open(ctx) + require.NoError(t, err) + + addr := "address" + u1 := &user{ + Id: "u1", + Address: addr, + } + at1 := &authtokens.AuthToken{ + Id: "at_1", + Token: "at_1_token", + UserId: u1.Id, + } + kt1 := KeyringToken{KeyringType: "k1", TokenName: "t1", AuthTokenId: at1.Id} + + u2 := &user{ + Id: "u2", + Address: addr, + } + at2 := &authtokens.AuthToken{ + Id: "at_2", + Token: "at_2_token", + UserId: u2.Id, + } + kt2 := KeyringToken{KeyringType: "k2", TokenName: "t2", AuthTokenId: at2.Id} + atMap := map[ringToken]*authtokens.AuthToken{ + {"k1", "t1"}: at1, + {"k2", "t2"}: at2, + } + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(maps.Values(atMap))) + require.NoError(t, err) + require.NoError(t, r.AddKeyringToken(ctx, addr, kt1)) + require.NoError(t, r.AddKeyringToken(ctx, addr, kt2)) + + var expectedScopes []*scopes.Scope + + ts := []*targets.Target{ + target("1"), + target("2"), + target("3"), + } + require.NoError(t, r.refreshTargets(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"}, + WithTargetRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*targets.Target{ts}, [][]string{nil})))) + + for _, t := range ts { + expectedScopes = append(expectedScopes, 
&scopes.Scope{ + Id: t.ScopeId, + }) + } + + ss := []*sessions.Session{ + { + Id: "ttcp_1", + Status: "status1", + Endpoint: "address1", + ScopeId: "p_123", + TargetId: "ttcp_123", + UserId: "u_123", + Type: "tcp", + }, + { + Id: "ttcp_2", + Status: "status2", + Endpoint: "address2", + ScopeId: "p_123", + TargetId: "ttcp_123", + UserId: "u_123", + Type: "tcp", + }, + { + Id: "ttcp_3", + Status: "status3", + Endpoint: "address3", + ScopeId: "p_123", + TargetId: "ttcp_123", + UserId: "u_123", + Type: "tcp", + }, + } + require.NoError(t, r.refreshSessions(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"}, + WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{ss}, [][]string{nil})))) + + expectedScopes = append(expectedScopes, &scopes.Scope{ + Id: ss[0].ScopeId, + }) + + t.Run("wrong user gets no implicit scopes", func(t *testing.T) { + l, err := r.ListImplicitScopes(ctx, kt2.AuthTokenId) + require.NoError(t, err) + assert.Empty(t, l.ImplicitScopes) + }) + t.Run("correct token gets implicit scopes from listing", func(t *testing.T) { + l, err := r.ListImplicitScopes(ctx, kt1.AuthTokenId) + require.NoError(t, err) + assert.Len(t, l.ImplicitScopes, len(expectedScopes)) + assert.ElementsMatch(t, l.ImplicitScopes, expectedScopes) + }) + t.Run("querying returns error", func(t *testing.T) { + _, err := r.QueryImplicitScopes(ctx, kt1.AuthTokenId, "anything") + require.Error(t, err) + }) +} diff --git a/internal/clientcache/internal/cache/search.go b/internal/clientcache/internal/cache/search.go index a26aa1a9c9..0b9ae252f8 100644 --- a/internal/clientcache/internal/cache/search.go +++ b/internal/clientcache/internal/cache/search.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/hashicorp/boundary/api/aliases" + "github.com/hashicorp/boundary/api/scopes" "github.com/hashicorp/boundary/api/sessions" "github.com/hashicorp/boundary/api/targets" "github.com/hashicorp/boundary/internal/errors" @@ -23,11 +24,12 @@ const ( ResolvableAliases SearchableResource = "resolvable-aliases" Targets SearchableResource = "targets" Sessions SearchableResource = "sessions" + ImplicitScopes SearchableResource = "implicit-scopes" ) func (r SearchableResource) Valid() bool { switch r { - case ResolvableAliases, Targets, Sessions: + case ResolvableAliases, Targets, Sessions, ImplicitScopes: return true } return false @@ -41,6 +43,8 @@ func ToSearchableResource(s string) SearchableResource { return Targets case strings.EqualFold(s, string(Sessions)): return Sessions + case strings.EqualFold(s, string(ImplicitScopes)): + return ImplicitScopes } return Unknown } @@ -64,6 +68,7 @@ type SearchResult struct { ResolvableAliases []*aliases.Alias `json:"resolvable_aliases,omitempty"` Targets []*targets.Target `json:"targets,omitempty"` Sessions []*sessions.Session `json:"sessions,omitempty"` + ImplicitScopes []*scopes.Scope `json:"implicit_scopes,omitempty"` // Incomplete is true if the search results are incomplete, that is, we are // returning only a subset based on the max result set size @@ -125,6 +130,19 @@ func NewSearchService(ctx context.Context, repo *Repository) (*SearchService, er in.Sessions = finalResults }, }, + ImplicitScopes: &resourceSearchFns[*scopes.Scope]{ + list: repo.ListImplicitScopes, + query: repo.QueryImplicitScopes, + filter: func(in *SearchResult, e *bexpr.Evaluator) { + finalResults := make([]*scopes.Scope, 0, len(in.ImplicitScopes)) + for _, item := range in.ImplicitScopes { + if m, err := e.Evaluate(filterItem{item}); err == nil && m { + finalResults = 
append(finalResults, item) + } + } + in.ImplicitScopes = finalResults + }, + }, }, }, nil } diff --git a/internal/clientcache/internal/daemon/search_handler.go b/internal/clientcache/internal/daemon/search_handler.go index f450a76491..2fbdfbd300 100644 --- a/internal/clientcache/internal/daemon/search_handler.go +++ b/internal/clientcache/internal/daemon/search_handler.go @@ -12,6 +12,7 @@ import ( "github.com/hashicorp/boundary/api" "github.com/hashicorp/boundary/api/aliases" + "github.com/hashicorp/boundary/api/scopes" "github.com/hashicorp/boundary/api/sessions" "github.com/hashicorp/boundary/api/targets" "github.com/hashicorp/boundary/internal/clientcache/internal/cache" @@ -26,6 +27,7 @@ type SearchResult struct { ResolvableAliases []*aliases.Alias `json:"resolvable_aliases,omitempty"` Targets []*targets.Target `json:"targets,omitempty"` Sessions []*sessions.Session `json:"sessions,omitempty"` + ImplicitScopes []*scopes.Scope `json:"implicit_scopes,omitempty"` Incomplete bool `json:"incomplete,omitempty"` } @@ -61,6 +63,8 @@ func newSearchHandlerFunc(ctx context.Context, repo *cache.Repository, refreshSe authTokenId := q.Get(authTokenIdKey) maxResultSetSizeStr := q.Get(maxResultSetSizeKey) maxResultSetSizeInt, maxResultSetSizeIntErr := strconv.Atoi(maxResultSetSizeStr) + query := q.Get(queryKey) + filter := q.Get(filterKey) searchableResource := cache.ToSearchableResource(resource) switch { @@ -84,6 +88,18 @@ func newSearchHandlerFunc(ctx context.Context, repo *cache.Repository, refreshSe event.WriteError(ctx, op, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("%s must be greater than or equal to -1", maxResultSetSizeStr))) writeError(w, fmt.Sprintf("%s must be greater than or equal to -1", maxResultSetSizeStr), http.StatusBadRequest) return + case searchableResource == cache.ImplicitScopes && maxResultSetSizeStr != "": + event.WriteError(ctx, op, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("max result set size is not supported for resource %q", resource))) + writeError(w, fmt.Sprintf("max result set size is not supported for resource %q", resource), http.StatusBadRequest) + return + case searchableResource == cache.ImplicitScopes && query != "": + event.WriteError(ctx, op, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("query is not supported for resource %q", resource))) + writeError(w, fmt.Sprintf("query is not supported for resource %q", resource), http.StatusBadRequest) + return + case searchableResource == cache.ImplicitScopes && filter != "": + event.WriteError(ctx, op, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("filter is not supported for resource %q", resource))) + writeError(w, fmt.Sprintf("filter is not supported for resource %q", resource), http.StatusBadRequest) + return } t, err := repo.LookupToken(reqCtx, authTokenId, cache.WithUpdateLastAccessedTime(true)) @@ -118,14 +134,15 @@ func newSearchHandlerFunc(ctx context.Context, repo *cache.Repository, refreshSe // Refresh the resources for the provided user, if possible. This is best // effort, so if there is any problem refreshing, we just log the error // and move on to handling the search request. 
- if err := refreshService.RefreshForSearch(reqCtx, authTokenId, searchableResource, opts...); err != nil { - // we don't stop the search, we just log that the inline refresh failed - event.WriteError(ctx, op, err, event.WithInfoMsg("when refreshing the resources inline for search", "auth_token_id", authTokenId, "resource", searchableResource)) + switch searchableResource { + case cache.ImplicitScopes: + default: + if err := refreshService.RefreshForSearch(reqCtx, authTokenId, searchableResource, opts...); err != nil { + // we don't stop the search, we just log that the inline refresh failed + event.WriteError(ctx, op, err, event.WithInfoMsg("when refreshing the resources inline for search", "auth_token_id", authTokenId, "resource", searchableResource)) + } } - query := r.URL.Query().Get(queryKey) - filter := r.URL.Query().Get(filterKey) - res, err := s.Search(reqCtx, cache.SearchParams{ AuthTokenId: authTokenId, Resource: searchableResource, @@ -166,6 +183,7 @@ func toApiResult(sr *cache.SearchResult) *SearchResult { ResolvableAliases: sr.ResolvableAliases, Targets: sr.Targets, Sessions: sr.Sessions, + ImplicitScopes: sr.ImplicitScopes, Incomplete: sr.Incomplete, } } From 829ea02e6f5aed8f124785edf5889c4d52e5c8d6 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Thu, 29 Aug 2024 17:48:10 -0400 Subject: [PATCH 04/15] Ensure only one refresh is running for a user/resource at a time (#5054) This ensures we don't have competing update processes running. When doing the inline RefreshForSearch, a specific error will be returned so we can hint in the API that we're in the middle of a refresh. (cherry picked from commit 54776091e32af2f2f110f2bf10d0617bd5677aba) --- .../clientcache/cmd/search/search_test.go | 8 +- .../clientcache/internal/cache/options.go | 16 ++ .../internal/cache/options_test.go | 12 ++ .../clientcache/internal/cache/refresh.go | 118 ++++++++-- .../internal/cache/refresh_test.go | 203 ++++++++++++++++++ .../internal/daemon/search_handler.go | 38 +++- 6 files changed, 370 insertions(+), 25 deletions(-) diff --git a/internal/clientcache/cmd/search/search_test.go b/internal/clientcache/cmd/search/search_test.go index d36ea5fa64..3718e49479 100644 --- a/internal/clientcache/cmd/search/search_test.go +++ b/internal/clientcache/cmd/search/search_test.go @@ -140,7 +140,9 @@ func TestSearch(t *testing.T) { assert.Nil(t, apiErr) assert.NotNil(t, resp) assert.NotNil(t, r) - assert.EqualValues(t, r, &daemon.SearchResult{}) + assert.EqualValues(t, r, &daemon.SearchResult{ + RefreshStatus: daemon.NotRefreshing, + }) }) t.Run("empty response from query", func(t *testing.T) { @@ -154,7 +156,9 @@ func TestSearch(t *testing.T) { assert.Nil(t, apiErr) assert.NotNil(t, resp) assert.NotNil(t, r) - assert.EqualValues(t, r, &daemon.SearchResult{}) + assert.EqualValues(t, r, &daemon.SearchResult{ + RefreshStatus: daemon.NotRefreshing, + }) }) t.Run("unsupported boundary instance", func(t *testing.T) { diff --git a/internal/clientcache/internal/cache/options.go b/internal/clientcache/internal/cache/options.go index 09e34ee9d8..36907f8833 100644 --- a/internal/clientcache/internal/cache/options.go +++ b/internal/clientcache/internal/cache/options.go @@ -9,6 +9,11 @@ import ( "github.com/hashicorp/go-dbw" ) +type testRefreshWaitChs struct { + firstSempahore chan struct{} + secondSemaphore chan struct{} +} + type options struct { withUpdateLastAccessedTime bool withDbType dbw.DbType @@ -19,6 +24,7 @@ type options struct { withSessionRetrievalFunc SessionRetrievalFunc withIgnoreSearchStaleness bool 
withMaxResultSetSize int + withTestRefreshWaitChs *testRefreshWaitChs } // Option - how options are passed as args @@ -113,3 +119,13 @@ func WithMaxResultSetSize(with int) Option { return nil } } + +// WithTestRefreshWaitChs provides an option for specifying channels to wait on +// before proceeding. This allows testing the logic that ensures only one is +// running at a time. +func WithTestRefreshWaitChs(with *testRefreshWaitChs) Option { + return func(o *options) error { + o.withTestRefreshWaitChs = with + return nil + } +} diff --git a/internal/clientcache/internal/cache/options_test.go b/internal/clientcache/internal/cache/options_test.go index aed96cd5ba..44074fdbe6 100644 --- a/internal/clientcache/internal/cache/options_test.go +++ b/internal/clientcache/internal/cache/options_test.go @@ -106,4 +106,16 @@ func Test_GetOpts(t *testing.T) { _, err = getOpts(WithMaxResultSetSize(-2)) require.Error(t, err) }) + t.Run("withTestRefreshWaitChs", func(t *testing.T) { + waitCh := &testRefreshWaitChs{ + firstSempahore: make(chan struct{}), + secondSemaphore: make(chan struct{}), + } + opts, err := getOpts(WithTestRefreshWaitChs(waitCh)) + require.NoError(t, err) + testOpts := getDefaultOptions() + assert.Empty(t, testOpts.withTestRefreshWaitChs) + testOpts.withTestRefreshWaitChs = waitCh + assert.Equal(t, opts, testOpts) + }) } diff --git a/internal/clientcache/internal/cache/refresh.go b/internal/clientcache/internal/cache/refresh.go index b9a77a5372..6cc294be09 100644 --- a/internal/clientcache/internal/cache/refresh.go +++ b/internal/clientcache/internal/cache/refresh.go @@ -7,6 +7,8 @@ import ( "context" stderrors "errors" "fmt" + "sync" + "sync/atomic" "time" "github.com/hashicorp/boundary/api" @@ -17,9 +19,19 @@ import ( "github.com/hashicorp/go-hclog" ) +// This is used as an internal error to indicate that a refresh is already in +// progress, as a signal that the caller may want to handle it a different way. +var ErrRefreshInProgress error = stderrors.New("cache refresh in progress") + type RefreshService struct { repo *Repository + // syncSempaphores is used to prevent multiple refreshes from happening at + // the same time for a given user ID and resource type. The key is a string + // (the user ID + resource type) and the value is an atomic bool we can CAS + // with. + syncSemaphores sync.Map + logger hclog.Logger // the amount of time that should have passed since the last refresh for a // call to RefreshForSearch will cause a refresh request to be sent to a @@ -135,21 +147,25 @@ func (r *RefreshService) cacheSupportedUsers(ctx context.Context, in []*user) ([ return ret, nil } -// RefreshForSearch refreshes a specific resource type owned by the user associated -// with the provided auth token id, if it hasn't been refreshed in a certain -// amount of time. If the resources has been updated recently, or if it is known -// that the user's boundary instance doesn't support partial refreshing, this -// method returns without any requests to the boundary controller. -// While the criteria of whether the user will refresh or not is almost the same -// as the criteria used in Refresh, RefreshForSearch will not refresh any -// data if there is not a refresh token stored for the resource. It -// might make sense to change this in the future, but the reasoning used is -// that we should not be making an initial load of all resources while blocking -// a search query, in case we have not yet even attempted to load the resources -// for this user yet. 
-// Note: Currently, if the context timesout we stop refreshing completely and -// return to the caller. A possible enhancement in the future would be to return -// when the context is Done, but allow the refresh to proceed in the background. +// RefreshForSearch refreshes a specific resource type owned by the user +// associated with the provided auth token id, if it hasn't been refreshed in a +// certain amount of time. If the resources has been updated recently, or if it +// is known that the user's boundary instance doesn't support partial +// refreshing, this method returns without any requests to the boundary +// controller. While the criteria of whether the user will refresh or not is +// almost the same as the criteria used in Refresh, RefreshForSearch will not +// refresh any data if there is not a refresh token stored for the resource. It +// might make sense to change this in the future, but the reasoning used is that +// we should not be making an initial load of all resources while blocking a +// search query, in case we have not yet even attempted to load the resources +// for this user yet. Note: Currently, if the context timesout we stop +// refreshing completely and return to the caller. A possible enhancement in +// the future would be to return when the context is Done, but allow the refresh +// to proceed in the background. +// +// If a refresh is already running for that user and resource type we will +// return ErrRefreshInProgress. If so we will not start another one and we can +// make it clear in the response that cache filling is ongoing. func (r *RefreshService) RefreshForSearch(ctx context.Context, authTokenid string, resourceType SearchableResource, opt ...Option) error { const op = "cache.(RefreshService).RefreshForSearch" if r.maxSearchRefreshTimeout > 0 { @@ -191,6 +207,8 @@ func (r *RefreshService) RefreshForSearch(ctx context.Context, authTokenid strin return errors.Wrap(ctx, err, op, errors.WithoutEvent()) } + cacheKey := fmt.Sprintf("%s-%s", u.Id, resourceType) + switch resourceType { case ResolvableAliases: rtv, err := r.repo.lookupRefreshToken(ctx, u, resolvableAliasResourceType) @@ -198,6 +216,11 @@ func (r *RefreshService) RefreshForSearch(ctx context.Context, authTokenid strin return errors.Wrap(ctx, err, op) } if opts.withIgnoreSearchStaleness || rtv != nil && time.Since(rtv.UpdateTime) > r.maxSearchStaleness { + semaphore, _ := r.syncSemaphores.LoadOrStore(cacheKey, new(atomic.Bool)) + if !semaphore.(*atomic.Bool).CompareAndSwap(false, true) { + return ErrRefreshInProgress + } + defer semaphore.(*atomic.Bool).Store(false) tokens, err := r.cleanAndPickAuthTokens(ctx, u) if err != nil { return errors.Wrap(ctx, err, op, errors.WithoutEvent()) @@ -206,6 +229,12 @@ func (r *RefreshService) RefreshForSearch(ctx context.Context, authTokenid strin if rtv != nil { args = append(args, "alias staleness", time.Since(rtv.UpdateTime)) } + + if opts.withTestRefreshWaitChs != nil { + close(opts.withTestRefreshWaitChs.firstSempahore) + <-opts.withTestRefreshWaitChs.secondSemaphore + } + r.logger.Debug("refreshing aliases before performing search", args...) 
if err := r.repo.refreshResolvableAliases(ctx, u, tokens, opt...); err != nil { return errors.Wrap(ctx, err, op) @@ -217,6 +246,12 @@ func (r *RefreshService) RefreshForSearch(ctx context.Context, authTokenid strin return errors.Wrap(ctx, err, op, errors.WithoutEvent()) } if opts.withIgnoreSearchStaleness || rtv != nil && time.Since(rtv.UpdateTime) > r.maxSearchStaleness { + semaphore, _ := r.syncSemaphores.LoadOrStore(cacheKey, new(atomic.Bool)) + if !semaphore.(*atomic.Bool).CompareAndSwap(false, true) { + return ErrRefreshInProgress + } + defer semaphore.(*atomic.Bool).Store(false) + tokens, err := r.cleanAndPickAuthTokens(ctx, u) if err != nil { return errors.Wrap(ctx, err, op, errors.WithoutEvent()) @@ -225,6 +260,12 @@ func (r *RefreshService) RefreshForSearch(ctx context.Context, authTokenid strin if rtv != nil { args = append(args, "target staleness", time.Since(rtv.UpdateTime)) } + + if opts.withTestRefreshWaitChs != nil { + close(opts.withTestRefreshWaitChs.firstSempahore) + <-opts.withTestRefreshWaitChs.secondSemaphore + } + r.logger.Debug("refreshing targets before performing search", args...) if err := r.repo.refreshTargets(ctx, u, tokens, opt...); err != nil { return errors.Wrap(ctx, err, op, errors.WithoutEvent()) @@ -236,6 +277,12 @@ func (r *RefreshService) RefreshForSearch(ctx context.Context, authTokenid strin return errors.Wrap(ctx, err, op) } if opts.withIgnoreSearchStaleness || rtv != nil && time.Since(rtv.UpdateTime) > r.maxSearchStaleness { + semaphore, _ := r.syncSemaphores.LoadOrStore(cacheKey, new(atomic.Bool)) + if !semaphore.(*atomic.Bool).CompareAndSwap(false, true) { + return ErrRefreshInProgress + } + defer semaphore.(*atomic.Bool).Store(false) + tokens, err := r.cleanAndPickAuthTokens(ctx, u) if err != nil { return errors.Wrap(ctx, err, op, errors.WithoutEvent()) @@ -244,6 +291,12 @@ func (r *RefreshService) RefreshForSearch(ctx context.Context, authTokenid strin if rtv != nil { args = append(args, "session staleness", time.Since(rtv.UpdateTime)) } + + if opts.withTestRefreshWaitChs != nil { + close(opts.withTestRefreshWaitChs.firstSempahore) + <-opts.withTestRefreshWaitChs.secondSemaphore + } + r.logger.Debug("refreshing sessions before performing search", args...) if err := r.repo.refreshSessions(ctx, u, tokens, opt...); err != nil { return errors.Wrap(ctx, err, op, errors.WithoutEvent()) @@ -262,6 +315,11 @@ func (r *RefreshService) RefreshForSearch(ctx context.Context, authTokenid strin // cache with the values retrieved there. Refresh accepts the options // WithTargetRetrievalFunc and WithSessionRetrievalFunc which overwrites the // default functions used to retrieve those resources from boundary. +// +// This shares the map of sync semaphores with RefreshForSearch, so if a refresh +// is already happening via either method, it will be skipped by the other. In +// the case of this function it won't return ErrRefreshInProgress, but it will +// log that the refresh is already in progress. 
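The guard shared by RefreshForSearch and Refresh boils down to the per-key compare-and-swap pattern below; this is a stripped-down sketch with placeholder names, not code from the patch:

package main

import (
	"errors"
	"fmt"
	"sync"
	"sync/atomic"
)

// errBusy stands in for ErrRefreshInProgress in this sketch.
var errBusy = errors.New("refresh in progress")

// tryOnce claims the per user/resource flag with CAS; if another goroutine
// already holds it, the caller bails out instead of refreshing twice.
func tryOnce(sems *sync.Map, key string, work func()) error {
	sem, _ := sems.LoadOrStore(key, new(atomic.Bool))
	if !sem.(*atomic.Bool).CompareAndSwap(false, true) {
		return errBusy
	}
	defer sem.(*atomic.Bool).Store(false)
	work()
	return nil
}

func main() {
	var sems sync.Map
	fmt.Println(tryOnce(&sems, "u1-targets", func() {})) // <nil>
	// A second call for the same key only fails while the first is still running.
}

Keying the map by user ID plus resource type means refreshes of different resources for the same user can still proceed in parallel.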
func (r *RefreshService) Refresh(ctx context.Context, opt ...Option) error { const op = "cache.(RefreshService).Refresh" if err := r.repo.cleanExpiredOrOrphanedAuthTokens(ctx); err != nil { @@ -290,16 +348,32 @@ func (r *RefreshService) Refresh(ctx context.Context, opt ...Option) error { continue } - if err := r.repo.refreshResolvableAliases(ctx, u, tokens, opt...); err != nil { - retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("for user id %s", u.Id)))) - } - if err := r.repo.refreshTargets(ctx, u, tokens, opt...); err != nil { - retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("for user id %s", u.Id)))) + cacheKey := fmt.Sprintf("%s-%s", u.Id, ResolvableAliases) + semaphore, _ := r.syncSemaphores.LoadOrStore(cacheKey, new(atomic.Bool)) + if semaphore.(*atomic.Bool).CompareAndSwap(false, true) { + if err := r.repo.refreshResolvableAliases(ctx, u, tokens, opt...); err != nil { + retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("for user id %s", u.Id)))) + } + semaphore.(*atomic.Bool).Store(false) } - if err := r.repo.refreshSessions(ctx, u, tokens, opt...); err != nil { - retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("for user id %s", u.Id)))) + + cacheKey = fmt.Sprintf("%s-%s", u.Id, Targets) + semaphore, _ = r.syncSemaphores.LoadOrStore(cacheKey, new(atomic.Bool)) + if semaphore.(*atomic.Bool).CompareAndSwap(false, true) { + if err := r.repo.refreshTargets(ctx, u, tokens, opt...); err != nil { + retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("for user id %s", u.Id)))) + } + semaphore.(*atomic.Bool).Store(false) } + cacheKey = fmt.Sprintf("%s-%s", u.Id, Sessions) + semaphore, _ = r.syncSemaphores.LoadOrStore(cacheKey, new(atomic.Bool)) + if semaphore.(*atomic.Bool).CompareAndSwap(false, true) { + if err := r.repo.refreshSessions(ctx, u, tokens, opt...); err != nil { + retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("for user id %s", u.Id)))) + } + semaphore.(*atomic.Bool).Store(false) + } } return retErr } diff --git a/internal/clientcache/internal/cache/refresh_test.go b/internal/clientcache/internal/cache/refresh_test.go index a24dedf198..bd0905de93 100644 --- a/internal/clientcache/internal/cache/refresh_test.go +++ b/internal/clientcache/internal/cache/refresh_test.go @@ -803,6 +803,209 @@ func TestRefreshForSearch(t *testing.T) { }) } +func TestRefreshNonBlocking(t *testing.T) { + ctx := context.Background() + + boundaryAddr := "address" + u := &user{Id: "u1", Address: boundaryAddr} + at := &authtokens.AuthToken{ + Id: "at_1", + Token: "at_1_token", + UserId: u.Id, + ExpirationTime: time.Now().Add(time.Minute), + } + + boundaryAuthTokens := []*authtokens.AuthToken{at} + atMap := make(map[ringToken]*authtokens.AuthToken) + + atMap[ringToken{"k", "t"}] = at + + t.Run("targets refreshed for searching", func(t *testing.T) { + t.Parallel() + s, err := db.Open(ctx) + require.NoError(t, err) + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(boundaryAuthTokens)) + require.NoError(t, err) + rs, err := NewRefreshService(ctx, r, hclog.NewNullLogger(), time.Millisecond, 0) + require.NoError(t, err) + require.NoError(t, r.AddKeyringToken(ctx, boundaryAddr, KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id})) + + retTargets := []*targets.Target{ + target("1"), + target("2"), + target("3"), + 
target("4"), + } + opts := []Option{ + WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), + WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), + WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, + [][]*targets.Target{ + retTargets[:3], + retTargets[3:], + }, + [][]string{ + nil, + {retTargets[0].Id, retTargets[1].Id}, + }, + )), + } + + refreshWaitChs := &testRefreshWaitChs{ + firstSempahore: make(chan struct{}), + secondSemaphore: make(chan struct{}), + } + wg := new(sync.WaitGroup) + wg.Add(2) + extraOpts := []Option{WithTestRefreshWaitChs(refreshWaitChs), WithIgnoreSearchStaleness(true)} + go func() { + defer wg.Done() + blockingRefreshError := rs.RefreshForSearch(ctx, at.Id, Targets, append(opts, extraOpts...)...) + assert.NoError(t, blockingRefreshError) + }() + go func() { + defer wg.Done() + // Sleep here to ensure ordering of the calls since both goroutines + // are spawned at the same time + <-refreshWaitChs.firstSempahore + nonblockingRefreshError := rs.RefreshForSearch(ctx, at.Id, Targets, append(opts, extraOpts...)...) + close(refreshWaitChs.secondSemaphore) + assert.ErrorIs(t, nonblockingRefreshError, ErrRefreshInProgress) + }() + wg.Wait() + + // Unlike in the TestRefreshForSearch test, since we did a force + // refresh we do expect to see values + cachedTargets, err := r.ListTargets(ctx, at.Id) + assert.NoError(t, err) + assert.ElementsMatch(t, retTargets[:3], cachedTargets.Targets) + }) + + t.Run("sessions refreshed for searching", func(t *testing.T) { + t.Parallel() + s, err := db.Open(ctx) + require.NoError(t, err) + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(boundaryAuthTokens)) + require.NoError(t, err) + rs, err := NewRefreshService(ctx, r, hclog.NewNullLogger(), 0, 0) + require.NoError(t, err) + require.NoError(t, r.AddKeyringToken(ctx, boundaryAddr, KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id})) + + retSess := []*sessions.Session{ + session("1"), + session("2"), + session("3"), + session("4"), + } + opts := []Option{ + WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), + WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil)), + WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, + [][]*sessions.Session{ + retSess[:3], + retSess[3:], + }, + [][]string{ + nil, + {retSess[0].Id, retSess[1].Id}, + }, + )), + } + + refreshWaitChs := &testRefreshWaitChs{ + firstSempahore: make(chan struct{}), + secondSemaphore: make(chan struct{}), + } + wg := new(sync.WaitGroup) + wg.Add(2) + extraOpts := []Option{WithTestRefreshWaitChs(refreshWaitChs), WithIgnoreSearchStaleness(true)} + go func() { + defer wg.Done() + blockingRefreshError := rs.RefreshForSearch(ctx, at.Id, Sessions, append(opts, extraOpts...)...) + assert.NoError(t, blockingRefreshError) + }() + go func() { + defer wg.Done() + // Sleep here to ensure ordering of the calls since both goroutines + // are spawned at the same time + <-refreshWaitChs.firstSempahore + nonblockingRefreshError := rs.RefreshForSearch(ctx, at.Id, Sessions, append(opts, extraOpts...)...) 
+ close(refreshWaitChs.secondSemaphore) + assert.ErrorIs(t, nonblockingRefreshError, ErrRefreshInProgress) + }() + + wg.Wait() + + // Unlike in the TestRefreshForSearch test, since we are did a force + // refresh we do expect to see values + cachedSessions, err := r.ListSessions(ctx, at.Id) + assert.NoError(t, err) + assert.ElementsMatch(t, retSess[:3], cachedSessions.Sessions) + }) + + t.Run("aliases refreshed for searching", func(t *testing.T) { + t.Parallel() + s, err := db.Open(ctx) + require.NoError(t, err) + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(boundaryAuthTokens)) + require.NoError(t, err) + rs, err := NewRefreshService(ctx, r, hclog.Default(), 0, 0) + require.NoError(t, err) + require.NoError(t, r.AddKeyringToken(ctx, boundaryAddr, KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id})) + + retAl := []*aliases.Alias{ + alias("1"), + alias("2"), + alias("3"), + alias("4"), + } + opts := []Option{ + WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), + WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil)), + WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, + [][]*aliases.Alias{ + retAl[:3], + retAl[3:], + }, + [][]string{ + nil, + {retAl[0].Id, retAl[1].Id}, + }, + )), + } + + refreshWaitChs := &testRefreshWaitChs{ + firstSempahore: make(chan struct{}), + secondSemaphore: make(chan struct{}), + } + wg := new(sync.WaitGroup) + wg.Add(2) + extraOpts := []Option{WithTestRefreshWaitChs(refreshWaitChs), WithIgnoreSearchStaleness(true)} + go func() { + defer wg.Done() + blockingRefreshError := rs.RefreshForSearch(ctx, at.Id, ResolvableAliases, append(opts, extraOpts...)...) + assert.NoError(t, blockingRefreshError) + }() + go func() { + defer wg.Done() + // Sleep here to ensure ordering of the calls since both goroutines + // are spawned at the same time + <-refreshWaitChs.firstSempahore + nonblockingRefreshError := rs.RefreshForSearch(ctx, at.Id, ResolvableAliases, append(opts, extraOpts...)...) 
+ close(refreshWaitChs.secondSemaphore) + assert.ErrorIs(t, nonblockingRefreshError, ErrRefreshInProgress) + }() + + wg.Wait() + + // Unlike in the TestRefreshForSearch test, since we are did a force + // refresh we do expect to see values + cachedAliases, err := r.ListResolvableAliases(ctx, at.Id) + require.NoError(t, err) + assert.ElementsMatch(t, retAl[:3], cachedAliases.ResolvableAliases) + }) +} + func TestRefresh(t *testing.T) { ctx := context.Background() diff --git a/internal/clientcache/internal/daemon/search_handler.go b/internal/clientcache/internal/daemon/search_handler.go index 2fbdfbd300..97dffa5cc2 100644 --- a/internal/clientcache/internal/daemon/search_handler.go +++ b/internal/clientcache/internal/daemon/search_handler.go @@ -6,6 +6,7 @@ package daemon import ( "context" "encoding/json" + stderrors "errors" "fmt" "net/http" "strconv" @@ -22,6 +23,22 @@ import ( "github.com/hashicorp/go-hclog" ) +type RefreshStatus string + +const ( + // Not refreshing means the result is complete, that is, the cache is not in + // the process of being built or refreshed + NotRefreshing RefreshStatus = "not-refreshing" + // Refreshing means that the cache is in the process of being refreshed, so + // the result may not be complete and the caller should try again later for + // more complete results + Refreshing RefreshStatus = "refreshing" + // RefreshError means that there was an error refreshing the cache. It says + // nothing about the completeness of the result, only that when attempting + // to refresh the cache in-line with the search an error was encountered. + RefreshError RefreshStatus = "refresh-error" +) + // SearchResult is the struct returned to search requests. type SearchResult struct { ResolvableAliases []*aliases.Alias `json:"resolvable_aliases,omitempty"` @@ -29,6 +46,8 @@ type SearchResult struct { Sessions []*sessions.Session `json:"sessions,omitempty"` ImplicitScopes []*scopes.Scope `json:"implicit_scopes,omitempty"` Incomplete bool `json:"incomplete,omitempty"` + RefreshStatus RefreshStatus `json:"refresh_status,omitempty"` + RefreshError string `json:"refresh_error,omitempty"` } const ( @@ -131,13 +150,20 @@ func newSearchHandlerFunc(ctx context.Context, repo *cache.Repository, refreshSe opts = append(opts, cache.WithIgnoreSearchStaleness(true)) } + var refreshError error // Refresh the resources for the provided user, if possible. This is best // effort, so if there is any problem refreshing, we just log the error // and move on to handling the search request. switch searchableResource { case cache.ImplicitScopes: + // This is not able to be refreshed, so continue on default: - if err := refreshService.RefreshForSearch(reqCtx, authTokenId, searchableResource, opts...); err != nil { + refreshError = refreshService.RefreshForSearch(reqCtx, authTokenId, searchableResource, opts...) 
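When another goroutine already holds the per-user, per-resource guard, RefreshForSearch returns ErrRefreshInProgress instead of blocking, which is what the tests above assert from the second caller and why the handler code just below treats that error as benign. For illustration only, the guard amounts to the following compare-and-swap sketch; tryRefresh and errInProgress are made-up names for this sketch, not identifiers from the patch:

package main

import (
	"errors"
	"fmt"
	"sync"
	"sync/atomic"
)

var errInProgress = errors.New("refresh already in progress")

// tryRefresh runs fn for a given user/resource pair unless a refresh for that
// same pair is already running, in which case it returns errInProgress
// immediately rather than waiting.
func tryRefresh(sems *sync.Map, userId, resource string, fn func() error) error {
	key := fmt.Sprintf("%s-%s", userId, resource)
	sem, _ := sems.LoadOrStore(key, new(atomic.Bool))
	if !sem.(*atomic.Bool).CompareAndSwap(false, true) {
		return errInProgress
	}
	defer sem.(*atomic.Bool).Store(false)
	return fn()
}

func main() {
	var sems sync.Map
	err := tryRefresh(&sems, "u1", "targets", func() error {
		// fetch resources from the controller and update the cache here
		return nil
	})
	fmt.Println(err) // <nil>; a concurrent second call for "u1"/"targets" would get errInProgress
}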
+ switch { + case refreshError == nil, + stderrors.Is(refreshError, cache.ErrRefreshInProgress): + // Don't event in these cases + default: // we don't stop the search, we just log that the inline refresh failed event.WriteError(ctx, op, err, event.WithInfoMsg("when refreshing the resources inline for search", "auth_token_id", authTokenId, "resource", searchableResource)) } @@ -166,6 +192,16 @@ func newSearchHandlerFunc(ctx context.Context, repo *cache.Repository, refreshSe } apiRes := toApiResult(res) + switch { + case refreshError == nil: + apiRes.RefreshStatus = NotRefreshing + case stderrors.Is(refreshError, cache.ErrRefreshInProgress): + apiRes.RefreshStatus = Refreshing + default: + apiRes.RefreshStatus = RefreshError + apiRes.RefreshError = refreshError.Error() + } + j, err := json.Marshal(apiRes) if err != nil { event.WriteError(ctx, op, err, event.WithInfoMsg("when marshaling search result to JSON", "auth_token_id", authTokenId, "resource", searchableResource, "query", query, "filter", filter)) From 863359aa26a1b0ca1b76419709e0ab917ddcd7e9 Mon Sep 17 00:00:00 2001 From: Jim Date: Fri, 6 Sep 2024 10:32:47 -0400 Subject: [PATCH 05/15] test (cache/search): fix NewTestServer(...) storage (#5081) Fix ensures that the TestServer returned uses the correct directory for writing its sqlite files, which was created via t.TempDir() (cherry picked from commit 85f91043a164bdacc1c6738ae0e7c50077d55c79) --- internal/clientcache/cmd/search/search_test.go | 4 ++-- internal/clientcache/internal/daemon/testing.go | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/internal/clientcache/cmd/search/search_test.go b/internal/clientcache/cmd/search/search_test.go index 3718e49479..1e5e134df2 100644 --- a/internal/clientcache/cmd/search/search_test.go +++ b/internal/clientcache/cmd/search/search_test.go @@ -140,9 +140,9 @@ func TestSearch(t *testing.T) { assert.Nil(t, apiErr) assert.NotNil(t, resp) assert.NotNil(t, r) - assert.EqualValues(t, r, &daemon.SearchResult{ + assert.EqualValues(t, &daemon.SearchResult{ RefreshStatus: daemon.NotRefreshing, - }) + }, r) }) t.Run("empty response from query", func(t *testing.T) { diff --git a/internal/clientcache/internal/daemon/testing.go b/internal/clientcache/internal/daemon/testing.go index b29ffc6b7d..7faa91a8c2 100644 --- a/internal/clientcache/internal/daemon/testing.go +++ b/internal/clientcache/internal/daemon/testing.go @@ -37,6 +37,9 @@ func NewTestServer(t *testing.T, cmd Commander, opt ...Option) *TestServer { RecheckSupportInterval: DefaultRecheckSupportInterval, LogWriter: io.Discard, DotDirectory: dotDir, + // we need to provide this, otherwise it will open a store in the user's + // home dir. See db.Open(...) 
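The refresh_status and refresh_error fields added to SearchResult above, together with the existing incomplete flag, let a caller distinguish "the cache is mid-refresh" from "the inline refresh failed" without treating either as a hard error. A rough sketch of how a consumer of the daemon's search endpoint might surface this, assuming the clientcache daemon package shown above is imported; describeRefresh is a hypothetical helper, not part of this change:

// describeRefresh turns the refresh metadata on a search result into an
// optional warning for the user; an empty string means there is nothing to report.
func describeRefresh(res *daemon.SearchResult) string {
	switch res.RefreshStatus {
	case daemon.Refreshing:
		// An inline refresh was already running, so the response reflects the
		// current cache contents and may lag behind the controller.
		return "cache refresh in progress; results may be out of date"
	case daemon.RefreshError:
		// The inline refresh failed, but cached results were still returned.
		return "cache refresh failed: " + res.RefreshError
	default:
		// daemon.NotRefreshing, or unset on older daemons: nothing to report.
		return ""
	}
}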
+ DatabaseUrl: dotDir + "cache.db?_pragma=foreign_keys(1)", } s, err := New(ctx, cfg) From 85afd56266be828267f27599771d55d55fce7599 Mon Sep 17 00:00:00 2001 From: Jim Date: Mon, 9 Sep 2024 12:38:18 -0400 Subject: [PATCH 06/15] fix (cache): support for DB debug logging (#5087) (cherry picked from commit a7fe48b8d827d1bf68ff8eea205428dd5faae381) --- go.mod | 2 +- go.sum | 4 +-- .../clientcache/internal/daemon/server.go | 26 ++++++++++++-- .../internal/daemon/server_test.go | 36 +++++++++++++++++-- internal/clientcache/internal/db/db.go | 7 ++-- 5 files changed, 66 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index 2864f17342..186a0bf079 100644 --- a/go.mod +++ b/go.mod @@ -81,7 +81,7 @@ require ( nhooyr.io/websocket v1.8.10 ) -require github.com/hashicorp/go-dbw v0.1.4 +require github.com/hashicorp/go-dbw v0.1.5-0.20240909162114-6cee92b3da36 require ( github.com/DATA-DOG/go-sqlmock v1.5.2 diff --git a/go.sum b/go.sum index e3c04e738d..828d0d4a71 100644 --- a/go.sum +++ b/go.sum @@ -201,8 +201,8 @@ github.com/hashicorp/go-bexpr v0.1.13 h1:HNwp7vZrMpRq8VZXj8VF90LbZpRjQQpim1oJF0D github.com/hashicorp/go-bexpr v0.1.13/go.mod h1:gN7hRKB3s7yT+YvTdnhZVLTENejvhlkZ8UE4YVBS+Q8= github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= -github.com/hashicorp/go-dbw v0.1.4 h1:FipsEg7UbWiB8hBQylTYWJobG6jVr0LNoRdo9IClGs4= -github.com/hashicorp/go-dbw v0.1.4/go.mod h1:/YHbfK7mgG9k09aB74Imw3fEOwno0eTtlFTTYGZ7SFk= +github.com/hashicorp/go-dbw v0.1.5-0.20240909162114-6cee92b3da36 h1:rPD+2QPhCLq8mKMx2FnIaqR5PTNT+LzhjfacYWuvFzY= +github.com/hashicorp/go-dbw v0.1.5-0.20240909162114-6cee92b3da36/go.mod h1:/YHbfK7mgG9k09aB74Imw3fEOwno0eTtlFTTYGZ7SFk= github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k= github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= github.com/hashicorp/go-kms-wrapping/extras/kms/v2 v2.0.0-20231219183231-6bac757bb482 h1:1DqTnLaNk658AEenlF4PNGYd9b1hXE/+0jSOBIGOAms= diff --git a/internal/clientcache/internal/daemon/server.go b/internal/clientcache/internal/daemon/server.go index cd0f13e77c..cc642629b7 100644 --- a/internal/clientcache/internal/daemon/server.go +++ b/internal/clientcache/internal/daemon/server.go @@ -229,7 +229,8 @@ func (s *CacheServer) Serve(ctx context.Context, cmd Commander, opt ...Option) e var store *db.DB store, err = openStore(ctx, WithUrl(ctx, s.conf.DatabaseUrl), - WithLogger(ctx, s.logger)) + WithLogger(ctx, s.logger), + ) if err != nil { return errors.Wrap(ctx, err, op) } @@ -527,7 +528,14 @@ func openStore(ctx context.Context, opt ...Option) (*db.DB, error) { dbOpts = append(dbOpts, cachedb.WithUrl(url)) } if !util.IsNil(opts.withLogger) { - dbOpts = append(dbOpts, cachedb.WithGormFormatter(opts.withLogger)) + opts.withLogger.Log(hclog.Debug, "Store GormFormatter", "LogLevel", opts.withLogger.GetLevel()) + switch { + case opts.withLogger.IsDebug(): + dbOpts = append(dbOpts, cachedb.WithGormFormatter(gormDebugLogger{Logger: opts.withLogger})) + dbOpts = append(dbOpts, cachedb.WithDebug(true)) + default: + dbOpts = append(dbOpts, cachedb.WithGormFormatter(opts.withLogger)) + } } store, err := cachedb.Open(ctx, dbOpts...) 
if err != nil { @@ -563,3 +571,17 @@ const ( dbFileName = "cache.db" fkPragma = "?_pragma=foreign_keys(1)" ) + +type gormDebugLogger struct { + hclog.Logger +} + +func (g gormDebugLogger) Printf(msg string, values ...any) { + b := new(strings.Builder) + fmt.Fprintf(b, msg, values...) + g.Debug(b.String()) +} + +func getGormLogger(log hclog.Logger) gormDebugLogger { + return gormDebugLogger{Logger: log} +} diff --git a/internal/clientcache/internal/daemon/server_test.go b/internal/clientcache/internal/daemon/server_test.go index 2f7e947c90..660f73dd3f 100644 --- a/internal/clientcache/internal/daemon/server_test.go +++ b/internal/clientcache/internal/daemon/server_test.go @@ -5,12 +5,16 @@ package daemon import ( "context" + "strings" + "sync" "testing" "github.com/hashicorp/boundary/api" "github.com/hashicorp/boundary/api/roles" "github.com/hashicorp/boundary/internal/cmd/base" "github.com/hashicorp/boundary/internal/daemon/controller" + "github.com/hashicorp/boundary/internal/db" + "github.com/hashicorp/go-hclog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -19,10 +23,12 @@ func Test_openStore(t *testing.T) { ctx := context.Background() t.Run("success", func(t *testing.T) { tmpDir := t.TempDir() - db, err := openStore(ctx, WithUrl(ctx, tmpDir+"/test.db"+fkPragma)) + store, err := openStore(ctx, WithUrl(ctx, tmpDir+"/test.db"+fkPragma)) require.NoError(t, err) - require.NotNil(t, db) + require.NotNil(t, store) assert.FileExists(t, tmpDir+"/test.db") + rw := db.New(store) + rw.Query(ctx, "select * from target", nil) }) t.Run("homedir", func(t *testing.T) { tmpDir := t.TempDir() @@ -31,6 +37,32 @@ func Test_openStore(t *testing.T) { require.NotNil(t, db) assert.FileExists(t, tmpDir+"/"+dotDirname+"/"+dbFileName) }) + t.Run("log-level-debug", func(t *testing.T) { + buf := new(strings.Builder) + testLock := &sync.Mutex{} + testLogger := hclog.New(&hclog.LoggerOptions{ + Mutex: testLock, + Name: "test", + JSONFormat: true, + Output: buf, + Level: hclog.Debug, + }) + tmpDir := t.TempDir() + store, err := openStore(ctx, + WithUrl(ctx, tmpDir+"/test.db"+fkPragma), + WithLogger(ctx, testLogger), + ) + require.NoError(t, err) + require.NotNil(t, store) + assert.FileExists(t, tmpDir+"/test.db") + rw := db.New(store) + + rows, err := rw.Query(ctx, "select * from target", nil) + require.NoError(t, err) + defer rows.Close() + assert.Contains(t, buf.String(), "select * from target") + t.Log(buf.String()) + }) } // Note: the name of this test must remain short because the temp dir created diff --git a/internal/clientcache/internal/db/db.go b/internal/clientcache/internal/db/db.go index 974faa7fbf..96443f533c 100644 --- a/internal/clientcache/internal/db/db.go +++ b/internal/clientcache/internal/db/db.go @@ -54,8 +54,11 @@ func Open(ctx context.Context, opt ...Option) (*db.DB, error) { if err != nil { return nil, errors.Wrap(ctx, err, op) } - conn.Debug(opts.withDebug) - + defer func() { + // let's not capture the output of resetSchema and createTables, so + // we'll defer turning on debug + conn.Debug(opts.withDebug) + }() switch { case opts.withDbType == dbw.Sqlite && url == DefaultStoreUrl: if err := createTables(ctx, conn); err != nil { From 2a87dcad4803925cd99ce99bacc8165212184ede Mon Sep 17 00:00:00 2001 From: Jim Date: Mon, 9 Sep 2024 14:32:06 -0400 Subject: [PATCH 07/15] fix (cache/db): add indexes for desktop queries (#5091) (cherry picked from commit bfdbae1f042c4eb1b05e0373e1aaf2db55c9bb3b) --- internal/clientcache/internal/db/db.go | 2 +- 
internal/clientcache/internal/db/schema.sql | 11 ++++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/internal/clientcache/internal/db/db.go b/internal/clientcache/internal/db/db.go index 96443f533c..4690c6e10b 100644 --- a/internal/clientcache/internal/db/db.go +++ b/internal/clientcache/internal/db/db.go @@ -146,7 +146,7 @@ type schema struct { const ( schemaTableName = "schema_version" - schemaCurrentVersion = "v0.0.1" + schemaCurrentVersion = "v0.0.2" ) // TableName returns the table name diff --git a/internal/clientcache/internal/db/schema.sql b/internal/clientcache/internal/db/schema.sql index ddd635b2d1..3806637716 100644 --- a/internal/clientcache/internal/db/schema.sql +++ b/internal/clientcache/internal/db/schema.sql @@ -34,7 +34,7 @@ when end; -insert into schema_version(version) values('v0.0.1'); +insert into schema_version(version) values('v0.0.2'); -- user contains the boundary user information for the boundary user that owns -- the information in the cache. @@ -163,6 +163,9 @@ create table if not exists target ( primary key (fk_user_id, id) ); +-- index for implicit scope search +create index target_scope_id_ix on target(scope_id); + -- session contains cached boundary session resource for a specific user and -- with specific fields extracted to facilitate searching over those fields create table if not exists session ( @@ -190,6 +193,9 @@ create table if not exists session ( primary key (fk_user_id, id) ); +-- implicit scope search +create index session_scope_id_ix on session(scope_id); + -- alias contains cached boundary alias resource for a specific user and -- with specific fields extracted to facilitate searching over those fields create table if not exists resolvable_alias ( @@ -211,6 +217,9 @@ create table if not exists resolvable_alias ( primary key (fk_user_id, id) ); +-- optimize query for destination_id +create index destination_id_resolvable_alias_ix on resolvable_alias(destination_id); + -- contains errors from the last attempt to sync data from boundary for a -- specific resource type create table if not exists api_error ( From 3a4ee032d954c6ed01f3539d9667ad70882111b7 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Thu, 12 Sep 2024 15:56:33 -0400 Subject: [PATCH 08/15] Paginate targets in cache (#5101) * Update Go API to support pagination * Implement pagination of targets in cache (cherry picked from commit cd25fad42f325fc050b7ea3bec3336780d2b5e16) --- api/accounts/account.gen.go | 187 ++++++++++----- api/accounts/option.gen.go | 33 ++- api/aliases/alias.gen.go | 196 ++++++++++----- api/aliases/option.gen.go | 34 ++- api/authmethods/authmethods.gen.go | 196 ++++++++++----- api/authmethods/option.gen.go | 34 ++- api/authtokens/authtokens.gen.go | 196 ++++++++++----- api/authtokens/option.gen.go | 34 ++- api/billing/option.gen.go | 33 ++- .../credential_library.gen.go | 187 ++++++++++----- api/credentiallibraries/option.gen.go | 33 ++- api/credentials/credential.gen.go | 187 ++++++++++----- api/credentials/option.gen.go | 33 ++- api/credentialstores/credential_store.gen.go | 196 ++++++++++----- api/credentialstores/option.gen.go | 34 ++- api/groups/group.gen.go | 196 ++++++++++----- api/groups/option.gen.go | 34 ++- api/hostcatalogs/host_catalog.gen.go | 196 ++++++++++----- api/hostcatalogs/option.gen.go | 34 ++- api/hosts/host.gen.go | 187 ++++++++++----- api/hosts/option.gen.go | 33 ++- api/hostsets/host_set.gen.go | 187 ++++++++++----- api/hostsets/option.gen.go | 33 ++- api/managedgroups/managedgroups.gen.go | 187 ++++++++++----- 
api/managedgroups/option.gen.go | 33 ++- api/policies/option.gen.go | 34 ++- api/policies/policy.gen.go | 196 ++++++++++----- api/roles/option.gen.go | 34 ++- api/roles/role.gen.go | 196 ++++++++++----- api/scopes/option.gen.go | 34 ++- api/scopes/scope.gen.go | 196 ++++++++++----- api/sessionrecordings/option.gen.go | 34 ++- .../session_recording.gen.go | 196 ++++++++++----- api/sessions/option.gen.go | 34 ++- api/sessions/session.gen.go | 196 ++++++++++----- api/storagebuckets/option.gen.go | 34 ++- api/storagebuckets/storage_bucket.gen.go | 196 ++++++++++----- api/targets/option.gen.go | 34 ++- api/targets/target.gen.go | 196 ++++++++++----- api/users/option.gen.go | 34 ++- api/users/user.gen.go | 196 ++++++++++----- api/workers/option.gen.go | 34 ++- api/workers/worker.gen.go | 111 +-------- internal/api/genapi/input.go | 5 + internal/api/genapi/templates.go | 223 +++++++++++++----- internal/clientcache/cmd/cache/start.go | 9 + .../clientcache/internal/cache/options.go | 10 + .../internal/cache/options_test.go | 12 +- .../clientcache/internal/cache/refresh.go | 51 ++-- .../internal/cache/refresh_test.go | 103 ++++---- .../cache/repository_implicit_scopes_test.go | 2 +- .../internal/cache/repository_targets.go | 119 ++++++---- .../internal/cache/repository_targets_test.go | 34 +-- .../clientcache/internal/cache/status_test.go | 6 +- .../clientcache/internal/daemon/options.go | 16 +- .../internal/daemon/options_test.go | 8 + .../clientcache/internal/daemon/server.go | 4 + .../clientcache/internal/daemon/testing.go | 20 +- internal/clientcache/internal/db/db.go | 2 +- internal/clientcache/internal/db/options.go | 19 +- .../clientcache/internal/db/options_test.go | 8 + internal/tests/api/targets/target_test.go | 33 +++ 62 files changed, 3872 insertions(+), 1530 deletions(-) diff --git a/api/accounts/account.gen.go b/api/accounts/account.gen.go index 98245e6998..c55e8b1cb7 100644 --- a/api/accounts/account.gen.go +++ b/api/accounts/account.gen.go @@ -11,6 +11,7 @@ import ( "fmt" "net/url" "slices" + "strconv" "time" "github.com/hashicorp/boundary/api" @@ -68,6 +69,13 @@ type AccountListResult struct { ListToken string `json:"list_token,omitempty"` ResponseType string `json:"response_type,omitempty"` Response *api.Response + + // The following fields are used for cached information when client-directed + // pagination is used. + recursive bool + pageSize uint32 + authMethodId string + allRemovedIds []string } func (n AccountListResult) GetItems() []*Account { @@ -340,96 +348,91 @@ func (c *Client) List(ctx context.Context, authMethodId string, opt ...Option) ( return nil, apiErr } target.Response = resp + if target.ResponseType == "complete" || target.ResponseType == "" { return target, nil } + + // In case we shortcut out due to client directed pagination, ensure these + // are set + + target.pageSize = opts.withPageSize + target.authMethodId = authMethodId + target.allRemovedIds = target.RemovedIds + if opts.withClientDirectedPagination { + return target, nil + } + + allItems := make([]*Account, 0, target.EstItemCount) + allItems = append(allItems, target.Items...) + // If there are more results, automatically fetch the rest of the results. // idToIndex keeps a map from the ID of an item to its index in target.Items. // This is used to update updated items in-place and remove deleted items // from the result after pagination is done. 
idToIndex := map[string]int{} - for i, item := range target.Items { + for i, item := range allItems { idToIndex[item.Id] = i } - for { - req, err := c.client.NewRequest(ctx, "GET", "accounts", nil, apiOpts...) - if err != nil { - return nil, fmt.Errorf("error creating List request: %w", err) - } - - opts.queryMap["list_token"] = target.ListToken - if len(opts.queryMap) > 0 { - q := url.Values{} - for k, v := range opts.queryMap { - q.Add(k, v) - } - req.URL.RawQuery = q.Encode() - } - resp, err := c.client.Do(req) + // If we're here there are more pages and the client does not want to + // paginate on their own; fetch them as this call returns all values. + currentPage := target + for { + nextPage, err := c.ListNextPage(ctx, currentPage, opt...) if err != nil { - return nil, fmt.Errorf("error performing client request during List call: %w", err) + return nil, fmt.Errorf("error getting next page in List call: %w", err) } - page := new(AccountListResult) - apiErr, err := resp.Decode(page) - if err != nil { - return nil, fmt.Errorf("error decoding List response: %w", err) - } - if apiErr != nil { - return nil, apiErr - } - for _, item := range page.Items { + for _, item := range nextPage.Items { if i, ok := idToIndex[item.Id]; ok { // Item has already been seen at index i, update in-place - target.Items[i] = item + allItems[i] = item } else { - target.Items = append(target.Items, item) - idToIndex[item.Id] = len(target.Items) - 1 + allItems = append(allItems, item) + idToIndex[item.Id] = len(allItems) - 1 } } - // RemovedIds contain any Account that were deleted since the last response. - target.RemovedIds = append(target.RemovedIds, page.RemovedIds...) - target.EstItemCount = page.EstItemCount - target.ListToken = page.ListToken - target.ResponseType = page.ResponseType - target.Response = resp - if target.ResponseType == "complete" { + + currentPage = nextPage + + if currentPage.ResponseType == "complete" { break } } - // For now, removedIds will only be populated if this pagination cycle was the result of a - // "refresh" operation (i.e., the caller provided a list token option to this call). - // - // Sort to make response deterministic - slices.Sort(target.RemovedIds) - // Remove any duplicates - target.RemovedIds = slices.Compact(target.RemovedIds) + + // The current page here is the final page of the results, that is, the + // response type is "complete" + // Remove items that were deleted since the end of the last iteration. // If a Account has been updated and subsequently removed, we don't want // it to appear both in the Items and RemovedIds, so we remove it from the Items. - for _, removedId := range target.RemovedIds { + for _, removedId := range currentPage.RemovedIds { if i, ok := idToIndex[removedId]; ok { // Remove the item at index i without preserving order // https://github.com/golang/go/wiki/SliceTricks#delete-without-preserving-order - target.Items[i] = target.Items[len(target.Items)-1] - target.Items = target.Items[:len(target.Items)-1] + allItems[i] = allItems[len(allItems)-1] + allItems = allItems[:len(allItems)-1] // Update the index of the previously last element - idToIndex[target.Items[i].Id] = i + idToIndex[allItems[i].Id] = i } } - // Since we paginated to the end, we can avoid confusion - // for the user by setting the estimated item count to the - // length of the items slice. If we don't set this here, it - // will equal the value returned in the last response, which is - // often much smaller than the total number returned. 
- target.EstItemCount = uint(len(target.Items)) // Sort the results again since in-place updates and deletes // may have shuffled items. We sort by created time descending // (most recently created first), same as the API. - slices.SortFunc(target.Items, func(i, j *Account) int { + slices.SortFunc(allItems, func(i, j *Account) int { return j.CreatedTime.Compare(i.CreatedTime) }) + // Since we paginated to the end, we can avoid confusion + // for the user by setting the estimated item count to the + // length of the items slice. If we don't set this here, it + // will equal the value returned in the last response, which is + // often much smaller than the total number returned. + currentPage.EstItemCount = uint(len(allItems)) + // Set items to the full list we have collected here + currentPage.Items = allItems + // Set the returned value to the last page with calculated values + target = currentPage // Finally, since we made at least 2 requests to the server to fulfill this // function call, resp.Body and resp.Map will only contain the most recent response. // Overwrite them with the true response. @@ -443,4 +446,80 @@ func (c *Client) List(ctx context.Context, authMethodId string, opt ...Option) ( // Note: the HTTP response body is consumed by resp.Decode in the loop, // so it doesn't need to be updated (it will always be, and has always been, empty). return target, nil + +} + +func (c *Client) ListNextPage(ctx context.Context, currentPage *AccountListResult, opt ...Option) (*AccountListResult, error) { + if currentPage == nil { + return nil, fmt.Errorf("empty currentPage value passed into ListNextPage request") + } + if currentPage.authMethodId == "" { + return nil, fmt.Errorf("empty authMethodId value in currentPage passed into ListNextPage request") + } + if c.client == nil { + return nil, fmt.Errorf("nil client") + } + if currentPage.ResponseType == "complete" || currentPage.ResponseType == "" { + return nil, fmt.Errorf("no more pages available in ListNextPage request") + } + + opts, apiOpts := getOpts(opt...) + opts.queryMap["auth_method_id"] = currentPage.authMethodId + + if currentPage.pageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) + } + + req, err := c.client.NewRequest(ctx, "GET", "accounts", nil, apiOpts...) + if err != nil { + return nil, fmt.Errorf("error creating List request: %w", err) + } + + opts.queryMap["list_token"] = currentPage.ListToken + if len(opts.queryMap) > 0 { + q := url.Values{} + for k, v := range opts.queryMap { + q.Add(k, v) + } + req.URL.RawQuery = q.Encode() + } + + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("error performing client request during List call during ListNextPage: %w", err) + } + + nextPage := new(AccountListResult) + apiErr, err := resp.Decode(nextPage) + if err != nil { + return nil, fmt.Errorf("error decoding List response during ListNextPage: %w", err) + } + if apiErr != nil { + return nil, apiErr + } + + // Ensure values are carried forward to the next call + nextPage.authMethodId = currentPage.authMethodId + + nextPage.pageSize = currentPage.pageSize + // Cache the removed IDs from this page + nextPage.allRemovedIds = append(currentPage.allRemovedIds, nextPage.RemovedIds...) 
+ // Set the response body to the current response + nextPage.Response = resp + // If we're done iterating, pull the full set of removed IDs into the last + // response + if nextPage.ResponseType == "complete" { + // Collect up the last values + nextPage.RemovedIds = nextPage.allRemovedIds + // For now, removedIds will only be populated if this pagination cycle + // was the result of a "refresh" operation (i.e., the caller provided a + // list token option to this call). + // + // Sort to make response deterministic + slices.Sort(nextPage.RemovedIds) + // Remove any duplicates + nextPage.RemovedIds = slices.Compact(nextPage.RemovedIds) + } + + return nextPage, nil } diff --git a/api/accounts/option.gen.go b/api/accounts/option.gen.go index b70d6950c1..384121c547 100644 --- a/api/accounts/option.gen.go +++ b/api/accounts/option.gen.go @@ -5,6 +5,7 @@ package accounts import ( + "strconv" "strings" "github.com/hashicorp/boundary/api" @@ -20,12 +21,14 @@ import ( type Option func(*options) type options struct { - postMap map[string]any - queryMap map[string]string - withAutomaticVersioning bool - withSkipCurlOutput bool - withFilter string - withListToken string + postMap map[string]any + queryMap map[string]string + withAutomaticVersioning bool + withSkipCurlOutput bool + withFilter string + withListToken string + withClientDirectedPagination bool + withPageSize uint32 } func getDefaultOptions() options { @@ -52,6 +55,9 @@ func getOpts(opt ...Option) (options, []api.Option) { if opts.withListToken != "" { opts.queryMap["list_token"] = opts.withListToken } + if opts.withPageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(opts.withPageSize), 10) + } return opts, apiOpts } @@ -90,6 +96,21 @@ func WithFilter(filter string) Option { } } +// WithClientDirectedPagination tells the List function to return only the first +// page, if more pages are available +func WithClientDirectedPagination(with bool) Option { + return func(o *options) { + o.withClientDirectedPagination = with + } +} + +// WithPageSize controls the size of pages used during List +func WithPageSize(with uint32) Option { + return func(o *options) { + o.withPageSize = with + } +} + func WithAttributes(inAttributes map[string]interface{}) Option { return func(o *options) { o.postMap["attributes"] = inAttributes diff --git a/api/aliases/alias.gen.go b/api/aliases/alias.gen.go index d4a25dff70..54b3803cfe 100644 --- a/api/aliases/alias.gen.go +++ b/api/aliases/alias.gen.go @@ -11,6 +11,7 @@ import ( "fmt" "net/url" "slices" + "strconv" "time" "github.com/hashicorp/boundary/api" @@ -69,6 +70,13 @@ type AliasListResult struct { ListToken string `json:"list_token,omitempty"` ResponseType string `json:"response_type,omitempty"` Response *api.Response + + // The following fields are used for cached information when client-directed + // pagination is used. 
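The generated List methods still paginate to completion by default; WithClientDirectedPagination returns after the first page so the caller can drive paging with ListNextPage and apply each page, including its RemovedIds, incrementally. A sketch of that loop using the accounts client shown above; the local map and the error handling are illustrative only, and the snippet assumes the accounts package is imported:

// listAllAccounts pages through an auth method's accounts one page at a time,
// applying updates and removals as each page arrives.
func listAllAccounts(ctx context.Context, client *accounts.Client, authMethodId string) (map[string]*accounts.Account, error) {
	byId := map[string]*accounts.Account{}
	page, err := client.List(ctx, authMethodId,
		accounts.WithClientDirectedPagination(true),
		accounts.WithPageSize(100))
	if err != nil {
		return nil, err
	}
	for {
		for _, item := range page.Items {
			byId[item.Id] = item // add or update in place
		}
		for _, removedId := range page.RemovedIds {
			delete(byId, removedId) // only populated on list-token refreshes
		}
		if page.ResponseType == "complete" {
			return byId, nil
		}
		if page, err = client.ListNextPage(ctx, page); err != nil {
			return nil, err
		}
	}
}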
+ recursive bool + pageSize uint32 + scopeId string + allRemovedIds []string } func (n AliasListResult) GetItems() []*Alias { @@ -346,96 +354,93 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Alia return nil, apiErr } target.Response = resp + if target.ResponseType == "complete" || target.ResponseType == "" { return target, nil } + + // In case we shortcut out due to client directed pagination, ensure these + // are set + + target.recursive = opts.withRecursive + + target.pageSize = opts.withPageSize + target.scopeId = scopeId + target.allRemovedIds = target.RemovedIds + if opts.withClientDirectedPagination { + return target, nil + } + + allItems := make([]*Alias, 0, target.EstItemCount) + allItems = append(allItems, target.Items...) + // If there are more results, automatically fetch the rest of the results. // idToIndex keeps a map from the ID of an item to its index in target.Items. // This is used to update updated items in-place and remove deleted items // from the result after pagination is done. idToIndex := map[string]int{} - for i, item := range target.Items { + for i, item := range allItems { idToIndex[item.Id] = i } - for { - req, err := c.client.NewRequest(ctx, "GET", "aliases", nil, apiOpts...) - if err != nil { - return nil, fmt.Errorf("error creating List request: %w", err) - } - opts.queryMap["list_token"] = target.ListToken - if len(opts.queryMap) > 0 { - q := url.Values{} - for k, v := range opts.queryMap { - q.Add(k, v) - } - req.URL.RawQuery = q.Encode() - } - - resp, err := c.client.Do(req) + // If we're here there are more pages and the client does not want to + // paginate on their own; fetch them as this call returns all values. + currentPage := target + for { + nextPage, err := c.ListNextPage(ctx, currentPage, opt...) if err != nil { - return nil, fmt.Errorf("error performing client request during List call: %w", err) + return nil, fmt.Errorf("error getting next page in List call: %w", err) } - page := new(AliasListResult) - apiErr, err := resp.Decode(page) - if err != nil { - return nil, fmt.Errorf("error decoding List response: %w", err) - } - if apiErr != nil { - return nil, apiErr - } - for _, item := range page.Items { + for _, item := range nextPage.Items { if i, ok := idToIndex[item.Id]; ok { // Item has already been seen at index i, update in-place - target.Items[i] = item + allItems[i] = item } else { - target.Items = append(target.Items, item) - idToIndex[item.Id] = len(target.Items) - 1 + allItems = append(allItems, item) + idToIndex[item.Id] = len(allItems) - 1 } } - // RemovedIds contain any Alias that were deleted since the last response. - target.RemovedIds = append(target.RemovedIds, page.RemovedIds...) - target.EstItemCount = page.EstItemCount - target.ListToken = page.ListToken - target.ResponseType = page.ResponseType - target.Response = resp - if target.ResponseType == "complete" { + + currentPage = nextPage + + if currentPage.ResponseType == "complete" { break } } - // For now, removedIds will only be populated if this pagination cycle was the result of a - // "refresh" operation (i.e., the caller provided a list token option to this call). - // - // Sort to make response deterministic - slices.Sort(target.RemovedIds) - // Remove any duplicates - target.RemovedIds = slices.Compact(target.RemovedIds) + + // The current page here is the final page of the results, that is, the + // response type is "complete" + // Remove items that were deleted since the end of the last iteration. 
// If a Alias has been updated and subsequently removed, we don't want // it to appear both in the Items and RemovedIds, so we remove it from the Items. - for _, removedId := range target.RemovedIds { + for _, removedId := range currentPage.RemovedIds { if i, ok := idToIndex[removedId]; ok { // Remove the item at index i without preserving order // https://github.com/golang/go/wiki/SliceTricks#delete-without-preserving-order - target.Items[i] = target.Items[len(target.Items)-1] - target.Items = target.Items[:len(target.Items)-1] + allItems[i] = allItems[len(allItems)-1] + allItems = allItems[:len(allItems)-1] // Update the index of the previously last element - idToIndex[target.Items[i].Id] = i + idToIndex[allItems[i].Id] = i } } - // Since we paginated to the end, we can avoid confusion - // for the user by setting the estimated item count to the - // length of the items slice. If we don't set this here, it - // will equal the value returned in the last response, which is - // often much smaller than the total number returned. - target.EstItemCount = uint(len(target.Items)) // Sort the results again since in-place updates and deletes // may have shuffled items. We sort by created time descending // (most recently created first), same as the API. - slices.SortFunc(target.Items, func(i, j *Alias) int { + slices.SortFunc(allItems, func(i, j *Alias) int { return j.CreatedTime.Compare(i.CreatedTime) }) + // Since we paginated to the end, we can avoid confusion + // for the user by setting the estimated item count to the + // length of the items slice. If we don't set this here, it + // will equal the value returned in the last response, which is + // often much smaller than the total number returned. + currentPage.EstItemCount = uint(len(allItems)) + // Set items to the full list we have collected here + currentPage.Items = allItems + // Set the returned value to the last page with calculated values + target = currentPage // Finally, since we made at least 2 requests to the server to fulfill this // function call, resp.Body and resp.Map will only contain the most recent response. // Overwrite them with the true response. @@ -449,4 +454,87 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Alia // Note: the HTTP response body is consumed by resp.Decode in the loop, // so it doesn't need to be updated (it will always be, and has always been, empty). return target, nil + +} + +func (c *Client) ListNextPage(ctx context.Context, currentPage *AliasListResult, opt ...Option) (*AliasListResult, error) { + if currentPage == nil { + return nil, fmt.Errorf("empty currentPage value passed into ListNextPage request") + } + if currentPage.scopeId == "" { + return nil, fmt.Errorf("empty scopeId value in currentPage passed into ListNextPage request") + } + if c.client == nil { + return nil, fmt.Errorf("nil client") + } + if currentPage.ResponseType == "complete" || currentPage.ResponseType == "" { + return nil, fmt.Errorf("no more pages available in ListNextPage request") + } + + opts, apiOpts := getOpts(opt...) + opts.queryMap["scope_id"] = currentPage.scopeId + + // Don't require them to re-specify recursive + if currentPage.recursive { + opts.queryMap["recursive"] = "true" + } + + if currentPage.pageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) + } + + req, err := c.client.NewRequest(ctx, "GET", "aliases", nil, apiOpts...) 
+ if err != nil { + return nil, fmt.Errorf("error creating List request: %w", err) + } + + opts.queryMap["list_token"] = currentPage.ListToken + if len(opts.queryMap) > 0 { + q := url.Values{} + for k, v := range opts.queryMap { + q.Add(k, v) + } + req.URL.RawQuery = q.Encode() + } + + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("error performing client request during List call during ListNextPage: %w", err) + } + + nextPage := new(AliasListResult) + apiErr, err := resp.Decode(nextPage) + if err != nil { + return nil, fmt.Errorf("error decoding List response during ListNextPage: %w", err) + } + if apiErr != nil { + return nil, apiErr + } + + // Ensure values are carried forward to the next call + nextPage.scopeId = currentPage.scopeId + + nextPage.recursive = currentPage.recursive + + nextPage.pageSize = currentPage.pageSize + // Cache the removed IDs from this page + nextPage.allRemovedIds = append(currentPage.allRemovedIds, nextPage.RemovedIds...) + // Set the response body to the current response + nextPage.Response = resp + // If we're done iterating, pull the full set of removed IDs into the last + // response + if nextPage.ResponseType == "complete" { + // Collect up the last values + nextPage.RemovedIds = nextPage.allRemovedIds + // For now, removedIds will only be populated if this pagination cycle + // was the result of a "refresh" operation (i.e., the caller provided a + // list token option to this call). + // + // Sort to make response deterministic + slices.Sort(nextPage.RemovedIds) + // Remove any duplicates + nextPage.RemovedIds = slices.Compact(nextPage.RemovedIds) + } + + return nextPage, nil } diff --git a/api/aliases/option.gen.go b/api/aliases/option.gen.go index 6e235a94e3..9ec819435b 100644 --- a/api/aliases/option.gen.go +++ b/api/aliases/option.gen.go @@ -21,13 +21,15 @@ import ( type Option func(*options) type options struct { - postMap map[string]any - queryMap map[string]string - withAutomaticVersioning bool - withSkipCurlOutput bool - withFilter string - withListToken string - withRecursive bool + postMap map[string]any + queryMap map[string]string + withAutomaticVersioning bool + withSkipCurlOutput bool + withFilter string + withListToken string + withClientDirectedPagination bool + withPageSize uint32 + withRecursive bool } func getDefaultOptions() options { @@ -57,6 +59,9 @@ func getOpts(opt ...Option) (options, []api.Option) { if opts.withRecursive { opts.queryMap["recursive"] = strconv.FormatBool(opts.withRecursive) } + if opts.withPageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(opts.withPageSize), 10) + } return opts, apiOpts } @@ -95,6 +100,21 @@ func WithFilter(filter string) Option { } } +// WithClientDirectedPagination tells the List function to return only the first +// page, if more pages are available +func WithClientDirectedPagination(with bool) Option { + return func(o *options) { + o.withClientDirectedPagination = with + } +} + +// WithPageSize controls the size of pages used during List +func WithPageSize(with uint32) Option { + return func(o *options) { + o.withPageSize = with + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/authmethods/authmethods.gen.go b/api/authmethods/authmethods.gen.go index 719d43fae5..afd51e5d61 100644 --- a/api/authmethods/authmethods.gen.go +++ b/api/authmethods/authmethods.gen.go @@ -11,6 +11,7 @@ import ( "fmt" "net/url" "slices" + "strconv" "time" 
"github.com/hashicorp/boundary/api" @@ -69,6 +70,13 @@ type AuthMethodListResult struct { ListToken string `json:"list_token,omitempty"` ResponseType string `json:"response_type,omitempty"` Response *api.Response + + // The following fields are used for cached information when client-directed + // pagination is used. + recursive bool + pageSize uint32 + scopeId string + allRemovedIds []string } func (n AuthMethodListResult) GetItems() []*AuthMethod { @@ -346,96 +354,93 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Auth return nil, apiErr } target.Response = resp + if target.ResponseType == "complete" || target.ResponseType == "" { return target, nil } + + // In case we shortcut out due to client directed pagination, ensure these + // are set + + target.recursive = opts.withRecursive + + target.pageSize = opts.withPageSize + target.scopeId = scopeId + target.allRemovedIds = target.RemovedIds + if opts.withClientDirectedPagination { + return target, nil + } + + allItems := make([]*AuthMethod, 0, target.EstItemCount) + allItems = append(allItems, target.Items...) + // If there are more results, automatically fetch the rest of the results. // idToIndex keeps a map from the ID of an item to its index in target.Items. // This is used to update updated items in-place and remove deleted items // from the result after pagination is done. idToIndex := map[string]int{} - for i, item := range target.Items { + for i, item := range allItems { idToIndex[item.Id] = i } - for { - req, err := c.client.NewRequest(ctx, "GET", "auth-methods", nil, apiOpts...) - if err != nil { - return nil, fmt.Errorf("error creating List request: %w", err) - } - opts.queryMap["list_token"] = target.ListToken - if len(opts.queryMap) > 0 { - q := url.Values{} - for k, v := range opts.queryMap { - q.Add(k, v) - } - req.URL.RawQuery = q.Encode() - } - - resp, err := c.client.Do(req) + // If we're here there are more pages and the client does not want to + // paginate on their own; fetch them as this call returns all values. + currentPage := target + for { + nextPage, err := c.ListNextPage(ctx, currentPage, opt...) if err != nil { - return nil, fmt.Errorf("error performing client request during List call: %w", err) + return nil, fmt.Errorf("error getting next page in List call: %w", err) } - page := new(AuthMethodListResult) - apiErr, err := resp.Decode(page) - if err != nil { - return nil, fmt.Errorf("error decoding List response: %w", err) - } - if apiErr != nil { - return nil, apiErr - } - for _, item := range page.Items { + for _, item := range nextPage.Items { if i, ok := idToIndex[item.Id]; ok { // Item has already been seen at index i, update in-place - target.Items[i] = item + allItems[i] = item } else { - target.Items = append(target.Items, item) - idToIndex[item.Id] = len(target.Items) - 1 + allItems = append(allItems, item) + idToIndex[item.Id] = len(allItems) - 1 } } - // RemovedIds contain any AuthMethod that were deleted since the last response. - target.RemovedIds = append(target.RemovedIds, page.RemovedIds...) - target.EstItemCount = page.EstItemCount - target.ListToken = page.ListToken - target.ResponseType = page.ResponseType - target.Response = resp - if target.ResponseType == "complete" { + + currentPage = nextPage + + if currentPage.ResponseType == "complete" { break } } - // For now, removedIds will only be populated if this pagination cycle was the result of a - // "refresh" operation (i.e., the caller provided a list token option to this call). 
- // - // Sort to make response deterministic - slices.Sort(target.RemovedIds) - // Remove any duplicates - target.RemovedIds = slices.Compact(target.RemovedIds) + + // The current page here is the final page of the results, that is, the + // response type is "complete" + // Remove items that were deleted since the end of the last iteration. // If a AuthMethod has been updated and subsequently removed, we don't want // it to appear both in the Items and RemovedIds, so we remove it from the Items. - for _, removedId := range target.RemovedIds { + for _, removedId := range currentPage.RemovedIds { if i, ok := idToIndex[removedId]; ok { // Remove the item at index i without preserving order // https://github.com/golang/go/wiki/SliceTricks#delete-without-preserving-order - target.Items[i] = target.Items[len(target.Items)-1] - target.Items = target.Items[:len(target.Items)-1] + allItems[i] = allItems[len(allItems)-1] + allItems = allItems[:len(allItems)-1] // Update the index of the previously last element - idToIndex[target.Items[i].Id] = i + idToIndex[allItems[i].Id] = i } } - // Since we paginated to the end, we can avoid confusion - // for the user by setting the estimated item count to the - // length of the items slice. If we don't set this here, it - // will equal the value returned in the last response, which is - // often much smaller than the total number returned. - target.EstItemCount = uint(len(target.Items)) // Sort the results again since in-place updates and deletes // may have shuffled items. We sort by created time descending // (most recently created first), same as the API. - slices.SortFunc(target.Items, func(i, j *AuthMethod) int { + slices.SortFunc(allItems, func(i, j *AuthMethod) int { return j.CreatedTime.Compare(i.CreatedTime) }) + // Since we paginated to the end, we can avoid confusion + // for the user by setting the estimated item count to the + // length of the items slice. If we don't set this here, it + // will equal the value returned in the last response, which is + // often much smaller than the total number returned. + currentPage.EstItemCount = uint(len(allItems)) + // Set items to the full list we have collected here + currentPage.Items = allItems + // Set the returned value to the last page with calculated values + target = currentPage // Finally, since we made at least 2 requests to the server to fulfill this // function call, resp.Body and resp.Map will only contain the most recent response. // Overwrite them with the true response. @@ -449,4 +454,87 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Auth // Note: the HTTP response body is consumed by resp.Decode in the loop, // so it doesn't need to be updated (it will always be, and has always been, empty). return target, nil + +} + +func (c *Client) ListNextPage(ctx context.Context, currentPage *AuthMethodListResult, opt ...Option) (*AuthMethodListResult, error) { + if currentPage == nil { + return nil, fmt.Errorf("empty currentPage value passed into ListNextPage request") + } + if currentPage.scopeId == "" { + return nil, fmt.Errorf("empty scopeId value in currentPage passed into ListNextPage request") + } + if c.client == nil { + return nil, fmt.Errorf("nil client") + } + if currentPage.ResponseType == "complete" || currentPage.ResponseType == "" { + return nil, fmt.Errorf("no more pages available in ListNextPage request") + } + + opts, apiOpts := getOpts(opt...) 
+ opts.queryMap["scope_id"] = currentPage.scopeId + + // Don't require them to re-specify recursive + if currentPage.recursive { + opts.queryMap["recursive"] = "true" + } + + if currentPage.pageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) + } + + req, err := c.client.NewRequest(ctx, "GET", "auth-methods", nil, apiOpts...) + if err != nil { + return nil, fmt.Errorf("error creating List request: %w", err) + } + + opts.queryMap["list_token"] = currentPage.ListToken + if len(opts.queryMap) > 0 { + q := url.Values{} + for k, v := range opts.queryMap { + q.Add(k, v) + } + req.URL.RawQuery = q.Encode() + } + + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("error performing client request during List call during ListNextPage: %w", err) + } + + nextPage := new(AuthMethodListResult) + apiErr, err := resp.Decode(nextPage) + if err != nil { + return nil, fmt.Errorf("error decoding List response during ListNextPage: %w", err) + } + if apiErr != nil { + return nil, apiErr + } + + // Ensure values are carried forward to the next call + nextPage.scopeId = currentPage.scopeId + + nextPage.recursive = currentPage.recursive + + nextPage.pageSize = currentPage.pageSize + // Cache the removed IDs from this page + nextPage.allRemovedIds = append(currentPage.allRemovedIds, nextPage.RemovedIds...) + // Set the response body to the current response + nextPage.Response = resp + // If we're done iterating, pull the full set of removed IDs into the last + // response + if nextPage.ResponseType == "complete" { + // Collect up the last values + nextPage.RemovedIds = nextPage.allRemovedIds + // For now, removedIds will only be populated if this pagination cycle + // was the result of a "refresh" operation (i.e., the caller provided a + // list token option to this call). 
+ // + // Sort to make response deterministic + slices.Sort(nextPage.RemovedIds) + // Remove any duplicates + nextPage.RemovedIds = slices.Compact(nextPage.RemovedIds) + } + + return nextPage, nil } diff --git a/api/authmethods/option.gen.go b/api/authmethods/option.gen.go index 9584d61eaa..079c35a855 100644 --- a/api/authmethods/option.gen.go +++ b/api/authmethods/option.gen.go @@ -21,13 +21,15 @@ import ( type Option func(*options) type options struct { - postMap map[string]any - queryMap map[string]string - withAutomaticVersioning bool - withSkipCurlOutput bool - withFilter string - withListToken string - withRecursive bool + postMap map[string]any + queryMap map[string]string + withAutomaticVersioning bool + withSkipCurlOutput bool + withFilter string + withListToken string + withClientDirectedPagination bool + withPageSize uint32 + withRecursive bool } func getDefaultOptions() options { @@ -57,6 +59,9 @@ func getOpts(opt ...Option) (options, []api.Option) { if opts.withRecursive { opts.queryMap["recursive"] = strconv.FormatBool(opts.withRecursive) } + if opts.withPageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(opts.withPageSize), 10) + } return opts, apiOpts } @@ -95,6 +100,21 @@ func WithFilter(filter string) Option { } } +// WithClientDirectedPagination tells the List function to return only the first +// page, if more pages are available +func WithClientDirectedPagination(with bool) Option { + return func(o *options) { + o.withClientDirectedPagination = with + } +} + +// WithPageSize controls the size of pages used during List +func WithPageSize(with uint32) Option { + return func(o *options) { + o.withPageSize = with + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/authtokens/authtokens.gen.go b/api/authtokens/authtokens.gen.go index 541456a183..ced3e321d6 100644 --- a/api/authtokens/authtokens.gen.go +++ b/api/authtokens/authtokens.gen.go @@ -10,6 +10,7 @@ import ( "fmt" "net/url" "slices" + "strconv" "time" "github.com/hashicorp/boundary/api" @@ -66,6 +67,13 @@ type AuthTokenListResult struct { ListToken string `json:"list_token,omitempty"` ResponseType string `json:"response_type,omitempty"` Response *api.Response + + // The following fields are used for cached information when client-directed + // pagination is used. + recursive bool + pageSize uint32 + scopeId string + allRemovedIds []string } func (n AuthTokenListResult) GetItems() []*AuthToken { @@ -231,96 +239,93 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Auth return nil, apiErr } target.Response = resp + if target.ResponseType == "complete" || target.ResponseType == "" { return target, nil } + + // In case we shortcut out due to client directed pagination, ensure these + // are set + + target.recursive = opts.withRecursive + + target.pageSize = opts.withPageSize + target.scopeId = scopeId + target.allRemovedIds = target.RemovedIds + if opts.withClientDirectedPagination { + return target, nil + } + + allItems := make([]*AuthToken, 0, target.EstItemCount) + allItems = append(allItems, target.Items...) + // If there are more results, automatically fetch the rest of the results. // idToIndex keeps a map from the ID of an item to its index in target.Items. // This is used to update updated items in-place and remove deleted items // from the result after pagination is done. 
idToIndex := map[string]int{} - for i, item := range target.Items { + for i, item := range allItems { idToIndex[item.Id] = i } - for { - req, err := c.client.NewRequest(ctx, "GET", "auth-tokens", nil, apiOpts...) - if err != nil { - return nil, fmt.Errorf("error creating List request: %w", err) - } - - opts.queryMap["list_token"] = target.ListToken - if len(opts.queryMap) > 0 { - q := url.Values{} - for k, v := range opts.queryMap { - q.Add(k, v) - } - req.URL.RawQuery = q.Encode() - } - resp, err := c.client.Do(req) + // If we're here there are more pages and the client does not want to + // paginate on their own; fetch them as this call returns all values. + currentPage := target + for { + nextPage, err := c.ListNextPage(ctx, currentPage, opt...) if err != nil { - return nil, fmt.Errorf("error performing client request during List call: %w", err) + return nil, fmt.Errorf("error getting next page in List call: %w", err) } - page := new(AuthTokenListResult) - apiErr, err := resp.Decode(page) - if err != nil { - return nil, fmt.Errorf("error decoding List response: %w", err) - } - if apiErr != nil { - return nil, apiErr - } - for _, item := range page.Items { + for _, item := range nextPage.Items { if i, ok := idToIndex[item.Id]; ok { // Item has already been seen at index i, update in-place - target.Items[i] = item + allItems[i] = item } else { - target.Items = append(target.Items, item) - idToIndex[item.Id] = len(target.Items) - 1 + allItems = append(allItems, item) + idToIndex[item.Id] = len(allItems) - 1 } } - // RemovedIds contain any AuthToken that were deleted since the last response. - target.RemovedIds = append(target.RemovedIds, page.RemovedIds...) - target.EstItemCount = page.EstItemCount - target.ListToken = page.ListToken - target.ResponseType = page.ResponseType - target.Response = resp - if target.ResponseType == "complete" { + + currentPage = nextPage + + if currentPage.ResponseType == "complete" { break } } - // For now, removedIds will only be populated if this pagination cycle was the result of a - // "refresh" operation (i.e., the caller provided a list token option to this call). - // - // Sort to make response deterministic - slices.Sort(target.RemovedIds) - // Remove any duplicates - target.RemovedIds = slices.Compact(target.RemovedIds) + + // The current page here is the final page of the results, that is, the + // response type is "complete" + // Remove items that were deleted since the end of the last iteration. // If a AuthToken has been updated and subsequently removed, we don't want // it to appear both in the Items and RemovedIds, so we remove it from the Items. - for _, removedId := range target.RemovedIds { + for _, removedId := range currentPage.RemovedIds { if i, ok := idToIndex[removedId]; ok { // Remove the item at index i without preserving order // https://github.com/golang/go/wiki/SliceTricks#delete-without-preserving-order - target.Items[i] = target.Items[len(target.Items)-1] - target.Items = target.Items[:len(target.Items)-1] + allItems[i] = allItems[len(allItems)-1] + allItems = allItems[:len(allItems)-1] // Update the index of the previously last element - idToIndex[target.Items[i].Id] = i + idToIndex[allItems[i].Id] = i } } - // Since we paginated to the end, we can avoid confusion - // for the user by setting the estimated item count to the - // length of the items slice. If we don't set this here, it - // will equal the value returned in the last response, which is - // often much smaller than the total number returned. 
- target.EstItemCount = uint(len(target.Items)) // Sort the results again since in-place updates and deletes // may have shuffled items. We sort by created time descending // (most recently created first), same as the API. - slices.SortFunc(target.Items, func(i, j *AuthToken) int { + slices.SortFunc(allItems, func(i, j *AuthToken) int { return j.CreatedTime.Compare(i.CreatedTime) }) + // Since we paginated to the end, we can avoid confusion + // for the user by setting the estimated item count to the + // length of the items slice. If we don't set this here, it + // will equal the value returned in the last response, which is + // often much smaller than the total number returned. + currentPage.EstItemCount = uint(len(allItems)) + // Set items to the full list we have collected here + currentPage.Items = allItems + // Set the returned value to the last page with calculated values + target = currentPage // Finally, since we made at least 2 requests to the server to fulfill this // function call, resp.Body and resp.Map will only contain the most recent response. // Overwrite them with the true response. @@ -334,4 +339,87 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Auth // Note: the HTTP response body is consumed by resp.Decode in the loop, // so it doesn't need to be updated (it will always be, and has always been, empty). return target, nil + +} + +func (c *Client) ListNextPage(ctx context.Context, currentPage *AuthTokenListResult, opt ...Option) (*AuthTokenListResult, error) { + if currentPage == nil { + return nil, fmt.Errorf("empty currentPage value passed into ListNextPage request") + } + if currentPage.scopeId == "" { + return nil, fmt.Errorf("empty scopeId value in currentPage passed into ListNextPage request") + } + if c.client == nil { + return nil, fmt.Errorf("nil client") + } + if currentPage.ResponseType == "complete" || currentPage.ResponseType == "" { + return nil, fmt.Errorf("no more pages available in ListNextPage request") + } + + opts, apiOpts := getOpts(opt...) + opts.queryMap["scope_id"] = currentPage.scopeId + + // Don't require them to re-specify recursive + if currentPage.recursive { + opts.queryMap["recursive"] = "true" + } + + if currentPage.pageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) + } + + req, err := c.client.NewRequest(ctx, "GET", "auth-tokens", nil, apiOpts...) + if err != nil { + return nil, fmt.Errorf("error creating List request: %w", err) + } + + opts.queryMap["list_token"] = currentPage.ListToken + if len(opts.queryMap) > 0 { + q := url.Values{} + for k, v := range opts.queryMap { + q.Add(k, v) + } + req.URL.RawQuery = q.Encode() + } + + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("error performing client request during List call during ListNextPage: %w", err) + } + + nextPage := new(AuthTokenListResult) + apiErr, err := resp.Decode(nextPage) + if err != nil { + return nil, fmt.Errorf("error decoding List response during ListNextPage: %w", err) + } + if apiErr != nil { + return nil, apiErr + } + + // Ensure values are carried forward to the next call + nextPage.scopeId = currentPage.scopeId + + nextPage.recursive = currentPage.recursive + + nextPage.pageSize = currentPage.pageSize + // Cache the removed IDs from this page + nextPage.allRemovedIds = append(currentPage.allRemovedIds, nextPage.RemovedIds...) 
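+	// allRemovedIds carries the removed IDs forward across pages; it is folded
+	// back into RemovedIds once the final page is reached below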
+ // Set the response body to the current response + nextPage.Response = resp + // If we're done iterating, pull the full set of removed IDs into the last + // response + if nextPage.ResponseType == "complete" { + // Collect up the last values + nextPage.RemovedIds = nextPage.allRemovedIds + // For now, removedIds will only be populated if this pagination cycle + // was the result of a "refresh" operation (i.e., the caller provided a + // list token option to this call). + // + // Sort to make response deterministic + slices.Sort(nextPage.RemovedIds) + // Remove any duplicates + nextPage.RemovedIds = slices.Compact(nextPage.RemovedIds) + } + + return nextPage, nil } diff --git a/api/authtokens/option.gen.go b/api/authtokens/option.gen.go index e1ed767edf..9fa43031a5 100644 --- a/api/authtokens/option.gen.go +++ b/api/authtokens/option.gen.go @@ -21,13 +21,15 @@ import ( type Option func(*options) type options struct { - postMap map[string]any - queryMap map[string]string - withAutomaticVersioning bool - withSkipCurlOutput bool - withFilter string - withListToken string - withRecursive bool + postMap map[string]any + queryMap map[string]string + withAutomaticVersioning bool + withSkipCurlOutput bool + withFilter string + withListToken string + withClientDirectedPagination bool + withPageSize uint32 + withRecursive bool } func getDefaultOptions() options { @@ -57,6 +59,9 @@ func getOpts(opt ...Option) (options, []api.Option) { if opts.withRecursive { opts.queryMap["recursive"] = strconv.FormatBool(opts.withRecursive) } + if opts.withPageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(opts.withPageSize), 10) + } return opts, apiOpts } @@ -85,6 +90,21 @@ func WithFilter(filter string) Option { } } +// WithClientDirectedPagination tells the List function to return only the first +// page, if more pages are available +func WithClientDirectedPagination(with bool) Option { + return func(o *options) { + o.withClientDirectedPagination = with + } +} + +// WithPageSize controls the size of pages used during List +func WithPageSize(with uint32) Option { + return func(o *options) { + o.withPageSize = with + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/billing/option.gen.go b/api/billing/option.gen.go index 88ddf32ed3..76babe8e89 100644 --- a/api/billing/option.gen.go +++ b/api/billing/option.gen.go @@ -6,6 +6,7 @@ package billing import ( "fmt" + "strconv" "strings" "github.com/hashicorp/boundary/api" @@ -21,12 +22,14 @@ import ( type Option func(*options) type options struct { - postMap map[string]any - queryMap map[string]string - withAutomaticVersioning bool - withSkipCurlOutput bool - withFilter string - withListToken string + postMap map[string]any + queryMap map[string]string + withAutomaticVersioning bool + withSkipCurlOutput bool + withFilter string + withListToken string + withClientDirectedPagination bool + withPageSize uint32 } func getDefaultOptions() options { @@ -53,6 +56,9 @@ func getOpts(opt ...Option) (options, []api.Option) { if opts.withListToken != "" { opts.queryMap["list_token"] = opts.withListToken } + if opts.withPageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(opts.withPageSize), 10) + } return opts, apiOpts } @@ -91,6 +97,21 @@ func WithFilter(filter string) Option { } } +// WithClientDirectedPagination tells the List function to return only the first +// page, if more pages are available +func 
WithClientDirectedPagination(with bool) Option { + return func(o *options) { + o.withClientDirectedPagination = with + } +} + +// WithPageSize controls the size of pages used during List +func WithPageSize(with uint32) Option { + return func(o *options) { + o.withPageSize = with + } +} + func WithEndTime(inEndTime string) Option { return func(o *options) { o.queryMap["end_time"] = fmt.Sprintf("%v", inEndTime) diff --git a/api/credentiallibraries/credential_library.gen.go b/api/credentiallibraries/credential_library.gen.go index a298196fbe..7206eb892f 100644 --- a/api/credentiallibraries/credential_library.gen.go +++ b/api/credentiallibraries/credential_library.gen.go @@ -11,6 +11,7 @@ import ( "fmt" "net/url" "slices" + "strconv" "time" "github.com/hashicorp/boundary/api" @@ -69,6 +70,13 @@ type CredentialLibraryListResult struct { ListToken string `json:"list_token,omitempty"` ResponseType string `json:"response_type,omitempty"` Response *api.Response + + // The following fields are used for cached information when client-directed + // pagination is used. + recursive bool + pageSize uint32 + credentialStoreId string + allRemovedIds []string } func (n CredentialLibraryListResult) GetItems() []*CredentialLibrary { @@ -346,96 +354,91 @@ func (c *Client) List(ctx context.Context, credentialStoreId string, opt ...Opti return nil, apiErr } target.Response = resp + if target.ResponseType == "complete" || target.ResponseType == "" { return target, nil } + + // In case we shortcut out due to client directed pagination, ensure these + // are set + + target.pageSize = opts.withPageSize + target.credentialStoreId = credentialStoreId + target.allRemovedIds = target.RemovedIds + if opts.withClientDirectedPagination { + return target, nil + } + + allItems := make([]*CredentialLibrary, 0, target.EstItemCount) + allItems = append(allItems, target.Items...) + // If there are more results, automatically fetch the rest of the results. // idToIndex keeps a map from the ID of an item to its index in target.Items. // This is used to update updated items in-place and remove deleted items // from the result after pagination is done. idToIndex := map[string]int{} - for i, item := range target.Items { + for i, item := range allItems { idToIndex[item.Id] = i } - for { - req, err := c.client.NewRequest(ctx, "GET", "credential-libraries", nil, apiOpts...) - if err != nil { - return nil, fmt.Errorf("error creating List request: %w", err) - } - - opts.queryMap["list_token"] = target.ListToken - if len(opts.queryMap) > 0 { - q := url.Values{} - for k, v := range opts.queryMap { - q.Add(k, v) - } - req.URL.RawQuery = q.Encode() - } - resp, err := c.client.Do(req) + // If we're here there are more pages and the client does not want to + // paginate on their own; fetch them as this call returns all values. + currentPage := target + for { + nextPage, err := c.ListNextPage(ctx, currentPage, opt...) 
if err != nil { - return nil, fmt.Errorf("error performing client request during List call: %w", err) + return nil, fmt.Errorf("error getting next page in List call: %w", err) } - page := new(CredentialLibraryListResult) - apiErr, err := resp.Decode(page) - if err != nil { - return nil, fmt.Errorf("error decoding List response: %w", err) - } - if apiErr != nil { - return nil, apiErr - } - for _, item := range page.Items { + for _, item := range nextPage.Items { if i, ok := idToIndex[item.Id]; ok { // Item has already been seen at index i, update in-place - target.Items[i] = item + allItems[i] = item } else { - target.Items = append(target.Items, item) - idToIndex[item.Id] = len(target.Items) - 1 + allItems = append(allItems, item) + idToIndex[item.Id] = len(allItems) - 1 } } - // RemovedIds contain any CredentialLibrary that were deleted since the last response. - target.RemovedIds = append(target.RemovedIds, page.RemovedIds...) - target.EstItemCount = page.EstItemCount - target.ListToken = page.ListToken - target.ResponseType = page.ResponseType - target.Response = resp - if target.ResponseType == "complete" { + + currentPage = nextPage + + if currentPage.ResponseType == "complete" { break } } - // For now, removedIds will only be populated if this pagination cycle was the result of a - // "refresh" operation (i.e., the caller provided a list token option to this call). - // - // Sort to make response deterministic - slices.Sort(target.RemovedIds) - // Remove any duplicates - target.RemovedIds = slices.Compact(target.RemovedIds) + + // The current page here is the final page of the results, that is, the + // response type is "complete" + // Remove items that were deleted since the end of the last iteration. // If a CredentialLibrary has been updated and subsequently removed, we don't want // it to appear both in the Items and RemovedIds, so we remove it from the Items. - for _, removedId := range target.RemovedIds { + for _, removedId := range currentPage.RemovedIds { if i, ok := idToIndex[removedId]; ok { // Remove the item at index i without preserving order // https://github.com/golang/go/wiki/SliceTricks#delete-without-preserving-order - target.Items[i] = target.Items[len(target.Items)-1] - target.Items = target.Items[:len(target.Items)-1] + allItems[i] = allItems[len(allItems)-1] + allItems = allItems[:len(allItems)-1] // Update the index of the previously last element - idToIndex[target.Items[i].Id] = i + idToIndex[allItems[i].Id] = i } } - // Since we paginated to the end, we can avoid confusion - // for the user by setting the estimated item count to the - // length of the items slice. If we don't set this here, it - // will equal the value returned in the last response, which is - // often much smaller than the total number returned. - target.EstItemCount = uint(len(target.Items)) // Sort the results again since in-place updates and deletes // may have shuffled items. We sort by created time descending // (most recently created first), same as the API. - slices.SortFunc(target.Items, func(i, j *CredentialLibrary) int { + slices.SortFunc(allItems, func(i, j *CredentialLibrary) int { return j.CreatedTime.Compare(i.CreatedTime) }) + // Since we paginated to the end, we can avoid confusion + // for the user by setting the estimated item count to the + // length of the items slice. If we don't set this here, it + // will equal the value returned in the last response, which is + // often much smaller than the total number returned. 
+ currentPage.EstItemCount = uint(len(allItems)) + // Set items to the full list we have collected here + currentPage.Items = allItems + // Set the returned value to the last page with calculated values + target = currentPage // Finally, since we made at least 2 requests to the server to fulfill this // function call, resp.Body and resp.Map will only contain the most recent response. // Overwrite them with the true response. @@ -449,4 +452,80 @@ func (c *Client) List(ctx context.Context, credentialStoreId string, opt ...Opti // Note: the HTTP response body is consumed by resp.Decode in the loop, // so it doesn't need to be updated (it will always be, and has always been, empty). return target, nil + +} + +func (c *Client) ListNextPage(ctx context.Context, currentPage *CredentialLibraryListResult, opt ...Option) (*CredentialLibraryListResult, error) { + if currentPage == nil { + return nil, fmt.Errorf("empty currentPage value passed into ListNextPage request") + } + if currentPage.credentialStoreId == "" { + return nil, fmt.Errorf("empty credentialStoreId value in currentPage passed into ListNextPage request") + } + if c.client == nil { + return nil, fmt.Errorf("nil client") + } + if currentPage.ResponseType == "complete" || currentPage.ResponseType == "" { + return nil, fmt.Errorf("no more pages available in ListNextPage request") + } + + opts, apiOpts := getOpts(opt...) + opts.queryMap["credential_store_id"] = currentPage.credentialStoreId + + if currentPage.pageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) + } + + req, err := c.client.NewRequest(ctx, "GET", "credential-libraries", nil, apiOpts...) + if err != nil { + return nil, fmt.Errorf("error creating List request: %w", err) + } + + opts.queryMap["list_token"] = currentPage.ListToken + if len(opts.queryMap) > 0 { + q := url.Values{} + for k, v := range opts.queryMap { + q.Add(k, v) + } + req.URL.RawQuery = q.Encode() + } + + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("error performing client request during List call during ListNextPage: %w", err) + } + + nextPage := new(CredentialLibraryListResult) + apiErr, err := resp.Decode(nextPage) + if err != nil { + return nil, fmt.Errorf("error decoding List response during ListNextPage: %w", err) + } + if apiErr != nil { + return nil, apiErr + } + + // Ensure values are carried forward to the next call + nextPage.credentialStoreId = currentPage.credentialStoreId + + nextPage.pageSize = currentPage.pageSize + // Cache the removed IDs from this page + nextPage.allRemovedIds = append(currentPage.allRemovedIds, nextPage.RemovedIds...) + // Set the response body to the current response + nextPage.Response = resp + // If we're done iterating, pull the full set of removed IDs into the last + // response + if nextPage.ResponseType == "complete" { + // Collect up the last values + nextPage.RemovedIds = nextPage.allRemovedIds + // For now, removedIds will only be populated if this pagination cycle + // was the result of a "refresh" operation (i.e., the caller provided a + // list token option to this call). 
+ // + // Sort to make response deterministic + slices.Sort(nextPage.RemovedIds) + // Remove any duplicates + nextPage.RemovedIds = slices.Compact(nextPage.RemovedIds) + } + + return nextPage, nil } diff --git a/api/credentiallibraries/option.gen.go b/api/credentiallibraries/option.gen.go index 09600aae60..a4dc80b1d4 100644 --- a/api/credentiallibraries/option.gen.go +++ b/api/credentiallibraries/option.gen.go @@ -5,6 +5,7 @@ package credentiallibraries import ( + "strconv" "strings" "github.com/hashicorp/boundary/api" @@ -20,12 +21,14 @@ import ( type Option func(*options) type options struct { - postMap map[string]any - queryMap map[string]string - withAutomaticVersioning bool - withSkipCurlOutput bool - withFilter string - withListToken string + postMap map[string]any + queryMap map[string]string + withAutomaticVersioning bool + withSkipCurlOutput bool + withFilter string + withListToken string + withClientDirectedPagination bool + withPageSize uint32 } func getDefaultOptions() options { @@ -52,6 +55,9 @@ func getOpts(opt ...Option) (options, []api.Option) { if opts.withListToken != "" { opts.queryMap["list_token"] = opts.withListToken } + if opts.withPageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(opts.withPageSize), 10) + } return opts, apiOpts } @@ -90,6 +96,21 @@ func WithFilter(filter string) Option { } } +// WithClientDirectedPagination tells the List function to return only the first +// page, if more pages are available +func WithClientDirectedPagination(with bool) Option { + return func(o *options) { + o.withClientDirectedPagination = with + } +} + +// WithPageSize controls the size of pages used during List +func WithPageSize(with uint32) Option { + return func(o *options) { + o.withPageSize = with + } +} + func WithVaultSSHCertificateCredentialLibraryAdditionalValidPrincipals(inAdditionalValidPrincipals []string) Option { return func(o *options) { raw, ok := o.postMap["attributes"] diff --git a/api/credentials/credential.gen.go b/api/credentials/credential.gen.go index ef8f496408..d77e7d1283 100644 --- a/api/credentials/credential.gen.go +++ b/api/credentials/credential.gen.go @@ -11,6 +11,7 @@ import ( "fmt" "net/url" "slices" + "strconv" "time" "github.com/hashicorp/boundary/api" @@ -67,6 +68,13 @@ type CredentialListResult struct { ListToken string `json:"list_token,omitempty"` ResponseType string `json:"response_type,omitempty"` Response *api.Response + + // The following fields are used for cached information when client-directed + // pagination is used. + recursive bool + pageSize uint32 + credentialStoreId string + allRemovedIds []string } func (n CredentialListResult) GetItems() []*Credential { @@ -344,96 +352,91 @@ func (c *Client) List(ctx context.Context, credentialStoreId string, opt ...Opti return nil, apiErr } target.Response = resp + if target.ResponseType == "complete" || target.ResponseType == "" { return target, nil } + + // In case we shortcut out due to client directed pagination, ensure these + // are set + + target.pageSize = opts.withPageSize + target.credentialStoreId = credentialStoreId + target.allRemovedIds = target.RemovedIds + if opts.withClientDirectedPagination { + return target, nil + } + + allItems := make([]*Credential, 0, target.EstItemCount) + allItems = append(allItems, target.Items...) + // If there are more results, automatically fetch the rest of the results. // idToIndex keeps a map from the ID of an item to its index in target.Items. 
// This is used to update updated items in-place and remove deleted items // from the result after pagination is done. idToIndex := map[string]int{} - for i, item := range target.Items { + for i, item := range allItems { idToIndex[item.Id] = i } - for { - req, err := c.client.NewRequest(ctx, "GET", "credentials", nil, apiOpts...) - if err != nil { - return nil, fmt.Errorf("error creating List request: %w", err) - } - - opts.queryMap["list_token"] = target.ListToken - if len(opts.queryMap) > 0 { - q := url.Values{} - for k, v := range opts.queryMap { - q.Add(k, v) - } - req.URL.RawQuery = q.Encode() - } - resp, err := c.client.Do(req) + // If we're here there are more pages and the client does not want to + // paginate on their own; fetch them as this call returns all values. + currentPage := target + for { + nextPage, err := c.ListNextPage(ctx, currentPage, opt...) if err != nil { - return nil, fmt.Errorf("error performing client request during List call: %w", err) + return nil, fmt.Errorf("error getting next page in List call: %w", err) } - page := new(CredentialListResult) - apiErr, err := resp.Decode(page) - if err != nil { - return nil, fmt.Errorf("error decoding List response: %w", err) - } - if apiErr != nil { - return nil, apiErr - } - for _, item := range page.Items { + for _, item := range nextPage.Items { if i, ok := idToIndex[item.Id]; ok { // Item has already been seen at index i, update in-place - target.Items[i] = item + allItems[i] = item } else { - target.Items = append(target.Items, item) - idToIndex[item.Id] = len(target.Items) - 1 + allItems = append(allItems, item) + idToIndex[item.Id] = len(allItems) - 1 } } - // RemovedIds contain any Credential that were deleted since the last response. - target.RemovedIds = append(target.RemovedIds, page.RemovedIds...) - target.EstItemCount = page.EstItemCount - target.ListToken = page.ListToken - target.ResponseType = page.ResponseType - target.Response = resp - if target.ResponseType == "complete" { + + currentPage = nextPage + + if currentPage.ResponseType == "complete" { break } } - // For now, removedIds will only be populated if this pagination cycle was the result of a - // "refresh" operation (i.e., the caller provided a list token option to this call). - // - // Sort to make response deterministic - slices.Sort(target.RemovedIds) - // Remove any duplicates - target.RemovedIds = slices.Compact(target.RemovedIds) + + // The current page here is the final page of the results, that is, the + // response type is "complete" + // Remove items that were deleted since the end of the last iteration. // If a Credential has been updated and subsequently removed, we don't want // it to appear both in the Items and RemovedIds, so we remove it from the Items. - for _, removedId := range target.RemovedIds { + for _, removedId := range currentPage.RemovedIds { if i, ok := idToIndex[removedId]; ok { // Remove the item at index i without preserving order // https://github.com/golang/go/wiki/SliceTricks#delete-without-preserving-order - target.Items[i] = target.Items[len(target.Items)-1] - target.Items = target.Items[:len(target.Items)-1] + allItems[i] = allItems[len(allItems)-1] + allItems = allItems[:len(allItems)-1] // Update the index of the previously last element - idToIndex[target.Items[i].Id] = i + idToIndex[allItems[i].Id] = i } } - // Since we paginated to the end, we can avoid confusion - // for the user by setting the estimated item count to the - // length of the items slice. 
If we don't set this here, it - // will equal the value returned in the last response, which is - // often much smaller than the total number returned. - target.EstItemCount = uint(len(target.Items)) // Sort the results again since in-place updates and deletes // may have shuffled items. We sort by created time descending // (most recently created first), same as the API. - slices.SortFunc(target.Items, func(i, j *Credential) int { + slices.SortFunc(allItems, func(i, j *Credential) int { return j.CreatedTime.Compare(i.CreatedTime) }) + // Since we paginated to the end, we can avoid confusion + // for the user by setting the estimated item count to the + // length of the items slice. If we don't set this here, it + // will equal the value returned in the last response, which is + // often much smaller than the total number returned. + currentPage.EstItemCount = uint(len(allItems)) + // Set items to the full list we have collected here + currentPage.Items = allItems + // Set the returned value to the last page with calculated values + target = currentPage // Finally, since we made at least 2 requests to the server to fulfill this // function call, resp.Body and resp.Map will only contain the most recent response. // Overwrite them with the true response. @@ -447,4 +450,80 @@ func (c *Client) List(ctx context.Context, credentialStoreId string, opt ...Opti // Note: the HTTP response body is consumed by resp.Decode in the loop, // so it doesn't need to be updated (it will always be, and has always been, empty). return target, nil + +} + +func (c *Client) ListNextPage(ctx context.Context, currentPage *CredentialListResult, opt ...Option) (*CredentialListResult, error) { + if currentPage == nil { + return nil, fmt.Errorf("empty currentPage value passed into ListNextPage request") + } + if currentPage.credentialStoreId == "" { + return nil, fmt.Errorf("empty credentialStoreId value in currentPage passed into ListNextPage request") + } + if c.client == nil { + return nil, fmt.Errorf("nil client") + } + if currentPage.ResponseType == "complete" || currentPage.ResponseType == "" { + return nil, fmt.Errorf("no more pages available in ListNextPage request") + } + + opts, apiOpts := getOpts(opt...) + opts.queryMap["credential_store_id"] = currentPage.credentialStoreId + + if currentPage.pageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) + } + + req, err := c.client.NewRequest(ctx, "GET", "credentials", nil, apiOpts...) + if err != nil { + return nil, fmt.Errorf("error creating List request: %w", err) + } + + opts.queryMap["list_token"] = currentPage.ListToken + if len(opts.queryMap) > 0 { + q := url.Values{} + for k, v := range opts.queryMap { + q.Add(k, v) + } + req.URL.RawQuery = q.Encode() + } + + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("error performing client request during List call during ListNextPage: %w", err) + } + + nextPage := new(CredentialListResult) + apiErr, err := resp.Decode(nextPage) + if err != nil { + return nil, fmt.Errorf("error decoding List response during ListNextPage: %w", err) + } + if apiErr != nil { + return nil, apiErr + } + + // Ensure values are carried forward to the next call + nextPage.credentialStoreId = currentPage.credentialStoreId + + nextPage.pageSize = currentPage.pageSize + // Cache the removed IDs from this page + nextPage.allRemovedIds = append(currentPage.allRemovedIds, nextPage.RemovedIds...) 
+ // Set the response body to the current response + nextPage.Response = resp + // If we're done iterating, pull the full set of removed IDs into the last + // response + if nextPage.ResponseType == "complete" { + // Collect up the last values + nextPage.RemovedIds = nextPage.allRemovedIds + // For now, removedIds will only be populated if this pagination cycle + // was the result of a "refresh" operation (i.e., the caller provided a + // list token option to this call). + // + // Sort to make response deterministic + slices.Sort(nextPage.RemovedIds) + // Remove any duplicates + nextPage.RemovedIds = slices.Compact(nextPage.RemovedIds) + } + + return nextPage, nil } diff --git a/api/credentials/option.gen.go b/api/credentials/option.gen.go index 6d7d002601..7cf0fb75e0 100644 --- a/api/credentials/option.gen.go +++ b/api/credentials/option.gen.go @@ -5,6 +5,7 @@ package credentials import ( + "strconv" "strings" "github.com/hashicorp/boundary/api" @@ -20,12 +21,14 @@ import ( type Option func(*options) type options struct { - postMap map[string]any - queryMap map[string]string - withAutomaticVersioning bool - withSkipCurlOutput bool - withFilter string - withListToken string + postMap map[string]any + queryMap map[string]string + withAutomaticVersioning bool + withSkipCurlOutput bool + withFilter string + withListToken string + withClientDirectedPagination bool + withPageSize uint32 } func getDefaultOptions() options { @@ -52,6 +55,9 @@ func getOpts(opt ...Option) (options, []api.Option) { if opts.withListToken != "" { opts.queryMap["list_token"] = opts.withListToken } + if opts.withPageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(opts.withPageSize), 10) + } return opts, apiOpts } @@ -90,6 +96,21 @@ func WithFilter(filter string) Option { } } +// WithClientDirectedPagination tells the List function to return only the first +// page, if more pages are available +func WithClientDirectedPagination(with bool) Option { + return func(o *options) { + o.withClientDirectedPagination = with + } +} + +// WithPageSize controls the size of pages used during List +func WithPageSize(with uint32) Option { + return func(o *options) { + o.withPageSize = with + } +} + func WithAttributes(inAttributes map[string]interface{}) Option { return func(o *options) { o.postMap["attributes"] = inAttributes diff --git a/api/credentialstores/credential_store.gen.go b/api/credentialstores/credential_store.gen.go index d0ec322695..055af8150b 100644 --- a/api/credentialstores/credential_store.gen.go +++ b/api/credentialstores/credential_store.gen.go @@ -11,6 +11,7 @@ import ( "fmt" "net/url" "slices" + "strconv" "time" "github.com/hashicorp/boundary/api" @@ -68,6 +69,13 @@ type CredentialStoreListResult struct { ListToken string `json:"list_token,omitempty"` ResponseType string `json:"response_type,omitempty"` Response *api.Response + + // The following fields are used for cached information when client-directed + // pagination is used. 
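+	// recursive, pageSize, and scopeId echo the options given to the original
+	// List call; allRemovedIds accumulates removed IDs across pages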
+ recursive bool + pageSize uint32 + scopeId string + allRemovedIds []string } func (n CredentialStoreListResult) GetItems() []*CredentialStore { @@ -345,96 +353,93 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Cred return nil, apiErr } target.Response = resp + if target.ResponseType == "complete" || target.ResponseType == "" { return target, nil } + + // In case we shortcut out due to client directed pagination, ensure these + // are set + + target.recursive = opts.withRecursive + + target.pageSize = opts.withPageSize + target.scopeId = scopeId + target.allRemovedIds = target.RemovedIds + if opts.withClientDirectedPagination { + return target, nil + } + + allItems := make([]*CredentialStore, 0, target.EstItemCount) + allItems = append(allItems, target.Items...) + // If there are more results, automatically fetch the rest of the results. // idToIndex keeps a map from the ID of an item to its index in target.Items. // This is used to update updated items in-place and remove deleted items // from the result after pagination is done. idToIndex := map[string]int{} - for i, item := range target.Items { + for i, item := range allItems { idToIndex[item.Id] = i } - for { - req, err := c.client.NewRequest(ctx, "GET", "credential-stores", nil, apiOpts...) - if err != nil { - return nil, fmt.Errorf("error creating List request: %w", err) - } - opts.queryMap["list_token"] = target.ListToken - if len(opts.queryMap) > 0 { - q := url.Values{} - for k, v := range opts.queryMap { - q.Add(k, v) - } - req.URL.RawQuery = q.Encode() - } - - resp, err := c.client.Do(req) + // If we're here there are more pages and the client does not want to + // paginate on their own; fetch them as this call returns all values. + currentPage := target + for { + nextPage, err := c.ListNextPage(ctx, currentPage, opt...) if err != nil { - return nil, fmt.Errorf("error performing client request during List call: %w", err) + return nil, fmt.Errorf("error getting next page in List call: %w", err) } - page := new(CredentialStoreListResult) - apiErr, err := resp.Decode(page) - if err != nil { - return nil, fmt.Errorf("error decoding List response: %w", err) - } - if apiErr != nil { - return nil, apiErr - } - for _, item := range page.Items { + for _, item := range nextPage.Items { if i, ok := idToIndex[item.Id]; ok { // Item has already been seen at index i, update in-place - target.Items[i] = item + allItems[i] = item } else { - target.Items = append(target.Items, item) - idToIndex[item.Id] = len(target.Items) - 1 + allItems = append(allItems, item) + idToIndex[item.Id] = len(allItems) - 1 } } - // RemovedIds contain any CredentialStore that were deleted since the last response. - target.RemovedIds = append(target.RemovedIds, page.RemovedIds...) - target.EstItemCount = page.EstItemCount - target.ListToken = page.ListToken - target.ResponseType = page.ResponseType - target.Response = resp - if target.ResponseType == "complete" { + + currentPage = nextPage + + if currentPage.ResponseType == "complete" { break } } - // For now, removedIds will only be populated if this pagination cycle was the result of a - // "refresh" operation (i.e., the caller provided a list token option to this call). 
- // - // Sort to make response deterministic - slices.Sort(target.RemovedIds) - // Remove any duplicates - target.RemovedIds = slices.Compact(target.RemovedIds) + + // The current page here is the final page of the results, that is, the + // response type is "complete" + // Remove items that were deleted since the end of the last iteration. // If a CredentialStore has been updated and subsequently removed, we don't want // it to appear both in the Items and RemovedIds, so we remove it from the Items. - for _, removedId := range target.RemovedIds { + for _, removedId := range currentPage.RemovedIds { if i, ok := idToIndex[removedId]; ok { // Remove the item at index i without preserving order // https://github.com/golang/go/wiki/SliceTricks#delete-without-preserving-order - target.Items[i] = target.Items[len(target.Items)-1] - target.Items = target.Items[:len(target.Items)-1] + allItems[i] = allItems[len(allItems)-1] + allItems = allItems[:len(allItems)-1] // Update the index of the previously last element - idToIndex[target.Items[i].Id] = i + idToIndex[allItems[i].Id] = i } } - // Since we paginated to the end, we can avoid confusion - // for the user by setting the estimated item count to the - // length of the items slice. If we don't set this here, it - // will equal the value returned in the last response, which is - // often much smaller than the total number returned. - target.EstItemCount = uint(len(target.Items)) // Sort the results again since in-place updates and deletes // may have shuffled items. We sort by created time descending // (most recently created first), same as the API. - slices.SortFunc(target.Items, func(i, j *CredentialStore) int { + slices.SortFunc(allItems, func(i, j *CredentialStore) int { return j.CreatedTime.Compare(i.CreatedTime) }) + // Since we paginated to the end, we can avoid confusion + // for the user by setting the estimated item count to the + // length of the items slice. If we don't set this here, it + // will equal the value returned in the last response, which is + // often much smaller than the total number returned. + currentPage.EstItemCount = uint(len(allItems)) + // Set items to the full list we have collected here + currentPage.Items = allItems + // Set the returned value to the last page with calculated values + target = currentPage // Finally, since we made at least 2 requests to the server to fulfill this // function call, resp.Body and resp.Map will only contain the most recent response. // Overwrite them with the true response. @@ -448,4 +453,87 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Cred // Note: the HTTP response body is consumed by resp.Decode in the loop, // so it doesn't need to be updated (it will always be, and has always been, empty). return target, nil + +} + +func (c *Client) ListNextPage(ctx context.Context, currentPage *CredentialStoreListResult, opt ...Option) (*CredentialStoreListResult, error) { + if currentPage == nil { + return nil, fmt.Errorf("empty currentPage value passed into ListNextPage request") + } + if currentPage.scopeId == "" { + return nil, fmt.Errorf("empty scopeId value in currentPage passed into ListNextPage request") + } + if c.client == nil { + return nil, fmt.Errorf("nil client") + } + if currentPage.ResponseType == "complete" || currentPage.ResponseType == "" { + return nil, fmt.Errorf("no more pages available in ListNextPage request") + } + + opts, apiOpts := getOpts(opt...) 
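+	// Carry the scope cached from the original List call forward so the next
+	// page is requested from the same scope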
+ opts.queryMap["scope_id"] = currentPage.scopeId + + // Don't require them to re-specify recursive + if currentPage.recursive { + opts.queryMap["recursive"] = "true" + } + + if currentPage.pageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) + } + + req, err := c.client.NewRequest(ctx, "GET", "credential-stores", nil, apiOpts...) + if err != nil { + return nil, fmt.Errorf("error creating List request: %w", err) + } + + opts.queryMap["list_token"] = currentPage.ListToken + if len(opts.queryMap) > 0 { + q := url.Values{} + for k, v := range opts.queryMap { + q.Add(k, v) + } + req.URL.RawQuery = q.Encode() + } + + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("error performing client request during List call during ListNextPage: %w", err) + } + + nextPage := new(CredentialStoreListResult) + apiErr, err := resp.Decode(nextPage) + if err != nil { + return nil, fmt.Errorf("error decoding List response during ListNextPage: %w", err) + } + if apiErr != nil { + return nil, apiErr + } + + // Ensure values are carried forward to the next call + nextPage.scopeId = currentPage.scopeId + + nextPage.recursive = currentPage.recursive + + nextPage.pageSize = currentPage.pageSize + // Cache the removed IDs from this page + nextPage.allRemovedIds = append(currentPage.allRemovedIds, nextPage.RemovedIds...) + // Set the response body to the current response + nextPage.Response = resp + // If we're done iterating, pull the full set of removed IDs into the last + // response + if nextPage.ResponseType == "complete" { + // Collect up the last values + nextPage.RemovedIds = nextPage.allRemovedIds + // For now, removedIds will only be populated if this pagination cycle + // was the result of a "refresh" operation (i.e., the caller provided a + // list token option to this call). 
+ // + // Sort to make response deterministic + slices.Sort(nextPage.RemovedIds) + // Remove any duplicates + nextPage.RemovedIds = slices.Compact(nextPage.RemovedIds) + } + + return nextPage, nil } diff --git a/api/credentialstores/option.gen.go b/api/credentialstores/option.gen.go index 1b9ec58bcd..eb44bdecf4 100644 --- a/api/credentialstores/option.gen.go +++ b/api/credentialstores/option.gen.go @@ -21,13 +21,15 @@ import ( type Option func(*options) type options struct { - postMap map[string]any - queryMap map[string]string - withAutomaticVersioning bool - withSkipCurlOutput bool - withFilter string - withListToken string - withRecursive bool + postMap map[string]any + queryMap map[string]string + withAutomaticVersioning bool + withSkipCurlOutput bool + withFilter string + withListToken string + withClientDirectedPagination bool + withPageSize uint32 + withRecursive bool } func getDefaultOptions() options { @@ -57,6 +59,9 @@ func getOpts(opt ...Option) (options, []api.Option) { if opts.withRecursive { opts.queryMap["recursive"] = strconv.FormatBool(opts.withRecursive) } + if opts.withPageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(opts.withPageSize), 10) + } return opts, apiOpts } @@ -95,6 +100,21 @@ func WithFilter(filter string) Option { } } +// WithClientDirectedPagination tells the List function to return only the first +// page, if more pages are available +func WithClientDirectedPagination(with bool) Option { + return func(o *options) { + o.withClientDirectedPagination = with + } +} + +// WithPageSize controls the size of pages used during List +func WithPageSize(with uint32) Option { + return func(o *options) { + o.withPageSize = with + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/groups/group.gen.go b/api/groups/group.gen.go index 49d6441f4e..ed0c31b65b 100644 --- a/api/groups/group.gen.go +++ b/api/groups/group.gen.go @@ -11,6 +11,7 @@ import ( "fmt" "net/url" "slices" + "strconv" "time" "github.com/hashicorp/boundary/api" @@ -67,6 +68,13 @@ type GroupListResult struct { ListToken string `json:"list_token,omitempty"` ResponseType string `json:"response_type,omitempty"` Response *api.Response + + // The following fields are used for cached information when client-directed + // pagination is used. + recursive bool + pageSize uint32 + scopeId string + allRemovedIds []string } func (n GroupListResult) GetItems() []*Group { @@ -339,96 +347,93 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Grou return nil, apiErr } target.Response = resp + if target.ResponseType == "complete" || target.ResponseType == "" { return target, nil } + + // In case we shortcut out due to client directed pagination, ensure these + // are set + + target.recursive = opts.withRecursive + + target.pageSize = opts.withPageSize + target.scopeId = scopeId + target.allRemovedIds = target.RemovedIds + if opts.withClientDirectedPagination { + return target, nil + } + + allItems := make([]*Group, 0, target.EstItemCount) + allItems = append(allItems, target.Items...) + // If there are more results, automatically fetch the rest of the results. // idToIndex keeps a map from the ID of an item to its index in target.Items. // This is used to update updated items in-place and remove deleted items // from the result after pagination is done. 
idToIndex := map[string]int{} - for i, item := range target.Items { + for i, item := range allItems { idToIndex[item.Id] = i } - for { - req, err := c.client.NewRequest(ctx, "GET", "groups", nil, apiOpts...) - if err != nil { - return nil, fmt.Errorf("error creating List request: %w", err) - } - opts.queryMap["list_token"] = target.ListToken - if len(opts.queryMap) > 0 { - q := url.Values{} - for k, v := range opts.queryMap { - q.Add(k, v) - } - req.URL.RawQuery = q.Encode() - } - - resp, err := c.client.Do(req) + // If we're here there are more pages and the client does not want to + // paginate on their own; fetch them as this call returns all values. + currentPage := target + for { + nextPage, err := c.ListNextPage(ctx, currentPage, opt...) if err != nil { - return nil, fmt.Errorf("error performing client request during List call: %w", err) + return nil, fmt.Errorf("error getting next page in List call: %w", err) } - page := new(GroupListResult) - apiErr, err := resp.Decode(page) - if err != nil { - return nil, fmt.Errorf("error decoding List response: %w", err) - } - if apiErr != nil { - return nil, apiErr - } - for _, item := range page.Items { + for _, item := range nextPage.Items { if i, ok := idToIndex[item.Id]; ok { // Item has already been seen at index i, update in-place - target.Items[i] = item + allItems[i] = item } else { - target.Items = append(target.Items, item) - idToIndex[item.Id] = len(target.Items) - 1 + allItems = append(allItems, item) + idToIndex[item.Id] = len(allItems) - 1 } } - // RemovedIds contain any Group that were deleted since the last response. - target.RemovedIds = append(target.RemovedIds, page.RemovedIds...) - target.EstItemCount = page.EstItemCount - target.ListToken = page.ListToken - target.ResponseType = page.ResponseType - target.Response = resp - if target.ResponseType == "complete" { + + currentPage = nextPage + + if currentPage.ResponseType == "complete" { break } } - // For now, removedIds will only be populated if this pagination cycle was the result of a - // "refresh" operation (i.e., the caller provided a list token option to this call). - // - // Sort to make response deterministic - slices.Sort(target.RemovedIds) - // Remove any duplicates - target.RemovedIds = slices.Compact(target.RemovedIds) + + // The current page here is the final page of the results, that is, the + // response type is "complete" + // Remove items that were deleted since the end of the last iteration. // If a Group has been updated and subsequently removed, we don't want // it to appear both in the Items and RemovedIds, so we remove it from the Items. - for _, removedId := range target.RemovedIds { + for _, removedId := range currentPage.RemovedIds { if i, ok := idToIndex[removedId]; ok { // Remove the item at index i without preserving order // https://github.com/golang/go/wiki/SliceTricks#delete-without-preserving-order - target.Items[i] = target.Items[len(target.Items)-1] - target.Items = target.Items[:len(target.Items)-1] + allItems[i] = allItems[len(allItems)-1] + allItems = allItems[:len(allItems)-1] // Update the index of the previously last element - idToIndex[target.Items[i].Id] = i + idToIndex[allItems[i].Id] = i } } - // Since we paginated to the end, we can avoid confusion - // for the user by setting the estimated item count to the - // length of the items slice. If we don't set this here, it - // will equal the value returned in the last response, which is - // often much smaller than the total number returned. 
- target.EstItemCount = uint(len(target.Items)) // Sort the results again since in-place updates and deletes // may have shuffled items. We sort by created time descending // (most recently created first), same as the API. - slices.SortFunc(target.Items, func(i, j *Group) int { + slices.SortFunc(allItems, func(i, j *Group) int { return j.CreatedTime.Compare(i.CreatedTime) }) + // Since we paginated to the end, we can avoid confusion + // for the user by setting the estimated item count to the + // length of the items slice. If we don't set this here, it + // will equal the value returned in the last response, which is + // often much smaller than the total number returned. + currentPage.EstItemCount = uint(len(allItems)) + // Set items to the full list we have collected here + currentPage.Items = allItems + // Set the returned value to the last page with calculated values + target = currentPage // Finally, since we made at least 2 requests to the server to fulfill this // function call, resp.Body and resp.Map will only contain the most recent response. // Overwrite them with the true response. @@ -442,6 +447,89 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Grou // Note: the HTTP response body is consumed by resp.Decode in the loop, // so it doesn't need to be updated (it will always be, and has always been, empty). return target, nil + +} + +func (c *Client) ListNextPage(ctx context.Context, currentPage *GroupListResult, opt ...Option) (*GroupListResult, error) { + if currentPage == nil { + return nil, fmt.Errorf("empty currentPage value passed into ListNextPage request") + } + if currentPage.scopeId == "" { + return nil, fmt.Errorf("empty scopeId value in currentPage passed into ListNextPage request") + } + if c.client == nil { + return nil, fmt.Errorf("nil client") + } + if currentPage.ResponseType == "complete" || currentPage.ResponseType == "" { + return nil, fmt.Errorf("no more pages available in ListNextPage request") + } + + opts, apiOpts := getOpts(opt...) + opts.queryMap["scope_id"] = currentPage.scopeId + + // Don't require them to re-specify recursive + if currentPage.recursive { + opts.queryMap["recursive"] = "true" + } + + if currentPage.pageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) + } + + req, err := c.client.NewRequest(ctx, "GET", "groups", nil, apiOpts...) + if err != nil { + return nil, fmt.Errorf("error creating List request: %w", err) + } + + opts.queryMap["list_token"] = currentPage.ListToken + if len(opts.queryMap) > 0 { + q := url.Values{} + for k, v := range opts.queryMap { + q.Add(k, v) + } + req.URL.RawQuery = q.Encode() + } + + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("error performing client request during List call during ListNextPage: %w", err) + } + + nextPage := new(GroupListResult) + apiErr, err := resp.Decode(nextPage) + if err != nil { + return nil, fmt.Errorf("error decoding List response during ListNextPage: %w", err) + } + if apiErr != nil { + return nil, apiErr + } + + // Ensure values are carried forward to the next call + nextPage.scopeId = currentPage.scopeId + + nextPage.recursive = currentPage.recursive + + nextPage.pageSize = currentPage.pageSize + // Cache the removed IDs from this page + nextPage.allRemovedIds = append(currentPage.allRemovedIds, nextPage.RemovedIds...) 
+ // Set the response body to the current response + nextPage.Response = resp + // If we're done iterating, pull the full set of removed IDs into the last + // response + if nextPage.ResponseType == "complete" { + // Collect up the last values + nextPage.RemovedIds = nextPage.allRemovedIds + // For now, removedIds will only be populated if this pagination cycle + // was the result of a "refresh" operation (i.e., the caller provided a + // list token option to this call). + // + // Sort to make response deterministic + slices.Sort(nextPage.RemovedIds) + // Remove any duplicates + nextPage.RemovedIds = slices.Compact(nextPage.RemovedIds) + } + + return nextPage, nil } func (c *Client) AddMembers(ctx context.Context, id string, version uint32, memberIds []string, opt ...Option) (*GroupUpdateResult, error) { diff --git a/api/groups/option.gen.go b/api/groups/option.gen.go index 13828df81b..d2e7b5a3d4 100644 --- a/api/groups/option.gen.go +++ b/api/groups/option.gen.go @@ -21,13 +21,15 @@ import ( type Option func(*options) type options struct { - postMap map[string]any - queryMap map[string]string - withAutomaticVersioning bool - withSkipCurlOutput bool - withFilter string - withListToken string - withRecursive bool + postMap map[string]any + queryMap map[string]string + withAutomaticVersioning bool + withSkipCurlOutput bool + withFilter string + withListToken string + withClientDirectedPagination bool + withPageSize uint32 + withRecursive bool } func getDefaultOptions() options { @@ -57,6 +59,9 @@ func getOpts(opt ...Option) (options, []api.Option) { if opts.withRecursive { opts.queryMap["recursive"] = strconv.FormatBool(opts.withRecursive) } + if opts.withPageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(opts.withPageSize), 10) + } return opts, apiOpts } @@ -95,6 +100,21 @@ func WithFilter(filter string) Option { } } +// WithClientDirectedPagination tells the List function to return only the first +// page, if more pages are available +func WithClientDirectedPagination(with bool) Option { + return func(o *options) { + o.withClientDirectedPagination = with + } +} + +// WithPageSize controls the size of pages used during List +func WithPageSize(with uint32) Option { + return func(o *options) { + o.withPageSize = with + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/hostcatalogs/host_catalog.gen.go b/api/hostcatalogs/host_catalog.gen.go index 9682824356..5e4feb45a8 100644 --- a/api/hostcatalogs/host_catalog.gen.go +++ b/api/hostcatalogs/host_catalog.gen.go @@ -11,6 +11,7 @@ import ( "fmt" "net/url" "slices" + "strconv" "time" "github.com/hashicorp/boundary/api" @@ -73,6 +74,13 @@ type HostCatalogListResult struct { ListToken string `json:"list_token,omitempty"` ResponseType string `json:"response_type,omitempty"` Response *api.Response + + // The following fields are used for cached information when client-directed + // pagination is used. 
+ recursive bool + pageSize uint32 + scopeId string + allRemovedIds []string } func (n HostCatalogListResult) GetItems() []*HostCatalog { @@ -350,96 +358,93 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Host return nil, apiErr } target.Response = resp + if target.ResponseType == "complete" || target.ResponseType == "" { return target, nil } + + // In case we shortcut out due to client directed pagination, ensure these + // are set + + target.recursive = opts.withRecursive + + target.pageSize = opts.withPageSize + target.scopeId = scopeId + target.allRemovedIds = target.RemovedIds + if opts.withClientDirectedPagination { + return target, nil + } + + allItems := make([]*HostCatalog, 0, target.EstItemCount) + allItems = append(allItems, target.Items...) + // If there are more results, automatically fetch the rest of the results. // idToIndex keeps a map from the ID of an item to its index in target.Items. // This is used to update updated items in-place and remove deleted items // from the result after pagination is done. idToIndex := map[string]int{} - for i, item := range target.Items { + for i, item := range allItems { idToIndex[item.Id] = i } - for { - req, err := c.client.NewRequest(ctx, "GET", "host-catalogs", nil, apiOpts...) - if err != nil { - return nil, fmt.Errorf("error creating List request: %w", err) - } - opts.queryMap["list_token"] = target.ListToken - if len(opts.queryMap) > 0 { - q := url.Values{} - for k, v := range opts.queryMap { - q.Add(k, v) - } - req.URL.RawQuery = q.Encode() - } - - resp, err := c.client.Do(req) + // If we're here there are more pages and the client does not want to + // paginate on their own; fetch them as this call returns all values. + currentPage := target + for { + nextPage, err := c.ListNextPage(ctx, currentPage, opt...) if err != nil { - return nil, fmt.Errorf("error performing client request during List call: %w", err) + return nil, fmt.Errorf("error getting next page in List call: %w", err) } - page := new(HostCatalogListResult) - apiErr, err := resp.Decode(page) - if err != nil { - return nil, fmt.Errorf("error decoding List response: %w", err) - } - if apiErr != nil { - return nil, apiErr - } - for _, item := range page.Items { + for _, item := range nextPage.Items { if i, ok := idToIndex[item.Id]; ok { // Item has already been seen at index i, update in-place - target.Items[i] = item + allItems[i] = item } else { - target.Items = append(target.Items, item) - idToIndex[item.Id] = len(target.Items) - 1 + allItems = append(allItems, item) + idToIndex[item.Id] = len(allItems) - 1 } } - // RemovedIds contain any HostCatalog that were deleted since the last response. - target.RemovedIds = append(target.RemovedIds, page.RemovedIds...) - target.EstItemCount = page.EstItemCount - target.ListToken = page.ListToken - target.ResponseType = page.ResponseType - target.Response = resp - if target.ResponseType == "complete" { + + currentPage = nextPage + + if currentPage.ResponseType == "complete" { break } } - // For now, removedIds will only be populated if this pagination cycle was the result of a - // "refresh" operation (i.e., the caller provided a list token option to this call). - // - // Sort to make response deterministic - slices.Sort(target.RemovedIds) - // Remove any duplicates - target.RemovedIds = slices.Compact(target.RemovedIds) + + // The current page here is the final page of the results, that is, the + // response type is "complete" + // Remove items that were deleted since the end of the last iteration. 
// If a HostCatalog has been updated and subsequently removed, we don't want // it to appear both in the Items and RemovedIds, so we remove it from the Items. - for _, removedId := range target.RemovedIds { + for _, removedId := range currentPage.RemovedIds { if i, ok := idToIndex[removedId]; ok { // Remove the item at index i without preserving order // https://github.com/golang/go/wiki/SliceTricks#delete-without-preserving-order - target.Items[i] = target.Items[len(target.Items)-1] - target.Items = target.Items[:len(target.Items)-1] + allItems[i] = allItems[len(allItems)-1] + allItems = allItems[:len(allItems)-1] // Update the index of the previously last element - idToIndex[target.Items[i].Id] = i + idToIndex[allItems[i].Id] = i } } - // Since we paginated to the end, we can avoid confusion - // for the user by setting the estimated item count to the - // length of the items slice. If we don't set this here, it - // will equal the value returned in the last response, which is - // often much smaller than the total number returned. - target.EstItemCount = uint(len(target.Items)) // Sort the results again since in-place updates and deletes // may have shuffled items. We sort by created time descending // (most recently created first), same as the API. - slices.SortFunc(target.Items, func(i, j *HostCatalog) int { + slices.SortFunc(allItems, func(i, j *HostCatalog) int { return j.CreatedTime.Compare(i.CreatedTime) }) + // Since we paginated to the end, we can avoid confusion + // for the user by setting the estimated item count to the + // length of the items slice. If we don't set this here, it + // will equal the value returned in the last response, which is + // often much smaller than the total number returned. + currentPage.EstItemCount = uint(len(allItems)) + // Set items to the full list we have collected here + currentPage.Items = allItems + // Set the returned value to the last page with calculated values + target = currentPage // Finally, since we made at least 2 requests to the server to fulfill this // function call, resp.Body and resp.Map will only contain the most recent response. // Overwrite them with the true response. @@ -453,4 +458,87 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Host // Note: the HTTP response body is consumed by resp.Decode in the loop, // so it doesn't need to be updated (it will always be, and has always been, empty). return target, nil + +} + +func (c *Client) ListNextPage(ctx context.Context, currentPage *HostCatalogListResult, opt ...Option) (*HostCatalogListResult, error) { + if currentPage == nil { + return nil, fmt.Errorf("empty currentPage value passed into ListNextPage request") + } + if currentPage.scopeId == "" { + return nil, fmt.Errorf("empty scopeId value in currentPage passed into ListNextPage request") + } + if c.client == nil { + return nil, fmt.Errorf("nil client") + } + if currentPage.ResponseType == "complete" || currentPage.ResponseType == "" { + return nil, fmt.Errorf("no more pages available in ListNextPage request") + } + + opts, apiOpts := getOpts(opt...) + opts.queryMap["scope_id"] = currentPage.scopeId + + // Don't require them to re-specify recursive + if currentPage.recursive { + opts.queryMap["recursive"] = "true" + } + + if currentPage.pageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) + } + + req, err := c.client.NewRequest(ctx, "GET", "host-catalogs", nil, apiOpts...) 
+ if err != nil { + return nil, fmt.Errorf("error creating List request: %w", err) + } + + opts.queryMap["list_token"] = currentPage.ListToken + if len(opts.queryMap) > 0 { + q := url.Values{} + for k, v := range opts.queryMap { + q.Add(k, v) + } + req.URL.RawQuery = q.Encode() + } + + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("error performing client request during List call during ListNextPage: %w", err) + } + + nextPage := new(HostCatalogListResult) + apiErr, err := resp.Decode(nextPage) + if err != nil { + return nil, fmt.Errorf("error decoding List response during ListNextPage: %w", err) + } + if apiErr != nil { + return nil, apiErr + } + + // Ensure values are carried forward to the next call + nextPage.scopeId = currentPage.scopeId + + nextPage.recursive = currentPage.recursive + + nextPage.pageSize = currentPage.pageSize + // Cache the removed IDs from this page + nextPage.allRemovedIds = append(currentPage.allRemovedIds, nextPage.RemovedIds...) + // Set the response body to the current response + nextPage.Response = resp + // If we're done iterating, pull the full set of removed IDs into the last + // response + if nextPage.ResponseType == "complete" { + // Collect up the last values + nextPage.RemovedIds = nextPage.allRemovedIds + // For now, removedIds will only be populated if this pagination cycle + // was the result of a "refresh" operation (i.e., the caller provided a + // list token option to this call). + // + // Sort to make response deterministic + slices.Sort(nextPage.RemovedIds) + // Remove any duplicates + nextPage.RemovedIds = slices.Compact(nextPage.RemovedIds) + } + + return nextPage, nil } diff --git a/api/hostcatalogs/option.gen.go b/api/hostcatalogs/option.gen.go index 466fd50364..fe5d26fd1e 100644 --- a/api/hostcatalogs/option.gen.go +++ b/api/hostcatalogs/option.gen.go @@ -22,13 +22,15 @@ import ( type Option func(*options) type options struct { - postMap map[string]any - queryMap map[string]string - withAutomaticVersioning bool - withSkipCurlOutput bool - withFilter string - withListToken string - withRecursive bool + postMap map[string]any + queryMap map[string]string + withAutomaticVersioning bool + withSkipCurlOutput bool + withFilter string + withListToken string + withClientDirectedPagination bool + withPageSize uint32 + withRecursive bool } func getDefaultOptions() options { @@ -58,6 +60,9 @@ func getOpts(opt ...Option) (options, []api.Option) { if opts.withRecursive { opts.queryMap["recursive"] = strconv.FormatBool(opts.withRecursive) } + if opts.withPageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(opts.withPageSize), 10) + } return opts, apiOpts } @@ -96,6 +101,21 @@ func WithFilter(filter string) Option { } } +// WithClientDirectedPagination tells the List function to return only the first +// page, if more pages are available +func WithClientDirectedPagination(with bool) Option { + return func(o *options) { + o.withClientDirectedPagination = with + } +} + +// WithPageSize controls the size of pages used during List +func WithPageSize(with uint32) Option { + return func(o *options) { + o.withPageSize = with + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/hosts/host.gen.go b/api/hosts/host.gen.go index be36382a16..016869f447 100644 --- a/api/hosts/host.gen.go +++ b/api/hosts/host.gen.go @@ -11,6 +11,7 @@ import ( "fmt" "net/url" "slices" + "strconv" "time" 
"github.com/hashicorp/boundary/api" @@ -74,6 +75,13 @@ type HostListResult struct { ListToken string `json:"list_token,omitempty"` ResponseType string `json:"response_type,omitempty"` Response *api.Response + + // The following fields are used for cached information when client-directed + // pagination is used. + recursive bool + pageSize uint32 + hostCatalogId string + allRemovedIds []string } func (n HostListResult) GetItems() []*Host { @@ -346,96 +354,91 @@ func (c *Client) List(ctx context.Context, hostCatalogId string, opt ...Option) return nil, apiErr } target.Response = resp + if target.ResponseType == "complete" || target.ResponseType == "" { return target, nil } + + // In case we shortcut out due to client directed pagination, ensure these + // are set + + target.pageSize = opts.withPageSize + target.hostCatalogId = hostCatalogId + target.allRemovedIds = target.RemovedIds + if opts.withClientDirectedPagination { + return target, nil + } + + allItems := make([]*Host, 0, target.EstItemCount) + allItems = append(allItems, target.Items...) + // If there are more results, automatically fetch the rest of the results. // idToIndex keeps a map from the ID of an item to its index in target.Items. // This is used to update updated items in-place and remove deleted items // from the result after pagination is done. idToIndex := map[string]int{} - for i, item := range target.Items { + for i, item := range allItems { idToIndex[item.Id] = i } - for { - req, err := c.client.NewRequest(ctx, "GET", "hosts", nil, apiOpts...) - if err != nil { - return nil, fmt.Errorf("error creating List request: %w", err) - } - - opts.queryMap["list_token"] = target.ListToken - if len(opts.queryMap) > 0 { - q := url.Values{} - for k, v := range opts.queryMap { - q.Add(k, v) - } - req.URL.RawQuery = q.Encode() - } - resp, err := c.client.Do(req) + // If we're here there are more pages and the client does not want to + // paginate on their own; fetch them as this call returns all values. + currentPage := target + for { + nextPage, err := c.ListNextPage(ctx, currentPage, opt...) if err != nil { - return nil, fmt.Errorf("error performing client request during List call: %w", err) + return nil, fmt.Errorf("error getting next page in List call: %w", err) } - page := new(HostListResult) - apiErr, err := resp.Decode(page) - if err != nil { - return nil, fmt.Errorf("error decoding List response: %w", err) - } - if apiErr != nil { - return nil, apiErr - } - for _, item := range page.Items { + for _, item := range nextPage.Items { if i, ok := idToIndex[item.Id]; ok { // Item has already been seen at index i, update in-place - target.Items[i] = item + allItems[i] = item } else { - target.Items = append(target.Items, item) - idToIndex[item.Id] = len(target.Items) - 1 + allItems = append(allItems, item) + idToIndex[item.Id] = len(allItems) - 1 } } - // RemovedIds contain any Host that were deleted since the last response. - target.RemovedIds = append(target.RemovedIds, page.RemovedIds...) - target.EstItemCount = page.EstItemCount - target.ListToken = page.ListToken - target.ResponseType = page.ResponseType - target.Response = resp - if target.ResponseType == "complete" { + + currentPage = nextPage + + if currentPage.ResponseType == "complete" { break } } - // For now, removedIds will only be populated if this pagination cycle was the result of a - // "refresh" operation (i.e., the caller provided a list token option to this call). 
- // - // Sort to make response deterministic - slices.Sort(target.RemovedIds) - // Remove any duplicates - target.RemovedIds = slices.Compact(target.RemovedIds) + + // The current page here is the final page of the results, that is, the + // response type is "complete" + // Remove items that were deleted since the end of the last iteration. // If a Host has been updated and subsequently removed, we don't want // it to appear both in the Items and RemovedIds, so we remove it from the Items. - for _, removedId := range target.RemovedIds { + for _, removedId := range currentPage.RemovedIds { if i, ok := idToIndex[removedId]; ok { // Remove the item at index i without preserving order // https://github.com/golang/go/wiki/SliceTricks#delete-without-preserving-order - target.Items[i] = target.Items[len(target.Items)-1] - target.Items = target.Items[:len(target.Items)-1] + allItems[i] = allItems[len(allItems)-1] + allItems = allItems[:len(allItems)-1] // Update the index of the previously last element - idToIndex[target.Items[i].Id] = i + idToIndex[allItems[i].Id] = i } } - // Since we paginated to the end, we can avoid confusion - // for the user by setting the estimated item count to the - // length of the items slice. If we don't set this here, it - // will equal the value returned in the last response, which is - // often much smaller than the total number returned. - target.EstItemCount = uint(len(target.Items)) // Sort the results again since in-place updates and deletes // may have shuffled items. We sort by created time descending // (most recently created first), same as the API. - slices.SortFunc(target.Items, func(i, j *Host) int { + slices.SortFunc(allItems, func(i, j *Host) int { return j.CreatedTime.Compare(i.CreatedTime) }) + // Since we paginated to the end, we can avoid confusion + // for the user by setting the estimated item count to the + // length of the items slice. If we don't set this here, it + // will equal the value returned in the last response, which is + // often much smaller than the total number returned. + currentPage.EstItemCount = uint(len(allItems)) + // Set items to the full list we have collected here + currentPage.Items = allItems + // Set the returned value to the last page with calculated values + target = currentPage // Finally, since we made at least 2 requests to the server to fulfill this // function call, resp.Body and resp.Map will only contain the most recent response. // Overwrite them with the true response. @@ -449,4 +452,80 @@ func (c *Client) List(ctx context.Context, hostCatalogId string, opt ...Option) // Note: the HTTP response body is consumed by resp.Decode in the loop, // so it doesn't need to be updated (it will always be, and has always been, empty). return target, nil + +} + +func (c *Client) ListNextPage(ctx context.Context, currentPage *HostListResult, opt ...Option) (*HostListResult, error) { + if currentPage == nil { + return nil, fmt.Errorf("empty currentPage value passed into ListNextPage request") + } + if currentPage.hostCatalogId == "" { + return nil, fmt.Errorf("empty hostCatalogId value in currentPage passed into ListNextPage request") + } + if c.client == nil { + return nil, fmt.Errorf("nil client") + } + if currentPage.ResponseType == "complete" || currentPage.ResponseType == "" { + return nil, fmt.Errorf("no more pages available in ListNextPage request") + } + + opts, apiOpts := getOpts(opt...) 
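+	// Re-apply the values cached on the previous page (the host catalog ID and
+	// page size here, and the list token below) so callers do not need to pass
+	// the original list options again on each ListNextPage call.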
+ opts.queryMap["host_catalog_id"] = currentPage.hostCatalogId + + if currentPage.pageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) + } + + req, err := c.client.NewRequest(ctx, "GET", "hosts", nil, apiOpts...) + if err != nil { + return nil, fmt.Errorf("error creating List request: %w", err) + } + + opts.queryMap["list_token"] = currentPage.ListToken + if len(opts.queryMap) > 0 { + q := url.Values{} + for k, v := range opts.queryMap { + q.Add(k, v) + } + req.URL.RawQuery = q.Encode() + } + + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("error performing client request during List call during ListNextPage: %w", err) + } + + nextPage := new(HostListResult) + apiErr, err := resp.Decode(nextPage) + if err != nil { + return nil, fmt.Errorf("error decoding List response during ListNextPage: %w", err) + } + if apiErr != nil { + return nil, apiErr + } + + // Ensure values are carried forward to the next call + nextPage.hostCatalogId = currentPage.hostCatalogId + + nextPage.pageSize = currentPage.pageSize + // Cache the removed IDs from this page + nextPage.allRemovedIds = append(currentPage.allRemovedIds, nextPage.RemovedIds...) + // Set the response body to the current response + nextPage.Response = resp + // If we're done iterating, pull the full set of removed IDs into the last + // response + if nextPage.ResponseType == "complete" { + // Collect up the last values + nextPage.RemovedIds = nextPage.allRemovedIds + // For now, removedIds will only be populated if this pagination cycle + // was the result of a "refresh" operation (i.e., the caller provided a + // list token option to this call). + // + // Sort to make response deterministic + slices.Sort(nextPage.RemovedIds) + // Remove any duplicates + nextPage.RemovedIds = slices.Compact(nextPage.RemovedIds) + } + + return nextPage, nil } diff --git a/api/hosts/option.gen.go b/api/hosts/option.gen.go index 9868141e5c..32809d61c2 100644 --- a/api/hosts/option.gen.go +++ b/api/hosts/option.gen.go @@ -5,6 +5,7 @@ package hosts import ( + "strconv" "strings" "github.com/hashicorp/boundary/api" @@ -20,12 +21,14 @@ import ( type Option func(*options) type options struct { - postMap map[string]any - queryMap map[string]string - withAutomaticVersioning bool - withSkipCurlOutput bool - withFilter string - withListToken string + postMap map[string]any + queryMap map[string]string + withAutomaticVersioning bool + withSkipCurlOutput bool + withFilter string + withListToken string + withClientDirectedPagination bool + withPageSize uint32 } func getDefaultOptions() options { @@ -52,6 +55,9 @@ func getOpts(opt ...Option) (options, []api.Option) { if opts.withListToken != "" { opts.queryMap["list_token"] = opts.withListToken } + if opts.withPageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(opts.withPageSize), 10) + } return opts, apiOpts } @@ -90,6 +96,21 @@ func WithFilter(filter string) Option { } } +// WithClientDirectedPagination tells the List function to return only the first +// page, if more pages are available +func WithClientDirectedPagination(with bool) Option { + return func(o *options) { + o.withClientDirectedPagination = with + } +} + +// WithPageSize controls the size of pages used during List +func WithPageSize(with uint32) Option { + return func(o *options) { + o.withPageSize = with + } +} + func WithStaticHostAddress(inAddress string) Option { return func(o *options) { raw, ok := o.postMap["attributes"] diff --git a/api/hostsets/host_set.gen.go 
b/api/hostsets/host_set.gen.go index 1b6642eda7..7ce7e6640b 100644 --- a/api/hostsets/host_set.gen.go +++ b/api/hostsets/host_set.gen.go @@ -11,6 +11,7 @@ import ( "fmt" "net/url" "slices" + "strconv" "time" "github.com/hashicorp/boundary/api" @@ -72,6 +73,13 @@ type HostSetListResult struct { ListToken string `json:"list_token,omitempty"` ResponseType string `json:"response_type,omitempty"` Response *api.Response + + // The following fields are used for cached information when client-directed + // pagination is used. + recursive bool + pageSize uint32 + hostCatalogId string + allRemovedIds []string } func (n HostSetListResult) GetItems() []*HostSet { @@ -344,96 +352,91 @@ func (c *Client) List(ctx context.Context, hostCatalogId string, opt ...Option) return nil, apiErr } target.Response = resp + if target.ResponseType == "complete" || target.ResponseType == "" { return target, nil } + + // In case we shortcut out due to client directed pagination, ensure these + // are set + + target.pageSize = opts.withPageSize + target.hostCatalogId = hostCatalogId + target.allRemovedIds = target.RemovedIds + if opts.withClientDirectedPagination { + return target, nil + } + + allItems := make([]*HostSet, 0, target.EstItemCount) + allItems = append(allItems, target.Items...) + // If there are more results, automatically fetch the rest of the results. // idToIndex keeps a map from the ID of an item to its index in target.Items. // This is used to update updated items in-place and remove deleted items // from the result after pagination is done. idToIndex := map[string]int{} - for i, item := range target.Items { + for i, item := range allItems { idToIndex[item.Id] = i } - for { - req, err := c.client.NewRequest(ctx, "GET", "host-sets", nil, apiOpts...) - if err != nil { - return nil, fmt.Errorf("error creating List request: %w", err) - } - - opts.queryMap["list_token"] = target.ListToken - if len(opts.queryMap) > 0 { - q := url.Values{} - for k, v := range opts.queryMap { - q.Add(k, v) - } - req.URL.RawQuery = q.Encode() - } - resp, err := c.client.Do(req) + // If we're here there are more pages and the client does not want to + // paginate on their own; fetch them as this call returns all values. + currentPage := target + for { + nextPage, err := c.ListNextPage(ctx, currentPage, opt...) if err != nil { - return nil, fmt.Errorf("error performing client request during List call: %w", err) + return nil, fmt.Errorf("error getting next page in List call: %w", err) } - page := new(HostSetListResult) - apiErr, err := resp.Decode(page) - if err != nil { - return nil, fmt.Errorf("error decoding List response: %w", err) - } - if apiErr != nil { - return nil, apiErr - } - for _, item := range page.Items { + for _, item := range nextPage.Items { if i, ok := idToIndex[item.Id]; ok { // Item has already been seen at index i, update in-place - target.Items[i] = item + allItems[i] = item } else { - target.Items = append(target.Items, item) - idToIndex[item.Id] = len(target.Items) - 1 + allItems = append(allItems, item) + idToIndex[item.Id] = len(allItems) - 1 } } - // RemovedIds contain any HostSet that were deleted since the last response. - target.RemovedIds = append(target.RemovedIds, page.RemovedIds...) 
- target.EstItemCount = page.EstItemCount - target.ListToken = page.ListToken - target.ResponseType = page.ResponseType - target.Response = resp - if target.ResponseType == "complete" { + + currentPage = nextPage + + if currentPage.ResponseType == "complete" { break } } - // For now, removedIds will only be populated if this pagination cycle was the result of a - // "refresh" operation (i.e., the caller provided a list token option to this call). - // - // Sort to make response deterministic - slices.Sort(target.RemovedIds) - // Remove any duplicates - target.RemovedIds = slices.Compact(target.RemovedIds) + + // The current page here is the final page of the results, that is, the + // response type is "complete" + // Remove items that were deleted since the end of the last iteration. // If a HostSet has been updated and subsequently removed, we don't want // it to appear both in the Items and RemovedIds, so we remove it from the Items. - for _, removedId := range target.RemovedIds { + for _, removedId := range currentPage.RemovedIds { if i, ok := idToIndex[removedId]; ok { // Remove the item at index i without preserving order // https://github.com/golang/go/wiki/SliceTricks#delete-without-preserving-order - target.Items[i] = target.Items[len(target.Items)-1] - target.Items = target.Items[:len(target.Items)-1] + allItems[i] = allItems[len(allItems)-1] + allItems = allItems[:len(allItems)-1] // Update the index of the previously last element - idToIndex[target.Items[i].Id] = i + idToIndex[allItems[i].Id] = i } } - // Since we paginated to the end, we can avoid confusion - // for the user by setting the estimated item count to the - // length of the items slice. If we don't set this here, it - // will equal the value returned in the last response, which is - // often much smaller than the total number returned. - target.EstItemCount = uint(len(target.Items)) // Sort the results again since in-place updates and deletes // may have shuffled items. We sort by created time descending // (most recently created first), same as the API. - slices.SortFunc(target.Items, func(i, j *HostSet) int { + slices.SortFunc(allItems, func(i, j *HostSet) int { return j.CreatedTime.Compare(i.CreatedTime) }) + // Since we paginated to the end, we can avoid confusion + // for the user by setting the estimated item count to the + // length of the items slice. If we don't set this here, it + // will equal the value returned in the last response, which is + // often much smaller than the total number returned. + currentPage.EstItemCount = uint(len(allItems)) + // Set items to the full list we have collected here + currentPage.Items = allItems + // Set the returned value to the last page with calculated values + target = currentPage // Finally, since we made at least 2 requests to the server to fulfill this // function call, resp.Body and resp.Map will only contain the most recent response. // Overwrite them with the true response. @@ -447,6 +450,82 @@ func (c *Client) List(ctx context.Context, hostCatalogId string, opt ...Option) // Note: the HTTP response body is consumed by resp.Decode in the loop, // so it doesn't need to be updated (it will always be, and has always been, empty). 
return target, nil + +} + +func (c *Client) ListNextPage(ctx context.Context, currentPage *HostSetListResult, opt ...Option) (*HostSetListResult, error) { + if currentPage == nil { + return nil, fmt.Errorf("empty currentPage value passed into ListNextPage request") + } + if currentPage.hostCatalogId == "" { + return nil, fmt.Errorf("empty hostCatalogId value in currentPage passed into ListNextPage request") + } + if c.client == nil { + return nil, fmt.Errorf("nil client") + } + if currentPage.ResponseType == "complete" || currentPage.ResponseType == "" { + return nil, fmt.Errorf("no more pages available in ListNextPage request") + } + + opts, apiOpts := getOpts(opt...) + opts.queryMap["host_catalog_id"] = currentPage.hostCatalogId + + if currentPage.pageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) + } + + req, err := c.client.NewRequest(ctx, "GET", "host-sets", nil, apiOpts...) + if err != nil { + return nil, fmt.Errorf("error creating List request: %w", err) + } + + opts.queryMap["list_token"] = currentPage.ListToken + if len(opts.queryMap) > 0 { + q := url.Values{} + for k, v := range opts.queryMap { + q.Add(k, v) + } + req.URL.RawQuery = q.Encode() + } + + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("error performing client request during List call during ListNextPage: %w", err) + } + + nextPage := new(HostSetListResult) + apiErr, err := resp.Decode(nextPage) + if err != nil { + return nil, fmt.Errorf("error decoding List response during ListNextPage: %w", err) + } + if apiErr != nil { + return nil, apiErr + } + + // Ensure values are carried forward to the next call + nextPage.hostCatalogId = currentPage.hostCatalogId + + nextPage.pageSize = currentPage.pageSize + // Cache the removed IDs from this page + nextPage.allRemovedIds = append(currentPage.allRemovedIds, nextPage.RemovedIds...) + // Set the response body to the current response + nextPage.Response = resp + // If we're done iterating, pull the full set of removed IDs into the last + // response + if nextPage.ResponseType == "complete" { + // Collect up the last values + nextPage.RemovedIds = nextPage.allRemovedIds + // For now, removedIds will only be populated if this pagination cycle + // was the result of a "refresh" operation (i.e., the caller provided a + // list token option to this call). 
+ // + // Sort to make response deterministic + slices.Sort(nextPage.RemovedIds) + // Remove any duplicates + nextPage.RemovedIds = slices.Compact(nextPage.RemovedIds) + } + + return nextPage, nil } func (c *Client) AddHosts(ctx context.Context, id string, version uint32, hostIds []string, opt ...Option) (*HostSetUpdateResult, error) { diff --git a/api/hostsets/option.gen.go b/api/hostsets/option.gen.go index f8649cf470..fa81eddb7b 100644 --- a/api/hostsets/option.gen.go +++ b/api/hostsets/option.gen.go @@ -5,6 +5,7 @@ package hostsets import ( + "strconv" "strings" "github.com/hashicorp/boundary/api" @@ -20,12 +21,14 @@ import ( type Option func(*options) type options struct { - postMap map[string]any - queryMap map[string]string - withAutomaticVersioning bool - withSkipCurlOutput bool - withFilter string - withListToken string + postMap map[string]any + queryMap map[string]string + withAutomaticVersioning bool + withSkipCurlOutput bool + withFilter string + withListToken string + withClientDirectedPagination bool + withPageSize uint32 } func getDefaultOptions() options { @@ -52,6 +55,9 @@ func getOpts(opt ...Option) (options, []api.Option) { if opts.withListToken != "" { opts.queryMap["list_token"] = opts.withListToken } + if opts.withPageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(opts.withPageSize), 10) + } return opts, apiOpts } @@ -90,6 +96,21 @@ func WithFilter(filter string) Option { } } +// WithClientDirectedPagination tells the List function to return only the first +// page, if more pages are available +func WithClientDirectedPagination(with bool) Option { + return func(o *options) { + o.withClientDirectedPagination = with + } +} + +// WithPageSize controls the size of pages used during List +func WithPageSize(with uint32) Option { + return func(o *options) { + o.withPageSize = with + } +} + func WithAttributes(inAttributes map[string]interface{}) Option { return func(o *options) { o.postMap["attributes"] = inAttributes diff --git a/api/managedgroups/managedgroups.gen.go b/api/managedgroups/managedgroups.gen.go index 9c7497fe4a..02cc9d5be8 100644 --- a/api/managedgroups/managedgroups.gen.go +++ b/api/managedgroups/managedgroups.gen.go @@ -11,6 +11,7 @@ import ( "fmt" "net/url" "slices" + "strconv" "time" "github.com/hashicorp/boundary/api" @@ -68,6 +69,13 @@ type ManagedGroupListResult struct { ListToken string `json:"list_token,omitempty"` ResponseType string `json:"response_type,omitempty"` Response *api.Response + + // The following fields are used for cached information when client-directed + // pagination is used. + recursive bool + pageSize uint32 + authMethodId string + allRemovedIds []string } func (n ManagedGroupListResult) GetItems() []*ManagedGroup { @@ -340,96 +348,91 @@ func (c *Client) List(ctx context.Context, authMethodId string, opt ...Option) ( return nil, apiErr } target.Response = resp + if target.ResponseType == "complete" || target.ResponseType == "" { return target, nil } + + // In case we shortcut out due to client directed pagination, ensure these + // are set + + target.pageSize = opts.withPageSize + target.authMethodId = authMethodId + target.allRemovedIds = target.RemovedIds + if opts.withClientDirectedPagination { + return target, nil + } + + allItems := make([]*ManagedGroup, 0, target.EstItemCount) + allItems = append(allItems, target.Items...) + // If there are more results, automatically fetch the rest of the results. // idToIndex keeps a map from the ID of an item to its index in target.Items. 
// This is used to update updated items in-place and remove deleted items // from the result after pagination is done. idToIndex := map[string]int{} - for i, item := range target.Items { + for i, item := range allItems { idToIndex[item.Id] = i } - for { - req, err := c.client.NewRequest(ctx, "GET", "managed-groups", nil, apiOpts...) - if err != nil { - return nil, fmt.Errorf("error creating List request: %w", err) - } - - opts.queryMap["list_token"] = target.ListToken - if len(opts.queryMap) > 0 { - q := url.Values{} - for k, v := range opts.queryMap { - q.Add(k, v) - } - req.URL.RawQuery = q.Encode() - } - resp, err := c.client.Do(req) + // If we're here there are more pages and the client does not want to + // paginate on their own; fetch them as this call returns all values. + currentPage := target + for { + nextPage, err := c.ListNextPage(ctx, currentPage, opt...) if err != nil { - return nil, fmt.Errorf("error performing client request during List call: %w", err) + return nil, fmt.Errorf("error getting next page in List call: %w", err) } - page := new(ManagedGroupListResult) - apiErr, err := resp.Decode(page) - if err != nil { - return nil, fmt.Errorf("error decoding List response: %w", err) - } - if apiErr != nil { - return nil, apiErr - } - for _, item := range page.Items { + for _, item := range nextPage.Items { if i, ok := idToIndex[item.Id]; ok { // Item has already been seen at index i, update in-place - target.Items[i] = item + allItems[i] = item } else { - target.Items = append(target.Items, item) - idToIndex[item.Id] = len(target.Items) - 1 + allItems = append(allItems, item) + idToIndex[item.Id] = len(allItems) - 1 } } - // RemovedIds contain any ManagedGroup that were deleted since the last response. - target.RemovedIds = append(target.RemovedIds, page.RemovedIds...) - target.EstItemCount = page.EstItemCount - target.ListToken = page.ListToken - target.ResponseType = page.ResponseType - target.Response = resp - if target.ResponseType == "complete" { + + currentPage = nextPage + + if currentPage.ResponseType == "complete" { break } } - // For now, removedIds will only be populated if this pagination cycle was the result of a - // "refresh" operation (i.e., the caller provided a list token option to this call). - // - // Sort to make response deterministic - slices.Sort(target.RemovedIds) - // Remove any duplicates - target.RemovedIds = slices.Compact(target.RemovedIds) + + // The current page here is the final page of the results, that is, the + // response type is "complete" + // Remove items that were deleted since the end of the last iteration. // If a ManagedGroup has been updated and subsequently removed, we don't want // it to appear both in the Items and RemovedIds, so we remove it from the Items. - for _, removedId := range target.RemovedIds { + for _, removedId := range currentPage.RemovedIds { if i, ok := idToIndex[removedId]; ok { // Remove the item at index i without preserving order // https://github.com/golang/go/wiki/SliceTricks#delete-without-preserving-order - target.Items[i] = target.Items[len(target.Items)-1] - target.Items = target.Items[:len(target.Items)-1] + allItems[i] = allItems[len(allItems)-1] + allItems = allItems[:len(allItems)-1] // Update the index of the previously last element - idToIndex[target.Items[i].Id] = i + idToIndex[allItems[i].Id] = i } } - // Since we paginated to the end, we can avoid confusion - // for the user by setting the estimated item count to the - // length of the items slice. 
If we don't set this here, it - // will equal the value returned in the last response, which is - // often much smaller than the total number returned. - target.EstItemCount = uint(len(target.Items)) // Sort the results again since in-place updates and deletes // may have shuffled items. We sort by created time descending // (most recently created first), same as the API. - slices.SortFunc(target.Items, func(i, j *ManagedGroup) int { + slices.SortFunc(allItems, func(i, j *ManagedGroup) int { return j.CreatedTime.Compare(i.CreatedTime) }) + // Since we paginated to the end, we can avoid confusion + // for the user by setting the estimated item count to the + // length of the items slice. If we don't set this here, it + // will equal the value returned in the last response, which is + // often much smaller than the total number returned. + currentPage.EstItemCount = uint(len(allItems)) + // Set items to the full list we have collected here + currentPage.Items = allItems + // Set the returned value to the last page with calculated values + target = currentPage // Finally, since we made at least 2 requests to the server to fulfill this // function call, resp.Body and resp.Map will only contain the most recent response. // Overwrite them with the true response. @@ -443,4 +446,80 @@ func (c *Client) List(ctx context.Context, authMethodId string, opt ...Option) ( // Note: the HTTP response body is consumed by resp.Decode in the loop, // so it doesn't need to be updated (it will always be, and has always been, empty). return target, nil + +} + +func (c *Client) ListNextPage(ctx context.Context, currentPage *ManagedGroupListResult, opt ...Option) (*ManagedGroupListResult, error) { + if currentPage == nil { + return nil, fmt.Errorf("empty currentPage value passed into ListNextPage request") + } + if currentPage.authMethodId == "" { + return nil, fmt.Errorf("empty authMethodId value in currentPage passed into ListNextPage request") + } + if c.client == nil { + return nil, fmt.Errorf("nil client") + } + if currentPage.ResponseType == "complete" || currentPage.ResponseType == "" { + return nil, fmt.Errorf("no more pages available in ListNextPage request") + } + + opts, apiOpts := getOpts(opt...) + opts.queryMap["auth_method_id"] = currentPage.authMethodId + + if currentPage.pageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) + } + + req, err := c.client.NewRequest(ctx, "GET", "managed-groups", nil, apiOpts...) + if err != nil { + return nil, fmt.Errorf("error creating List request: %w", err) + } + + opts.queryMap["list_token"] = currentPage.ListToken + if len(opts.queryMap) > 0 { + q := url.Values{} + for k, v := range opts.queryMap { + q.Add(k, v) + } + req.URL.RawQuery = q.Encode() + } + + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("error performing client request during List call during ListNextPage: %w", err) + } + + nextPage := new(ManagedGroupListResult) + apiErr, err := resp.Decode(nextPage) + if err != nil { + return nil, fmt.Errorf("error decoding List response during ListNextPage: %w", err) + } + if apiErr != nil { + return nil, apiErr + } + + // Ensure values are carried forward to the next call + nextPage.authMethodId = currentPage.authMethodId + + nextPage.pageSize = currentPage.pageSize + // Cache the removed IDs from this page + nextPage.allRemovedIds = append(currentPage.allRemovedIds, nextPage.RemovedIds...) 
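+	// allRemovedIds carries the running set across pages; it is only folded
+	// back into RemovedIds (sorted and de-duplicated) once the "complete"
+	// response is reached below.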
+ // Set the response body to the current response + nextPage.Response = resp + // If we're done iterating, pull the full set of removed IDs into the last + // response + if nextPage.ResponseType == "complete" { + // Collect up the last values + nextPage.RemovedIds = nextPage.allRemovedIds + // For now, removedIds will only be populated if this pagination cycle + // was the result of a "refresh" operation (i.e., the caller provided a + // list token option to this call). + // + // Sort to make response deterministic + slices.Sort(nextPage.RemovedIds) + // Remove any duplicates + nextPage.RemovedIds = slices.Compact(nextPage.RemovedIds) + } + + return nextPage, nil } diff --git a/api/managedgroups/option.gen.go b/api/managedgroups/option.gen.go index 15ca4a69f3..7da2f0803a 100644 --- a/api/managedgroups/option.gen.go +++ b/api/managedgroups/option.gen.go @@ -5,6 +5,7 @@ package managedgroups import ( + "strconv" "strings" "github.com/hashicorp/boundary/api" @@ -20,12 +21,14 @@ import ( type Option func(*options) type options struct { - postMap map[string]any - queryMap map[string]string - withAutomaticVersioning bool - withSkipCurlOutput bool - withFilter string - withListToken string + postMap map[string]any + queryMap map[string]string + withAutomaticVersioning bool + withSkipCurlOutput bool + withFilter string + withListToken string + withClientDirectedPagination bool + withPageSize uint32 } func getDefaultOptions() options { @@ -52,6 +55,9 @@ func getOpts(opt ...Option) (options, []api.Option) { if opts.withListToken != "" { opts.queryMap["list_token"] = opts.withListToken } + if opts.withPageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(opts.withPageSize), 10) + } return opts, apiOpts } @@ -90,6 +96,21 @@ func WithFilter(filter string) Option { } } +// WithClientDirectedPagination tells the List function to return only the first +// page, if more pages are available +func WithClientDirectedPagination(with bool) Option { + return func(o *options) { + o.withClientDirectedPagination = with + } +} + +// WithPageSize controls the size of pages used during List +func WithPageSize(with uint32) Option { + return func(o *options) { + o.withPageSize = with + } +} + func WithAttributes(inAttributes map[string]interface{}) Option { return func(o *options) { o.postMap["attributes"] = inAttributes diff --git a/api/policies/option.gen.go b/api/policies/option.gen.go index 59a9674b2a..376e19c66b 100644 --- a/api/policies/option.gen.go +++ b/api/policies/option.gen.go @@ -21,13 +21,15 @@ import ( type Option func(*options) type options struct { - postMap map[string]any - queryMap map[string]string - withAutomaticVersioning bool - withSkipCurlOutput bool - withFilter string - withListToken string - withRecursive bool + postMap map[string]any + queryMap map[string]string + withAutomaticVersioning bool + withSkipCurlOutput bool + withFilter string + withListToken string + withClientDirectedPagination bool + withPageSize uint32 + withRecursive bool } func getDefaultOptions() options { @@ -57,6 +59,9 @@ func getOpts(opt ...Option) (options, []api.Option) { if opts.withRecursive { opts.queryMap["recursive"] = strconv.FormatBool(opts.withRecursive) } + if opts.withPageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(opts.withPageSize), 10) + } return opts, apiOpts } @@ -95,6 +100,21 @@ func WithFilter(filter string) Option { } } +// WithClientDirectedPagination tells the List function to return only the first +// page, if more pages are available +func 
WithClientDirectedPagination(with bool) Option { + return func(o *options) { + o.withClientDirectedPagination = with + } +} + +// WithPageSize controls the size of pages used during List +func WithPageSize(with uint32) Option { + return func(o *options) { + o.withPageSize = with + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/policies/policy.gen.go b/api/policies/policy.gen.go index 4229172ec5..bf105a2baa 100644 --- a/api/policies/policy.gen.go +++ b/api/policies/policy.gen.go @@ -11,6 +11,7 @@ import ( "fmt" "net/url" "slices" + "strconv" "time" "github.com/hashicorp/boundary/api" @@ -67,6 +68,13 @@ type PolicyListResult struct { ListToken string `json:"list_token,omitempty"` ResponseType string `json:"response_type,omitempty"` Response *api.Response + + // The following fields are used for cached information when client-directed + // pagination is used. + recursive bool + pageSize uint32 + scopeId string + allRemovedIds []string } func (n PolicyListResult) GetItems() []*Policy { @@ -344,96 +352,93 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Poli return nil, apiErr } target.Response = resp + if target.ResponseType == "complete" || target.ResponseType == "" { return target, nil } + + // In case we shortcut out due to client directed pagination, ensure these + // are set + + target.recursive = opts.withRecursive + + target.pageSize = opts.withPageSize + target.scopeId = scopeId + target.allRemovedIds = target.RemovedIds + if opts.withClientDirectedPagination { + return target, nil + } + + allItems := make([]*Policy, 0, target.EstItemCount) + allItems = append(allItems, target.Items...) + // If there are more results, automatically fetch the rest of the results. // idToIndex keeps a map from the ID of an item to its index in target.Items. // This is used to update updated items in-place and remove deleted items // from the result after pagination is done. idToIndex := map[string]int{} - for i, item := range target.Items { + for i, item := range allItems { idToIndex[item.Id] = i } - for { - req, err := c.client.NewRequest(ctx, "GET", "policies", nil, apiOpts...) - if err != nil { - return nil, fmt.Errorf("error creating List request: %w", err) - } - opts.queryMap["list_token"] = target.ListToken - if len(opts.queryMap) > 0 { - q := url.Values{} - for k, v := range opts.queryMap { - q.Add(k, v) - } - req.URL.RawQuery = q.Encode() - } - - resp, err := c.client.Do(req) + // If we're here there are more pages and the client does not want to + // paginate on their own; fetch them as this call returns all values. + currentPage := target + for { + nextPage, err := c.ListNextPage(ctx, currentPage, opt...) 
if err != nil { - return nil, fmt.Errorf("error performing client request during List call: %w", err) + return nil, fmt.Errorf("error getting next page in List call: %w", err) } - page := new(PolicyListResult) - apiErr, err := resp.Decode(page) - if err != nil { - return nil, fmt.Errorf("error decoding List response: %w", err) - } - if apiErr != nil { - return nil, apiErr - } - for _, item := range page.Items { + for _, item := range nextPage.Items { if i, ok := idToIndex[item.Id]; ok { // Item has already been seen at index i, update in-place - target.Items[i] = item + allItems[i] = item } else { - target.Items = append(target.Items, item) - idToIndex[item.Id] = len(target.Items) - 1 + allItems = append(allItems, item) + idToIndex[item.Id] = len(allItems) - 1 } } - // RemovedIds contain any Policy that were deleted since the last response. - target.RemovedIds = append(target.RemovedIds, page.RemovedIds...) - target.EstItemCount = page.EstItemCount - target.ListToken = page.ListToken - target.ResponseType = page.ResponseType - target.Response = resp - if target.ResponseType == "complete" { + + currentPage = nextPage + + if currentPage.ResponseType == "complete" { break } } - // For now, removedIds will only be populated if this pagination cycle was the result of a - // "refresh" operation (i.e., the caller provided a list token option to this call). - // - // Sort to make response deterministic - slices.Sort(target.RemovedIds) - // Remove any duplicates - target.RemovedIds = slices.Compact(target.RemovedIds) + + // The current page here is the final page of the results, that is, the + // response type is "complete" + // Remove items that were deleted since the end of the last iteration. // If a Policy has been updated and subsequently removed, we don't want // it to appear both in the Items and RemovedIds, so we remove it from the Items. - for _, removedId := range target.RemovedIds { + for _, removedId := range currentPage.RemovedIds { if i, ok := idToIndex[removedId]; ok { // Remove the item at index i without preserving order // https://github.com/golang/go/wiki/SliceTricks#delete-without-preserving-order - target.Items[i] = target.Items[len(target.Items)-1] - target.Items = target.Items[:len(target.Items)-1] + allItems[i] = allItems[len(allItems)-1] + allItems = allItems[:len(allItems)-1] // Update the index of the previously last element - idToIndex[target.Items[i].Id] = i + idToIndex[allItems[i].Id] = i } } - // Since we paginated to the end, we can avoid confusion - // for the user by setting the estimated item count to the - // length of the items slice. If we don't set this here, it - // will equal the value returned in the last response, which is - // often much smaller than the total number returned. - target.EstItemCount = uint(len(target.Items)) // Sort the results again since in-place updates and deletes // may have shuffled items. We sort by created time descending // (most recently created first), same as the API. - slices.SortFunc(target.Items, func(i, j *Policy) int { + slices.SortFunc(allItems, func(i, j *Policy) int { return j.CreatedTime.Compare(i.CreatedTime) }) + // Since we paginated to the end, we can avoid confusion + // for the user by setting the estimated item count to the + // length of the items slice. If we don't set this here, it + // will equal the value returned in the last response, which is + // often much smaller than the total number returned. 
+ currentPage.EstItemCount = uint(len(allItems)) + // Set items to the full list we have collected here + currentPage.Items = allItems + // Set the returned value to the last page with calculated values + target = currentPage // Finally, since we made at least 2 requests to the server to fulfill this // function call, resp.Body and resp.Map will only contain the most recent response. // Overwrite them with the true response. @@ -447,4 +452,87 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Poli // Note: the HTTP response body is consumed by resp.Decode in the loop, // so it doesn't need to be updated (it will always be, and has always been, empty). return target, nil + +} + +func (c *Client) ListNextPage(ctx context.Context, currentPage *PolicyListResult, opt ...Option) (*PolicyListResult, error) { + if currentPage == nil { + return nil, fmt.Errorf("empty currentPage value passed into ListNextPage request") + } + if currentPage.scopeId == "" { + return nil, fmt.Errorf("empty scopeId value in currentPage passed into ListNextPage request") + } + if c.client == nil { + return nil, fmt.Errorf("nil client") + } + if currentPage.ResponseType == "complete" || currentPage.ResponseType == "" { + return nil, fmt.Errorf("no more pages available in ListNextPage request") + } + + opts, apiOpts := getOpts(opt...) + opts.queryMap["scope_id"] = currentPage.scopeId + + // Don't require them to re-specify recursive + if currentPage.recursive { + opts.queryMap["recursive"] = "true" + } + + if currentPage.pageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) + } + + req, err := c.client.NewRequest(ctx, "GET", "policies", nil, apiOpts...) + if err != nil { + return nil, fmt.Errorf("error creating List request: %w", err) + } + + opts.queryMap["list_token"] = currentPage.ListToken + if len(opts.queryMap) > 0 { + q := url.Values{} + for k, v := range opts.queryMap { + q.Add(k, v) + } + req.URL.RawQuery = q.Encode() + } + + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("error performing client request during List call during ListNextPage: %w", err) + } + + nextPage := new(PolicyListResult) + apiErr, err := resp.Decode(nextPage) + if err != nil { + return nil, fmt.Errorf("error decoding List response during ListNextPage: %w", err) + } + if apiErr != nil { + return nil, apiErr + } + + // Ensure values are carried forward to the next call + nextPage.scopeId = currentPage.scopeId + + nextPage.recursive = currentPage.recursive + + nextPage.pageSize = currentPage.pageSize + // Cache the removed IDs from this page + nextPage.allRemovedIds = append(currentPage.allRemovedIds, nextPage.RemovedIds...) + // Set the response body to the current response + nextPage.Response = resp + // If we're done iterating, pull the full set of removed IDs into the last + // response + if nextPage.ResponseType == "complete" { + // Collect up the last values + nextPage.RemovedIds = nextPage.allRemovedIds + // For now, removedIds will only be populated if this pagination cycle + // was the result of a "refresh" operation (i.e., the caller provided a + // list token option to this call). 
+ // + // Sort to make response deterministic + slices.Sort(nextPage.RemovedIds) + // Remove any duplicates + nextPage.RemovedIds = slices.Compact(nextPage.RemovedIds) + } + + return nextPage, nil } diff --git a/api/roles/option.gen.go b/api/roles/option.gen.go index fd552e1f40..06cda67ead 100644 --- a/api/roles/option.gen.go +++ b/api/roles/option.gen.go @@ -21,13 +21,15 @@ import ( type Option func(*options) type options struct { - postMap map[string]any - queryMap map[string]string - withAutomaticVersioning bool - withSkipCurlOutput bool - withFilter string - withListToken string - withRecursive bool + postMap map[string]any + queryMap map[string]string + withAutomaticVersioning bool + withSkipCurlOutput bool + withFilter string + withListToken string + withClientDirectedPagination bool + withPageSize uint32 + withRecursive bool } func getDefaultOptions() options { @@ -57,6 +59,9 @@ func getOpts(opt ...Option) (options, []api.Option) { if opts.withRecursive { opts.queryMap["recursive"] = strconv.FormatBool(opts.withRecursive) } + if opts.withPageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(opts.withPageSize), 10) + } return opts, apiOpts } @@ -95,6 +100,21 @@ func WithFilter(filter string) Option { } } +// WithClientDirectedPagination tells the List function to return only the first +// page, if more pages are available +func WithClientDirectedPagination(with bool) Option { + return func(o *options) { + o.withClientDirectedPagination = with + } +} + +// WithPageSize controls the size of pages used during List +func WithPageSize(with uint32) Option { + return func(o *options) { + o.withPageSize = with + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/roles/role.gen.go b/api/roles/role.gen.go index 7fae3fd186..7b6b550711 100644 --- a/api/roles/role.gen.go +++ b/api/roles/role.gen.go @@ -11,6 +11,7 @@ import ( "fmt" "net/url" "slices" + "strconv" "time" "github.com/hashicorp/boundary/api" @@ -70,6 +71,13 @@ type RoleListResult struct { ListToken string `json:"list_token,omitempty"` ResponseType string `json:"response_type,omitempty"` Response *api.Response + + // The following fields are used for cached information when client-directed + // pagination is used. + recursive bool + pageSize uint32 + scopeId string + allRemovedIds []string } func (n RoleListResult) GetItems() []*Role { @@ -342,96 +350,93 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Role return nil, apiErr } target.Response = resp + if target.ResponseType == "complete" || target.ResponseType == "" { return target, nil } + + // In case we shortcut out due to client directed pagination, ensure these + // are set + + target.recursive = opts.withRecursive + + target.pageSize = opts.withPageSize + target.scopeId = scopeId + target.allRemovedIds = target.RemovedIds + if opts.withClientDirectedPagination { + return target, nil + } + + allItems := make([]*Role, 0, target.EstItemCount) + allItems = append(allItems, target.Items...) + // If there are more results, automatically fetch the rest of the results. // idToIndex keeps a map from the ID of an item to its index in target.Items. // This is used to update updated items in-place and remove deleted items // from the result after pagination is done. 
idToIndex := map[string]int{} - for i, item := range target.Items { + for i, item := range allItems { idToIndex[item.Id] = i } - for { - req, err := c.client.NewRequest(ctx, "GET", "roles", nil, apiOpts...) - if err != nil { - return nil, fmt.Errorf("error creating List request: %w", err) - } - opts.queryMap["list_token"] = target.ListToken - if len(opts.queryMap) > 0 { - q := url.Values{} - for k, v := range opts.queryMap { - q.Add(k, v) - } - req.URL.RawQuery = q.Encode() - } - - resp, err := c.client.Do(req) + // If we're here there are more pages and the client does not want to + // paginate on their own; fetch them as this call returns all values. + currentPage := target + for { + nextPage, err := c.ListNextPage(ctx, currentPage, opt...) if err != nil { - return nil, fmt.Errorf("error performing client request during List call: %w", err) + return nil, fmt.Errorf("error getting next page in List call: %w", err) } - page := new(RoleListResult) - apiErr, err := resp.Decode(page) - if err != nil { - return nil, fmt.Errorf("error decoding List response: %w", err) - } - if apiErr != nil { - return nil, apiErr - } - for _, item := range page.Items { + for _, item := range nextPage.Items { if i, ok := idToIndex[item.Id]; ok { // Item has already been seen at index i, update in-place - target.Items[i] = item + allItems[i] = item } else { - target.Items = append(target.Items, item) - idToIndex[item.Id] = len(target.Items) - 1 + allItems = append(allItems, item) + idToIndex[item.Id] = len(allItems) - 1 } } - // RemovedIds contain any Role that were deleted since the last response. - target.RemovedIds = append(target.RemovedIds, page.RemovedIds...) - target.EstItemCount = page.EstItemCount - target.ListToken = page.ListToken - target.ResponseType = page.ResponseType - target.Response = resp - if target.ResponseType == "complete" { + + currentPage = nextPage + + if currentPage.ResponseType == "complete" { break } } - // For now, removedIds will only be populated if this pagination cycle was the result of a - // "refresh" operation (i.e., the caller provided a list token option to this call). - // - // Sort to make response deterministic - slices.Sort(target.RemovedIds) - // Remove any duplicates - target.RemovedIds = slices.Compact(target.RemovedIds) + + // The current page here is the final page of the results, that is, the + // response type is "complete" + // Remove items that were deleted since the end of the last iteration. // If a Role has been updated and subsequently removed, we don't want // it to appear both in the Items and RemovedIds, so we remove it from the Items. - for _, removedId := range target.RemovedIds { + for _, removedId := range currentPage.RemovedIds { if i, ok := idToIndex[removedId]; ok { // Remove the item at index i without preserving order // https://github.com/golang/go/wiki/SliceTricks#delete-without-preserving-order - target.Items[i] = target.Items[len(target.Items)-1] - target.Items = target.Items[:len(target.Items)-1] + allItems[i] = allItems[len(allItems)-1] + allItems = allItems[:len(allItems)-1] // Update the index of the previously last element - idToIndex[target.Items[i].Id] = i + idToIndex[allItems[i].Id] = i } } - // Since we paginated to the end, we can avoid confusion - // for the user by setting the estimated item count to the - // length of the items slice. If we don't set this here, it - // will equal the value returned in the last response, which is - // often much smaller than the total number returned. 
- target.EstItemCount = uint(len(target.Items)) // Sort the results again since in-place updates and deletes // may have shuffled items. We sort by created time descending // (most recently created first), same as the API. - slices.SortFunc(target.Items, func(i, j *Role) int { + slices.SortFunc(allItems, func(i, j *Role) int { return j.CreatedTime.Compare(i.CreatedTime) }) + // Since we paginated to the end, we can avoid confusion + // for the user by setting the estimated item count to the + // length of the items slice. If we don't set this here, it + // will equal the value returned in the last response, which is + // often much smaller than the total number returned. + currentPage.EstItemCount = uint(len(allItems)) + // Set items to the full list we have collected here + currentPage.Items = allItems + // Set the returned value to the last page with calculated values + target = currentPage // Finally, since we made at least 2 requests to the server to fulfill this // function call, resp.Body and resp.Map will only contain the most recent response. // Overwrite them with the true response. @@ -445,6 +450,89 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Role // Note: the HTTP response body is consumed by resp.Decode in the loop, // so it doesn't need to be updated (it will always be, and has always been, empty). return target, nil + +} + +func (c *Client) ListNextPage(ctx context.Context, currentPage *RoleListResult, opt ...Option) (*RoleListResult, error) { + if currentPage == nil { + return nil, fmt.Errorf("empty currentPage value passed into ListNextPage request") + } + if currentPage.scopeId == "" { + return nil, fmt.Errorf("empty scopeId value in currentPage passed into ListNextPage request") + } + if c.client == nil { + return nil, fmt.Errorf("nil client") + } + if currentPage.ResponseType == "complete" || currentPage.ResponseType == "" { + return nil, fmt.Errorf("no more pages available in ListNextPage request") + } + + opts, apiOpts := getOpts(opt...) + opts.queryMap["scope_id"] = currentPage.scopeId + + // Don't require them to re-specify recursive + if currentPage.recursive { + opts.queryMap["recursive"] = "true" + } + + if currentPage.pageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) + } + + req, err := c.client.NewRequest(ctx, "GET", "roles", nil, apiOpts...) + if err != nil { + return nil, fmt.Errorf("error creating List request: %w", err) + } + + opts.queryMap["list_token"] = currentPage.ListToken + if len(opts.queryMap) > 0 { + q := url.Values{} + for k, v := range opts.queryMap { + q.Add(k, v) + } + req.URL.RawQuery = q.Encode() + } + + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("error performing client request during List call during ListNextPage: %w", err) + } + + nextPage := new(RoleListResult) + apiErr, err := resp.Decode(nextPage) + if err != nil { + return nil, fmt.Errorf("error decoding List response during ListNextPage: %w", err) + } + if apiErr != nil { + return nil, apiErr + } + + // Ensure values are carried forward to the next call + nextPage.scopeId = currentPage.scopeId + + nextPage.recursive = currentPage.recursive + + nextPage.pageSize = currentPage.pageSize + // Cache the removed IDs from this page + nextPage.allRemovedIds = append(currentPage.allRemovedIds, nextPage.RemovedIds...) 
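+	// allRemovedIds carries the running set across pages; it is only folded
+	// back into RemovedIds (sorted and de-duplicated) once the "complete"
+	// response is reached below.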
+ // Set the response body to the current response + nextPage.Response = resp + // If we're done iterating, pull the full set of removed IDs into the last + // response + if nextPage.ResponseType == "complete" { + // Collect up the last values + nextPage.RemovedIds = nextPage.allRemovedIds + // For now, removedIds will only be populated if this pagination cycle + // was the result of a "refresh" operation (i.e., the caller provided a + // list token option to this call). + // + // Sort to make response deterministic + slices.Sort(nextPage.RemovedIds) + // Remove any duplicates + nextPage.RemovedIds = slices.Compact(nextPage.RemovedIds) + } + + return nextPage, nil } func (c *Client) AddGrantScopes(ctx context.Context, id string, version uint32, grantScopeIds []string, opt ...Option) (*RoleUpdateResult, error) { diff --git a/api/scopes/option.gen.go b/api/scopes/option.gen.go index f8d59d6c1f..d20936ba0c 100644 --- a/api/scopes/option.gen.go +++ b/api/scopes/option.gen.go @@ -22,13 +22,15 @@ import ( type Option func(*options) type options struct { - postMap map[string]any - queryMap map[string]string - withAutomaticVersioning bool - withSkipCurlOutput bool - withFilter string - withListToken string - withRecursive bool + postMap map[string]any + queryMap map[string]string + withAutomaticVersioning bool + withSkipCurlOutput bool + withFilter string + withListToken string + withClientDirectedPagination bool + withPageSize uint32 + withRecursive bool } func getDefaultOptions() options { @@ -58,6 +60,9 @@ func getOpts(opt ...Option) (options, []api.Option) { if opts.withRecursive { opts.queryMap["recursive"] = strconv.FormatBool(opts.withRecursive) } + if opts.withPageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(opts.withPageSize), 10) + } return opts, apiOpts } @@ -96,6 +101,21 @@ func WithFilter(filter string) Option { } } +// WithClientDirectedPagination tells the List function to return only the first +// page, if more pages are available +func WithClientDirectedPagination(with bool) Option { + return func(o *options) { + o.withClientDirectedPagination = with + } +} + +// WithPageSize controls the size of pages used during List +func WithPageSize(with uint32) Option { + return func(o *options) { + o.withPageSize = with + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/scopes/scope.gen.go b/api/scopes/scope.gen.go index eeb91f42c1..93119cd5e3 100644 --- a/api/scopes/scope.gen.go +++ b/api/scopes/scope.gen.go @@ -11,6 +11,7 @@ import ( "fmt" "net/url" "slices" + "strconv" "time" "github.com/hashicorp/boundary/api" @@ -68,6 +69,13 @@ type ScopeListResult struct { ListToken string `json:"list_token,omitempty"` ResponseType string `json:"response_type,omitempty"` Response *api.Response + + // The following fields are used for cached information when client-directed + // pagination is used. 
+ recursive bool + pageSize uint32 + scopeId string + allRemovedIds []string } func (n ScopeListResult) GetItems() []*Scope { @@ -340,96 +348,93 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Scop return nil, apiErr } target.Response = resp + if target.ResponseType == "complete" || target.ResponseType == "" { return target, nil } + + // In case we shortcut out due to client directed pagination, ensure these + // are set + + target.recursive = opts.withRecursive + + target.pageSize = opts.withPageSize + target.scopeId = scopeId + target.allRemovedIds = target.RemovedIds + if opts.withClientDirectedPagination { + return target, nil + } + + allItems := make([]*Scope, 0, target.EstItemCount) + allItems = append(allItems, target.Items...) + // If there are more results, automatically fetch the rest of the results. // idToIndex keeps a map from the ID of an item to its index in target.Items. // This is used to update updated items in-place and remove deleted items // from the result after pagination is done. idToIndex := map[string]int{} - for i, item := range target.Items { + for i, item := range allItems { idToIndex[item.Id] = i } - for { - req, err := c.client.NewRequest(ctx, "GET", "scopes", nil, apiOpts...) - if err != nil { - return nil, fmt.Errorf("error creating List request: %w", err) - } - opts.queryMap["list_token"] = target.ListToken - if len(opts.queryMap) > 0 { - q := url.Values{} - for k, v := range opts.queryMap { - q.Add(k, v) - } - req.URL.RawQuery = q.Encode() - } - - resp, err := c.client.Do(req) + // If we're here there are more pages and the client does not want to + // paginate on their own; fetch them as this call returns all values. + currentPage := target + for { + nextPage, err := c.ListNextPage(ctx, currentPage, opt...) if err != nil { - return nil, fmt.Errorf("error performing client request during List call: %w", err) + return nil, fmt.Errorf("error getting next page in List call: %w", err) } - page := new(ScopeListResult) - apiErr, err := resp.Decode(page) - if err != nil { - return nil, fmt.Errorf("error decoding List response: %w", err) - } - if apiErr != nil { - return nil, apiErr - } - for _, item := range page.Items { + for _, item := range nextPage.Items { if i, ok := idToIndex[item.Id]; ok { // Item has already been seen at index i, update in-place - target.Items[i] = item + allItems[i] = item } else { - target.Items = append(target.Items, item) - idToIndex[item.Id] = len(target.Items) - 1 + allItems = append(allItems, item) + idToIndex[item.Id] = len(allItems) - 1 } } - // RemovedIds contain any Scope that were deleted since the last response. - target.RemovedIds = append(target.RemovedIds, page.RemovedIds...) - target.EstItemCount = page.EstItemCount - target.ListToken = page.ListToken - target.ResponseType = page.ResponseType - target.Response = resp - if target.ResponseType == "complete" { + + currentPage = nextPage + + if currentPage.ResponseType == "complete" { break } } - // For now, removedIds will only be populated if this pagination cycle was the result of a - // "refresh" operation (i.e., the caller provided a list token option to this call). - // - // Sort to make response deterministic - slices.Sort(target.RemovedIds) - // Remove any duplicates - target.RemovedIds = slices.Compact(target.RemovedIds) + + // The current page here is the final page of the results, that is, the + // response type is "complete" + // Remove items that were deleted since the end of the last iteration. 
// If a Scope has been updated and subsequently removed, we don't want // it to appear both in the Items and RemovedIds, so we remove it from the Items. - for _, removedId := range target.RemovedIds { + for _, removedId := range currentPage.RemovedIds { if i, ok := idToIndex[removedId]; ok { // Remove the item at index i without preserving order // https://github.com/golang/go/wiki/SliceTricks#delete-without-preserving-order - target.Items[i] = target.Items[len(target.Items)-1] - target.Items = target.Items[:len(target.Items)-1] + allItems[i] = allItems[len(allItems)-1] + allItems = allItems[:len(allItems)-1] // Update the index of the previously last element - idToIndex[target.Items[i].Id] = i + idToIndex[allItems[i].Id] = i } } - // Since we paginated to the end, we can avoid confusion - // for the user by setting the estimated item count to the - // length of the items slice. If we don't set this here, it - // will equal the value returned in the last response, which is - // often much smaller than the total number returned. - target.EstItemCount = uint(len(target.Items)) // Sort the results again since in-place updates and deletes // may have shuffled items. We sort by created time descending // (most recently created first), same as the API. - slices.SortFunc(target.Items, func(i, j *Scope) int { + slices.SortFunc(allItems, func(i, j *Scope) int { return j.CreatedTime.Compare(i.CreatedTime) }) + // Since we paginated to the end, we can avoid confusion + // for the user by setting the estimated item count to the + // length of the items slice. If we don't set this here, it + // will equal the value returned in the last response, which is + // often much smaller than the total number returned. + currentPage.EstItemCount = uint(len(allItems)) + // Set items to the full list we have collected here + currentPage.Items = allItems + // Set the returned value to the last page with calculated values + target = currentPage // Finally, since we made at least 2 requests to the server to fulfill this // function call, resp.Body and resp.Map will only contain the most recent response. // Overwrite them with the true response. @@ -443,4 +448,87 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Scop // Note: the HTTP response body is consumed by resp.Decode in the loop, // so it doesn't need to be updated (it will always be, and has always been, empty). return target, nil + +} + +func (c *Client) ListNextPage(ctx context.Context, currentPage *ScopeListResult, opt ...Option) (*ScopeListResult, error) { + if currentPage == nil { + return nil, fmt.Errorf("empty currentPage value passed into ListNextPage request") + } + if currentPage.scopeId == "" { + return nil, fmt.Errorf("empty scopeId value in currentPage passed into ListNextPage request") + } + if c.client == nil { + return nil, fmt.Errorf("nil client") + } + if currentPage.ResponseType == "complete" || currentPage.ResponseType == "" { + return nil, fmt.Errorf("no more pages available in ListNextPage request") + } + + opts, apiOpts := getOpts(opt...) + opts.queryMap["scope_id"] = currentPage.scopeId + + // Don't require them to re-specify recursive + if currentPage.recursive { + opts.queryMap["recursive"] = "true" + } + + if currentPage.pageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) + } + + req, err := c.client.NewRequest(ctx, "GET", "scopes", nil, apiOpts...) 
+ if err != nil { + return nil, fmt.Errorf("error creating List request: %w", err) + } + + opts.queryMap["list_token"] = currentPage.ListToken + if len(opts.queryMap) > 0 { + q := url.Values{} + for k, v := range opts.queryMap { + q.Add(k, v) + } + req.URL.RawQuery = q.Encode() + } + + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("error performing client request during List call during ListNextPage: %w", err) + } + + nextPage := new(ScopeListResult) + apiErr, err := resp.Decode(nextPage) + if err != nil { + return nil, fmt.Errorf("error decoding List response during ListNextPage: %w", err) + } + if apiErr != nil { + return nil, apiErr + } + + // Ensure values are carried forward to the next call + nextPage.scopeId = currentPage.scopeId + + nextPage.recursive = currentPage.recursive + + nextPage.pageSize = currentPage.pageSize + // Cache the removed IDs from this page + nextPage.allRemovedIds = append(currentPage.allRemovedIds, nextPage.RemovedIds...) + // Set the response body to the current response + nextPage.Response = resp + // If we're done iterating, pull the full set of removed IDs into the last + // response + if nextPage.ResponseType == "complete" { + // Collect up the last values + nextPage.RemovedIds = nextPage.allRemovedIds + // For now, removedIds will only be populated if this pagination cycle + // was the result of a "refresh" operation (i.e., the caller provided a + // list token option to this call). + // + // Sort to make response deterministic + slices.Sort(nextPage.RemovedIds) + // Remove any duplicates + nextPage.RemovedIds = slices.Compact(nextPage.RemovedIds) + } + + return nextPage, nil } diff --git a/api/sessionrecordings/option.gen.go b/api/sessionrecordings/option.gen.go index fb53594df3..beae7de2ab 100644 --- a/api/sessionrecordings/option.gen.go +++ b/api/sessionrecordings/option.gen.go @@ -20,13 +20,15 @@ import ( type Option func(*options) type options struct { - postMap map[string]any - queryMap map[string]string - withAutomaticVersioning bool - withSkipCurlOutput bool - withFilter string - withListToken string - withRecursive bool + postMap map[string]any + queryMap map[string]string + withAutomaticVersioning bool + withSkipCurlOutput bool + withFilter string + withListToken string + withClientDirectedPagination bool + withPageSize uint32 + withRecursive bool } func getDefaultOptions() options { @@ -56,6 +58,9 @@ func getOpts(opt ...Option) (options, []api.Option) { if opts.withRecursive { opts.queryMap["recursive"] = strconv.FormatBool(opts.withRecursive) } + if opts.withPageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(opts.withPageSize), 10) + } return opts, apiOpts } @@ -75,6 +80,21 @@ func WithListToken(listToken string) Option { } } +// WithClientDirectedPagination tells the List function to return only the first +// page, if more pages are available +func WithClientDirectedPagination(with bool) Option { + return func(o *options) { + o.withClientDirectedPagination = with + } +} + +// WithPageSize controls the size of pages used during List +func WithPageSize(with uint32) Option { + return func(o *options) { + o.withPageSize = with + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/sessionrecordings/session_recording.gen.go b/api/sessionrecordings/session_recording.gen.go index e217ea6fa7..164880cf93 100644 --- a/api/sessionrecordings/session_recording.gen.go +++ 
b/api/sessionrecordings/session_recording.gen.go @@ -10,6 +10,7 @@ import ( "fmt" "net/url" "slices" + "strconv" "time" "github.com/hashicorp/boundary/api" @@ -73,6 +74,13 @@ type SessionRecordingListResult struct { ListToken string `json:"list_token,omitempty"` ResponseType string `json:"response_type,omitempty"` Response *api.Response + + // The following fields are used for cached information when client-directed + // pagination is used. + recursive bool + pageSize uint32 + scopeId string + allRemovedIds []string } func (n SessionRecordingListResult) GetItems() []*SessionRecording { @@ -238,96 +246,93 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Sess return nil, apiErr } target.Response = resp + if target.ResponseType == "complete" || target.ResponseType == "" { return target, nil } + + // In case we shortcut out due to client directed pagination, ensure these + // are set + + target.recursive = opts.withRecursive + + target.pageSize = opts.withPageSize + target.scopeId = scopeId + target.allRemovedIds = target.RemovedIds + if opts.withClientDirectedPagination { + return target, nil + } + + allItems := make([]*SessionRecording, 0, target.EstItemCount) + allItems = append(allItems, target.Items...) + // If there are more results, automatically fetch the rest of the results. // idToIndex keeps a map from the ID of an item to its index in target.Items. // This is used to update updated items in-place and remove deleted items // from the result after pagination is done. idToIndex := map[string]int{} - for i, item := range target.Items { + for i, item := range allItems { idToIndex[item.Id] = i } - for { - req, err := c.client.NewRequest(ctx, "GET", "session-recordings", nil, apiOpts...) - if err != nil { - return nil, fmt.Errorf("error creating List request: %w", err) - } - - opts.queryMap["list_token"] = target.ListToken - if len(opts.queryMap) > 0 { - q := url.Values{} - for k, v := range opts.queryMap { - q.Add(k, v) - } - req.URL.RawQuery = q.Encode() - } - resp, err := c.client.Do(req) + // If we're here there are more pages and the client does not want to + // paginate on their own; fetch them as this call returns all values. + currentPage := target + for { + nextPage, err := c.ListNextPage(ctx, currentPage, opt...) if err != nil { - return nil, fmt.Errorf("error performing client request during List call: %w", err) + return nil, fmt.Errorf("error getting next page in List call: %w", err) } - page := new(SessionRecordingListResult) - apiErr, err := resp.Decode(page) - if err != nil { - return nil, fmt.Errorf("error decoding List response: %w", err) - } - if apiErr != nil { - return nil, apiErr - } - for _, item := range page.Items { + for _, item := range nextPage.Items { if i, ok := idToIndex[item.Id]; ok { // Item has already been seen at index i, update in-place - target.Items[i] = item + allItems[i] = item } else { - target.Items = append(target.Items, item) - idToIndex[item.Id] = len(target.Items) - 1 + allItems = append(allItems, item) + idToIndex[item.Id] = len(allItems) - 1 } } - // RemovedIds contain any SessionRecording that were deleted since the last response. - target.RemovedIds = append(target.RemovedIds, page.RemovedIds...) 
- target.EstItemCount = page.EstItemCount - target.ListToken = page.ListToken - target.ResponseType = page.ResponseType - target.Response = resp - if target.ResponseType == "complete" { + + currentPage = nextPage + + if currentPage.ResponseType == "complete" { break } } - // For now, removedIds will only be populated if this pagination cycle was the result of a - // "refresh" operation (i.e., the caller provided a list token option to this call). - // - // Sort to make response deterministic - slices.Sort(target.RemovedIds) - // Remove any duplicates - target.RemovedIds = slices.Compact(target.RemovedIds) + + // The current page here is the final page of the results, that is, the + // response type is "complete" + // Remove items that were deleted since the end of the last iteration. // If a SessionRecording has been updated and subsequently removed, we don't want // it to appear both in the Items and RemovedIds, so we remove it from the Items. - for _, removedId := range target.RemovedIds { + for _, removedId := range currentPage.RemovedIds { if i, ok := idToIndex[removedId]; ok { // Remove the item at index i without preserving order // https://github.com/golang/go/wiki/SliceTricks#delete-without-preserving-order - target.Items[i] = target.Items[len(target.Items)-1] - target.Items = target.Items[:len(target.Items)-1] + allItems[i] = allItems[len(allItems)-1] + allItems = allItems[:len(allItems)-1] // Update the index of the previously last element - idToIndex[target.Items[i].Id] = i + idToIndex[allItems[i].Id] = i } } - // Since we paginated to the end, we can avoid confusion - // for the user by setting the estimated item count to the - // length of the items slice. If we don't set this here, it - // will equal the value returned in the last response, which is - // often much smaller than the total number returned. - target.EstItemCount = uint(len(target.Items)) // Sort the results again since in-place updates and deletes // may have shuffled items. We sort by created time descending // (most recently created first), same as the API. - slices.SortFunc(target.Items, func(i, j *SessionRecording) int { + slices.SortFunc(allItems, func(i, j *SessionRecording) int { return j.CreatedTime.Compare(i.CreatedTime) }) + // Since we paginated to the end, we can avoid confusion + // for the user by setting the estimated item count to the + // length of the items slice. If we don't set this here, it + // will equal the value returned in the last response, which is + // often much smaller than the total number returned. + currentPage.EstItemCount = uint(len(allItems)) + // Set items to the full list we have collected here + currentPage.Items = allItems + // Set the returned value to the last page with calculated values + target = currentPage // Finally, since we made at least 2 requests to the server to fulfill this // function call, resp.Body and resp.Map will only contain the most recent response. // Overwrite them with the true response. @@ -341,4 +346,87 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Sess // Note: the HTTP response body is consumed by resp.Decode in the loop, // so it doesn't need to be updated (it will always be, and has always been, empty). 
return target, nil + +} + +func (c *Client) ListNextPage(ctx context.Context, currentPage *SessionRecordingListResult, opt ...Option) (*SessionRecordingListResult, error) { + if currentPage == nil { + return nil, fmt.Errorf("empty currentPage value passed into ListNextPage request") + } + if currentPage.scopeId == "" { + return nil, fmt.Errorf("empty scopeId value in currentPage passed into ListNextPage request") + } + if c.client == nil { + return nil, fmt.Errorf("nil client") + } + if currentPage.ResponseType == "complete" || currentPage.ResponseType == "" { + return nil, fmt.Errorf("no more pages available in ListNextPage request") + } + + opts, apiOpts := getOpts(opt...) + opts.queryMap["scope_id"] = currentPage.scopeId + + // Don't require them to re-specify recursive + if currentPage.recursive { + opts.queryMap["recursive"] = "true" + } + + if currentPage.pageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) + } + + req, err := c.client.NewRequest(ctx, "GET", "session-recordings", nil, apiOpts...) + if err != nil { + return nil, fmt.Errorf("error creating List request: %w", err) + } + + opts.queryMap["list_token"] = currentPage.ListToken + if len(opts.queryMap) > 0 { + q := url.Values{} + for k, v := range opts.queryMap { + q.Add(k, v) + } + req.URL.RawQuery = q.Encode() + } + + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("error performing client request during List call during ListNextPage: %w", err) + } + + nextPage := new(SessionRecordingListResult) + apiErr, err := resp.Decode(nextPage) + if err != nil { + return nil, fmt.Errorf("error decoding List response during ListNextPage: %w", err) + } + if apiErr != nil { + return nil, apiErr + } + + // Ensure values are carried forward to the next call + nextPage.scopeId = currentPage.scopeId + + nextPage.recursive = currentPage.recursive + + nextPage.pageSize = currentPage.pageSize + // Cache the removed IDs from this page + nextPage.allRemovedIds = append(currentPage.allRemovedIds, nextPage.RemovedIds...) + // Set the response body to the current response + nextPage.Response = resp + // If we're done iterating, pull the full set of removed IDs into the last + // response + if nextPage.ResponseType == "complete" { + // Collect up the last values + nextPage.RemovedIds = nextPage.allRemovedIds + // For now, removedIds will only be populated if this pagination cycle + // was the result of a "refresh" operation (i.e., the caller provided a + // list token option to this call). 
+ // + // Sort to make response deterministic + slices.Sort(nextPage.RemovedIds) + // Remove any duplicates + nextPage.RemovedIds = slices.Compact(nextPage.RemovedIds) + } + + return nextPage, nil } diff --git a/api/sessions/option.gen.go b/api/sessions/option.gen.go index 15310f74d1..3887657af2 100644 --- a/api/sessions/option.gen.go +++ b/api/sessions/option.gen.go @@ -22,13 +22,15 @@ import ( type Option func(*options) type options struct { - postMap map[string]any - queryMap map[string]string - withAutomaticVersioning bool - withSkipCurlOutput bool - withFilter string - withListToken string - withRecursive bool + postMap map[string]any + queryMap map[string]string + withAutomaticVersioning bool + withSkipCurlOutput bool + withFilter string + withListToken string + withClientDirectedPagination bool + withPageSize uint32 + withRecursive bool } func getDefaultOptions() options { @@ -58,6 +60,9 @@ func getOpts(opt ...Option) (options, []api.Option) { if opts.withRecursive { opts.queryMap["recursive"] = strconv.FormatBool(opts.withRecursive) } + if opts.withPageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(opts.withPageSize), 10) + } return opts, apiOpts } @@ -96,6 +101,21 @@ func WithFilter(filter string) Option { } } +// WithClientDirectedPagination tells the List function to return only the first +// page, if more pages are available +func WithClientDirectedPagination(with bool) Option { + return func(o *options) { + o.withClientDirectedPagination = with + } +} + +// WithPageSize controls the size of pages used during List +func WithPageSize(with uint32) Option { + return func(o *options) { + o.withPageSize = with + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/sessions/session.gen.go b/api/sessions/session.gen.go index 16b2693e3d..5ffaa9f6fe 100644 --- a/api/sessions/session.gen.go +++ b/api/sessions/session.gen.go @@ -10,6 +10,7 @@ import ( "fmt" "net/url" "slices" + "strconv" "time" "github.com/hashicorp/boundary/api" @@ -75,6 +76,13 @@ type SessionListResult struct { ListToken string `json:"list_token,omitempty"` ResponseType string `json:"response_type,omitempty"` Response *api.Response + + // The following fields are used for cached information when client-directed + // pagination is used. + recursive bool + pageSize uint32 + scopeId string + allRemovedIds []string } func (n SessionListResult) GetItems() []*Session { @@ -198,96 +206,93 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Sess return nil, apiErr } target.Response = resp + if target.ResponseType == "complete" || target.ResponseType == "" { return target, nil } + + // In case we shortcut out due to client directed pagination, ensure these + // are set + + target.recursive = opts.withRecursive + + target.pageSize = opts.withPageSize + target.scopeId = scopeId + target.allRemovedIds = target.RemovedIds + if opts.withClientDirectedPagination { + return target, nil + } + + allItems := make([]*Session, 0, target.EstItemCount) + allItems = append(allItems, target.Items...) + // If there are more results, automatically fetch the rest of the results. // idToIndex keeps a map from the ID of an item to its index in target.Items. // This is used to update updated items in-place and remove deleted items // from the result after pagination is done. 
idToIndex := map[string]int{} - for i, item := range target.Items { + for i, item := range allItems { idToIndex[item.Id] = i } - for { - req, err := c.client.NewRequest(ctx, "GET", "sessions", nil, apiOpts...) - if err != nil { - return nil, fmt.Errorf("error creating List request: %w", err) - } - - opts.queryMap["list_token"] = target.ListToken - if len(opts.queryMap) > 0 { - q := url.Values{} - for k, v := range opts.queryMap { - q.Add(k, v) - } - req.URL.RawQuery = q.Encode() - } - resp, err := c.client.Do(req) + // If we're here there are more pages and the client does not want to + // paginate on their own; fetch them as this call returns all values. + currentPage := target + for { + nextPage, err := c.ListNextPage(ctx, currentPage, opt...) if err != nil { - return nil, fmt.Errorf("error performing client request during List call: %w", err) + return nil, fmt.Errorf("error getting next page in List call: %w", err) } - page := new(SessionListResult) - apiErr, err := resp.Decode(page) - if err != nil { - return nil, fmt.Errorf("error decoding List response: %w", err) - } - if apiErr != nil { - return nil, apiErr - } - for _, item := range page.Items { + for _, item := range nextPage.Items { if i, ok := idToIndex[item.Id]; ok { // Item has already been seen at index i, update in-place - target.Items[i] = item + allItems[i] = item } else { - target.Items = append(target.Items, item) - idToIndex[item.Id] = len(target.Items) - 1 + allItems = append(allItems, item) + idToIndex[item.Id] = len(allItems) - 1 } } - // RemovedIds contain any Session that were deleted since the last response. - target.RemovedIds = append(target.RemovedIds, page.RemovedIds...) - target.EstItemCount = page.EstItemCount - target.ListToken = page.ListToken - target.ResponseType = page.ResponseType - target.Response = resp - if target.ResponseType == "complete" { + + currentPage = nextPage + + if currentPage.ResponseType == "complete" { break } } - // For now, removedIds will only be populated if this pagination cycle was the result of a - // "refresh" operation (i.e., the caller provided a list token option to this call). - // - // Sort to make response deterministic - slices.Sort(target.RemovedIds) - // Remove any duplicates - target.RemovedIds = slices.Compact(target.RemovedIds) + + // The current page here is the final page of the results, that is, the + // response type is "complete" + // Remove items that were deleted since the end of the last iteration. // If a Session has been updated and subsequently removed, we don't want // it to appear both in the Items and RemovedIds, so we remove it from the Items. - for _, removedId := range target.RemovedIds { + for _, removedId := range currentPage.RemovedIds { if i, ok := idToIndex[removedId]; ok { // Remove the item at index i without preserving order // https://github.com/golang/go/wiki/SliceTricks#delete-without-preserving-order - target.Items[i] = target.Items[len(target.Items)-1] - target.Items = target.Items[:len(target.Items)-1] + allItems[i] = allItems[len(allItems)-1] + allItems = allItems[:len(allItems)-1] // Update the index of the previously last element - idToIndex[target.Items[i].Id] = i + idToIndex[allItems[i].Id] = i } } - // Since we paginated to the end, we can avoid confusion - // for the user by setting the estimated item count to the - // length of the items slice. If we don't set this here, it - // will equal the value returned in the last response, which is - // often much smaller than the total number returned. 
- target.EstItemCount = uint(len(target.Items)) // Sort the results again since in-place updates and deletes // may have shuffled items. We sort by created time descending // (most recently created first), same as the API. - slices.SortFunc(target.Items, func(i, j *Session) int { + slices.SortFunc(allItems, func(i, j *Session) int { return j.CreatedTime.Compare(i.CreatedTime) }) + // Since we paginated to the end, we can avoid confusion + // for the user by setting the estimated item count to the + // length of the items slice. If we don't set this here, it + // will equal the value returned in the last response, which is + // often much smaller than the total number returned. + currentPage.EstItemCount = uint(len(allItems)) + // Set items to the full list we have collected here + currentPage.Items = allItems + // Set the returned value to the last page with calculated values + target = currentPage // Finally, since we made at least 2 requests to the server to fulfill this // function call, resp.Body and resp.Map will only contain the most recent response. // Overwrite them with the true response. @@ -301,4 +306,87 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Sess // Note: the HTTP response body is consumed by resp.Decode in the loop, // so it doesn't need to be updated (it will always be, and has always been, empty). return target, nil + +} + +func (c *Client) ListNextPage(ctx context.Context, currentPage *SessionListResult, opt ...Option) (*SessionListResult, error) { + if currentPage == nil { + return nil, fmt.Errorf("empty currentPage value passed into ListNextPage request") + } + if currentPage.scopeId == "" { + return nil, fmt.Errorf("empty scopeId value in currentPage passed into ListNextPage request") + } + if c.client == nil { + return nil, fmt.Errorf("nil client") + } + if currentPage.ResponseType == "complete" || currentPage.ResponseType == "" { + return nil, fmt.Errorf("no more pages available in ListNextPage request") + } + + opts, apiOpts := getOpts(opt...) + opts.queryMap["scope_id"] = currentPage.scopeId + + // Don't require them to re-specify recursive + if currentPage.recursive { + opts.queryMap["recursive"] = "true" + } + + if currentPage.pageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) + } + + req, err := c.client.NewRequest(ctx, "GET", "sessions", nil, apiOpts...) + if err != nil { + return nil, fmt.Errorf("error creating List request: %w", err) + } + + opts.queryMap["list_token"] = currentPage.ListToken + if len(opts.queryMap) > 0 { + q := url.Values{} + for k, v := range opts.queryMap { + q.Add(k, v) + } + req.URL.RawQuery = q.Encode() + } + + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("error performing client request during List call during ListNextPage: %w", err) + } + + nextPage := new(SessionListResult) + apiErr, err := resp.Decode(nextPage) + if err != nil { + return nil, fmt.Errorf("error decoding List response during ListNextPage: %w", err) + } + if apiErr != nil { + return nil, apiErr + } + + // Ensure values are carried forward to the next call + nextPage.scopeId = currentPage.scopeId + + nextPage.recursive = currentPage.recursive + + nextPage.pageSize = currentPage.pageSize + // Cache the removed IDs from this page + nextPage.allRemovedIds = append(currentPage.allRemovedIds, nextPage.RemovedIds...) 
+ // Set the response body to the current response + nextPage.Response = resp + // If we're done iterating, pull the full set of removed IDs into the last + // response + if nextPage.ResponseType == "complete" { + // Collect up the last values + nextPage.RemovedIds = nextPage.allRemovedIds + // For now, removedIds will only be populated if this pagination cycle + // was the result of a "refresh" operation (i.e., the caller provided a + // list token option to this call). + // + // Sort to make response deterministic + slices.Sort(nextPage.RemovedIds) + // Remove any duplicates + nextPage.RemovedIds = slices.Compact(nextPage.RemovedIds) + } + + return nextPage, nil } diff --git a/api/storagebuckets/option.gen.go b/api/storagebuckets/option.gen.go index ad2b9f6234..fc97a45869 100644 --- a/api/storagebuckets/option.gen.go +++ b/api/storagebuckets/option.gen.go @@ -22,13 +22,15 @@ import ( type Option func(*options) type options struct { - postMap map[string]any - queryMap map[string]string - withAutomaticVersioning bool - withSkipCurlOutput bool - withFilter string - withListToken string - withRecursive bool + postMap map[string]any + queryMap map[string]string + withAutomaticVersioning bool + withSkipCurlOutput bool + withFilter string + withListToken string + withClientDirectedPagination bool + withPageSize uint32 + withRecursive bool } func getDefaultOptions() options { @@ -58,6 +60,9 @@ func getOpts(opt ...Option) (options, []api.Option) { if opts.withRecursive { opts.queryMap["recursive"] = strconv.FormatBool(opts.withRecursive) } + if opts.withPageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(opts.withPageSize), 10) + } return opts, apiOpts } @@ -96,6 +101,21 @@ func WithFilter(filter string) Option { } } +// WithClientDirectedPagination tells the List function to return only the first +// page, if more pages are available +func WithClientDirectedPagination(with bool) Option { + return func(o *options) { + o.withClientDirectedPagination = with + } +} + +// WithPageSize controls the size of pages used during List +func WithPageSize(with uint32) Option { + return func(o *options) { + o.withPageSize = with + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/storagebuckets/storage_bucket.gen.go b/api/storagebuckets/storage_bucket.gen.go index ee75c9a41e..b232f5a2db 100644 --- a/api/storagebuckets/storage_bucket.gen.go +++ b/api/storagebuckets/storage_bucket.gen.go @@ -11,6 +11,7 @@ import ( "fmt" "net/url" "slices" + "strconv" "time" "github.com/hashicorp/boundary/api" @@ -76,6 +77,13 @@ type StorageBucketListResult struct { ListToken string `json:"list_token,omitempty"` ResponseType string `json:"response_type,omitempty"` Response *api.Response + + // The following fields are used for cached information when client-directed + // pagination is used. 
+ recursive bool + pageSize uint32 + scopeId string + allRemovedIds []string } func (n StorageBucketListResult) GetItems() []*StorageBucket { @@ -348,96 +356,93 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Stor return nil, apiErr } target.Response = resp + if target.ResponseType == "complete" || target.ResponseType == "" { return target, nil } + + // In case we shortcut out due to client directed pagination, ensure these + // are set + + target.recursive = opts.withRecursive + + target.pageSize = opts.withPageSize + target.scopeId = scopeId + target.allRemovedIds = target.RemovedIds + if opts.withClientDirectedPagination { + return target, nil + } + + allItems := make([]*StorageBucket, 0, target.EstItemCount) + allItems = append(allItems, target.Items...) + // If there are more results, automatically fetch the rest of the results. // idToIndex keeps a map from the ID of an item to its index in target.Items. // This is used to update updated items in-place and remove deleted items // from the result after pagination is done. idToIndex := map[string]int{} - for i, item := range target.Items { + for i, item := range allItems { idToIndex[item.Id] = i } - for { - req, err := c.client.NewRequest(ctx, "GET", "storage-buckets", nil, apiOpts...) - if err != nil { - return nil, fmt.Errorf("error creating List request: %w", err) - } - opts.queryMap["list_token"] = target.ListToken - if len(opts.queryMap) > 0 { - q := url.Values{} - for k, v := range opts.queryMap { - q.Add(k, v) - } - req.URL.RawQuery = q.Encode() - } - - resp, err := c.client.Do(req) + // If we're here there are more pages and the client does not want to + // paginate on their own; fetch them as this call returns all values. + currentPage := target + for { + nextPage, err := c.ListNextPage(ctx, currentPage, opt...) if err != nil { - return nil, fmt.Errorf("error performing client request during List call: %w", err) + return nil, fmt.Errorf("error getting next page in List call: %w", err) } - page := new(StorageBucketListResult) - apiErr, err := resp.Decode(page) - if err != nil { - return nil, fmt.Errorf("error decoding List response: %w", err) - } - if apiErr != nil { - return nil, apiErr - } - for _, item := range page.Items { + for _, item := range nextPage.Items { if i, ok := idToIndex[item.Id]; ok { // Item has already been seen at index i, update in-place - target.Items[i] = item + allItems[i] = item } else { - target.Items = append(target.Items, item) - idToIndex[item.Id] = len(target.Items) - 1 + allItems = append(allItems, item) + idToIndex[item.Id] = len(allItems) - 1 } } - // RemovedIds contain any StorageBucket that were deleted since the last response. - target.RemovedIds = append(target.RemovedIds, page.RemovedIds...) - target.EstItemCount = page.EstItemCount - target.ListToken = page.ListToken - target.ResponseType = page.ResponseType - target.Response = resp - if target.ResponseType == "complete" { + + currentPage = nextPage + + if currentPage.ResponseType == "complete" { break } } - // For now, removedIds will only be populated if this pagination cycle was the result of a - // "refresh" operation (i.e., the caller provided a list token option to this call). 
- // - // Sort to make response deterministic - slices.Sort(target.RemovedIds) - // Remove any duplicates - target.RemovedIds = slices.Compact(target.RemovedIds) + + // The current page here is the final page of the results, that is, the + // response type is "complete" + // Remove items that were deleted since the end of the last iteration. // If a StorageBucket has been updated and subsequently removed, we don't want // it to appear both in the Items and RemovedIds, so we remove it from the Items. - for _, removedId := range target.RemovedIds { + for _, removedId := range currentPage.RemovedIds { if i, ok := idToIndex[removedId]; ok { // Remove the item at index i without preserving order // https://github.com/golang/go/wiki/SliceTricks#delete-without-preserving-order - target.Items[i] = target.Items[len(target.Items)-1] - target.Items = target.Items[:len(target.Items)-1] + allItems[i] = allItems[len(allItems)-1] + allItems = allItems[:len(allItems)-1] // Update the index of the previously last element - idToIndex[target.Items[i].Id] = i + idToIndex[allItems[i].Id] = i } } - // Since we paginated to the end, we can avoid confusion - // for the user by setting the estimated item count to the - // length of the items slice. If we don't set this here, it - // will equal the value returned in the last response, which is - // often much smaller than the total number returned. - target.EstItemCount = uint(len(target.Items)) // Sort the results again since in-place updates and deletes // may have shuffled items. We sort by created time descending // (most recently created first), same as the API. - slices.SortFunc(target.Items, func(i, j *StorageBucket) int { + slices.SortFunc(allItems, func(i, j *StorageBucket) int { return j.CreatedTime.Compare(i.CreatedTime) }) + // Since we paginated to the end, we can avoid confusion + // for the user by setting the estimated item count to the + // length of the items slice. If we don't set this here, it + // will equal the value returned in the last response, which is + // often much smaller than the total number returned. + currentPage.EstItemCount = uint(len(allItems)) + // Set items to the full list we have collected here + currentPage.Items = allItems + // Set the returned value to the last page with calculated values + target = currentPage // Finally, since we made at least 2 requests to the server to fulfill this // function call, resp.Body and resp.Map will only contain the most recent response. // Overwrite them with the true response. @@ -451,4 +456,87 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Stor // Note: the HTTP response body is consumed by resp.Decode in the loop, // so it doesn't need to be updated (it will always be, and has always been, empty). return target, nil + +} + +func (c *Client) ListNextPage(ctx context.Context, currentPage *StorageBucketListResult, opt ...Option) (*StorageBucketListResult, error) { + if currentPage == nil { + return nil, fmt.Errorf("empty currentPage value passed into ListNextPage request") + } + if currentPage.scopeId == "" { + return nil, fmt.Errorf("empty scopeId value in currentPage passed into ListNextPage request") + } + if c.client == nil { + return nil, fmt.Errorf("nil client") + } + if currentPage.ResponseType == "complete" || currentPage.ResponseType == "" { + return nil, fmt.Errorf("no more pages available in ListNextPage request") + } + + opts, apiOpts := getOpts(opt...) 
+ opts.queryMap["scope_id"] = currentPage.scopeId + + // Don't require them to re-specify recursive + if currentPage.recursive { + opts.queryMap["recursive"] = "true" + } + + if currentPage.pageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) + } + + req, err := c.client.NewRequest(ctx, "GET", "storage-buckets", nil, apiOpts...) + if err != nil { + return nil, fmt.Errorf("error creating List request: %w", err) + } + + opts.queryMap["list_token"] = currentPage.ListToken + if len(opts.queryMap) > 0 { + q := url.Values{} + for k, v := range opts.queryMap { + q.Add(k, v) + } + req.URL.RawQuery = q.Encode() + } + + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("error performing client request during List call during ListNextPage: %w", err) + } + + nextPage := new(StorageBucketListResult) + apiErr, err := resp.Decode(nextPage) + if err != nil { + return nil, fmt.Errorf("error decoding List response during ListNextPage: %w", err) + } + if apiErr != nil { + return nil, apiErr + } + + // Ensure values are carried forward to the next call + nextPage.scopeId = currentPage.scopeId + + nextPage.recursive = currentPage.recursive + + nextPage.pageSize = currentPage.pageSize + // Cache the removed IDs from this page + nextPage.allRemovedIds = append(currentPage.allRemovedIds, nextPage.RemovedIds...) + // Set the response body to the current response + nextPage.Response = resp + // If we're done iterating, pull the full set of removed IDs into the last + // response + if nextPage.ResponseType == "complete" { + // Collect up the last values + nextPage.RemovedIds = nextPage.allRemovedIds + // For now, removedIds will only be populated if this pagination cycle + // was the result of a "refresh" operation (i.e., the caller provided a + // list token option to this call). 
+ // + // Sort to make response deterministic + slices.Sort(nextPage.RemovedIds) + // Remove any duplicates + nextPage.RemovedIds = slices.Compact(nextPage.RemovedIds) + } + + return nextPage, nil } diff --git a/api/targets/option.gen.go b/api/targets/option.gen.go index 73b3e85e07..4d71168be6 100644 --- a/api/targets/option.gen.go +++ b/api/targets/option.gen.go @@ -21,13 +21,15 @@ import ( type Option func(*options) type options struct { - postMap map[string]any - queryMap map[string]string - withAutomaticVersioning bool - withSkipCurlOutput bool - withFilter string - withListToken string - withRecursive bool + postMap map[string]any + queryMap map[string]string + withAutomaticVersioning bool + withSkipCurlOutput bool + withFilter string + withListToken string + withClientDirectedPagination bool + withPageSize uint32 + withRecursive bool } func getDefaultOptions() options { @@ -57,6 +59,9 @@ func getOpts(opt ...Option) (options, []api.Option) { if opts.withRecursive { opts.queryMap["recursive"] = strconv.FormatBool(opts.withRecursive) } + if opts.withPageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(opts.withPageSize), 10) + } return opts, apiOpts } @@ -95,6 +100,21 @@ func WithFilter(filter string) Option { } } +// WithClientDirectedPagination tells the List function to return only the first +// page, if more pages are available +func WithClientDirectedPagination(with bool) Option { + return func(o *options) { + o.withClientDirectedPagination = with + } +} + +// WithPageSize controls the size of pages used during List +func WithPageSize(with uint32) Option { + return func(o *options) { + o.withPageSize = with + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/targets/target.gen.go b/api/targets/target.gen.go index 36d28555e0..6d7ad43315 100644 --- a/api/targets/target.gen.go +++ b/api/targets/target.gen.go @@ -11,6 +11,7 @@ import ( "fmt" "net/url" "slices" + "strconv" "time" "github.com/hashicorp/boundary/api" @@ -81,6 +82,13 @@ type TargetListResult struct { ListToken string `json:"list_token,omitempty"` ResponseType string `json:"response_type,omitempty"` Response *api.Response + + // The following fields are used for cached information when client-directed + // pagination is used. + recursive bool + pageSize uint32 + scopeId string + allRemovedIds []string } func (n TargetListResult) GetItems() []*Target { @@ -358,96 +366,93 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Targ return nil, apiErr } target.Response = resp + if target.ResponseType == "complete" || target.ResponseType == "" { return target, nil } + + // In case we shortcut out due to client directed pagination, ensure these + // are set + + target.recursive = opts.withRecursive + + target.pageSize = opts.withPageSize + target.scopeId = scopeId + target.allRemovedIds = target.RemovedIds + if opts.withClientDirectedPagination { + return target, nil + } + + allItems := make([]*Target, 0, target.EstItemCount) + allItems = append(allItems, target.Items...) + // If there are more results, automatically fetch the rest of the results. // idToIndex keeps a map from the ID of an item to its index in target.Items. // This is used to update updated items in-place and remove deleted items // from the result after pagination is done. 
idToIndex := map[string]int{} - for i, item := range target.Items { + for i, item := range allItems { idToIndex[item.Id] = i } - for { - req, err := c.client.NewRequest(ctx, "GET", "targets", nil, apiOpts...) - if err != nil { - return nil, fmt.Errorf("error creating List request: %w", err) - } - opts.queryMap["list_token"] = target.ListToken - if len(opts.queryMap) > 0 { - q := url.Values{} - for k, v := range opts.queryMap { - q.Add(k, v) - } - req.URL.RawQuery = q.Encode() - } - - resp, err := c.client.Do(req) + // If we're here there are more pages and the client does not want to + // paginate on their own; fetch them as this call returns all values. + currentPage := target + for { + nextPage, err := c.ListNextPage(ctx, currentPage, opt...) if err != nil { - return nil, fmt.Errorf("error performing client request during List call: %w", err) + return nil, fmt.Errorf("error getting next page in List call: %w", err) } - page := new(TargetListResult) - apiErr, err := resp.Decode(page) - if err != nil { - return nil, fmt.Errorf("error decoding List response: %w", err) - } - if apiErr != nil { - return nil, apiErr - } - for _, item := range page.Items { + for _, item := range nextPage.Items { if i, ok := idToIndex[item.Id]; ok { // Item has already been seen at index i, update in-place - target.Items[i] = item + allItems[i] = item } else { - target.Items = append(target.Items, item) - idToIndex[item.Id] = len(target.Items) - 1 + allItems = append(allItems, item) + idToIndex[item.Id] = len(allItems) - 1 } } - // RemovedIds contain any Target that were deleted since the last response. - target.RemovedIds = append(target.RemovedIds, page.RemovedIds...) - target.EstItemCount = page.EstItemCount - target.ListToken = page.ListToken - target.ResponseType = page.ResponseType - target.Response = resp - if target.ResponseType == "complete" { + + currentPage = nextPage + + if currentPage.ResponseType == "complete" { break } } - // For now, removedIds will only be populated if this pagination cycle was the result of a - // "refresh" operation (i.e., the caller provided a list token option to this call). - // - // Sort to make response deterministic - slices.Sort(target.RemovedIds) - // Remove any duplicates - target.RemovedIds = slices.Compact(target.RemovedIds) + + // The current page here is the final page of the results, that is, the + // response type is "complete" + // Remove items that were deleted since the end of the last iteration. // If a Target has been updated and subsequently removed, we don't want // it to appear both in the Items and RemovedIds, so we remove it from the Items. - for _, removedId := range target.RemovedIds { + for _, removedId := range currentPage.RemovedIds { if i, ok := idToIndex[removedId]; ok { // Remove the item at index i without preserving order // https://github.com/golang/go/wiki/SliceTricks#delete-without-preserving-order - target.Items[i] = target.Items[len(target.Items)-1] - target.Items = target.Items[:len(target.Items)-1] + allItems[i] = allItems[len(allItems)-1] + allItems = allItems[:len(allItems)-1] // Update the index of the previously last element - idToIndex[target.Items[i].Id] = i + idToIndex[allItems[i].Id] = i } } - // Since we paginated to the end, we can avoid confusion - // for the user by setting the estimated item count to the - // length of the items slice. If we don't set this here, it - // will equal the value returned in the last response, which is - // often much smaller than the total number returned. 
- target.EstItemCount = uint(len(target.Items)) // Sort the results again since in-place updates and deletes // may have shuffled items. We sort by created time descending // (most recently created first), same as the API. - slices.SortFunc(target.Items, func(i, j *Target) int { + slices.SortFunc(allItems, func(i, j *Target) int { return j.CreatedTime.Compare(i.CreatedTime) }) + // Since we paginated to the end, we can avoid confusion + // for the user by setting the estimated item count to the + // length of the items slice. If we don't set this here, it + // will equal the value returned in the last response, which is + // often much smaller than the total number returned. + currentPage.EstItemCount = uint(len(allItems)) + // Set items to the full list we have collected here + currentPage.Items = allItems + // Set the returned value to the last page with calculated values + target = currentPage // Finally, since we made at least 2 requests to the server to fulfill this // function call, resp.Body and resp.Map will only contain the most recent response. // Overwrite them with the true response. @@ -461,6 +466,89 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Targ // Note: the HTTP response body is consumed by resp.Decode in the loop, // so it doesn't need to be updated (it will always be, and has always been, empty). return target, nil + +} + +func (c *Client) ListNextPage(ctx context.Context, currentPage *TargetListResult, opt ...Option) (*TargetListResult, error) { + if currentPage == nil { + return nil, fmt.Errorf("empty currentPage value passed into ListNextPage request") + } + if currentPage.scopeId == "" { + return nil, fmt.Errorf("empty scopeId value in currentPage passed into ListNextPage request") + } + if c.client == nil { + return nil, fmt.Errorf("nil client") + } + if currentPage.ResponseType == "complete" || currentPage.ResponseType == "" { + return nil, fmt.Errorf("no more pages available in ListNextPage request") + } + + opts, apiOpts := getOpts(opt...) + opts.queryMap["scope_id"] = currentPage.scopeId + + // Don't require them to re-specify recursive + if currentPage.recursive { + opts.queryMap["recursive"] = "true" + } + + if currentPage.pageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) + } + + req, err := c.client.NewRequest(ctx, "GET", "targets", nil, apiOpts...) + if err != nil { + return nil, fmt.Errorf("error creating List request: %w", err) + } + + opts.queryMap["list_token"] = currentPage.ListToken + if len(opts.queryMap) > 0 { + q := url.Values{} + for k, v := range opts.queryMap { + q.Add(k, v) + } + req.URL.RawQuery = q.Encode() + } + + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("error performing client request during List call during ListNextPage: %w", err) + } + + nextPage := new(TargetListResult) + apiErr, err := resp.Decode(nextPage) + if err != nil { + return nil, fmt.Errorf("error decoding List response during ListNextPage: %w", err) + } + if apiErr != nil { + return nil, apiErr + } + + // Ensure values are carried forward to the next call + nextPage.scopeId = currentPage.scopeId + + nextPage.recursive = currentPage.recursive + + nextPage.pageSize = currentPage.pageSize + // Cache the removed IDs from this page + nextPage.allRemovedIds = append(currentPage.allRemovedIds, nextPage.RemovedIds...) 
+ // Set the response body to the current response + nextPage.Response = resp + // If we're done iterating, pull the full set of removed IDs into the last + // response + if nextPage.ResponseType == "complete" { + // Collect up the last values + nextPage.RemovedIds = nextPage.allRemovedIds + // For now, removedIds will only be populated if this pagination cycle + // was the result of a "refresh" operation (i.e., the caller provided a + // list token option to this call). + // + // Sort to make response deterministic + slices.Sort(nextPage.RemovedIds) + // Remove any duplicates + nextPage.RemovedIds = slices.Compact(nextPage.RemovedIds) + } + + return nextPage, nil } func (c *Client) AddCredentialSources(ctx context.Context, id string, version uint32, opt ...Option) (*TargetUpdateResult, error) { diff --git a/api/users/option.gen.go b/api/users/option.gen.go index 7c0c5ed80c..a6349cde10 100644 --- a/api/users/option.gen.go +++ b/api/users/option.gen.go @@ -21,13 +21,15 @@ import ( type Option func(*options) type options struct { - postMap map[string]any - queryMap map[string]string - withAutomaticVersioning bool - withSkipCurlOutput bool - withFilter string - withListToken string - withRecursive bool + postMap map[string]any + queryMap map[string]string + withAutomaticVersioning bool + withSkipCurlOutput bool + withFilter string + withListToken string + withClientDirectedPagination bool + withPageSize uint32 + withRecursive bool } func getDefaultOptions() options { @@ -57,6 +59,9 @@ func getOpts(opt ...Option) (options, []api.Option) { if opts.withRecursive { opts.queryMap["recursive"] = strconv.FormatBool(opts.withRecursive) } + if opts.withPageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(opts.withPageSize), 10) + } return opts, apiOpts } @@ -95,6 +100,21 @@ func WithFilter(filter string) Option { } } +// WithClientDirectedPagination tells the List function to return only the first +// page, if more pages are available +func WithClientDirectedPagination(with bool) Option { + return func(o *options) { + o.withClientDirectedPagination = with + } +} + +// WithPageSize controls the size of pages used during List +func WithPageSize(with uint32) Option { + return func(o *options) { + o.withPageSize = with + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/users/user.gen.go b/api/users/user.gen.go index c1499f8361..51065721f9 100644 --- a/api/users/user.gen.go +++ b/api/users/user.gen.go @@ -11,6 +11,7 @@ import ( "fmt" "net/url" "slices" + "strconv" "time" "github.com/hashicorp/boundary/api" @@ -71,6 +72,13 @@ type UserListResult struct { ListToken string `json:"list_token,omitempty"` ResponseType string `json:"response_type,omitempty"` Response *api.Response + + // The following fields are used for cached information when client-directed + // pagination is used. 
+ recursive bool + pageSize uint32 + scopeId string + allRemovedIds []string } func (n UserListResult) GetItems() []*User { @@ -343,96 +351,93 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*User return nil, apiErr } target.Response = resp + if target.ResponseType == "complete" || target.ResponseType == "" { return target, nil } + + // In case we shortcut out due to client directed pagination, ensure these + // are set + + target.recursive = opts.withRecursive + + target.pageSize = opts.withPageSize + target.scopeId = scopeId + target.allRemovedIds = target.RemovedIds + if opts.withClientDirectedPagination { + return target, nil + } + + allItems := make([]*User, 0, target.EstItemCount) + allItems = append(allItems, target.Items...) + // If there are more results, automatically fetch the rest of the results. // idToIndex keeps a map from the ID of an item to its index in target.Items. // This is used to update updated items in-place and remove deleted items // from the result after pagination is done. idToIndex := map[string]int{} - for i, item := range target.Items { + for i, item := range allItems { idToIndex[item.Id] = i } - for { - req, err := c.client.NewRequest(ctx, "GET", "users", nil, apiOpts...) - if err != nil { - return nil, fmt.Errorf("error creating List request: %w", err) - } - opts.queryMap["list_token"] = target.ListToken - if len(opts.queryMap) > 0 { - q := url.Values{} - for k, v := range opts.queryMap { - q.Add(k, v) - } - req.URL.RawQuery = q.Encode() - } - - resp, err := c.client.Do(req) + // If we're here there are more pages and the client does not want to + // paginate on their own; fetch them as this call returns all values. + currentPage := target + for { + nextPage, err := c.ListNextPage(ctx, currentPage, opt...) if err != nil { - return nil, fmt.Errorf("error performing client request during List call: %w", err) + return nil, fmt.Errorf("error getting next page in List call: %w", err) } - page := new(UserListResult) - apiErr, err := resp.Decode(page) - if err != nil { - return nil, fmt.Errorf("error decoding List response: %w", err) - } - if apiErr != nil { - return nil, apiErr - } - for _, item := range page.Items { + for _, item := range nextPage.Items { if i, ok := idToIndex[item.Id]; ok { // Item has already been seen at index i, update in-place - target.Items[i] = item + allItems[i] = item } else { - target.Items = append(target.Items, item) - idToIndex[item.Id] = len(target.Items) - 1 + allItems = append(allItems, item) + idToIndex[item.Id] = len(allItems) - 1 } } - // RemovedIds contain any User that were deleted since the last response. - target.RemovedIds = append(target.RemovedIds, page.RemovedIds...) - target.EstItemCount = page.EstItemCount - target.ListToken = page.ListToken - target.ResponseType = page.ResponseType - target.Response = resp - if target.ResponseType == "complete" { + + currentPage = nextPage + + if currentPage.ResponseType == "complete" { break } } - // For now, removedIds will only be populated if this pagination cycle was the result of a - // "refresh" operation (i.e., the caller provided a list token option to this call). - // - // Sort to make response deterministic - slices.Sort(target.RemovedIds) - // Remove any duplicates - target.RemovedIds = slices.Compact(target.RemovedIds) + + // The current page here is the final page of the results, that is, the + // response type is "complete" + // Remove items that were deleted since the end of the last iteration. 
// If a User has been updated and subsequently removed, we don't want // it to appear both in the Items and RemovedIds, so we remove it from the Items. - for _, removedId := range target.RemovedIds { + for _, removedId := range currentPage.RemovedIds { if i, ok := idToIndex[removedId]; ok { // Remove the item at index i without preserving order // https://github.com/golang/go/wiki/SliceTricks#delete-without-preserving-order - target.Items[i] = target.Items[len(target.Items)-1] - target.Items = target.Items[:len(target.Items)-1] + allItems[i] = allItems[len(allItems)-1] + allItems = allItems[:len(allItems)-1] // Update the index of the previously last element - idToIndex[target.Items[i].Id] = i + idToIndex[allItems[i].Id] = i } } - // Since we paginated to the end, we can avoid confusion - // for the user by setting the estimated item count to the - // length of the items slice. If we don't set this here, it - // will equal the value returned in the last response, which is - // often much smaller than the total number returned. - target.EstItemCount = uint(len(target.Items)) // Sort the results again since in-place updates and deletes // may have shuffled items. We sort by created time descending // (most recently created first), same as the API. - slices.SortFunc(target.Items, func(i, j *User) int { + slices.SortFunc(allItems, func(i, j *User) int { return j.CreatedTime.Compare(i.CreatedTime) }) + // Since we paginated to the end, we can avoid confusion + // for the user by setting the estimated item count to the + // length of the items slice. If we don't set this here, it + // will equal the value returned in the last response, which is + // often much smaller than the total number returned. + currentPage.EstItemCount = uint(len(allItems)) + // Set items to the full list we have collected here + currentPage.Items = allItems + // Set the returned value to the last page with calculated values + target = currentPage // Finally, since we made at least 2 requests to the server to fulfill this // function call, resp.Body and resp.Map will only contain the most recent response. // Overwrite them with the true response. @@ -446,6 +451,89 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*User // Note: the HTTP response body is consumed by resp.Decode in the loop, // so it doesn't need to be updated (it will always be, and has always been, empty). return target, nil + +} + +func (c *Client) ListNextPage(ctx context.Context, currentPage *UserListResult, opt ...Option) (*UserListResult, error) { + if currentPage == nil { + return nil, fmt.Errorf("empty currentPage value passed into ListNextPage request") + } + if currentPage.scopeId == "" { + return nil, fmt.Errorf("empty scopeId value in currentPage passed into ListNextPage request") + } + if c.client == nil { + return nil, fmt.Errorf("nil client") + } + if currentPage.ResponseType == "complete" || currentPage.ResponseType == "" { + return nil, fmt.Errorf("no more pages available in ListNextPage request") + } + + opts, apiOpts := getOpts(opt...) + opts.queryMap["scope_id"] = currentPage.scopeId + + // Don't require them to re-specify recursive + if currentPage.recursive { + opts.queryMap["recursive"] = "true" + } + + if currentPage.pageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) + } + + req, err := c.client.NewRequest(ctx, "GET", "users", nil, apiOpts...) 
+ if err != nil { + return nil, fmt.Errorf("error creating List request: %w", err) + } + + opts.queryMap["list_token"] = currentPage.ListToken + if len(opts.queryMap) > 0 { + q := url.Values{} + for k, v := range opts.queryMap { + q.Add(k, v) + } + req.URL.RawQuery = q.Encode() + } + + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("error performing client request during List call during ListNextPage: %w", err) + } + + nextPage := new(UserListResult) + apiErr, err := resp.Decode(nextPage) + if err != nil { + return nil, fmt.Errorf("error decoding List response during ListNextPage: %w", err) + } + if apiErr != nil { + return nil, apiErr + } + + // Ensure values are carried forward to the next call + nextPage.scopeId = currentPage.scopeId + + nextPage.recursive = currentPage.recursive + + nextPage.pageSize = currentPage.pageSize + // Cache the removed IDs from this page + nextPage.allRemovedIds = append(currentPage.allRemovedIds, nextPage.RemovedIds...) + // Set the response body to the current response + nextPage.Response = resp + // If we're done iterating, pull the full set of removed IDs into the last + // response + if nextPage.ResponseType == "complete" { + // Collect up the last values + nextPage.RemovedIds = nextPage.allRemovedIds + // For now, removedIds will only be populated if this pagination cycle + // was the result of a "refresh" operation (i.e., the caller provided a + // list token option to this call). + // + // Sort to make response deterministic + slices.Sort(nextPage.RemovedIds) + // Remove any duplicates + nextPage.RemovedIds = slices.Compact(nextPage.RemovedIds) + } + + return nextPage, nil } func (c *Client) AddAccounts(ctx context.Context, id string, version uint32, accountIds []string, opt ...Option) (*UserUpdateResult, error) { diff --git a/api/workers/option.gen.go b/api/workers/option.gen.go index c9f51af416..1ca3ada826 100644 --- a/api/workers/option.gen.go +++ b/api/workers/option.gen.go @@ -21,13 +21,15 @@ import ( type Option func(*options) type options struct { - postMap map[string]any - queryMap map[string]string - withAutomaticVersioning bool - withSkipCurlOutput bool - withFilter string - withListToken string - withRecursive bool + postMap map[string]any + queryMap map[string]string + withAutomaticVersioning bool + withSkipCurlOutput bool + withFilter string + withListToken string + withClientDirectedPagination bool + withPageSize uint32 + withRecursive bool } func getDefaultOptions() options { @@ -57,6 +59,9 @@ func getOpts(opt ...Option) (options, []api.Option) { if opts.withRecursive { opts.queryMap["recursive"] = strconv.FormatBool(opts.withRecursive) } + if opts.withPageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(opts.withPageSize), 10) + } return opts, apiOpts } @@ -95,6 +100,21 @@ func WithFilter(filter string) Option { } } +// WithClientDirectedPagination tells the List function to return only the first +// page, if more pages are available +func WithClientDirectedPagination(with bool) Option { + return func(o *options) { + o.withClientDirectedPagination = with + } +} + +// WithPageSize controls the size of pages used during List +func WithPageSize(with uint32) Option { + return func(o *options) { + o.withPageSize = with + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/workers/worker.gen.go b/api/workers/worker.gen.go index 140dc8e920..efe18d859e 100644 --- a/api/workers/worker.gen.go +++ 
b/api/workers/worker.gen.go @@ -6,11 +6,9 @@ package workers import ( "context" - "encoding/json" "errors" "fmt" "net/url" - "slices" "time" "github.com/hashicorp/boundary/api" @@ -78,6 +76,13 @@ type WorkerListResult struct { ListToken string `json:"list_token,omitempty"` ResponseType string `json:"response_type,omitempty"` Response *api.Response + + // The following fields are used for cached information when client-directed + // pagination is used. + recursive bool + pageSize uint32 + scopeId string + allRemovedIds []string } func (n WorkerListResult) GetItems() []*Worker { @@ -399,109 +404,9 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Work return nil, apiErr } target.Response = resp - if target.ResponseType == "complete" || target.ResponseType == "" { - return target, nil - } - // If there are more results, automatically fetch the rest of the results. - // idToIndex keeps a map from the ID of an item to its index in target.Items. - // This is used to update updated items in-place and remove deleted items - // from the result after pagination is done. - idToIndex := map[string]int{} - for i, item := range target.Items { - idToIndex[item.Id] = i - } - for { - req, err := c.client.NewRequest(ctx, "GET", "workers", nil, apiOpts...) - if err != nil { - return nil, fmt.Errorf("error creating List request: %w", err) - } - - opts.queryMap["list_token"] = target.ListToken - if len(opts.queryMap) > 0 { - q := url.Values{} - for k, v := range opts.queryMap { - q.Add(k, v) - } - req.URL.RawQuery = q.Encode() - } - - resp, err := c.client.Do(req) - if err != nil { - return nil, fmt.Errorf("error performing client request during List call: %w", err) - } - page := new(WorkerListResult) - apiErr, err := resp.Decode(page) - if err != nil { - return nil, fmt.Errorf("error decoding List response: %w", err) - } - if apiErr != nil { - return nil, apiErr - } - for _, item := range page.Items { - if i, ok := idToIndex[item.Id]; ok { - // Item has already been seen at index i, update in-place - target.Items[i] = item - } else { - target.Items = append(target.Items, item) - idToIndex[item.Id] = len(target.Items) - 1 - } - } - // RemovedIds contain any Worker that were deleted since the last response. - target.RemovedIds = append(target.RemovedIds, page.RemovedIds...) - target.EstItemCount = page.EstItemCount - target.ListToken = page.ListToken - target.ResponseType = page.ResponseType - target.Response = resp - if target.ResponseType == "complete" { - break - } - } - // For now, removedIds will only be populated if this pagination cycle was the result of a - // "refresh" operation (i.e., the caller provided a list token option to this call). - // - // Sort to make response deterministic - slices.Sort(target.RemovedIds) - // Remove any duplicates - target.RemovedIds = slices.Compact(target.RemovedIds) - // Remove items that were deleted since the end of the last iteration. - // If a Worker has been updated and subsequently removed, we don't want - // it to appear both in the Items and RemovedIds, so we remove it from the Items. 
- for _, removedId := range target.RemovedIds { - if i, ok := idToIndex[removedId]; ok { - // Remove the item at index i without preserving order - // https://github.com/golang/go/wiki/SliceTricks#delete-without-preserving-order - target.Items[i] = target.Items[len(target.Items)-1] - target.Items = target.Items[:len(target.Items)-1] - // Update the index of the previously last element - idToIndex[target.Items[i].Id] = i - } - } - // Since we paginated to the end, we can avoid confusion - // for the user by setting the estimated item count to the - // length of the items slice. If we don't set this here, it - // will equal the value returned in the last response, which is - // often much smaller than the total number returned. - target.EstItemCount = uint(len(target.Items)) - // Sort the results again since in-place updates and deletes - // may have shuffled items. We sort by created time descending - // (most recently created first), same as the API. - slices.SortFunc(target.Items, func(i, j *Worker) int { - return j.CreatedTime.Compare(i.CreatedTime) - }) - // Finally, since we made at least 2 requests to the server to fulfill this - // function call, resp.Body and resp.Map will only contain the most recent response. - // Overwrite them with the true response. - target.Response.Body.Reset() - if err := json.NewEncoder(target.Response.Body).Encode(target); err != nil { - return nil, fmt.Errorf("error encoding final JSON list response: %w", err) - } - if err := json.Unmarshal(target.Response.Body.Bytes(), &target.Response.Map); err != nil { - return nil, fmt.Errorf("error encoding final map list response: %w", err) - } - // Note: the HTTP response body is consumed by resp.Decode in the loop, - // so it doesn't need to be updated (it will always be, and has always been, empty). 
return target, nil + } func (c *Client) AddWorkerTags(ctx context.Context, id string, version uint32, apiTags map[string][]string, opt ...Option) (*WorkerUpdateResult, error) { diff --git a/internal/api/genapi/input.go b/internal/api/genapi/input.go index 5e9291f32f..0f497ec2b3 100644 --- a/internal/api/genapi/input.go +++ b/internal/api/genapi/input.go @@ -142,6 +142,10 @@ type structInfo struct { // fields fieldFilter []string + // nonPaginatedListing indicates a collection that does not support + // pagination + nonPaginatedListing bool + allowEmpty bool } @@ -1426,5 +1430,6 @@ var inputStructs = []*structInfo{ createResponseTypes: []string{CreateResponseType, ReadResponseType, UpdateResponseType, DeleteResponseType, ListResponseType}, recursiveListing: true, versionEnabled: true, + nonPaginatedListing: true, }, } diff --git a/internal/api/genapi/templates.go b/internal/api/genapi/templates.go index 8e769e77a2..aa37311e5a 100644 --- a/internal/api/genapi/templates.go +++ b/internal/api/genapi/templates.go @@ -49,6 +49,7 @@ type templateInput struct { SliceSubtypes map[string]sliceSubtypeInfo ExtraFields []fieldInfo VersionEnabled bool + NonPaginatedListing bool CreateResponseTypes []string SkipListFiltering bool RecursiveListing bool @@ -69,6 +70,7 @@ func fillTemplates() { ParentTypeName: in.parentTypeName, ExtraFields: in.extraFields, VersionEnabled: in.versionEnabled, + NonPaginatedListing: in.nonPaginatedListing, CreateResponseTypes: in.createResponseTypes, SkipListFiltering: in.skipListFiltering, RecursiveListing: in.recursiveListing, @@ -268,96 +270,96 @@ func (c *Client) List(ctx context.Context, {{ .CollectionFunctionArg }} string, return nil, apiErr } target.Response = resp +{{ if .NonPaginatedListing }} + return target, nil +{{ end }} +{{ if ( not ( .NonPaginatedListing ) ) }} if target.ResponseType == "complete" || target.ResponseType == "" { return target, nil } + + // In case we shortcut out due to client directed pagination, ensure these + // are set +{{ if .RecursiveListing }} + target.recursive = opts.withRecursive +{{ end }} + target.pageSize = opts.withPageSize + target.{{ .CollectionFunctionArg }} = {{ .CollectionFunctionArg }} + target.allRemovedIds = target.RemovedIds + if opts.withClientDirectedPagination { + return target, nil + } + + allItems := make([]*{{ .Name }}, 0, target.EstItemCount) + allItems = append(allItems, target.Items...) + // If there are more results, automatically fetch the rest of the results. // idToIndex keeps a map from the ID of an item to its index in target.Items. // This is used to update updated items in-place and remove deleted items // from the result after pagination is done. idToIndex := map[string]int{} - for i, item := range target.Items { + for i, item := range allItems { idToIndex[item.Id] = i } - for { - req, err := c.client.NewRequest(ctx, "GET", "{{ .CollectionPath }}", nil, apiOpts...) - if err != nil { - return nil, fmt.Errorf("error creating List request: %w", err) - } - opts.queryMap["list_token"] = target.ListToken - if len(opts.queryMap) > 0 { - q := url.Values{} - for k, v := range opts.queryMap { - q.Add(k, v) - } - req.URL.RawQuery = q.Encode() - } - - resp, err := c.client.Do(req) + // If we're here there are more pages and the client does not want to + // paginate on their own; fetch them as this call returns all values. + currentPage := target + for { + nextPage, err := c.ListNextPage(ctx, currentPage, opt...) 
if err != nil { - return nil, fmt.Errorf("error performing client request during List call: %w", err) + return nil, fmt.Errorf("error getting next page in List call: %w", err) } - page := new({{ .Name }}ListResult) - apiErr, err := resp.Decode(page) - if err != nil { - return nil, fmt.Errorf("error decoding List response: %w", err) - } - if apiErr != nil { - return nil, apiErr - } - for _, item := range page.Items { + for _, item := range nextPage.Items { if i, ok := idToIndex[item.Id]; ok { // Item has already been seen at index i, update in-place - target.Items[i] = item + allItems[i] = item } else { - target.Items = append(target.Items, item) - idToIndex[item.Id] = len(target.Items) - 1 + allItems = append(allItems, item) + idToIndex[item.Id] = len(allItems) - 1 } } - // RemovedIds contain any {{ .Name }} that were deleted since the last response. - target.RemovedIds = append(target.RemovedIds, page.RemovedIds...) - target.EstItemCount = page.EstItemCount - target.ListToken = page.ListToken - target.ResponseType = page.ResponseType - target.Response = resp - if target.ResponseType == "complete" { + + currentPage = nextPage + + if currentPage.ResponseType == "complete" { break } } - // For now, removedIds will only be populated if this pagination cycle was the result of a - // "refresh" operation (i.e., the caller provided a list token option to this call). - // - // Sort to make response deterministic - slices.Sort(target.RemovedIds) - // Remove any duplicates - target.RemovedIds = slices.Compact(target.RemovedIds) + + // The current page here is the final page of the results, that is, the + // response type is "complete" + // Remove items that were deleted since the end of the last iteration. // If a {{ .Name }} has been updated and subsequently removed, we don't want // it to appear both in the Items and RemovedIds, so we remove it from the Items. - for _, removedId := range target.RemovedIds { + for _, removedId := range currentPage.RemovedIds { if i, ok := idToIndex[removedId]; ok { // Remove the item at index i without preserving order // https://github.com/golang/go/wiki/SliceTricks#delete-without-preserving-order - target.Items[i] = target.Items[len(target.Items)-1] - target.Items = target.Items[:len(target.Items)-1] + allItems[i] = allItems[len(allItems)-1] + allItems = allItems[:len(allItems)-1] // Update the index of the previously last element - idToIndex[target.Items[i].Id] = i + idToIndex[allItems[i].Id] = i } } - // Since we paginated to the end, we can avoid confusion - // for the user by setting the estimated item count to the - // length of the items slice. If we don't set this here, it - // will equal the value returned in the last response, which is - // often much smaller than the total number returned. - target.EstItemCount = uint(len(target.Items)) // Sort the results again since in-place updates and deletes // may have shuffled items. We sort by created time descending // (most recently created first), same as the API. - slices.SortFunc(target.Items, func(i, j *{{ .Name }}) int { + slices.SortFunc(allItems, func(i, j *{{ .Name }}) int { return j.CreatedTime.Compare(i.CreatedTime) }) + // Since we paginated to the end, we can avoid confusion + // for the user by setting the estimated item count to the + // length of the items slice. If we don't set this here, it + // will equal the value returned in the last response, which is + // often much smaller than the total number returned. 
+ currentPage.EstItemCount = uint(len(allItems)) + // Set items to the full list we have collected here + currentPage.Items = allItems + // Set the returned value to the last page with calculated values + target = currentPage // Finally, since we made at least 2 requests to the server to fulfill this // function call, resp.Body and resp.Map will only contain the most recent response. // Overwrite them with the true response. @@ -371,7 +373,93 @@ func (c *Client) List(ctx context.Context, {{ .CollectionFunctionArg }} string, // Note: the HTTP response body is consumed by resp.Decode in the loop, // so it doesn't need to be updated (it will always be, and has always been, empty). return target, nil +{{ end }} } + +{{ if ( not ( .NonPaginatedListing ) ) }} +func (c *Client) ListNextPage(ctx context.Context, currentPage *{{ .Name }}ListResult, opt ...Option) (*{{ .Name }}ListResult, error) { + if currentPage == nil { + return nil, fmt.Errorf("empty currentPage value passed into ListNextPage request") + } + if currentPage.{{ .CollectionFunctionArg }} == "" { + return nil, fmt.Errorf("empty {{ .CollectionFunctionArg }} value in currentPage passed into ListNextPage request") + } + if c.client == nil { + return nil, fmt.Errorf("nil client") + } + if currentPage.ResponseType == "complete" || currentPage.ResponseType == "" { + return nil, fmt.Errorf("no more pages available in ListNextPage request") + } + + opts, apiOpts := getOpts(opt...) + opts.queryMap["{{ snakeCase .CollectionFunctionArg }}"] = currentPage.{{ .CollectionFunctionArg }} + +{{ if .RecursiveListing }} + // Don't require them to re-specify recursive + if currentPage.recursive { + opts.queryMap["recursive"] = "true" + } +{{ end }} + if currentPage.pageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) + } + + req, err := c.client.NewRequest(ctx, "GET", "{{ .CollectionPath }}", nil, apiOpts...) + if err != nil { + return nil, fmt.Errorf("error creating List request: %w", err) + } + + opts.queryMap["list_token"] = currentPage.ListToken + if len(opts.queryMap) > 0 { + q := url.Values{} + for k, v := range opts.queryMap { + q.Add(k, v) + } + req.URL.RawQuery = q.Encode() + } + + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("error performing client request during List call during ListNextPage: %w", err) + } + + nextPage := new({{ .Name }}ListResult) + apiErr, err := resp.Decode(nextPage) + if err != nil { + return nil, fmt.Errorf("error decoding List response during ListNextPage: %w", err) + } + if apiErr != nil { + return nil, apiErr + } + + // Ensure values are carried forward to the next call + nextPage.{{ .CollectionFunctionArg }} = currentPage.{{ .CollectionFunctionArg }} +{{ if .RecursiveListing }} + nextPage.recursive = currentPage.recursive +{{ end }} + nextPage.pageSize = currentPage.pageSize + // Cache the removed IDs from this page + nextPage.allRemovedIds = append(currentPage.allRemovedIds, nextPage.RemovedIds...) + // Set the response body to the current response + nextPage.Response = resp + // If we're done iterating, pull the full set of removed IDs into the last + // response + if nextPage.ResponseType == "complete" { + // Collect up the last values + nextPage.RemovedIds = nextPage.allRemovedIds + // For now, removedIds will only be populated if this pagination cycle + // was the result of a "refresh" operation (i.e., the caller provided a + // list token option to this call). 
+ // + // Sort to make response deterministic + slices.Sort(nextPage.RemovedIds) + // Remove any duplicates + nextPage.RemovedIds = slices.Compact(nextPage.RemovedIds) + } + + return nextPage, nil +} +{{ end }} `)) var readTemplate = template.Must(template.New("").Parse(` @@ -764,6 +852,14 @@ type {{ .Name }}ListResult struct { ListToken string `, "`json:\"list_token,omitempty\"`", ` ResponseType string `, "`json:\"response_type,omitempty\"`", ` Response *api.Response + + + // The following fields are used for cached information when client-directed + // pagination is used. + recursive bool + pageSize uint32 + {{ .CollectionFunctionArg }} string + allRemovedIds []string } func (n {{ .Name }}ListResult) GetItems() []*{{ .Name }} { @@ -845,6 +941,8 @@ type options struct { withSkipCurlOutput bool withFilter string withListToken string + withClientDirectedPagination bool + withPageSize uint32 {{ if .RecursiveListing }} withRecursive bool {{ end }} } @@ -875,6 +973,9 @@ func getOpts(opt ...Option) (options, []api.Option) { if opts.withRecursive { opts.queryMap["recursive"] = strconv.FormatBool(opts.withRecursive) } {{ end }} + if opts.withPageSize != 0 { + opts.queryMap["page_size"] = strconv.FormatUint(uint64(opts.withPageSize), 10) + } return opts, apiOpts } @@ -916,6 +1017,20 @@ func WithFilter(filter string) Option { } } {{ end }} +// WithClientDirectedPagination tells the List function to return only the first +// page, if more pages are available +func WithClientDirectedPagination(with bool) Option { + return func(o *options) { + o.withClientDirectedPagination = with + } +} + +// WithPageSize controls the size of pages used during List +func WithPageSize(with uint32) Option { + return func(o *options) { + o.withPageSize = with + } +} {{ if .RecursiveListing }} // WithRecursive tells the API to use recursion for listing operations on this // resource diff --git a/internal/clientcache/cmd/cache/start.go b/internal/clientcache/cmd/cache/start.go index c94d507e09..e9d0385977 100644 --- a/internal/clientcache/cmd/cache/start.go +++ b/internal/clientcache/cmd/cache/start.go @@ -50,6 +50,7 @@ type StartCommand struct { flagLogFormat string flagStoreDebug bool flagBackground bool + flagForceResetSchema bool } func (c *StartCommand) Synopsis() string { @@ -133,6 +134,13 @@ func (c *StartCommand) Flags() *base.FlagSets { Default: false, Usage: `Run the cache daemon in the background`, }) + f.BoolVar(&base.BoolVar{ + Name: "force-reset-schema", + Target: &c.flagForceResetSchema, + Default: false, + Usage: `Force resetting the cache schema and all contained data`, + Hidden: true, + }) return set } @@ -207,6 +215,7 @@ func (c *StartCommand) Run(args []string) int { LogFileName: logFileName, DotDirectory: dotDir, RunningInBackground: os.Getenv(backgroundEnvName) == backgroundEnvVal, + ForceResetSchema: c.flagForceResetSchema, } srv, err := daemon.New(ctx, cfg) diff --git a/internal/clientcache/internal/cache/options.go b/internal/clientcache/internal/cache/options.go index 36907f8833..f57ab71809 100644 --- a/internal/clientcache/internal/cache/options.go +++ b/internal/clientcache/internal/cache/options.go @@ -25,6 +25,7 @@ type options struct { withIgnoreSearchStaleness bool withMaxResultSetSize int withTestRefreshWaitChs *testRefreshWaitChs + withUseNonPagedListing bool } // Option - how options are passed as args @@ -129,3 +130,12 @@ func WithTestRefreshWaitChs(with *testRefreshWaitChs) Option { return nil } } + +// WithUseNonPagedListing provides an option for ignoring the resource +// staleness 
when performing a search. +func WithUseNonPagedListing(b bool) Option { + return func(o *options) error { + o.withUseNonPagedListing = b + return nil + } +} diff --git a/internal/clientcache/internal/cache/options_test.go b/internal/clientcache/internal/cache/options_test.go index 44074fdbe6..9f25faebe3 100644 --- a/internal/clientcache/internal/cache/options_test.go +++ b/internal/clientcache/internal/cache/options_test.go @@ -49,8 +49,8 @@ func Test_GetOpts(t *testing.T) { assert.Equal(t, opts, testOpts) }) t.Run("WithTargetRetrievalFunc", func(t *testing.T) { - var f TargetRetrievalFunc = func(ctx context.Context, addr, authTok string, refreshTok RefreshTokenValue) ([]*targets.Target, []string, RefreshTokenValue, error) { - return nil, nil, "", nil + var f TargetRetrievalFunc = func(ctx context.Context, addr, authTok string, refreshTok RefreshTokenValue, inPage *targets.TargetListResult, opt ...Option) (*targets.TargetListResult, RefreshTokenValue, error) { + return nil, "", nil } opts, err := getOpts(WithTargetRetrievalFunc(f)) require.NoError(t, err) @@ -118,4 +118,12 @@ func Test_GetOpts(t *testing.T) { testOpts.withTestRefreshWaitChs = waitCh assert.Equal(t, opts, testOpts) }) + t.Run("WithUseNonPagedListing", func(t *testing.T) { + opts, err := getOpts(WithUseNonPagedListing(true)) + require.NoError(t, err) + testOpts := getDefaultOptions() + assert.False(t, testOpts.withUseNonPagedListing) + testOpts.withUseNonPagedListing = true + assert.Equal(t, opts, testOpts) + }) } diff --git a/internal/clientcache/internal/cache/refresh.go b/internal/clientcache/internal/cache/refresh.go index 6cc294be09..015ca44800 100644 --- a/internal/clientcache/internal/cache/refresh.go +++ b/internal/clientcache/internal/cache/refresh.go @@ -420,28 +420,47 @@ func (r *RefreshService) RecheckCachingSupport(ctx context.Context, opt ...Optio continue } - if err := r.repo.checkCachingTargets(ctx, u, tokens, opt...); err != nil { - if err == ErrRefreshNotSupported { - // This is expected so no need to propagate the error up - continue + cacheKey := fmt.Sprintf("%s-%s", u.Id, Targets) + semaphore, _ := r.syncSemaphores.LoadOrStore(cacheKey, new(atomic.Bool)) + if semaphore.(*atomic.Bool).CompareAndSwap(false, true) { + if err := r.repo.checkCachingTargets(ctx, u, tokens, opt...); err != nil { + if err == ErrRefreshNotSupported { + semaphore.(*atomic.Bool).Store(false) + // This is expected so no need to propagate the error up + continue + } + retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("for user id %s", u.Id)))) } - retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("for user id %s", u.Id)))) + semaphore.(*atomic.Bool).Store(false) } - if err := r.repo.checkCachingSessions(ctx, u, tokens, opt...); err != nil { - if err == ErrRefreshNotSupported { - // This is expected so no need to propagate the error up - continue + + cacheKey = fmt.Sprintf("%s-%s", u.Id, ResolvableAliases) + semaphore, _ = r.syncSemaphores.LoadOrStore(cacheKey, new(atomic.Bool)) + if semaphore.(*atomic.Bool).CompareAndSwap(false, true) { + if err := r.repo.checkCachingResolvableAliases(ctx, u, tokens, opt...); err != nil { + if err == ErrRefreshNotSupported { + // This is expected so no need to propagate the error up + semaphore.(*atomic.Bool).Store(false) + continue + } + retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("for user id %s", u.Id)))) } - retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, 
errors.WithMsg(fmt.Sprintf("for user id %s", u.Id)))) + semaphore.(*atomic.Bool).Store(false) } - if err := r.repo.checkCachingResolvableAliases(ctx, u, tokens, opt...); err != nil { - if err == ErrRefreshNotSupported { - // This is expected so no need to propagate the error up - continue + + cacheKey = fmt.Sprintf("%s-%s", u.Id, Sessions) + semaphore, _ = r.syncSemaphores.LoadOrStore(cacheKey, new(atomic.Bool)) + if semaphore.(*atomic.Bool).CompareAndSwap(false, true) { + if err := r.repo.checkCachingSessions(ctx, u, tokens, opt...); err != nil { + if err == ErrRefreshNotSupported { + semaphore.(*atomic.Bool).Store(false) + // This is expected so no need to propagate the error up + continue + } + retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("for user id %s", u.Id)))) } - retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("for user id %s", u.Id)))) + semaphore.(*atomic.Bool).Store(false) } - } return retErr } diff --git a/internal/clientcache/internal/cache/refresh_test.go b/internal/clientcache/internal/cache/refresh_test.go index bd0905de93..0d45187279 100644 --- a/internal/clientcache/internal/cache/refresh_test.go +++ b/internal/clientcache/internal/cache/refresh_test.go @@ -54,6 +54,22 @@ func testStaticResourceRetrievalFunc[T any](t *testing.T, ret [][]T, removed [][ } } +func testTargetStaticResourceRetrievalFunc(inFunc func(ctx context.Context, s1, s2 string, refToken RefreshTokenValue) ([]*targets.Target, []string, RefreshTokenValue, error)) TargetRetrievalFunc { + return func(ctx context.Context, addr, authTok string, refreshTok RefreshTokenValue, inPage *targets.TargetListResult, opt ...Option) (ret *targets.TargetListResult, refreshToken RefreshTokenValue, err error) { + retTargets, removed, refreshToken, err := inFunc(ctx, addr, authTok, refreshTok) + if err != nil { + return nil, "", err + } + + ret = &targets.TargetListResult{ + Items: retTargets, + RemovedIds: removed, + ResponseType: "complete", + } + return ret, refreshToken, nil + } +} + // testNoRefreshRetrievalFunc simulates a controller that doesn't support refresh // since it does not return any refresh token. 
func testNoRefreshRetrievalFunc[T any](t *testing.T) func(context.Context, string, string, RefreshTokenValue) ([]T, []string, RefreshTokenValue, error) { @@ -426,7 +442,7 @@ func TestRefreshForSearch(t *testing.T) { opts := []Option{ WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), - WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, [][]*targets.Target{ retTargets[:3], retTargets[3:], @@ -435,7 +451,7 @@ func TestRefreshForSearch(t *testing.T) { nil, {retTargets[0].Id, retTargets[1].Id}, }, - )), + ))), } assert.NoError(t, rs.RefreshForSearch(ctx, at.Id, Targets, opts...)) @@ -481,7 +497,7 @@ func TestRefreshForSearch(t *testing.T) { opts := []Option{ WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), - WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, [][]*targets.Target{ retTargets[:3], retTargets[3:], @@ -490,7 +506,7 @@ func TestRefreshForSearch(t *testing.T) { nil, {retTargets[0].Id, retTargets[1].Id}, }, - )), + ))), } assert.NoError(t, rs.RefreshForSearch(ctx, at.Id, Targets, opts...)) @@ -538,8 +554,8 @@ func TestRefreshForSearch(t *testing.T) { // Get the first set of resources, but no refresh tokens err = rs.Refresh(ctx, WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)), - WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))) + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))), + WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))) assert.ErrorContains(t, err, ErrRefreshNotSupported.Error()) got, err := r.ListTargets(ctx, at.Id) @@ -554,14 +570,14 @@ func TestRefreshForSearch(t *testing.T) { // any more. err = rs.Refresh(ctx, WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)), - WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))) + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))), + WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))) assert.Nil(t, err) err = rs.RecheckCachingSupport(ctx, WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)), - WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))) + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))), + WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))) assert.Nil(t, err) got, err = r.ListTargets(ctx, at.Id) @@ -575,8 +591,9 @@ func TestRefreshForSearch(t *testing.T) { // the resources starting to be cached. 
err = rs.RecheckCachingSupport(ctx, WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, [][]*targets.Target{retTargets}, [][]string{{}}))), WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), - WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, [][]*targets.Target{retTargets}, [][]string{{}}))) + ) assert.Nil(t, err, err) got, err = r.ListTargets(ctx, at.Id) @@ -601,7 +618,7 @@ func TestRefreshForSearch(t *testing.T) { } opts := []Option{ WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), - WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil)), + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil))), WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, [][]*sessions.Session{ retSess[:3], @@ -653,7 +670,7 @@ func TestRefreshForSearch(t *testing.T) { } opts := []Option{ WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), - WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil)), + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil))), WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, [][]*sessions.Session{ retSess[:3], @@ -710,7 +727,7 @@ func TestRefreshForSearch(t *testing.T) { } opts := []Option{ WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), - WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil)), + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil))), WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, [][]*aliases.Alias{ retAl[:3], @@ -762,7 +779,7 @@ func TestRefreshForSearch(t *testing.T) { } opts := []Option{ WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), - WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil)), + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil))), WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, [][]*aliases.Alias{ retAls[:3], @@ -839,7 +856,7 @@ func TestRefreshNonBlocking(t *testing.T) { opts := []Option{ WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), - WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, [][]*targets.Target{ retTargets[:3], retTargets[3:], @@ -848,7 +865,7 @@ func TestRefreshNonBlocking(t *testing.T) { nil, {retTargets[0].Id, retTargets[1].Id}, }, - )), + ))), } refreshWaitChs := &testRefreshWaitChs{ @@ -899,7 +916,7 @@ func TestRefreshNonBlocking(t *testing.T) { } opts := []Option{ WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), - WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil)), + 
WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil))), WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, [][]*sessions.Session{ retSess[:3], @@ -961,7 +978,7 @@ func TestRefreshNonBlocking(t *testing.T) { } opts := []Option{ WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), - WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil)), + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil))), WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, [][]*aliases.Alias{ retAl[:3], @@ -1041,7 +1058,7 @@ func TestRefresh(t *testing.T) { opts := []Option{ WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), - WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, [][]*targets.Target{ retTargets[:3], retTargets[3:], @@ -1050,7 +1067,7 @@ func TestRefresh(t *testing.T) { nil, {retTargets[0].Id, retTargets[1].Id}, }, - )), + ))), } assert.NoError(t, rs.Refresh(ctx, opts...)) @@ -1082,7 +1099,7 @@ func TestRefresh(t *testing.T) { } opts := []Option{ WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), - WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil)), + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil))), WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, [][]*sessions.Session{ retSess[:3], @@ -1123,7 +1140,7 @@ func TestRefresh(t *testing.T) { } opts := []Option{ WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), - WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil)), + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil))), WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, [][]*aliases.Alias{ retAls[:3], @@ -1160,15 +1177,15 @@ func TestRefresh(t *testing.T) { err = rs.Refresh(ctx, WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), - WithTargetRetrievalFunc(func(ctx context.Context, addr, token string, refreshTok RefreshTokenValue) ([]*targets.Target, []string, RefreshTokenValue, error) { + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(func(ctx context.Context, addr, token string, refreshTok RefreshTokenValue) ([]*targets.Target, []string, RefreshTokenValue, error) { require.Equal(t, boundaryAddr, addr) require.Equal(t, at.Token, token) return nil, nil, "", innerErr - })) + }))) assert.ErrorContains(t, err, innerErr.Error()) err = rs.Refresh(ctx, WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), - WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil)), + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil))), WithSessionRetrievalFunc(func(ctx 
context.Context, addr, token string, refreshTok RefreshTokenValue) ([]*sessions.Session, []string, RefreshTokenValue, error) { require.Equal(t, boundaryAddr, addr) require.Equal(t, at.Token, token) @@ -1202,7 +1219,7 @@ func TestRefresh(t *testing.T) { require.NoError(t, rs.Refresh(ctx, WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), - WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil)))) + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil))))) ps, err = r.listTokens(ctx, u) require.NoError(t, err) @@ -1246,7 +1263,7 @@ func TestRecheckCachingSupport(t *testing.T) { assert.NoError(t, rs.RecheckCachingSupport(ctx, WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)), - WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)))) + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))))) got, err := r.ListTargets(ctx, at.Id) require.NoError(t, err) @@ -1258,7 +1275,7 @@ func TestRecheckCachingSupport(t *testing.T) { err = rs.Refresh(ctx, WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)), - WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))) + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)))) assert.ErrorIs(t, err, ErrRefreshNotSupported) got, err = r.ListTargets(ctx, at.Id) @@ -1272,7 +1289,7 @@ func TestRecheckCachingSupport(t *testing.T) { assert.NoError(t, rs.RecheckCachingSupport(ctx, WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)), - WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)))) + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))))) }) t.Run("sessions", func(t *testing.T) { @@ -1286,7 +1303,7 @@ func TestRecheckCachingSupport(t *testing.T) { assert.NoError(t, rs.RecheckCachingSupport(ctx, WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), - WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)), + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))), WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))) got, err := r.ListSessions(ctx, at.Id) @@ -1298,7 +1315,7 @@ func TestRecheckCachingSupport(t *testing.T) { err = rs.Refresh(ctx, WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), - WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)), + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))), WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))) assert.ErrorIs(t, err, ErrRefreshNotSupported) @@ -1311,7 +1328,7 @@ func TestRecheckCachingSupport(t *testing.T) { assert.NoError(t, rs.RecheckCachingSupport(ctx, WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), - WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)), + 
WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))), WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))) got, err = r.ListSessions(ctx, at.Id) require.NoError(t, err) @@ -1332,7 +1349,7 @@ func TestRecheckCachingSupport(t *testing.T) { assert.NoError(t, rs.RecheckCachingSupport(ctx, WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), - WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)), + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))), WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))) got, err := r.ListResolvableAliases(ctx, at.Id) @@ -1344,7 +1361,7 @@ func TestRecheckCachingSupport(t *testing.T) { err = rs.Refresh(ctx, WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), - WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)), + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))), WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))) assert.ErrorIs(t, err, ErrRefreshNotSupported) @@ -1357,7 +1374,7 @@ func TestRecheckCachingSupport(t *testing.T) { assert.NoError(t, rs.RecheckCachingSupport(ctx, WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), - WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)), + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))), WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))) got, err = r.ListResolvableAliases(ctx, at.Id) require.NoError(t, err) @@ -1378,7 +1395,7 @@ func TestRecheckCachingSupport(t *testing.T) { err = rs.Refresh(ctx, WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), - WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)), + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))), WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))) assert.ErrorIs(t, err, ErrRefreshNotSupported) @@ -1386,21 +1403,21 @@ func TestRecheckCachingSupport(t *testing.T) { err = rs.RecheckCachingSupport(ctx, WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)), - WithTargetRetrievalFunc(func(ctx context.Context, addr, token string, refreshTok RefreshTokenValue) ([]*targets.Target, []string, RefreshTokenValue, error) { + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(func(ctx context.Context, addr, token string, refreshTok RefreshTokenValue) ([]*targets.Target, []string, RefreshTokenValue, error) { require.Equal(t, boundaryAddr, addr) require.Equal(t, at.Token, token) return nil, nil, "", innerErr - })) + }))) assert.ErrorContains(t, err, innerErr.Error()) err = rs.RecheckCachingSupport(ctx, WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)), - WithTargetRetrievalFunc(func(ctx context.Context, addr, token string, refreshTok RefreshTokenValue) ([]*targets.Target, []string, RefreshTokenValue, error) { + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(func(ctx context.Context, addr, token string, refreshTok RefreshTokenValue) ([]*targets.Target, []string, 
RefreshTokenValue, error) { require.Equal(t, boundaryAddr, addr) require.Equal(t, at.Token, token) return nil, nil, "", innerErr - })) + }))) assert.ErrorContains(t, err, innerErr.Error()) }) @@ -1415,7 +1432,7 @@ func TestRecheckCachingSupport(t *testing.T) { err = rs.Refresh(ctx, WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), - WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)), + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))), WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))) assert.ErrorIs(t, err, ErrRefreshNotSupported) @@ -1435,7 +1452,7 @@ func TestRecheckCachingSupport(t *testing.T) { err = rs.RecheckCachingSupport(ctx, WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)), - WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))) + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)))) assert.NoError(t, err) ps, err = r.listTokens(ctx, u) diff --git a/internal/clientcache/internal/cache/repository_implicit_scopes_test.go b/internal/clientcache/internal/cache/repository_implicit_scopes_test.go index f5bc226337..ca5f5c0b64 100644 --- a/internal/clientcache/internal/cache/repository_implicit_scopes_test.go +++ b/internal/clientcache/internal/cache/repository_implicit_scopes_test.go @@ -62,7 +62,7 @@ func TestRepository_ImplicitScopes(t *testing.T) { target("3"), } require.NoError(t, r.refreshTargets(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"}, - WithTargetRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*targets.Target{ts}, [][]string{nil})))) + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*targets.Target{ts}, [][]string{nil}))))) for _, t := range ts { expectedScopes = append(expectedScopes, &scopes.Scope{ diff --git a/internal/clientcache/internal/cache/repository_targets.go b/internal/clientcache/internal/cache/repository_targets.go index e5cfa6e1f9..56ccc16f20 100644 --- a/internal/clientcache/internal/cache/repository_targets.go +++ b/internal/clientcache/internal/cache/repository_targets.go @@ -21,32 +21,42 @@ import ( // TargetRetrievalFunc is a function that retrieves targets // from the provided boundary addr using the provided token. -type TargetRetrievalFunc func(ctx context.Context, addr, authTok string, refreshTok RefreshTokenValue) (ret []*targets.Target, removedIds []string, refreshToken RefreshTokenValue, err error) +type TargetRetrievalFunc func(ctx context.Context, addr, authTok string, refreshTok RefreshTokenValue, inPage *targets.TargetListResult, opt ...Option) (ret *targets.TargetListResult, refreshToken RefreshTokenValue, err error) -func defaultTargetFunc(ctx context.Context, addr, authTok string, refreshTok RefreshTokenValue) ([]*targets.Target, []string, RefreshTokenValue, error) { +func defaultTargetFunc(ctx context.Context, addr, authTok string, refreshTok RefreshTokenValue, inPage *targets.TargetListResult, opt ...Option) (*targets.TargetListResult, RefreshTokenValue, error) { const op = "cache.defaultTargetFunc" conf, err := api.DefaultConfig() if err != nil { - return nil, nil, "", errors.Wrap(ctx, err, op) + return nil, "", errors.Wrap(ctx, err, op) + } + opts, err := getOpts(opt...) 
+ if err != nil { + return nil, "", errors.Wrap(ctx, err, op) } conf.Addr = addr conf.Token = authTok client, err := api.NewClient(conf) if err != nil { - return nil, nil, "", errors.Wrap(ctx, err, op) + return nil, "", errors.Wrap(ctx, err, op) } tarClient := targets.NewClient(client) - l, err := tarClient.List(ctx, "global", targets.WithRecursive(true), targets.WithListToken(string(refreshTok))) + var l *targets.TargetListResult + switch inPage { + case nil: + l, err = tarClient.List(ctx, "global", targets.WithRecursive(true), targets.WithListToken(string(refreshTok)), targets.WithClientDirectedPagination(!opts.withUseNonPagedListing)) + default: + l, err = tarClient.ListNextPage(ctx, inPage, targets.WithListToken(string(refreshTok))) + } if err != nil { if api.ErrInvalidListToken.Is(err) { - return nil, nil, "", err + return nil, "", err } - return nil, nil, "", errors.Wrap(ctx, err, op) + return nil, "", errors.Wrap(ctx, err, op) } if l.ResponseType == "" { - return nil, nil, "", ErrRefreshNotSupported + return nil, "", ErrRefreshNotSupported } - return l.Items, l.RemovedIds, RefreshTokenValue(l.ListToken), nil + return l, RefreshTokenValue(l.ListToken), nil } // refreshTargets uses attempts to refresh the targets for the provided user @@ -81,13 +91,13 @@ func (r *Repository) refreshTargets(ctx context.Context, u *user, tokens map[Aut // Find and use a token for retrieving targets var gotResponse bool - var resp []*targets.Target - var removedIds []string + var currentPage *targets.TargetListResult var newRefreshToken RefreshTokenValue + var foundAuthToken string var unsupportedCacheRequest bool var retErr error for at, t := range tokens { - resp, removedIds, newRefreshToken, err = opts.withTargetRetrievalFunc(ctx, u.Address, t, oldRefreshTokenVal) + currentPage, newRefreshToken, err = opts.withTargetRetrievalFunc(ctx, u.Address, t, oldRefreshTokenVal, currentPage) if api.ErrInvalidListToken.Is(err) { event.WriteSysEvent(ctx, op, "old list token is no longer valid, starting new initial fetch", "user_id", u.Id) if err := r.deleteRefreshToken(ctx, u, resourceType); err != nil { @@ -95,7 +105,7 @@ func (r *Repository) refreshTargets(ctx context.Context, u *user, tokens map[Aut } // try again without the refresh token oldRefreshToken = nil - resp, removedIds, newRefreshToken, err = opts.withTargetRetrievalFunc(ctx, u.Address, t, "") + currentPage, newRefreshToken, err = opts.withTargetRetrievalFunc(ctx, u.Address, t, "", currentPage) } if err != nil { if err == ErrRefreshNotSupported { @@ -105,6 +115,7 @@ func (r *Repository) refreshTargets(ctx context.Context, u *user, tokens map[Aut continue } } + foundAuthToken = t gotResponse = true break } @@ -118,44 +129,57 @@ func (r *Repository) refreshTargets(ctx context.Context, u *user, tokens map[Aut } var numDeleted int - _, err = r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(_ db.Reader, w db.Writer) error { - var err error - switch { - case oldRefreshToken == nil || unsupportedCacheRequest: - if numDeleted, err = w.Exec(ctx, "delete from target where fk_user_id = @fk_user_id", - []any{sql.Named("fk_user_id", u.Id)}); err != nil { - return err - } - case len(removedIds) > 0: - if numDeleted, err = w.Exec(ctx, "delete from target where id in @ids", - []any{sql.Named("ids", removedIds)}); err != nil { - return err - } - } - switch { - case unsupportedCacheRequest: - if err := upsertRefreshToken(ctx, w, u, resourceType, sentinelNoRefreshToken); err != nil { - return err + var numUpserted int + var clearPerformed bool + for { + _, err = 
r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(_ db.Reader, w db.Writer) error { + var err error + if (oldRefreshToken == nil || unsupportedCacheRequest) && !clearPerformed { + if numDeleted, err = w.Exec(ctx, "delete from target where fk_user_id = @fk_user_id", + []any{sql.Named("fk_user_id", u.Id)}); err != nil { + return err + } + clearPerformed = true } - case newRefreshToken != "": - if err := upsertTargets(ctx, w, u, resp); err != nil { - return err + switch { + case unsupportedCacheRequest: + if err := upsertRefreshToken(ctx, w, u, resourceType, sentinelNoRefreshToken); err != nil { + return err + } + case newRefreshToken != "": + if err := upsertTargets(ctx, w, u, currentPage.Items); err != nil { + return err + } + numUpserted += len(currentPage.Items) + if err := upsertRefreshToken(ctx, w, u, resourceType, newRefreshToken); err != nil { + return err + } + default: + // controller supports caching, but doesn't have any resources } - if err := upsertRefreshToken(ctx, w, u, resourceType, newRefreshToken); err != nil { - return err + if !unsupportedCacheRequest && len(currentPage.RemovedIds) > 0 { + if numDeleted, err = w.Exec(ctx, "delete from target where id in @ids", + []any{sql.Named("ids", currentPage.RemovedIds)}); err != nil { + return err + } } - default: - // controller supports caching, but doesn't have any resources + return nil + }) + if unsupportedCacheRequest || currentPage.ResponseType == "" || currentPage.ResponseType == "complete" { + break } - return nil - }) + currentPage, newRefreshToken, err = opts.withTargetRetrievalFunc(ctx, u.Address, foundAuthToken, newRefreshToken, currentPage) + if err != nil { + break + } + } if err != nil { return errors.Wrap(ctx, err, op) } if unsupportedCacheRequest { return ErrRefreshNotSupported } - event.WriteSysEvent(ctx, op, "targets updated", "deleted", numDeleted, "upserted", len(resp), "user_id", u.Id) + event.WriteSysEvent(ctx, op, "targets updated", "deleted", numDeleted, "upserted", numUpserted, "user_id", u.Id) return nil } @@ -185,12 +209,12 @@ func (r *Repository) checkCachingTargets(ctx context.Context, u *user, tokens ma // Find and use a token for retrieving targets var gotResponse bool - var resp []*targets.Target + var resp *targets.TargetListResult var newRefreshToken RefreshTokenValue var unsupportedCacheRequest bool var retErr error for at, t := range tokens { - resp, _, newRefreshToken, err = opts.withTargetRetrievalFunc(ctx, u.Address, t, "") + resp, newRefreshToken, err = opts.withTargetRetrievalFunc(ctx, u.Address, t, "", nil, WithUseNonPagedListing(true)) if err != nil { if err == ErrRefreshNotSupported { unsupportedCacheRequest = true @@ -221,14 +245,15 @@ func (r *Repository) checkCachingTargets(ctx context.Context, u *user, tokens ma return err } case newRefreshToken != "": - var err error // Now that there is a refresh token, the data can be cached, so - // cache it and store the refresh token for future refreshes. + // cache it and store the refresh token for future refreshes. 
First + // remove any values, then add the new ones + var err error if numDeleted, err = w.Exec(ctx, "delete from target where fk_user_id = @fk_user_id", []any{sql.Named("fk_user_id", u.Id)}); err != nil { return err } - if err := upsertTargets(ctx, w, u, resp); err != nil { + if err := upsertTargets(ctx, w, u, resp.Items); err != nil { return err } if err := upsertRefreshToken(ctx, w, u, resourceType, newRefreshToken); err != nil { @@ -249,7 +274,7 @@ func (r *Repository) checkCachingTargets(ctx context.Context, u *user, tokens ma if unsupportedCacheRequest { return ErrRefreshNotSupported } - event.WriteSysEvent(ctx, op, "targets updated", "deleted", numDeleted, "upserted", len(resp), "user_id", u.Id) + event.WriteSysEvent(ctx, op, "targets updated", "deleted", numDeleted, "upserted", len(resp.Items), "user_id", u.Id) return nil } diff --git a/internal/clientcache/internal/cache/repository_targets_test.go b/internal/clientcache/internal/cache/repository_targets_test.go index a3c705b513..deb87d692b 100644 --- a/internal/clientcache/internal/cache/repository_targets_test.go +++ b/internal/clientcache/internal/cache/repository_targets_test.go @@ -141,7 +141,7 @@ func TestRepository_refreshTargets(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { err := r.refreshTargets(ctx, tc.u, map[AuthToken]string{{Id: "id"}: "something"}, - WithTargetRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*targets.Target{tc.targets}, [][]string{nil}))) + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*targets.Target{tc.targets}, [][]string{nil})))) if tc.errorContains == "" { assert.NoError(t, err) rw := db.New(s) @@ -213,12 +213,12 @@ func TestRepository_RefreshTargets_InvalidListTokenError(t *testing.T) { } require.NoError(t, r.refreshTargets(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"}, - WithTargetRetrievalFunc(invalidAuthTokenFunc))) + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(invalidAuthTokenFunc)))) // This time an invalid auth token should be returned, and refreshTargets should fall back // to requesting without one. 
require.NoError(t, r.refreshTargets(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"}, - WithTargetRetrievalFunc(invalidAuthTokenFunc))) + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(invalidAuthTokenFunc)))) assert.Equal(t, 1, withRefreshToken) assert.Equal(t, 2, withoutRefreshToken) @@ -261,7 +261,7 @@ func TestRepository_RefreshTargets_withRefreshTokens(t *testing.T) { } require.NoError(t, r.refreshTargets(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"}, - WithTargetRetrievalFunc(testStaticResourceRetrievalFunc(t, ts, [][]string{nil, nil})))) + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc(t, ts, [][]string{nil, nil}))))) got, err := r.ListTargets(ctx, at.Id) require.NoError(t, err) @@ -270,7 +270,7 @@ func TestRepository_RefreshTargets_withRefreshTokens(t *testing.T) { // Refreshing again uses the refresh token and get additional sessions, appending // them to the response require.NoError(t, r.refreshTargets(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"}, - WithTargetRetrievalFunc(testStaticResourceRetrievalFunc(t, ts, [][]string{nil, nil})))) + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc(t, ts, [][]string{nil, nil}))))) assert.NoError(t, err) got, err = r.ListTargets(ctx, at.Id) @@ -280,7 +280,7 @@ func TestRepository_RefreshTargets_withRefreshTokens(t *testing.T) { // Refreshing again wont return any more resources, but also none should be // removed require.NoError(t, r.refreshTargets(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"}, - WithTargetRetrievalFunc(testStaticResourceRetrievalFunc(t, ts, [][]string{nil, nil})))) + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc(t, ts, [][]string{nil, nil}))))) assert.NoError(t, err) got, err = r.ListTargets(ctx, at.Id) @@ -289,7 +289,7 @@ func TestRepository_RefreshTargets_withRefreshTokens(t *testing.T) { // Refresh again with the refresh token being reported as invalid. 
require.NoError(t, r.refreshTargets(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"}, - WithTargetRetrievalFunc(testErroringForRefreshTokenRetrievalFunc(t, ts[0])))) + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testErroringForRefreshTokenRetrievalFunc(t, ts[0]))))) assert.NoError(t, err) got, err = r.ListTargets(ctx, at.Id) @@ -345,7 +345,7 @@ func TestRepository_ListTargets(t *testing.T) { target("3"), } require.NoError(t, r.refreshTargets(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"}, - WithTargetRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*targets.Target{ts}, [][]string{nil})))) + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*targets.Target{ts}, [][]string{nil}))))) t.Run("wrong user gets no targets", func(t *testing.T) { l, err := r.ListTargets(ctx, kt2.AuthTokenId) @@ -389,7 +389,7 @@ func TestRepository_ListTargetsLimiting(t *testing.T) { ts = append(ts, target("t"+strconv.Itoa(i))) } require.NoError(t, r.refreshTargets(ctx, u, map[AuthToken]string{{Id: "id"}: "something"}, - WithTargetRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*targets.Target{ts}, [][]string{nil})))) + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*targets.Target{ts}, [][]string{nil}))))) searchService, err := NewSearchService(ctx, r) require.NoError(t, err) @@ -512,7 +512,7 @@ func TestRepository_QueryTargets(t *testing.T) { }, } require.NoError(t, r.refreshTargets(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"}, - WithTargetRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*targets.Target{ts}, [][]string{nil})))) + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*targets.Target{ts}, [][]string{nil}))))) t.Run("wrong token gets no targets", func(t *testing.T) { l, err := r.QueryTargets(ctx, kt2.AuthTokenId, query) @@ -556,7 +556,7 @@ func TestRepository_QueryTargetsLimiting(t *testing.T) { ts = append(ts, target("t"+strconv.Itoa(i))) } require.NoError(t, r.refreshTargets(ctx, u, map[AuthToken]string{{Id: "id"}: "something"}, - WithTargetRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*targets.Target{ts}, [][]string{nil})))) + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*targets.Target{ts}, [][]string{nil}))))) searchService, err := NewSearchService(ctx, r) require.NoError(t, err) @@ -606,13 +606,13 @@ func TestDefaultTargetRetrievalFunc(t *testing.T) { require.NoError(t, err) require.NotNil(t, tar2) - got, removed, refTok, err := defaultTargetFunc(tc.Context(), tc.ApiAddrs()[0], tc.Token().Token, "") + got, refTok, err := defaultTargetFunc(tc.Context(), tc.ApiAddrs()[0], tc.Token().Token, "", nil) assert.NoError(t, err) assert.NotEmpty(t, refTok) - assert.Empty(t, removed) + assert.Empty(t, got.RemovedIds) found1 := false found2 := false - for _, t := range got { + for _, t := range got.Items { if t.Id == tar1.Item.Id { found1 = true } @@ -623,10 +623,10 @@ func TestDefaultTargetRetrievalFunc(t *testing.T) { assert.True(t, found1, "expected to find target %s in list", tar1.Item.Id) assert.True(t, found2, "expected to find target %s in list", tar2.Item.Id) - got2, removed2, refTok2, err := defaultTargetFunc(tc.Context(), tc.ApiAddrs()[0], tc.Token().Token, refTok) + got2, refTok2, err := defaultTargetFunc(tc.Context(), tc.ApiAddrs()[0], tc.Token().Token, refTok, nil) assert.NoError(t, err) assert.NotEmpty(t, refTok2) 
assert.NotEqual(t, refTok2, refTok) - assert.Empty(t, removed2) - assert.Empty(t, got2) + assert.Empty(t, got.RemovedIds) + assert.Empty(t, got2.Items) } diff --git a/internal/clientcache/internal/cache/status_test.go b/internal/clientcache/internal/cache/status_test.go index 2a87384549..67e34fc2ab 100644 --- a/internal/clientcache/internal/cache/status_test.go +++ b/internal/clientcache/internal/cache/status_test.go @@ -190,11 +190,11 @@ func TestStatus(t *testing.T) { target("4"), } err = r.refreshTargets(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"}, - WithTargetRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*targets.Target{ts}, [][]string{nil}))) + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*targets.Target{ts}, [][]string{nil})))) require.NoError(t, err) err = r.refreshTargets(ctx, u2, map[AuthToken]string{{Id: "id"}: "something"}, - WithTargetRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*targets.Target{ts[:2]}, [][]string{nil}))) + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*targets.Target{ts[:2]}, [][]string{nil})))) require.NoError(t, err) sess := []*sessions.Session{ @@ -312,7 +312,7 @@ func TestStatus_unsupported(t *testing.T) { require.ErrorIs(t, err, ErrRefreshNotSupported) err = r.refreshTargets(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"}, - WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))) + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)))) require.ErrorIs(t, err, ErrRefreshNotSupported) err = r.refreshSessions(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"}, diff --git a/internal/clientcache/internal/daemon/options.go b/internal/clientcache/internal/daemon/options.go index daed203df3..1a4b15ad5f 100644 --- a/internal/clientcache/internal/daemon/options.go +++ b/internal/clientcache/internal/daemon/options.go @@ -20,9 +20,10 @@ type options struct { WithReadyToServeNotificationCh chan struct{} withBoundaryTokenReaderFunc cache.BoundaryTokenReaderFn - withUrl string - withLogger hclog.Logger - withHomeDir string + withUrl string + withLogger hclog.Logger + withHomeDir string + withForceResetSchema bool } // Option - how options are passed as args @@ -110,6 +111,15 @@ func WithBoundaryTokenReaderFunc(_ context.Context, fn cache.BoundaryTokenReader } } +// WithForceResetSchema provides an optional way to force resetting the schema, +// e.g. wiping the cache +func WithForceResetSchema(_ context.Context, force bool) Option { + return func(o *options) error { + o.withForceResetSchema = force + return nil + } +} + // WithReadyToServeNotificationCh provides an optional channel to notify when // the server is ready to serve; mainly used for test timing but exported for // availability. 
The channel will be closed just before the HTTP server is diff --git a/internal/clientcache/internal/daemon/options_test.go b/internal/clientcache/internal/daemon/options_test.go index 0afcbab1b5..4fd8bb4d68 100644 --- a/internal/clientcache/internal/daemon/options_test.go +++ b/internal/clientcache/internal/daemon/options_test.go @@ -108,6 +108,14 @@ func Test_GetOpts(t *testing.T) { testOpts.withHomeDir = "/tmp" assert.Equal(t, opts, testOpts) }) + t.Run("WithForceResetSchema", func(t *testing.T) { + opts, err := getOpts(WithForceResetSchema(ctx, true)) + require.NoError(t, err) + testOpts := getDefaultOptions() + assert.False(t, testOpts.withForceResetSchema) + testOpts.withForceResetSchema = true + assert.Equal(t, opts, testOpts) + }) t.Run("WithReadyToServeNotificationCh", func(t *testing.T) { ch := make(chan struct{}) opts, err := getOpts(WithReadyToServeNotificationCh(ctx, ch)) diff --git a/internal/clientcache/internal/daemon/server.go b/internal/clientcache/internal/daemon/server.go index cc642629b7..2c0037c34c 100644 --- a/internal/clientcache/internal/daemon/server.go +++ b/internal/clientcache/internal/daemon/server.go @@ -90,6 +90,8 @@ type Config struct { // The maximum amount of time a refresh should block a search request from // completing before it times out. MaxSearchRefreshTimeout time.Duration + // Force resetting the schema, that is, drop all existing data + ForceResetSchema bool } func (sc *Config) validate(ctx context.Context) error { @@ -230,6 +232,7 @@ func (s *CacheServer) Serve(ctx context.Context, cmd Commander, opt ...Option) e store, err = openStore(ctx, WithUrl(ctx, s.conf.DatabaseUrl), WithLogger(ctx, s.logger), + WithForceResetSchema(ctx, s.conf.ForceResetSchema), ) if err != nil { return errors.Wrap(ctx, err, op) @@ -537,6 +540,7 @@ func openStore(ctx context.Context, opt ...Option) (*db.DB, error) { dbOpts = append(dbOpts, cachedb.WithGormFormatter(opts.withLogger)) } } + dbOpts = append(dbOpts, cachedb.WithForceResetSchema(opts.withForceResetSchema)) store, err := cachedb.Open(ctx, dbOpts...) 
if err != nil { return nil, errors.Wrap(ctx, err, op) diff --git a/internal/clientcache/internal/daemon/testing.go b/internal/clientcache/internal/daemon/testing.go index 7faa91a8c2..0689d9cd4b 100644 --- a/internal/clientcache/internal/daemon/testing.go +++ b/internal/clientcache/internal/daemon/testing.go @@ -97,11 +97,13 @@ func (s *TestServer) AddResources(t *testing.T, p *authtokens.AuthToken, alts [] } return alts, nil, "addedaliases", nil } - tarFn := func(ctx context.Context, _, tok string, _ cache.RefreshTokenValue) ([]*targets.Target, []string, cache.RefreshTokenValue, error) { + tarFn := func(ctx context.Context, _ string, tok string, _ cache.RefreshTokenValue, inPage *targets.TargetListResult, opt ...cache.Option) (*targets.TargetListResult, cache.RefreshTokenValue, error) { if tok != p.Token { - return nil, nil, "", nil + return nil, "", nil } - return tars, nil, "addedtargets", nil + return &targets.TargetListResult{ + Items: tars, + }, "addedtargets", nil } sessFn := func(ctx context.Context, _, tok string, _ cache.RefreshTokenValue) ([]*sessions.Session, []string, cache.RefreshTokenValue, error) { if tok != p.Token { @@ -124,13 +126,15 @@ func (s *TestServer) AddUnsupportedCachingData(t *testing.T, p *authtokens.AuthT r, err := cache.NewRepository(ctx, s.CacheServer.store.Load(), &sync.Map{}, s.cmd.ReadTokenFromKeyring, atReadFn) require.NoError(t, err) - tarFn := func(ctx context.Context, _, tok string, _ cache.RefreshTokenValue) ([]*targets.Target, []string, cache.RefreshTokenValue, error) { + tarFn := func(ctx context.Context, _, tok string, _ cache.RefreshTokenValue, inPage *targets.TargetListResult, opt ...cache.Option) (*targets.TargetListResult, cache.RefreshTokenValue, error) { if tok != p.Token { - return nil, nil, "", nil + return &targets.TargetListResult{}, "", nil } - return []*targets.Target{ - {Id: "ttcp_unsupported", Name: "unsupported", Description: "not supported"}, - }, nil, "", cache.ErrRefreshNotSupported + return &targets.TargetListResult{ + Items: []*targets.Target{ + {Id: "ttcp_unsupported", Name: "unsupported", Description: "not supported"}, + }, + }, "", cache.ErrRefreshNotSupported } sessFn := func(ctx context.Context, _, tok string, _ cache.RefreshTokenValue) ([]*sessions.Session, []string, cache.RefreshTokenValue, error) { if tok != p.Token { diff --git a/internal/clientcache/internal/db/db.go b/internal/clientcache/internal/db/db.go index 4690c6e10b..9063cf3725 100644 --- a/internal/clientcache/internal/db/db.go +++ b/internal/clientcache/internal/db/db.go @@ -69,7 +69,7 @@ func Open(ctx context.Context, opt ...Option) (*db.DB, error) { if err != nil { return nil, errors.Wrap(ctx, err, op) } - if !ok { + if !ok || opts.withForceResetSchema { if err := resetSchema(ctx, conn); err != nil { return nil, errors.Wrap(ctx, err, op) } diff --git a/internal/clientcache/internal/db/options.go b/internal/clientcache/internal/db/options.go index 30186c90b1..6300ba840b 100644 --- a/internal/clientcache/internal/db/options.go +++ b/internal/clientcache/internal/db/options.go @@ -9,11 +9,12 @@ import ( ) type options struct { - withSchemaVersion string - withDebug bool - withUrl string - withDbType dbw.DbType - withGormFormatter hclog.Logger + withSchemaVersion string + withDebug bool + withUrl string + withDbType dbw.DbType + withGormFormatter hclog.Logger + withForceResetSchema bool } // Option - how options are passed as args @@ -67,3 +68,11 @@ func WithDebug(debug bool) Option { return nil } } + +// WithForceResetSchema provides an optional way to force 
resetting the cache +func WithForceResetSchema(debug bool) Option { + return func(o *options) error { + o.withForceResetSchema = debug + return nil + } +} diff --git a/internal/clientcache/internal/db/options_test.go b/internal/clientcache/internal/db/options_test.go index b67096b081..251f20c3ff 100644 --- a/internal/clientcache/internal/db/options_test.go +++ b/internal/clientcache/internal/db/options_test.go @@ -45,4 +45,12 @@ func Test_GetOpts(t *testing.T) { testOpts.withSchemaVersion = version assert.Equal(t, opts, testOpts) }) + t.Run("WithForceResetSchema", func(t *testing.T) { + opts, err := getOpts(WithForceResetSchema(true)) + require.NoError(t, err) + testOpts := getDefaultOptions() + assert.False(t, testOpts.withForceResetSchema) + testOpts.withForceResetSchema = true + assert.Equal(t, opts, testOpts) + }) } diff --git a/internal/tests/api/targets/target_test.go b/internal/tests/api/targets/target_test.go index 0a6263abe9..72a6bb1ea9 100644 --- a/internal/tests/api/targets/target_test.go +++ b/internal/tests/api/targets/target_test.go @@ -336,6 +336,39 @@ func TestListWithListToken(t *testing.T) { require.Empty(res.Items) } +func TestListWithPageSize(t *testing.T) { + // Set database read timeout to avoid duplicates in response + oldReadTimeout := globals.RefreshReadLookbackDuration + globals.RefreshReadLookbackDuration = 0 + t.Cleanup(func() { + globals.RefreshReadLookbackDuration = oldReadTimeout + }) + require := require.New(t) + tc := controller.NewTestController(t, nil) + defer tc.Shutdown() + + client := tc.Client() + token := tc.Token() + client.SetToken(token.Token) + _, proj := iam.TestScopes(t, tc.IamRepo(), iam.WithUserId(token.UserId)) + + tarClient := targets.NewClient(client) + _, err := tarClient.Create(tc.Context(), "tcp", proj.GetPublicId(), targets.WithName("1"), targets.WithTcpTargetDefaultPort(2)) + require.NoError(err) + _, err = tarClient.Create(tc.Context(), "tcp", proj.GetPublicId(), targets.WithName("2"), targets.WithTcpTargetDefaultPort(2)) + require.NoError(err) + + // Refresh tokens recursive listing over global scope + res, err := tarClient.List(tc.Context(), "global", targets.WithRecursive(true), targets.WithPageSize(2)) + require.NoError(err) + require.Len(res.Items, 4, "expected the 2 targets created above and the 2 auto created for the test controller") + refTok := res.ListToken + + res, err = tarClient.List(tc.Context(), "global", targets.WithRecursive(true), targets.WithListToken(refTok)) + require.NoError(err) + require.Empty(res.Items) +} + func TestTarget_AddressMutualExclusiveRelationship(t *testing.T) { tc := controller.NewTestController(t, nil) From 2945692bf4fbd90b35ad9a42c5016a6fc96d25ec Mon Sep 17 00:00:00 2001 From: Irena Rindos Date: Fri, 13 Sep 2024 15:59:49 -0400 Subject: [PATCH 09/15] Client cache API paging (#5107) * internal/clientcache: add -force-reset-schema flag * clientcache: stream list pages directly to DB --------- Co-authored-by: Johan Brandhorst-Satzkorn (cherry picked from commit ccea8a4ce53f7b0d1105c1c1e4875e443635d96a) --- api/accounts/account.gen.go | 14 +- api/accounts/option.gen.go | 8 + api/aliases/alias.gen.go | 14 +- api/aliases/option.gen.go | 8 + api/authmethods/authmethods.gen.go | 14 +- api/authmethods/option.gen.go | 8 + api/authtokens/authtokens.gen.go | 14 +- api/authtokens/option.gen.go | 8 + api/billing/option.gen.go | 8 + .../credential_library.gen.go | 14 +- api/credentiallibraries/option.gen.go | 8 + api/credentials/credential.gen.go | 14 +- api/credentials/option.gen.go | 8 + 
api/credentialstores/credential_store.gen.go | 14 +- api/credentialstores/option.gen.go | 8 + api/groups/group.gen.go | 14 +- api/groups/option.gen.go | 8 + api/hostcatalogs/host_catalog.gen.go | 14 +- api/hostcatalogs/option.gen.go | 8 + api/hosts/host.gen.go | 14 +- api/hosts/option.gen.go | 8 + api/hostsets/host_set.gen.go | 14 +- api/hostsets/option.gen.go | 8 + api/managedgroups/managedgroups.gen.go | 14 +- api/managedgroups/option.gen.go | 8 + api/policies/option.gen.go | 8 + api/policies/policy.gen.go | 14 +- api/roles/option.gen.go | 8 + api/roles/role.gen.go | 14 +- api/scopes/option.gen.go | 8 + api/scopes/scope.gen.go | 14 +- api/sessionrecordings/option.gen.go | 8 + .../session_recording.gen.go | 14 +- api/sessions/option.gen.go | 8 + api/sessions/session.gen.go | 14 +- api/storagebuckets/option.gen.go | 8 + api/storagebuckets/storage_bucket.gen.go | 14 +- api/targets/option.gen.go | 8 + api/targets/target.gen.go | 14 +- api/users/custom.go | 161 ++++----------- api/users/option.gen.go | 8 + api/users/user.gen.go | 14 +- api/workers/option.gen.go | 8 + api/workers/worker.gen.go | 7 +- internal/api/genapi/templates.go | 26 ++- .../internal/cache/options_test.go | 8 +- .../internal/cache/refresh_test.go | 184 ++++++++++-------- .../cache/repository_implicit_scopes_test.go | 2 +- .../cache/repository_resolvable_aliases.go | 116 ++++++----- .../repository_resolvable_aliases_test.go | 30 +-- .../internal/cache/repository_sessions.go | 124 +++++++----- .../cache/repository_sessions_test.go | 30 +-- .../clientcache/internal/cache/status_test.go | 8 +- .../clientcache/internal/daemon/testing.go | 26 ++- 54 files changed, 787 insertions(+), 391 deletions(-) diff --git a/api/accounts/account.gen.go b/api/accounts/account.gen.go index c55e8b1cb7..c80533ada1 100644 --- a/api/accounts/account.gen.go +++ b/api/accounts/account.gen.go @@ -321,7 +321,12 @@ func (c *Client) List(ctx context.Context, authMethodId string, opt ...Option) ( opts, apiOpts := getOpts(opt...) opts.queryMap["auth_method_id"] = authMethodId - req, err := c.client.NewRequest(ctx, "GET", "accounts", nil, apiOpts...) + requestPath := "accounts" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -470,7 +475,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *AccountListResul opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "accounts", nil, apiOpts...) + requestPath := "accounts" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) 
if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/accounts/option.gen.go b/api/accounts/option.gen.go index 384121c547..304f7212c1 100644 --- a/api/accounts/option.gen.go +++ b/api/accounts/option.gen.go @@ -29,6 +29,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string } func getDefaultOptions() options { @@ -111,6 +112,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + func WithAttributes(inAttributes map[string]interface{}) Option { return func(o *options) { o.postMap["attributes"] = inAttributes diff --git a/api/aliases/alias.gen.go b/api/aliases/alias.gen.go index 54b3803cfe..233e4a2a6b 100644 --- a/api/aliases/alias.gen.go +++ b/api/aliases/alias.gen.go @@ -327,7 +327,12 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Alia opts, apiOpts := getOpts(opt...) opts.queryMap["scope_id"] = scopeId - req, err := c.client.NewRequest(ctx, "GET", "aliases", nil, apiOpts...) + requestPath := "aliases" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -483,7 +488,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *AliasListResult, opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "aliases", nil, apiOpts...) + requestPath := "aliases" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/aliases/option.gen.go b/api/aliases/option.gen.go index 9ec819435b..aa7d59a4ad 100644 --- a/api/aliases/option.gen.go +++ b/api/aliases/option.gen.go @@ -29,6 +29,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string withRecursive bool } @@ -115,6 +116,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/authmethods/authmethods.gen.go b/api/authmethods/authmethods.gen.go index afd51e5d61..7e48a76564 100644 --- a/api/authmethods/authmethods.gen.go +++ b/api/authmethods/authmethods.gen.go @@ -327,7 +327,12 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Auth opts, apiOpts := getOpts(opt...) opts.queryMap["scope_id"] = scopeId - req, err := c.client.NewRequest(ctx, "GET", "auth-methods", nil, apiOpts...) + requestPath := "auth-methods" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) 
if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -483,7 +488,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *AuthMethodListRe opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "auth-methods", nil, apiOpts...) + requestPath := "auth-methods" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/authmethods/option.gen.go b/api/authmethods/option.gen.go index 079c35a855..4ad2f57919 100644 --- a/api/authmethods/option.gen.go +++ b/api/authmethods/option.gen.go @@ -29,6 +29,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string withRecursive bool } @@ -115,6 +116,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/authtokens/authtokens.gen.go b/api/authtokens/authtokens.gen.go index ced3e321d6..90fbbd69a0 100644 --- a/api/authtokens/authtokens.gen.go +++ b/api/authtokens/authtokens.gen.go @@ -212,7 +212,12 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Auth opts, apiOpts := getOpts(opt...) opts.queryMap["scope_id"] = scopeId - req, err := c.client.NewRequest(ctx, "GET", "auth-tokens", nil, apiOpts...) + requestPath := "auth-tokens" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -368,7 +373,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *AuthTokenListRes opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "auth-tokens", nil, apiOpts...) + requestPath := "auth-tokens" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) 
if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/authtokens/option.gen.go b/api/authtokens/option.gen.go index 9fa43031a5..83b09a7dba 100644 --- a/api/authtokens/option.gen.go +++ b/api/authtokens/option.gen.go @@ -29,6 +29,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string withRecursive bool } @@ -105,6 +106,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/billing/option.gen.go b/api/billing/option.gen.go index 76babe8e89..04692ceb84 100644 --- a/api/billing/option.gen.go +++ b/api/billing/option.gen.go @@ -30,6 +30,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string } func getDefaultOptions() options { @@ -112,6 +113,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + func WithEndTime(inEndTime string) Option { return func(o *options) { o.queryMap["end_time"] = fmt.Sprintf("%v", inEndTime) diff --git a/api/credentiallibraries/credential_library.gen.go b/api/credentiallibraries/credential_library.gen.go index 7206eb892f..7a0d2998fd 100644 --- a/api/credentiallibraries/credential_library.gen.go +++ b/api/credentiallibraries/credential_library.gen.go @@ -327,7 +327,12 @@ func (c *Client) List(ctx context.Context, credentialStoreId string, opt ...Opti opts, apiOpts := getOpts(opt...) opts.queryMap["credential_store_id"] = credentialStoreId - req, err := c.client.NewRequest(ctx, "GET", "credential-libraries", nil, apiOpts...) + requestPath := "credential-libraries" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -476,7 +481,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *CredentialLibrar opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "credential-libraries", nil, apiOpts...) + requestPath := "credential-libraries" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) 
if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/credentiallibraries/option.gen.go b/api/credentiallibraries/option.gen.go index a4dc80b1d4..ffbadffd77 100644 --- a/api/credentiallibraries/option.gen.go +++ b/api/credentiallibraries/option.gen.go @@ -29,6 +29,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string } func getDefaultOptions() options { @@ -111,6 +112,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + func WithVaultSSHCertificateCredentialLibraryAdditionalValidPrincipals(inAdditionalValidPrincipals []string) Option { return func(o *options) { raw, ok := o.postMap["attributes"] diff --git a/api/credentials/credential.gen.go b/api/credentials/credential.gen.go index d77e7d1283..bb723871f2 100644 --- a/api/credentials/credential.gen.go +++ b/api/credentials/credential.gen.go @@ -325,7 +325,12 @@ func (c *Client) List(ctx context.Context, credentialStoreId string, opt ...Opti opts, apiOpts := getOpts(opt...) opts.queryMap["credential_store_id"] = credentialStoreId - req, err := c.client.NewRequest(ctx, "GET", "credentials", nil, apiOpts...) + requestPath := "credentials" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -474,7 +479,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *CredentialListRe opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "credentials", nil, apiOpts...) + requestPath := "credentials" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/credentials/option.gen.go b/api/credentials/option.gen.go index 7cf0fb75e0..c6508253cf 100644 --- a/api/credentials/option.gen.go +++ b/api/credentials/option.gen.go @@ -29,6 +29,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string } func getDefaultOptions() options { @@ -111,6 +112,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + func WithAttributes(inAttributes map[string]interface{}) Option { return func(o *options) { o.postMap["attributes"] = inAttributes diff --git a/api/credentialstores/credential_store.gen.go b/api/credentialstores/credential_store.gen.go index 055af8150b..07b1bc1e33 100644 --- a/api/credentialstores/credential_store.gen.go +++ b/api/credentialstores/credential_store.gen.go @@ -326,7 +326,12 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Cred opts, apiOpts := getOpts(opt...) opts.queryMap["scope_id"] = scopeId - req, err := c.client.NewRequest(ctx, "GET", "credential-stores", nil, apiOpts...) 
+ requestPath := "credential-stores" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -482,7 +487,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *CredentialStoreL opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "credential-stores", nil, apiOpts...) + requestPath := "credential-stores" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/credentialstores/option.gen.go b/api/credentialstores/option.gen.go index eb44bdecf4..86e670af0c 100644 --- a/api/credentialstores/option.gen.go +++ b/api/credentialstores/option.gen.go @@ -29,6 +29,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string withRecursive bool } @@ -115,6 +116,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/groups/group.gen.go b/api/groups/group.gen.go index ed0c31b65b..dc9b249412 100644 --- a/api/groups/group.gen.go +++ b/api/groups/group.gen.go @@ -320,7 +320,12 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Grou opts, apiOpts := getOpts(opt...) opts.queryMap["scope_id"] = scopeId - req, err := c.client.NewRequest(ctx, "GET", "groups", nil, apiOpts...) + requestPath := "groups" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -476,7 +481,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *GroupListResult, opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "groups", nil, apiOpts...) + requestPath := "groups" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) 
if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/groups/option.gen.go b/api/groups/option.gen.go index d2e7b5a3d4..81553cd4a4 100644 --- a/api/groups/option.gen.go +++ b/api/groups/option.gen.go @@ -29,6 +29,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string withRecursive bool } @@ -115,6 +116,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/hostcatalogs/host_catalog.gen.go b/api/hostcatalogs/host_catalog.gen.go index 5e4feb45a8..43d907a11c 100644 --- a/api/hostcatalogs/host_catalog.gen.go +++ b/api/hostcatalogs/host_catalog.gen.go @@ -331,7 +331,12 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Host opts, apiOpts := getOpts(opt...) opts.queryMap["scope_id"] = scopeId - req, err := c.client.NewRequest(ctx, "GET", "host-catalogs", nil, apiOpts...) + requestPath := "host-catalogs" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -487,7 +492,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *HostCatalogListR opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "host-catalogs", nil, apiOpts...) + requestPath := "host-catalogs" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/hostcatalogs/option.gen.go b/api/hostcatalogs/option.gen.go index fe5d26fd1e..fdd1496b28 100644 --- a/api/hostcatalogs/option.gen.go +++ b/api/hostcatalogs/option.gen.go @@ -30,6 +30,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string withRecursive bool } @@ -116,6 +117,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/hosts/host.gen.go b/api/hosts/host.gen.go index 016869f447..015414e40d 100644 --- a/api/hosts/host.gen.go +++ b/api/hosts/host.gen.go @@ -327,7 +327,12 @@ func (c *Client) List(ctx context.Context, hostCatalogId string, opt ...Option) opts, apiOpts := getOpts(opt...) opts.queryMap["host_catalog_id"] = hostCatalogId - req, err := c.client.NewRequest(ctx, "GET", "hosts", nil, apiOpts...) + requestPath := "hosts" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) 
if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -476,7 +481,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *HostListResult, opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "hosts", nil, apiOpts...) + requestPath := "hosts" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/hosts/option.gen.go b/api/hosts/option.gen.go index 32809d61c2..d15e5dd338 100644 --- a/api/hosts/option.gen.go +++ b/api/hosts/option.gen.go @@ -29,6 +29,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string } func getDefaultOptions() options { @@ -111,6 +112,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + func WithStaticHostAddress(inAddress string) Option { return func(o *options) { raw, ok := o.postMap["attributes"] diff --git a/api/hostsets/host_set.gen.go b/api/hostsets/host_set.gen.go index 7ce7e6640b..688f4fc6c1 100644 --- a/api/hostsets/host_set.gen.go +++ b/api/hostsets/host_set.gen.go @@ -325,7 +325,12 @@ func (c *Client) List(ctx context.Context, hostCatalogId string, opt ...Option) opts, apiOpts := getOpts(opt...) opts.queryMap["host_catalog_id"] = hostCatalogId - req, err := c.client.NewRequest(ctx, "GET", "host-sets", nil, apiOpts...) + requestPath := "host-sets" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -474,7 +479,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *HostSetListResul opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "host-sets", nil, apiOpts...) + requestPath := "host-sets" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) 
if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/hostsets/option.gen.go b/api/hostsets/option.gen.go index fa81eddb7b..ec67230598 100644 --- a/api/hostsets/option.gen.go +++ b/api/hostsets/option.gen.go @@ -29,6 +29,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string } func getDefaultOptions() options { @@ -111,6 +112,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + func WithAttributes(inAttributes map[string]interface{}) Option { return func(o *options) { o.postMap["attributes"] = inAttributes diff --git a/api/managedgroups/managedgroups.gen.go b/api/managedgroups/managedgroups.gen.go index 02cc9d5be8..e0eeb668f1 100644 --- a/api/managedgroups/managedgroups.gen.go +++ b/api/managedgroups/managedgroups.gen.go @@ -321,7 +321,12 @@ func (c *Client) List(ctx context.Context, authMethodId string, opt ...Option) ( opts, apiOpts := getOpts(opt...) opts.queryMap["auth_method_id"] = authMethodId - req, err := c.client.NewRequest(ctx, "GET", "managed-groups", nil, apiOpts...) + requestPath := "managed-groups" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -470,7 +475,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *ManagedGroupList opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "managed-groups", nil, apiOpts...) + requestPath := "managed-groups" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) 
if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/managedgroups/option.gen.go b/api/managedgroups/option.gen.go index 7da2f0803a..5c13b52fff 100644 --- a/api/managedgroups/option.gen.go +++ b/api/managedgroups/option.gen.go @@ -29,6 +29,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string } func getDefaultOptions() options { @@ -111,6 +112,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + func WithAttributes(inAttributes map[string]interface{}) Option { return func(o *options) { o.postMap["attributes"] = inAttributes diff --git a/api/policies/option.gen.go b/api/policies/option.gen.go index 376e19c66b..53dccd6597 100644 --- a/api/policies/option.gen.go +++ b/api/policies/option.gen.go @@ -29,6 +29,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string withRecursive bool } @@ -115,6 +116,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/policies/policy.gen.go b/api/policies/policy.gen.go index bf105a2baa..3fb60ce7bd 100644 --- a/api/policies/policy.gen.go +++ b/api/policies/policy.gen.go @@ -325,7 +325,12 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Poli opts, apiOpts := getOpts(opt...) opts.queryMap["scope_id"] = scopeId - req, err := c.client.NewRequest(ctx, "GET", "policies", nil, apiOpts...) + requestPath := "policies" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -481,7 +486,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *PolicyListResult opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "policies", nil, apiOpts...) + requestPath := "policies" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) 
if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/roles/option.gen.go b/api/roles/option.gen.go index 06cda67ead..2671a82d36 100644 --- a/api/roles/option.gen.go +++ b/api/roles/option.gen.go @@ -29,6 +29,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string withRecursive bool } @@ -115,6 +116,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/roles/role.gen.go b/api/roles/role.gen.go index 7b6b550711..b9c052c2f2 100644 --- a/api/roles/role.gen.go +++ b/api/roles/role.gen.go @@ -323,7 +323,12 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Role opts, apiOpts := getOpts(opt...) opts.queryMap["scope_id"] = scopeId - req, err := c.client.NewRequest(ctx, "GET", "roles", nil, apiOpts...) + requestPath := "roles" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -479,7 +484,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *RoleListResult, opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "roles", nil, apiOpts...) + requestPath := "roles" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/scopes/option.gen.go b/api/scopes/option.gen.go index d20936ba0c..aef8586a37 100644 --- a/api/scopes/option.gen.go +++ b/api/scopes/option.gen.go @@ -30,6 +30,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string withRecursive bool } @@ -116,6 +117,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/scopes/scope.gen.go b/api/scopes/scope.gen.go index 93119cd5e3..9752443562 100644 --- a/api/scopes/scope.gen.go +++ b/api/scopes/scope.gen.go @@ -321,7 +321,12 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Scop opts, apiOpts := getOpts(opt...) opts.queryMap["scope_id"] = scopeId - req, err := c.client.NewRequest(ctx, "GET", "scopes", nil, apiOpts...) + requestPath := "scopes" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) 
if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -477,7 +482,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *ScopeListResult, opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "scopes", nil, apiOpts...) + requestPath := "scopes" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/sessionrecordings/option.gen.go b/api/sessionrecordings/option.gen.go index beae7de2ab..d59df635a7 100644 --- a/api/sessionrecordings/option.gen.go +++ b/api/sessionrecordings/option.gen.go @@ -28,6 +28,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string withRecursive bool } @@ -95,6 +96,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/sessionrecordings/session_recording.gen.go b/api/sessionrecordings/session_recording.gen.go index 164880cf93..1c9bcbe770 100644 --- a/api/sessionrecordings/session_recording.gen.go +++ b/api/sessionrecordings/session_recording.gen.go @@ -219,7 +219,12 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Sess opts, apiOpts := getOpts(opt...) opts.queryMap["scope_id"] = scopeId - req, err := c.client.NewRequest(ctx, "GET", "session-recordings", nil, apiOpts...) + requestPath := "session-recordings" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -375,7 +380,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *SessionRecording opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "session-recordings", nil, apiOpts...) + requestPath := "session-recordings" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) 
if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/sessions/option.gen.go b/api/sessions/option.gen.go index 3887657af2..b9799ff3ee 100644 --- a/api/sessions/option.gen.go +++ b/api/sessions/option.gen.go @@ -30,6 +30,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string withRecursive bool } @@ -116,6 +117,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/sessions/session.gen.go b/api/sessions/session.gen.go index 5ffaa9f6fe..63cdb047a4 100644 --- a/api/sessions/session.gen.go +++ b/api/sessions/session.gen.go @@ -179,7 +179,12 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Sess opts, apiOpts := getOpts(opt...) opts.queryMap["scope_id"] = scopeId - req, err := c.client.NewRequest(ctx, "GET", "sessions", nil, apiOpts...) + requestPath := "sessions" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -335,7 +340,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *SessionListResul opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "sessions", nil, apiOpts...) + requestPath := "sessions" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/storagebuckets/option.gen.go b/api/storagebuckets/option.gen.go index fc97a45869..9d8e30aba5 100644 --- a/api/storagebuckets/option.gen.go +++ b/api/storagebuckets/option.gen.go @@ -30,6 +30,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string withRecursive bool } @@ -116,6 +117,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/storagebuckets/storage_bucket.gen.go b/api/storagebuckets/storage_bucket.gen.go index b232f5a2db..834671e02f 100644 --- a/api/storagebuckets/storage_bucket.gen.go +++ b/api/storagebuckets/storage_bucket.gen.go @@ -329,7 +329,12 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Stor opts, apiOpts := getOpts(opt...) opts.queryMap["scope_id"] = scopeId - req, err := c.client.NewRequest(ctx, "GET", "storage-buckets", nil, apiOpts...) + requestPath := "storage-buckets" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) 
if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -485,7 +490,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *StorageBucketLis opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "storage-buckets", nil, apiOpts...) + requestPath := "storage-buckets" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/targets/option.gen.go b/api/targets/option.gen.go index 4d71168be6..80351506f7 100644 --- a/api/targets/option.gen.go +++ b/api/targets/option.gen.go @@ -29,6 +29,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string withRecursive bool } @@ -115,6 +116,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/targets/target.gen.go b/api/targets/target.gen.go index 6d7ad43315..4743dbe644 100644 --- a/api/targets/target.gen.go +++ b/api/targets/target.gen.go @@ -339,7 +339,12 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Targ opts, apiOpts := getOpts(opt...) opts.queryMap["scope_id"] = scopeId - req, err := c.client.NewRequest(ctx, "GET", "targets", nil, apiOpts...) + requestPath := "targets" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -495,7 +500,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *TargetListResult opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "targets", nil, apiOpts...) + requestPath := "targets" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/users/custom.go b/api/users/custom.go index 3c33c3fd5e..bd211cec83 100644 --- a/api/users/custom.go +++ b/api/users/custom.go @@ -5,10 +5,8 @@ package users import ( "context" - "encoding/json" "fmt" "net/url" - "slices" "github.com/hashicorp/boundary/api/aliases" ) @@ -25,135 +23,44 @@ func (c *Client) ListResolvableAliases(ctx context.Context, userId string, opt . return nil, fmt.Errorf("nil client") } - opts, apiOpts := getOpts(opt...) - req, err := c.client.NewRequest(ctx, "GET", fmt.Sprintf("users/%s:list-resolvable-aliases", url.PathEscape(userId)), nil, apiOpts...) 
- if err != nil { - return nil, fmt.Errorf("error creating List request: %w", err) - } - - if len(opts.queryMap) > 0 { - q := url.Values{} - for k, v := range opts.queryMap { - q.Add(k, v) - } - req.URL.RawQuery = q.Encode() - } - - resp, err := c.client.Do(req) - if err != nil { - return nil, fmt.Errorf("error performing client request during List call: %w", err) - } + opts, _ := getOpts(opt...) + apiClient := aliases.NewClient(c.client) + return apiClient.List(ctx, "global", + aliases.WithAutomaticVersioning(opts.withAutomaticVersioning), + aliases.WithSkipCurlOutput(opts.withSkipCurlOutput), + aliases.WithFilter(opts.withFilter), + aliases.WithListToken(opts.withListToken), + aliases.WithClientDirectedPagination(opts.withClientDirectedPagination), + aliases.WithPageSize(opts.withPageSize), + aliases.WithRecursive(opts.withRecursive), + aliases.WithResourcePathOverride(fmt.Sprintf("users/%s:list-resolvable-aliases", url.PathEscape(userId))), + ) +} - target := new(aliases.AliasListResult) - apiErr, err := resp.Decode(target) - if err != nil { - return nil, fmt.Errorf("error decoding List response: %w", err) +func (c *Client) ListResolvableAliasesNextPage(ctx context.Context, userId string, currentPage *aliases.AliasListResult, opt ...Option) (*aliases.AliasListResult, error) { + if currentPage == nil { + return nil, fmt.Errorf("empty currentPage value passed into ListResolvableAliasesNextPage request") } - if apiErr != nil { - return nil, apiErr + if userId == "" { + return nil, fmt.Errorf("empty userId value passed into ListResolvableAliasesNextPage request") } - target.Response = resp - if target.ResponseType == "complete" || target.ResponseType == "" { - return target, nil + if c.client == nil { + return nil, fmt.Errorf("nil client") } - // If there are more results, automatically fetch the rest of the results. - // idToIndex keeps a map from the ID of an item to its index in target.Items. - // This is used to update updated items in-place and remove deleted items - // from the result after pagination is done. - idToIndex := map[string]int{} - for i, item := range target.Items { - idToIndex[item.Id] = i + if currentPage.ResponseType == "complete" || currentPage.ResponseType == "" { + return nil, fmt.Errorf("no more pages available in ListResolvableAliasesNextPage request") } - for { - req, err := c.client.NewRequest(ctx, "GET", fmt.Sprintf("users/%s:list-resolvable-aliases", url.PathEscape(userId)), nil, apiOpts...) - if err != nil { - return nil, fmt.Errorf("error creating List request: %w", err) - } - - opts.queryMap["list_token"] = target.ListToken - if len(opts.queryMap) > 0 { - q := url.Values{} - for k, v := range opts.queryMap { - q.Add(k, v) - } - req.URL.RawQuery = q.Encode() - } - resp, err := c.client.Do(req) - if err != nil { - return nil, fmt.Errorf("error performing client request during List call: %w", err) - } - - page := new(aliases.AliasListResult) - apiErr, err := resp.Decode(page) - if err != nil { - return nil, fmt.Errorf("error decoding List response: %w", err) - } - if apiErr != nil { - return nil, apiErr - } - for _, item := range page.Items { - if i, ok := idToIndex[item.Id]; ok { - // Item has already been seen at index i, update in-place - target.Items[i] = item - } else { - target.Items = append(target.Items, item) - idToIndex[item.Id] = len(target.Items) - 1 - } - } - // RemovedIds contain any Alias that were deleted since the last response. - target.RemovedIds = append(target.RemovedIds, page.RemovedIds...) 
- target.EstItemCount = page.EstItemCount - target.ListToken = page.ListToken - target.ResponseType = page.ResponseType - target.Response = resp - if target.ResponseType == "complete" { - break - } - } - // For now, removedIds will only be populated if this pagination cycle was the result of a - // "refresh" operation (i.e., the caller provided a list token option to this call). - // - // Sort to make response deterministic - slices.Sort(target.RemovedIds) - // Remove any duplicates - target.RemovedIds = slices.Compact(target.RemovedIds) - // Remove items that were deleted since the end of the last iteration. - // If an Alias has been updated and subsequently removed, we don't want - // it to appear both in the Items and RemovedIds, so we remove it from the Items. - for _, removedId := range target.RemovedIds { - if i, ok := idToIndex[removedId]; ok { - // Remove the item at index i without preserving order - // https://github.com/golang/go/wiki/SliceTricks#delete-without-preserving-order - target.Items[i] = target.Items[len(target.Items)-1] - target.Items = target.Items[:len(target.Items)-1] - // Update the index of the previously last element - idToIndex[target.Items[i].Id] = i - } - } - // Since we paginated to the end, we can avoid confusion - // for the user by setting the estimated item count to the - // length of the items slice. If we don't set this here, it - // will equal the value returned in the last response, which is - // often much smaller than the total number returned. - target.EstItemCount = uint(len(target.Items)) - // Sort the results again since in-place updates and deletes - // may have shuffled items. We sort by created time descending - // (most recently created first), same as the API. - slices.SortFunc(target.Items, func(i, j *aliases.Alias) int { - return j.CreatedTime.Compare(i.CreatedTime) - }) - // Finally, since we made at least 2 requests to the server to fulfill this - // function call, resp.Body and resp.Map will only contain the most recent response. - // Overwrite them with the true response. - target.GetResponse().Body.Reset() - if err := json.NewEncoder(target.GetResponse().Body).Encode(target); err != nil { - return nil, fmt.Errorf("error encoding final JSON list response: %w", err) - } - if err := json.Unmarshal(target.GetResponse().Body.Bytes(), &target.GetResponse().Map); err != nil { - return nil, fmt.Errorf("error encoding final map list response: %w", err) - } - // Note: the HTTP response body is consumed by resp.Decode in the loop, - // so it doesn't need to be updated (it will always be, and has always been, empty). - return target, nil + opts, _ := getOpts(opt...) 
+ apiClient := aliases.NewClient(c.client) + return apiClient.ListNextPage(ctx, currentPage, + aliases.WithAutomaticVersioning(opts.withAutomaticVersioning), + aliases.WithSkipCurlOutput(opts.withSkipCurlOutput), + aliases.WithFilter(opts.withFilter), + aliases.WithListToken(opts.withListToken), + aliases.WithClientDirectedPagination(opts.withClientDirectedPagination), + aliases.WithPageSize(opts.withPageSize), + aliases.WithRecursive(opts.withRecursive), + aliases.WithResourcePathOverride(fmt.Sprintf("users/%s:list-resolvable-aliases", url.PathEscape(userId))), + ) } diff --git a/api/users/option.gen.go b/api/users/option.gen.go index a6349cde10..170754148a 100644 --- a/api/users/option.gen.go +++ b/api/users/option.gen.go @@ -29,6 +29,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string withRecursive bool } @@ -115,6 +116,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/users/user.gen.go b/api/users/user.gen.go index 51065721f9..0728043f6e 100644 --- a/api/users/user.gen.go +++ b/api/users/user.gen.go @@ -324,7 +324,12 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*User opts, apiOpts := getOpts(opt...) opts.queryMap["scope_id"] = scopeId - req, err := c.client.NewRequest(ctx, "GET", "users", nil, apiOpts...) + requestPath := "users" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -480,7 +485,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *UserListResult, opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "users", nil, apiOpts...) + requestPath := "users" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/api/workers/option.gen.go b/api/workers/option.gen.go index 1ca3ada826..d5af29ef43 100644 --- a/api/workers/option.gen.go +++ b/api/workers/option.gen.go @@ -29,6 +29,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string withRecursive bool } @@ -115,6 +116,13 @@ func WithPageSize(with uint32) Option { } } +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} + // WithRecursive tells the API to use recursion for listing operations on this // resource func WithRecursive(recurse bool) Option { diff --git a/api/workers/worker.gen.go b/api/workers/worker.gen.go index efe18d859e..c55636263c 100644 --- a/api/workers/worker.gen.go +++ b/api/workers/worker.gen.go @@ -377,7 +377,12 @@ func (c *Client) List(ctx context.Context, scopeId string, opt ...Option) (*Work opts, apiOpts := getOpts(opt...) 
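
With the custom.go rewrite above, users.Client.ListResolvableAliases no longer hand-rolls request building, decoding, and auto-draining of pages; it delegates to the generated aliases client and points it at the users/{id}:list-resolvable-aliases action through WithResourcePathOverride, while the new ListResolvableAliasesNextPage exposes client-directed paging to callers. A rough usage sketch, assuming BOUNDARY_ADDR and BOUNDARY_TOKEN (or equivalent config) are already set and that the standard generated pagination options exist on the users package; the user ID is hypothetical and error handling is minimal.

    package main

    import (
        "context"
        "fmt"
        "log"

        "github.com/hashicorp/boundary/api"
        "github.com/hashicorp/boundary/api/users"
    )

    func main() {
        ctx := context.Background()

        // DefaultConfig picks up the address and token from the environment.
        conf, err := api.DefaultConfig()
        if err != nil {
            log.Fatal(err)
        }
        client, err := api.NewClient(conf)
        if err != nil {
            log.Fatal(err)
        }
        uClient := users.NewClient(client)

        // First page: ask the controller to paginate instead of having the
        // client drain the full result set on our behalf.
        page, err := uClient.ListResolvableAliases(ctx, "u_1234567890",
            users.WithClientDirectedPagination(true))
        if err != nil {
            log.Fatal(err)
        }
        for {
            for _, a := range page.Items {
                fmt.Println(a.Id)
            }
            if page.ResponseType == "" || page.ResponseType == "complete" {
                break
            }
            // Subsequent pages continue from the list token carried in the
            // previous result.
            page, err = uClient.ListResolvableAliasesNextPage(ctx, "u_1234567890", page)
            if err != nil {
                log.Fatal(err)
            }
        }
    }

The same override mechanism is what lets the cache's default retrieval function (further down, in repository_resolvable_aliases.go) fetch resolvable aliases page by page rather than all at once.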
opts.queryMap["scope_id"] = scopeId - req, err := c.client.NewRequest(ctx, "GET", "workers", nil, apiOpts...) + requestPath := "workers" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } diff --git a/internal/api/genapi/templates.go b/internal/api/genapi/templates.go index aa37311e5a..7899fb1c3e 100644 --- a/internal/api/genapi/templates.go +++ b/internal/api/genapi/templates.go @@ -155,7 +155,7 @@ func fillTemplates() { optionsMap[input.Package] = optionMap } // Override some defined options - if len(in.fieldOverrides) > 0 && optionsMap != nil { + if len(in.fieldOverrides) > 0 { for _, override := range in.fieldOverrides { inOpts := optionsMap[input.Package] if inOpts != nil { @@ -243,7 +243,12 @@ func (c *Client) List(ctx context.Context, {{ .CollectionFunctionArg }} string, opts, apiOpts := getOpts(opt...) opts.queryMap["{{ snakeCase .CollectionFunctionArg }}"] = {{ .CollectionFunctionArg }} - req, err := c.client.NewRequest(ctx, "GET", "{{ .CollectionPath }}", nil, apiOpts...) + requestPath := "{{ .CollectionPath }}" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -404,7 +409,12 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *{{ .Name }}ListR opts.queryMap["page_size"] = strconv.FormatUint(uint64(currentPage.pageSize), 10) } - req, err := c.client.NewRequest(ctx, "GET", "{{ .CollectionPath }}", nil, apiOpts...) + requestPath := "{{ .CollectionPath }}" + if opts.withResourcePathOverride != "" { + requestPath = opts.withResourcePathOverride + } + + req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...) 
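
The templates.go hunk above is where the per-package repetition comes from: the WithResourcePathOverride option and the requestPath selection are emitted by the genapi generator's List and ListNextPage templates, so every api/* resource package picks the change up on regeneration (the earlier option.gen.go and *.gen.go hunks are that generated output). A tiny sketch of how such a template fragment expands, using only the standard library text/template package; the fragment text is lifted from the List template, the rest is illustrative.

    package main

    import (
        "os"
        "text/template"
    )

    // listPathFragment is the piece the patch adds to the generated List and
    // ListNextPage bodies, reduced to the path-override logic.
    const listPathFragment = `requestPath := "{{ .CollectionPath }}"
    if opts.withResourcePathOverride != "" {
        requestPath = opts.withResourcePathOverride
    }

    req, err := c.client.NewRequest(ctx, "GET", requestPath, nil, apiOpts...)
    `

    func main() {
        t := template.Must(template.New("list").Parse(listPathFragment))
        // Rendering with CollectionPath "targets" yields the code that ends up
        // in api/targets/target.gen.go.
        if err := t.Execute(os.Stdout, struct{ CollectionPath string }{CollectionPath: "targets"}); err != nil {
            panic(err)
        }
    }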
if err != nil { return nil, fmt.Errorf("error creating List request: %w", err) } @@ -434,7 +444,7 @@ func (c *Client) ListNextPage(ctx context.Context, currentPage *{{ .Name }}ListR // Ensure values are carried forward to the next call nextPage.{{ .CollectionFunctionArg }} = currentPage.{{ .CollectionFunctionArg }} -{{ if .RecursiveListing }} +{{ if .RecursiveListing }} nextPage.recursive = currentPage.recursive {{ end }} nextPage.pageSize = currentPage.pageSize @@ -943,6 +953,7 @@ type options struct { withListToken string withClientDirectedPagination bool withPageSize uint32 + withResourcePathOverride string {{ if .RecursiveListing }} withRecursive bool {{ end }} } @@ -1031,6 +1042,13 @@ func WithPageSize(with uint32) Option { o.withPageSize = with } } + +// WithResourcePathOverride tells the API to use the provided resource path +func WithResourcePathOverride(path string) Option { + return func(o *options) { + o.withResourcePathOverride = path + } +} {{ if .RecursiveListing }} // WithRecursive tells the API to use recursion for listing operations on this // resource diff --git a/internal/clientcache/internal/cache/options_test.go b/internal/clientcache/internal/cache/options_test.go index 9f25faebe3..26810f8309 100644 --- a/internal/clientcache/internal/cache/options_test.go +++ b/internal/clientcache/internal/cache/options_test.go @@ -62,8 +62,8 @@ func Test_GetOpts(t *testing.T) { assert.Equal(t, opts, testOpts) }) t.Run("WithSessionRetrievalFunc", func(t *testing.T) { - var f SessionRetrievalFunc = func(ctx context.Context, addr, authTok string, refreshTok RefreshTokenValue) ([]*sessions.Session, []string, RefreshTokenValue, error) { - return nil, nil, "", nil + var f SessionRetrievalFunc = func(ctx context.Context, addr, authTok string, refreshTok RefreshTokenValue, inPage *sessions.SessionListResult, opt ...Option) (*sessions.SessionListResult, RefreshTokenValue, error) { + return nil, "", nil } opts, err := getOpts(WithSessionRetrievalFunc(f)) require.NoError(t, err) @@ -75,8 +75,8 @@ func Test_GetOpts(t *testing.T) { assert.Equal(t, opts, testOpts) }) t.Run("WithAliasRetrievalFunc", func(t *testing.T) { - var f ResolvableAliasRetrievalFunc = func(ctx context.Context, addr, authTok, userId string, refreshTok RefreshTokenValue) ([]*aliases.Alias, []string, RefreshTokenValue, error) { - return nil, nil, "", nil + var f ResolvableAliasRetrievalFunc = func(ctx context.Context, addr, authTok, userId string, refreshTok RefreshTokenValue, inPage *aliases.AliasListResult, opt ...Option) (*aliases.AliasListResult, RefreshTokenValue, error) { + return nil, "", nil } opts, err := getOpts(WithAliasRetrievalFunc(f)) require.NoError(t, err) diff --git a/internal/clientcache/internal/cache/refresh_test.go b/internal/clientcache/internal/cache/refresh_test.go index 0d45187279..3895376132 100644 --- a/internal/clientcache/internal/cache/refresh_test.go +++ b/internal/clientcache/internal/cache/refresh_test.go @@ -70,6 +70,38 @@ func testTargetStaticResourceRetrievalFunc(inFunc func(ctx context.Context, s1, } } +func testSessionStaticResourceRetrievalFunc(inFunc func(ctx context.Context, s1, s2 string, refToken RefreshTokenValue) ([]*sessions.Session, []string, RefreshTokenValue, error)) SessionRetrievalFunc { + return func(ctx context.Context, addr, authTok string, refreshTok RefreshTokenValue, inPage *sessions.SessionListResult, opt ...Option) (ret *sessions.SessionListResult, refreshToken RefreshTokenValue, err error) { + retSessions, removed, refreshToken, err := inFunc(ctx, addr, authTok, 
refreshTok) + if err != nil { + return nil, "", err + } + + ret = &sessions.SessionListResult{ + Items: retSessions, + RemovedIds: removed, + ResponseType: "complete", + } + return ret, refreshToken, nil + } +} + +func testResolvableAliasStaticResourceRetrievalFunc(inFunc func(ctx context.Context, s1, s2, s3 string, refToken RefreshTokenValue) ([]*aliases.Alias, []string, RefreshTokenValue, error)) ResolvableAliasRetrievalFunc { + return func(ctx context.Context, addr, authTok, userId string, refreshTok RefreshTokenValue, inPage *aliases.AliasListResult, opt ...Option) (ret *aliases.AliasListResult, refreshToken RefreshTokenValue, err error) { + retSessions, removed, refreshToken, err := inFunc(ctx, addr, authTok, userId, refreshTok) + if err != nil { + return nil, "", err + } + + ret = &aliases.AliasListResult{ + Items: retSessions, + RemovedIds: removed, + ResponseType: "complete", + } + return ret, refreshToken, nil + } +} + // testNoRefreshRetrievalFunc simulates a controller that doesn't support refresh // since it does not return any refresh token. func testNoRefreshRetrievalFunc[T any](t *testing.T) func(context.Context, string, string, RefreshTokenValue) ([]T, []string, RefreshTokenValue, error) { @@ -440,8 +472,8 @@ func TestRefreshForSearch(t *testing.T) { target("4"), } opts := []Option{ - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil))), + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, [][]*targets.Target{ retTargets[:3], @@ -495,8 +527,8 @@ func TestRefreshForSearch(t *testing.T) { target("4"), } opts := []Option{ - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil))), + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, [][]*targets.Target{ retTargets[:3], @@ -553,9 +585,9 @@ func TestRefreshForSearch(t *testing.T) { // Get the first set of resources, but no refresh tokens err = rs.Refresh(ctx, - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), - WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))) + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil))), + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))), + WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)))) assert.ErrorContains(t, err, ErrRefreshNotSupported.Error()) got, err := r.ListTargets(ctx, at.Id) @@ -569,15 +601,15 
@@ func TestRefreshForSearch(t *testing.T) { // wont be refreshed any more, and we wont see the error when refreshing // any more. err = rs.Refresh(ctx, - WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))) assert.Nil(t, err) err = rs.RecheckCachingSupport(ctx, - WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))) assert.Nil(t, err) got, err = r.ListTargets(ctx, at.Id) @@ -590,9 +622,9 @@ func TestRefreshForSearch(t *testing.T) { // Now simulate the controller updating to support refresh tokens and // the resources starting to be cached. err = rs.RecheckCachingSupport(ctx, - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, [][]*targets.Target{retTargets}, [][]string{{}}))), - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil))), ) assert.Nil(t, err, err) @@ -617,9 +649,9 @@ func TestRefreshForSearch(t *testing.T) { session("4"), } opts := []Option{ - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil))), - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, [][]*sessions.Session{ retSess[:3], retSess[3:], @@ -628,7 +660,7 @@ func TestRefreshForSearch(t *testing.T) { nil, {retSess[0].Id, retSess[1].Id}, }, - )), + ))), } // First call doesn't sync anything because no sessions were already synced yet @@ -669,9 +701,9 @@ func TestRefreshForSearch(t *testing.T) { session("4"), } opts := []Option{ - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil))), - 
WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, [][]*sessions.Session{ retSess[:3], retSess[3:], @@ -680,7 +712,7 @@ func TestRefreshForSearch(t *testing.T) { nil, {retSess[0].Id, retSess[1].Id}, }, - )), + ))), } // First call doesn't sync anything because no sessions were already synced yet @@ -726,9 +758,9 @@ func TestRefreshForSearch(t *testing.T) { alias("4"), } opts := []Option{ - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil))), - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, [][]*aliases.Alias{ retAl[:3], retAl[3:], @@ -737,7 +769,7 @@ func TestRefreshForSearch(t *testing.T) { nil, {retAl[0].Id, retAl[1].Id}, }, - )), + ))), } // First call doesn't sync anything because no aliases were already synced yet @@ -778,9 +810,9 @@ func TestRefreshForSearch(t *testing.T) { alias("4"), } opts := []Option{ - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil))), - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, [][]*aliases.Alias{ retAls[:3], retAls[3:], @@ -789,7 +821,7 @@ func TestRefreshForSearch(t *testing.T) { nil, {retAls[0].Id, retAls[1].Id}, }, - )), + ))), } // First call doesn't sync anything because no aliases were already synced yet @@ -854,8 +886,8 @@ func TestRefreshNonBlocking(t *testing.T) { target("4"), } opts := []Option{ - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil))), + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, [][]*targets.Target{ retTargets[:3], @@ -915,9 +947,9 @@ func TestRefreshNonBlocking(t *testing.T) { session("4"), } opts := []Option{ - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil))), - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, + 
WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, [][]*sessions.Session{ retSess[:3], retSess[3:], @@ -926,7 +958,7 @@ func TestRefreshNonBlocking(t *testing.T) { nil, {retSess[0].Id, retSess[1].Id}, }, - )), + ))), } refreshWaitChs := &testRefreshWaitChs{ @@ -977,9 +1009,9 @@ func TestRefreshNonBlocking(t *testing.T) { alias("4"), } opts := []Option{ - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil))), - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, [][]*aliases.Alias{ retAl[:3], retAl[3:], @@ -988,7 +1020,7 @@ func TestRefreshNonBlocking(t *testing.T) { nil, {retAl[0].Id, retAl[1].Id}, }, - )), + ))), } refreshWaitChs := &testRefreshWaitChs{ @@ -1056,8 +1088,8 @@ func TestRefresh(t *testing.T) { target("4"), } opts := []Option{ - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil))), + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, [][]*targets.Target{ retTargets[:3], @@ -1098,9 +1130,9 @@ func TestRefresh(t *testing.T) { session("4"), } opts := []Option{ - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil))), - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, [][]*sessions.Session{ retSess[:3], retSess[3:], @@ -1109,7 +1141,7 @@ func TestRefresh(t *testing.T) { nil, {retSess[0].Id, retSess[1].Id}, }, - )), + ))), } assert.NoError(t, rs.Refresh(ctx, opts...)) cachedSessions, err := r.ListSessions(ctx, at.Id) @@ -1139,9 +1171,9 @@ func TestRefresh(t *testing.T) { alias("4"), } opts := []Option{ - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil))), - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, [][]*aliases.Alias{ retAls[:3], retAls[3:], @@ -1150,7 +1182,7 @@ func TestRefresh(t *testing.T) { nil, {retAls[0].Id, 
retAls[1].Id}, }, - )), + ))), } assert.NoError(t, rs.Refresh(ctx, opts...)) cachedAliases, err := r.ListResolvableAliases(ctx, at.Id) @@ -1175,8 +1207,8 @@ func TestRefresh(t *testing.T) { innerErr := errors.New("test error") err = rs.Refresh(ctx, - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil))), + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(func(ctx context.Context, addr, token string, refreshTok RefreshTokenValue) ([]*targets.Target, []string, RefreshTokenValue, error) { require.Equal(t, boundaryAddr, addr) require.Equal(t, at.Token, token) @@ -1184,13 +1216,13 @@ func TestRefresh(t *testing.T) { }))) assert.ErrorContains(t, err, innerErr.Error()) err = rs.Refresh(ctx, - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil))), - WithSessionRetrievalFunc(func(ctx context.Context, addr, token string, refreshTok RefreshTokenValue) ([]*sessions.Session, []string, RefreshTokenValue, error) { + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(func(ctx context.Context, addr, token string, refreshTok RefreshTokenValue) ([]*sessions.Session, []string, RefreshTokenValue, error) { require.Equal(t, boundaryAddr, addr) require.Equal(t, at.Token, token) return nil, nil, "", innerErr - })) + }))) assert.ErrorContains(t, err, innerErr.Error()) }) @@ -1217,8 +1249,8 @@ func TestRefresh(t *testing.T) { assert.Len(t, us, 1) require.NoError(t, rs.Refresh(ctx, - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil))), + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil))))) ps, err = r.listTokens(ctx, u) @@ -1261,8 +1293,8 @@ func TestRecheckCachingSupport(t *testing.T) { // Since this user doesn't have any resources, the user's data will still // only get updated with a call to Refresh. 
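
The refresh_test.go churn running through this file follows a single mechanical pattern: the pre-existing static test retrieval helpers still return a flat slice of resources plus removed IDs, so the wrappers introduced near the top of the file (testSessionStaticResourceRetrievalFunc and testResolvableAliasStaticResourceRetrievalFunc) adapt them to the page-based SessionRetrievalFunc and ResolvableAliasRetrievalFunc signatures by packing everything into one result marked "complete". A stripped-down, self-contained version of that adapter idea is below; the generic names are illustrative, not the test helpers themselves.

    package main

    import (
        "context"
        "fmt"
    )

    // listResult stands in for the API's *ListResult types returned by the
    // new page-based retrieval functions.
    type listResult[T any] struct {
        Items        []T
        RemovedIds   []string
        ResponseType string
    }

    // flatFunc is the old helper shape: everything returned in one shot.
    type flatFunc[T any] func(ctx context.Context, addr, tok string) ([]T, []string, error)

    // pagedFunc is the new shape: returns a page the caller may loop over.
    type pagedFunc[T any] func(ctx context.Context, addr, tok string, inPage *listResult[T]) (*listResult[T], error)

    // adaptFlat wraps a flat helper so it satisfies the paged shape by
    // returning a single page marked "complete".
    func adaptFlat[T any](f flatFunc[T]) pagedFunc[T] {
        return func(ctx context.Context, addr, tok string, _ *listResult[T]) (*listResult[T], error) {
            items, removed, err := f(ctx, addr, tok)
            if err != nil {
                return nil, err
            }
            return &listResult[T]{Items: items, RemovedIds: removed, ResponseType: "complete"}, nil
        }
    }

    func main() {
        flat := func(_ context.Context, _, _ string) ([]string, []string, error) {
            return []string{"ttcp_1", "ttcp_2"}, nil, nil
        }
        paged := adaptFlat[string](flat)
        page, err := paged(context.Background(), "addr", "token", nil)
        if err != nil {
            panic(err)
        }
        fmt.Println(page.Items, page.ResponseType) // [ttcp_1 ttcp_2] complete
    }

Wrapping rather than rewriting each helper keeps the existing table-driven test data untouched; only the option constructors in each test change.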
assert.NoError(t, rs.RecheckCachingSupport(ctx, - WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t))), + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))))) got, err := r.ListTargets(ctx, at.Id) @@ -1273,8 +1305,8 @@ func TestRecheckCachingSupport(t *testing.T) { assert.False(t, got.Incomplete) err = rs.Refresh(ctx, - WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t))), + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)))) assert.ErrorIs(t, err, ErrRefreshNotSupported) @@ -1287,8 +1319,8 @@ func TestRecheckCachingSupport(t *testing.T) { // now a full fetch will work since the user has resources and no refresh token assert.NoError(t, rs.RecheckCachingSupport(ctx, - WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t))), + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))))) }) @@ -1302,9 +1334,9 @@ func TestRecheckCachingSupport(t *testing.T) { require.NoError(t, r.AddKeyringToken(ctx, boundaryAddr, KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id})) assert.NoError(t, rs.RecheckCachingSupport(ctx, - WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))))) got, err := r.ListSessions(ctx, at.Id) require.NoError(t, err) @@ -1314,9 +1346,9 @@ func TestRecheckCachingSupport(t *testing.T) { assert.False(t, got.Incomplete) err = rs.Refresh(ctx, - WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))) assert.ErrorIs(t, err, ErrRefreshNotSupported) got, err = r.ListSessions(ctx, at.Id) @@ -1327,9 +1359,9 @@ 
func TestRecheckCachingSupport(t *testing.T) { assert.False(t, got.Incomplete) assert.NoError(t, rs.RecheckCachingSupport(ctx, - WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))))) got, err = r.ListSessions(ctx, at.Id) require.NoError(t, err) assert.Empty(t, got.Targets) @@ -1348,9 +1380,9 @@ func TestRecheckCachingSupport(t *testing.T) { require.NoError(t, r.AddKeyringToken(ctx, boundaryAddr, KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id})) assert.NoError(t, rs.RecheckCachingSupport(ctx, - WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))))) got, err := r.ListResolvableAliases(ctx, at.Id) require.NoError(t, err) @@ -1360,9 +1392,9 @@ func TestRecheckCachingSupport(t *testing.T) { assert.False(t, got.Incomplete) err = rs.Refresh(ctx, - WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))) assert.ErrorIs(t, err, ErrRefreshNotSupported) got, err = r.ListResolvableAliases(ctx, at.Id) @@ -1373,9 +1405,9 @@ func TestRecheckCachingSupport(t *testing.T) { assert.False(t, got.Incomplete) assert.NoError(t, rs.RecheckCachingSupport(ctx, - WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))))) got, err = r.ListResolvableAliases(ctx, at.Id) require.NoError(t, err) assert.Empty(t, got.Targets) @@ -1394,15 +1426,15 @@ func TestRecheckCachingSupport(t *testing.T) { require.NoError(t, r.AddKeyringToken(ctx, boundaryAddr, KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id})) err = rs.Refresh(ctx, - WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t))), 
WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))) assert.ErrorIs(t, err, ErrRefreshNotSupported) innerErr := errors.New("test error") err = rs.RecheckCachingSupport(ctx, - WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t))), + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(func(ctx context.Context, addr, token string, refreshTok RefreshTokenValue) ([]*targets.Target, []string, RefreshTokenValue, error) { require.Equal(t, boundaryAddr, addr) require.Equal(t, at.Token, token) @@ -1411,8 +1443,8 @@ func TestRecheckCachingSupport(t *testing.T) { assert.ErrorContains(t, err, innerErr.Error()) err = rs.RecheckCachingSupport(ctx, - WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t))), + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(func(ctx context.Context, addr, token string, refreshTok RefreshTokenValue) ([]*targets.Target, []string, RefreshTokenValue, error) { require.Equal(t, boundaryAddr, addr) require.Equal(t, at.Token, token) @@ -1431,9 +1463,9 @@ func TestRecheckCachingSupport(t *testing.T) { require.NoError(t, r.AddKeyringToken(ctx, boundaryAddr, KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id})) err = rs.Refresh(ctx, - WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t))), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))) assert.ErrorIs(t, err, ErrRefreshNotSupported) // Remove the token from the keyring, see that we can still see the @@ -1450,8 +1482,8 @@ func TestRecheckCachingSupport(t *testing.T) { assert.Len(t, us, 1) err = rs.RecheckCachingSupport(ctx, - WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)), + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t))), + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))), WithTargetRetrievalFunc(testTargetStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)))) assert.NoError(t, err) diff --git a/internal/clientcache/internal/cache/repository_implicit_scopes_test.go 
b/internal/clientcache/internal/cache/repository_implicit_scopes_test.go index ca5f5c0b64..6363f4f96f 100644 --- a/internal/clientcache/internal/cache/repository_implicit_scopes_test.go +++ b/internal/clientcache/internal/cache/repository_implicit_scopes_test.go @@ -100,7 +100,7 @@ func TestRepository_ImplicitScopes(t *testing.T) { }, } require.NoError(t, r.refreshSessions(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"}, - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{ss}, [][]string{nil})))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{ss}, [][]string{nil}))))) expectedScopes = append(expectedScopes, &scopes.Scope{ Id: ss[0].ScopeId, diff --git a/internal/clientcache/internal/cache/repository_resolvable_aliases.go b/internal/clientcache/internal/cache/repository_resolvable_aliases.go index 39e3bf5f36..c0578e88e4 100644 --- a/internal/clientcache/internal/cache/repository_resolvable_aliases.go +++ b/internal/clientcache/internal/cache/repository_resolvable_aliases.go @@ -22,35 +22,45 @@ import ( // ResolvableAliasRetrievalFunc is a function that retrieves aliases // from the provided boundary addr using the provided token. -type ResolvableAliasRetrievalFunc func(ctx context.Context, addr, authTok, userId string, refreshTok RefreshTokenValue) (ret []*aliases.Alias, removedIds []string, refreshToken RefreshTokenValue, err error) +type ResolvableAliasRetrievalFunc func(ctx context.Context, addr, authTok, userId string, refreshTok RefreshTokenValue, inPage *aliases.AliasListResult, opt ...Option) (ret *aliases.AliasListResult, refreshToken RefreshTokenValue, err error) -func defaultResolvableAliasFunc(ctx context.Context, addr, authTok, userId string, refreshTok RefreshTokenValue) ([]*aliases.Alias, []string, RefreshTokenValue, error) { +func defaultResolvableAliasFunc(ctx context.Context, addr, authTok, userId string, refreshTok RefreshTokenValue, inPage *aliases.AliasListResult, opt ...Option) (*aliases.AliasListResult, RefreshTokenValue, error) { const op = "cache.defaultResolvableAliasFunc" conf, err := api.DefaultConfig() if err != nil { - return nil, nil, "", errors.Wrap(ctx, err, op) + return nil, "", errors.Wrap(ctx, err, op) + } + opts, err := getOpts(opt...) 
+ if err != nil { + return nil, "", errors.Wrap(ctx, err, op) } conf.Addr = addr conf.Token = authTok client, err := api.NewClient(conf) if err != nil { - return nil, nil, "", errors.Wrap(ctx, err, op) + return nil, "", errors.Wrap(ctx, err, op) } aClient := users.NewClient(client) - l, err := aClient.ListResolvableAliases(ctx, userId, users.WithListToken(string(refreshTok))) + var l *aliases.AliasListResult + switch inPage { + case nil: + l, err = aClient.ListResolvableAliases(ctx, userId, users.WithRecursive(true), users.WithListToken(string(refreshTok)), users.WithClientDirectedPagination(!opts.withUseNonPagedListing)) + default: + l, err = aClient.ListResolvableAliasesNextPage(ctx, userId, inPage, users.WithListToken(string(refreshTok))) + } if err != nil { if api.ErrInvalidListToken.Is(err) { - return nil, nil, "", err + return nil, "", err } - return nil, nil, "", errors.Wrap(ctx, err, op) + return nil, "", errors.Wrap(ctx, err, op) } if l.ResponseType == "" { - return nil, nil, "", ErrRefreshNotSupported + return nil, "", ErrRefreshNotSupported } - return l.Items, l.RemovedIds, RefreshTokenValue(l.ListToken), nil + return l, RefreshTokenValue(l.ListToken), nil } -// refreshResolvableAliases attempts to refresh the resolvabl aliases for the +// refreshResolvableAliases attempts to refresh the resolvable aliases for the // provided user using the provided tokens. If available, it uses the refresh // tokens in storage to retrieve and apply only the delta. func (r *Repository) refreshResolvableAliases(ctx context.Context, u *user, tokens map[AuthToken]string, opt ...Option) error { @@ -83,13 +93,13 @@ func (r *Repository) refreshResolvableAliases(ctx context.Context, u *user, toke // Find and use a token for retrieving aliases var gotResponse bool - var resp []*aliases.Alias + var currentPage *aliases.AliasListResult var newRefreshToken RefreshTokenValue + var foundAuthToken string var unsupportedCacheRequest bool - var removedIds []string var retErr error for at, t := range tokens { - resp, removedIds, newRefreshToken, err = opts.withResolvableAliasRetrievalFunc(ctx, u.Address, t, u.Id, oldRefreshTokenVal) + currentPage, newRefreshToken, err = opts.withResolvableAliasRetrievalFunc(ctx, u.Address, t, u.Id, oldRefreshTokenVal, currentPage) if api.ErrInvalidListToken.Is(err) { event.WriteSysEvent(ctx, op, "old list token is no longer valid, starting new initial fetch", "user_id", u.Id) if err := r.deleteRefreshToken(ctx, u, resourceType); err != nil { @@ -97,7 +107,7 @@ func (r *Repository) refreshResolvableAliases(ctx context.Context, u *user, toke } // try again without the refresh token oldRefreshToken = nil - resp, removedIds, newRefreshToken, err = opts.withResolvableAliasRetrievalFunc(ctx, u.Address, t, u.Id, "") + currentPage, newRefreshToken, err = opts.withResolvableAliasRetrievalFunc(ctx, u.Address, t, u.Id, "", currentPage) } if err != nil { if err == ErrRefreshNotSupported { @@ -107,6 +117,7 @@ func (r *Repository) refreshResolvableAliases(ctx context.Context, u *user, toke continue } } + foundAuthToken = t gotResponse = true break } @@ -121,44 +132,57 @@ func (r *Repository) refreshResolvableAliases(ctx context.Context, u *user, toke } var numDeleted int - _, err = r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(_ db.Reader, w db.Writer) error { - var err error - switch { - case oldRefreshToken == nil: - if numDeleted, err = w.Exec(ctx, "delete from resolvable_alias where fk_user_id = @fk_user_id", - []any{sql.Named("fk_user_id", u.Id)}); err != nil { - return err - } - 
case len(removedIds) > 0: - if numDeleted, err = w.Exec(ctx, "delete from resolvable_alias where id in @ids", - []any{sql.Named("ids", removedIds)}); err != nil { - return err - } - } - switch { - case unsupportedCacheRequest: - if err := upsertRefreshToken(ctx, w, u, resourceType, sentinelNoRefreshToken); err != nil { - return err + var numUpserted int + var clearPerformed bool + for { + _, err = r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(_ db.Reader, w db.Writer) error { + var err error + if (oldRefreshToken == nil || unsupportedCacheRequest) && !clearPerformed { + if numDeleted, err = w.Exec(ctx, "delete from resolvable_alias where fk_user_id = @fk_user_id", + []any{sql.Named("fk_user_id", u.Id)}); err != nil { + return err + } + clearPerformed = true } - case newRefreshToken != "": - if err := upsertResolvableAliases(ctx, w, u, resp); err != nil { - return err + switch { + case unsupportedCacheRequest: + if err := upsertRefreshToken(ctx, w, u, resourceType, sentinelNoRefreshToken); err != nil { + return err + } + case newRefreshToken != "": + numUpserted += len(currentPage.Items) + if err := upsertResolvableAliases(ctx, w, u, currentPage.Items); err != nil { + return err + } + if err := upsertRefreshToken(ctx, w, u, resourceType, newRefreshToken); err != nil { + return err + } + default: + // controller supports caching, but doesn't have any resources } - if err := upsertRefreshToken(ctx, w, u, resourceType, newRefreshToken); err != nil { - return err + if !unsupportedCacheRequest && len(currentPage.RemovedIds) > 0 { + if numDeleted, err = w.Exec(ctx, "delete from resolvable_alias where id in @ids", + []any{sql.Named("ids", currentPage.RemovedIds)}); err != nil { + return err + } } - default: - // controller supports caching, but doesn't have any resources + return nil + }) + if unsupportedCacheRequest || currentPage.ResponseType == "" || currentPage.ResponseType == "complete" { + break } - return nil - }) + currentPage, newRefreshToken, err = opts.withResolvableAliasRetrievalFunc(ctx, u.Address, foundAuthToken, u.Id, newRefreshToken, currentPage) + if err != nil { + break + } + } if err != nil { return errors.Wrap(ctx, err, op) } if unsupportedCacheRequest { return ErrRefreshNotSupported } - event.WriteSysEvent(ctx, op, "resolvable-aliases updated", "deleted", numDeleted, "upserted", len(resp), "user_id", u.Id) + event.WriteSysEvent(ctx, op, "resolvable-aliases updated", "deleted", numDeleted, "upserted", numUpserted, "user_id", u.Id) return nil } @@ -187,12 +211,12 @@ func (r *Repository) checkCachingResolvableAliases(ctx context.Context, u *user, // Find and use a token for retrieving aliases var gotResponse bool - var resp []*aliases.Alias + var resp *aliases.AliasListResult var newRefreshToken RefreshTokenValue var unsupportedCacheRequest bool var retErr error for at, t := range tokens { - resp, _, newRefreshToken, err = opts.withResolvableAliasRetrievalFunc(ctx, u.Address, t, u.Id, "") + resp, newRefreshToken, err = opts.withResolvableAliasRetrievalFunc(ctx, u.Address, t, u.Id, "", nil, WithUseNonPagedListing(true)) if err != nil { if err == ErrRefreshNotSupported { unsupportedCacheRequest = true @@ -227,7 +251,7 @@ func (r *Repository) checkCachingResolvableAliases(ctx context.Context, u *user, []any{sql.Named("fk_user_id", u.Id)}); err != nil { return err } - if err := upsertResolvableAliases(ctx, w, u, resp); err != nil { + if err := upsertResolvableAliases(ctx, w, u, resp.Items); err != nil { return err } if err := upsertRefreshToken(ctx, w, u, resourceType, 
newRefreshToken); err != nil { @@ -248,7 +272,7 @@ func (r *Repository) checkCachingResolvableAliases(ctx context.Context, u *user, if unsupportedCacheRequest { return ErrRefreshNotSupported } - event.WriteSysEvent(ctx, op, "resolvable-aliases updated", "deleted", numDeleted, "upserted", len(resp), "user_id", u.Id) + event.WriteSysEvent(ctx, op, "resolvable-aliases updated", "deleted", numDeleted, "upserted", len(resp.Items), "user_id", u.Id) return nil } diff --git a/internal/clientcache/internal/cache/repository_resolvable_aliases_test.go b/internal/clientcache/internal/cache/repository_resolvable_aliases_test.go index 086bf17c4c..3096d1d5f7 100644 --- a/internal/clientcache/internal/cache/repository_resolvable_aliases_test.go +++ b/internal/clientcache/internal/cache/repository_resolvable_aliases_test.go @@ -139,7 +139,7 @@ func TestRepository_refreshAliases(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { err := r.refreshResolvableAliases(ctx, tc.u, map[AuthToken]string{{Id: "id"}: "something"}, - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId(t, [][]*aliases.Alias{tc.al}, [][]string{nil}))) + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId(t, [][]*aliases.Alias{tc.al}, [][]string{nil})))) if tc.errorContains == "" { assert.NoError(t, err) rw := db.New(s) @@ -218,7 +218,7 @@ func TestRepository_RefreshAliases_withRefreshTokens(t *testing.T) { } err = r.refreshResolvableAliases(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"}, - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId(t, ss, [][]string{nil, nil}))) + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId(t, ss, [][]string{nil, nil})))) assert.NoError(t, err) got, err := r.ListResolvableAliases(ctx, at.Id) @@ -228,7 +228,7 @@ func TestRepository_RefreshAliases_withRefreshTokens(t *testing.T) { // Refreshing again uses the refresh token and get additional aliases, appending // them to the response err = r.refreshResolvableAliases(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"}, - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId(t, ss, [][]string{nil, nil}))) + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId(t, ss, [][]string{nil, nil})))) assert.NoError(t, err) got, err = r.ListResolvableAliases(ctx, at.Id) @@ -238,7 +238,7 @@ func TestRepository_RefreshAliases_withRefreshTokens(t *testing.T) { // Refreshing again wont return any more resources, but also none should be // removed require.NoError(t, r.refreshResolvableAliases(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"}, - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId(t, ss, [][]string{nil, nil})))) + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId(t, ss, [][]string{nil, nil}))))) assert.NoError(t, err) got, err = r.ListResolvableAliases(ctx, at.Id) @@ -247,7 +247,7 @@ func TestRepository_RefreshAliases_withRefreshTokens(t *testing.T) { // Refresh again with the refresh token being reported as invalid. 
require.NoError(t, r.refreshResolvableAliases(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"}, - WithAliasRetrievalFunc(testErroringForRefreshTokenRetrievalFuncForId(t, ss[0])))) + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testErroringForRefreshTokenRetrievalFuncForId(t, ss[0]))))) assert.NoError(t, err) got, err = r.ListResolvableAliases(ctx, at.Id) @@ -328,7 +328,7 @@ func TestRepository_ListAliases(t *testing.T) { }, } require.NoError(t, r.refreshResolvableAliases(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"}, - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId(t, [][]*aliases.Alias{ss}, [][]string{nil})))) + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId(t, [][]*aliases.Alias{ss}, [][]string{nil}))))) t.Run("wrong user gets no aliases", func(t *testing.T) { l, err := r.ListResolvableAliases(ctx, kt2.AuthTokenId) @@ -372,7 +372,7 @@ func TestRepository_ListAliasesLimiting(t *testing.T) { ts = append(ts, alias("s"+strconv.Itoa(i))) } require.NoError(t, r.refreshResolvableAliases(ctx, u, map[AuthToken]string{{Id: "id"}: "something"}, - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId(t, [][]*aliases.Alias{ts}, [][]string{nil})))) + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId(t, [][]*aliases.Alias{ts}, [][]string{nil}))))) searchService, err := NewSearchService(ctx, r) require.NoError(t, err) @@ -498,7 +498,7 @@ func TestRepository_QueryAliases(t *testing.T) { }, } require.NoError(t, r.refreshResolvableAliases(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"}, - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId(t, [][]*aliases.Alias{ss}, [][]string{nil})))) + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId(t, [][]*aliases.Alias{ss}, [][]string{nil}))))) t.Run("wrong token gets no aliases", func(t *testing.T) { l, err := r.QueryResolvableAliases(ctx, kt2.AuthTokenId, query) @@ -542,7 +542,7 @@ func TestRepository_QueryResolvableAliasesLimiting(t *testing.T) { ts = append(ts, alias("s"+strconv.Itoa(i))) } require.NoError(t, r.refreshResolvableAliases(ctx, u, map[AuthToken]string{{Id: "id"}: "something"}, - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId(t, [][]*aliases.Alias{ts}, [][]string{nil})))) + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId(t, [][]*aliases.Alias{ts}, [][]string{nil}))))) searchService, err := NewSearchService(ctx, r) require.NoError(t, err) @@ -593,15 +593,15 @@ func TestDefaultAliasRetrievalFunc(t *testing.T) { require.NoError(t, err) require.NotNil(t, tar1) - got, removed, refTok, err := defaultResolvableAliasFunc(tc.Context(), tc.ApiAddrs()[0], tc.Token().Token, tc.Token().UserId, "") + got, refTok, err := defaultResolvableAliasFunc(tc.Context(), tc.ApiAddrs()[0], tc.Token().Token, tc.Token().UserId, "", nil) assert.NoError(t, err) assert.NotEmpty(t, refTok) - assert.Empty(t, removed) - assert.Len(t, got, 1) + assert.Empty(t, got.RemovedIds) + assert.Len(t, got.Items, 1) - got2, removed2, refTok2, err := defaultResolvableAliasFunc(tc.Context(), tc.ApiAddrs()[0], tc.Token().Token, tc.Token().UserId, refTok) + got2, refTok2, err := defaultResolvableAliasFunc(tc.Context(), tc.ApiAddrs()[0], tc.Token().Token, tc.Token().UserId, refTok, nil) assert.NoError(t, err) assert.NotEmpty(t, refTok2) - assert.Empty(t, removed2) - assert.Empty(t, got2) 
+ assert.Empty(t, got2.RemovedIds) + assert.Empty(t, got2.Items) } diff --git a/internal/clientcache/internal/cache/repository_sessions.go b/internal/clientcache/internal/cache/repository_sessions.go index 904202252a..9d53498316 100644 --- a/internal/clientcache/internal/cache/repository_sessions.go +++ b/internal/clientcache/internal/cache/repository_sessions.go @@ -21,32 +21,42 @@ import ( // SessionRetrievalFunc is a function that retrieves sessions // from the provided boundary addr using the provided token. -type SessionRetrievalFunc func(ctx context.Context, addr, authTok string, refreshTok RefreshTokenValue) (ret []*sessions.Session, removedIds []string, refreshToken RefreshTokenValue, err error) +type SessionRetrievalFunc func(ctx context.Context, addr, authTok string, refreshTok RefreshTokenValue, inPage *sessions.SessionListResult, opt ...Option) (ret *sessions.SessionListResult, refreshToken RefreshTokenValue, err error) -func defaultSessionFunc(ctx context.Context, addr, authTok string, refreshTok RefreshTokenValue) ([]*sessions.Session, []string, RefreshTokenValue, error) { +func defaultSessionFunc(ctx context.Context, addr, authTok string, refreshTok RefreshTokenValue, inPage *sessions.SessionListResult, opt ...Option) (ret *sessions.SessionListResult, refreshToken RefreshTokenValue, err error) { const op = "cache.defaultSessionFunc" conf, err := api.DefaultConfig() if err != nil { - return nil, nil, "", errors.Wrap(ctx, err, op) + return nil, "", errors.Wrap(ctx, err, op) + } + opts, err := getOpts(opt...) + if err != nil { + return nil, "", errors.Wrap(ctx, err, op) } conf.Addr = addr conf.Token = authTok client, err := api.NewClient(conf) if err != nil { - return nil, nil, "", errors.Wrap(ctx, err, op) + return nil, "", errors.Wrap(ctx, err, op) } sClient := sessions.NewClient(client) - l, err := sClient.List(ctx, "global", sessions.WithIncludeTerminated(true), sessions.WithRecursive(true), sessions.WithListToken(string(refreshTok))) + var l *sessions.SessionListResult + switch inPage { + case nil: + l, err = sClient.List(ctx, "global", sessions.WithIncludeTerminated(true), sessions.WithRecursive(true), sessions.WithListToken(string(refreshTok)), sessions.WithClientDirectedPagination(!opts.withUseNonPagedListing)) + default: + l, err = sClient.ListNextPage(ctx, inPage, sessions.WithListToken(string(refreshTok))) + } if err != nil { if api.ErrInvalidListToken.Is(err) { - return nil, nil, "", err + return nil, "", err } - return nil, nil, "", errors.Wrap(ctx, err, op) + return nil, "", errors.Wrap(ctx, err, op) } if l.ResponseType == "" { - return nil, nil, "", ErrRefreshNotSupported + return nil, "", ErrRefreshNotSupported } - return l.Items, l.RemovedIds, RefreshTokenValue(l.ListToken), nil + return l, RefreshTokenValue(l.ListToken), nil } // refreshSessions uses attempts to refresh the sessions for the provided user @@ -59,8 +69,6 @@ func (r *Repository) refreshSessions(ctx context.Context, u *user, tokens map[Au return errors.New(ctx, errors.InvalidParameter, op, "user is nil") case u.Id == "": return errors.New(ctx, errors.InvalidParameter, op, "user id is missing") - case u.Address == "": - return errors.New(ctx, errors.InvalidParameter, op, "user boundary address is missing") } const resourceType = sessionResourceType @@ -82,13 +90,13 @@ func (r *Repository) refreshSessions(ctx context.Context, u *user, tokens map[Au // Find and use a token for retrieving sessions var gotResponse bool - var resp []*sessions.Session + var currentPage *sessions.SessionListResult var 
newRefreshToken RefreshTokenValue + var foundAuthToken string var unsupportedCacheRequest bool - var removedIds []string var retErr error for at, t := range tokens { - resp, removedIds, newRefreshToken, err = opts.withSessionRetrievalFunc(ctx, u.Address, t, oldRefreshTokenVal) + currentPage, newRefreshToken, err = opts.withSessionRetrievalFunc(ctx, u.Address, t, oldRefreshTokenVal, currentPage) if api.ErrInvalidListToken.Is(err) { event.WriteSysEvent(ctx, op, "old list token is no longer valid, starting new initial fetch", "user_id", u.Id) if err := r.deleteRefreshToken(ctx, u, resourceType); err != nil { @@ -96,7 +104,7 @@ func (r *Repository) refreshSessions(ctx context.Context, u *user, tokens map[Au } // try again without the refresh token oldRefreshToken = nil - resp, removedIds, newRefreshToken, err = opts.withSessionRetrievalFunc(ctx, u.Address, t, "") + currentPage, newRefreshToken, err = opts.withSessionRetrievalFunc(ctx, u.Address, t, "", currentPage) } if err != nil { if err == ErrRefreshNotSupported { @@ -106,6 +114,7 @@ func (r *Repository) refreshSessions(ctx context.Context, u *user, tokens map[Au continue } } + foundAuthToken = t gotResponse = true break } @@ -120,44 +129,56 @@ func (r *Repository) refreshSessions(ctx context.Context, u *user, tokens map[Au } var numDeleted int - _, err = r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(_ db.Reader, w db.Writer) error { - var err error - switch { - case oldRefreshToken == nil: - if numDeleted, err = w.Exec(ctx, "delete from session where fk_user_id = @fk_user_id", - []any{sql.Named("fk_user_id", u.Id)}); err != nil { - return err - } - case len(removedIds) > 0: - if numDeleted, err = w.Exec(ctx, "delete from session where id in @ids", - []any{sql.Named("ids", removedIds)}); err != nil { - return err - } - } - switch { - case unsupportedCacheRequest: - if err := upsertRefreshToken(ctx, w, u, resourceType, sentinelNoRefreshToken); err != nil { - return err + var numUpserted int + var clearPerformed bool + for { + _, err = r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(_ db.Reader, w db.Writer) error { + var err error + if (oldRefreshToken == nil || unsupportedCacheRequest) && !clearPerformed { + if numDeleted, err = w.Exec(ctx, "delete from session where fk_user_id = @fk_user_id", + []any{sql.Named("fk_user_id", u.Id)}); err != nil { + return err + } } - case newRefreshToken != "": - if err := upsertSessions(ctx, w, u, resp); err != nil { - return err + switch { + case unsupportedCacheRequest: + if err := upsertRefreshToken(ctx, w, u, resourceType, sentinelNoRefreshToken); err != nil { + return err + } + case newRefreshToken != "": + numUpserted += len(currentPage.Items) + if err := upsertSessions(ctx, w, u, currentPage.Items); err != nil { + return err + } + if err := upsertRefreshToken(ctx, w, u, resourceType, newRefreshToken); err != nil { + return err + } + default: + // controller supports caching, but doesn't have any resources } - if err := upsertRefreshToken(ctx, w, u, resourceType, newRefreshToken); err != nil { - return err + if !unsupportedCacheRequest && len(currentPage.RemovedIds) > 0 { + if numDeleted, err = w.Exec(ctx, "delete from session where id in @ids", + []any{sql.Named("ids", currentPage.RemovedIds)}); err != nil { + return err + } } - default: - // controller supports caching, but doesn't have any resources + return nil + }) + if unsupportedCacheRequest || currentPage.ResponseType == "" || currentPage.ResponseType == "complete" { + break } - return nil - }) + currentPage, newRefreshToken, err 
= opts.withSessionRetrievalFunc(ctx, u.Address, foundAuthToken, newRefreshToken, currentPage) + if err != nil { + break + } + } if err != nil { return errors.Wrap(ctx, err, op) } if unsupportedCacheRequest { return ErrRefreshNotSupported } - event.WriteSysEvent(ctx, op, "sessions updated", "deleted", numDeleted, "upserted", len(resp), "user_id", u.Id) + event.WriteSysEvent(ctx, op, "sessions updated", "deleted", numDeleted, "upserted", numUpserted, "user id", u.Id) return nil } @@ -186,12 +207,12 @@ func (r *Repository) checkCachingSessions(ctx context.Context, u *user, tokens m // Find and use a token for retrieving sessions var gotResponse bool - var resp []*sessions.Session + var resp *sessions.SessionListResult var newRefreshToken RefreshTokenValue var unsupportedCacheRequest bool var retErr error for at, t := range tokens { - resp, _, newRefreshToken, err = opts.withSessionRetrievalFunc(ctx, u.Address, t, "") + resp, newRefreshToken, err = opts.withSessionRetrievalFunc(ctx, u.Address, t, "", nil, WithUseNonPagedListing(true)) if err != nil { if err == ErrRefreshNotSupported { unsupportedCacheRequest = true @@ -217,24 +238,29 @@ func (r *Repository) checkCachingSessions(ctx context.Context, u *user, tokens m _, err = r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(reader db.Reader, w db.Writer) error { switch { case unsupportedCacheRequest: + // Since we know the controller doesn't support caching, we mark the + // user as unable to cache the data. if err := upsertRefreshToken(ctx, w, u, resourceType, sentinelNoRefreshToken); err != nil { return err } case newRefreshToken != "": + // Now that there is a refresh token, the data can be cached, so + // cache it and store the refresh token for future refreshes. First + // remove any values, then add the new ones var err error if numDeleted, err = w.Exec(ctx, "delete from session where fk_user_id = @fk_user_id", []any{sql.Named("fk_user_id", u.Id)}); err != nil { return err } - if err := upsertSessions(ctx, w, u, resp); err != nil { + if err := upsertSessions(ctx, w, u, resp.Items); err != nil { return err } if err := upsertRefreshToken(ctx, w, u, resourceType, newRefreshToken); err != nil { return err } default: - // This is no longer flagged as not supported, but we dont have a - // refresh token so clear out any refresh token we have stored. + // We know the controller supports caching, but doesn't have a + // refresh token so clear out any refresh token we have for this resource. 
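For reference, the per-page retrieval driving the refresh loop above follows the client-directed listing pattern of the sessions API client: an initial List call carrying the stored list token, then ListNextPage on each returned page until the controller reports a complete response. The sketch below is illustrative only and is not part of the patch; it reuses the client calls shown in defaultSessionFunc, and the process callback is a placeholder standing in for the per-page cache writes performed inside DoTx.

// Illustrative sketch: client-directed pagination with the sessions client.
// The process callback stands in for the per-page cache writes above.
package cacheexample

import (
	"context"

	"github.com/hashicorp/boundary/api"
	"github.com/hashicorp/boundary/api/sessions"
)

func listAllSessions(ctx context.Context, client *api.Client, listToken string,
	process func(page *sessions.SessionListResult) error,
) error {
	sClient := sessions.NewClient(client)
	// Initial request; an empty list token starts a fresh listing.
	page, err := sClient.List(ctx, "global",
		sessions.WithIncludeTerminated(true),
		sessions.WithRecursive(true),
		sessions.WithListToken(listToken),
		sessions.WithClientDirectedPagination(true),
	)
	if err != nil {
		return err
	}
	for {
		// Each page carries Items, RemovedIds, and an updated ListToken.
		if err := process(page); err != nil {
			return err
		}
		if page.ResponseType == "" || page.ResponseType == "complete" {
			// No further pages (or the controller does not support refresh).
			return nil
		}
		page, err = sClient.ListNextPage(ctx, page, sessions.WithListToken(listToken))
		if err != nil {
			return err
		}
	}
}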
if err := deleteRefreshToken(ctx, w, u, resourceType); err != nil { return err } @@ -247,7 +273,7 @@ func (r *Repository) checkCachingSessions(ctx context.Context, u *user, tokens m if unsupportedCacheRequest { return ErrRefreshNotSupported } - event.WriteSysEvent(ctx, op, "sessions updated", "deleted", numDeleted, "upserted", len(resp), "user_id", u.Id) + event.WriteSysEvent(ctx, op, "sessions updated", "deleted", numDeleted, "upserted", len(resp.Items), "user_id", u.Id) return nil } diff --git a/internal/clientcache/internal/cache/repository_sessions_test.go b/internal/clientcache/internal/cache/repository_sessions_test.go index 4627f8cd88..5cc31dca0a 100644 --- a/internal/clientcache/internal/cache/repository_sessions_test.go +++ b/internal/clientcache/internal/cache/repository_sessions_test.go @@ -150,7 +150,7 @@ func TestRepository_refreshSessions(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { err := r.refreshSessions(ctx, tc.u, map[AuthToken]string{{Id: "id"}: "something"}, - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{tc.sess}, [][]string{nil}))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{tc.sess}, [][]string{nil})))) if tc.errorContains == "" { assert.NoError(t, err) rw := db.New(s) @@ -235,7 +235,7 @@ func TestRepository_RefreshSessions_withRefreshTokens(t *testing.T) { } err = r.refreshSessions(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"}, - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, ss, [][]string{nil, nil}))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc(t, ss, [][]string{nil, nil})))) assert.NoError(t, err) got, err := r.ListSessions(ctx, at.Id) @@ -245,7 +245,7 @@ func TestRepository_RefreshSessions_withRefreshTokens(t *testing.T) { // Refreshing again uses the refresh token and get additional sessions, appending // them to the response err = r.refreshSessions(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"}, - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, ss, [][]string{nil, nil}))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc(t, ss, [][]string{nil, nil})))) assert.NoError(t, err) got, err = r.ListSessions(ctx, at.Id) @@ -255,7 +255,7 @@ func TestRepository_RefreshSessions_withRefreshTokens(t *testing.T) { // Refreshing again wont return any more resources, but also none should be // removed require.NoError(t, r.refreshSessions(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"}, - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, ss, [][]string{nil, nil})))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc(t, ss, [][]string{nil, nil}))))) assert.NoError(t, err) got, err = r.ListSessions(ctx, at.Id) @@ -264,7 +264,7 @@ func TestRepository_RefreshSessions_withRefreshTokens(t *testing.T) { // Refresh again with the refresh token being reported as invalid. 
require.NoError(t, r.refreshSessions(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"}, - WithSessionRetrievalFunc(testErroringForRefreshTokenRetrievalFunc(t, ss[0])))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testErroringForRefreshTokenRetrievalFunc(t, ss[0]))))) assert.NoError(t, err) got, err = r.ListSessions(ctx, at.Id) @@ -351,7 +351,7 @@ func TestRepository_ListSessions(t *testing.T) { }, } require.NoError(t, r.refreshSessions(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"}, - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{ss}, [][]string{nil})))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{ss}, [][]string{nil}))))) t.Run("wrong user gets no sessions", func(t *testing.T) { l, err := r.ListSessions(ctx, kt2.AuthTokenId) @@ -395,7 +395,7 @@ func TestRepository_ListSessionsLimiting(t *testing.T) { ts = append(ts, session("s"+strconv.Itoa(i))) } require.NoError(t, r.refreshSessions(ctx, u, map[AuthToken]string{{Id: "id"}: "something"}, - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{ts}, [][]string{nil})))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{ts}, [][]string{nil}))))) searchService, err := NewSearchService(ctx, r) require.NoError(t, err) @@ -527,7 +527,7 @@ func TestRepository_QuerySessions(t *testing.T) { }, } require.NoError(t, r.refreshSessions(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"}, - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{ss}, [][]string{nil})))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{ss}, [][]string{nil}))))) t.Run("wrong token gets no sessions", func(t *testing.T) { l, err := r.QuerySessions(ctx, kt2.AuthTokenId, query) @@ -571,7 +571,7 @@ func TestRepository_QuerySessionsLimiting(t *testing.T) { ts = append(ts, session("t"+strconv.Itoa(i))) } require.NoError(t, r.refreshSessions(ctx, u, map[AuthToken]string{{Id: "id"}: "something"}, - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{ts}, [][]string{nil})))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{ts}, [][]string{nil}))))) searchService, err := NewSearchService(ctx, r) require.NoError(t, err) @@ -637,15 +637,15 @@ func TestDefaultSessionRetrievalFunc(t *testing.T) { } } - got, removed, refTok, err := defaultSessionFunc(tc.Context(), tc.ApiAddrs()[0], tc.Token().Token, "") + got, refTok, err := defaultSessionFunc(tc.Context(), tc.ApiAddrs()[0], tc.Token().Token, "", nil) assert.NoError(t, err) assert.NotEmpty(t, refTok) - assert.Empty(t, removed) - assert.Len(t, got, 1) + assert.Empty(t, got.RemovedIds) + assert.Len(t, got.Items, 1) - got2, removed2, refTok2, err := defaultSessionFunc(tc.Context(), tc.ApiAddrs()[0], tc.Token().Token, refTok) + got2, refTok2, err := defaultSessionFunc(tc.Context(), tc.ApiAddrs()[0], tc.Token().Token, refTok, nil) assert.NoError(t, err) assert.NotEmpty(t, refTok2) - assert.Empty(t, removed2) - assert.Empty(t, got2) + assert.Empty(t, got2.RemovedIds) + assert.Empty(t, got2.Items) } diff --git a/internal/clientcache/internal/cache/status_test.go b/internal/clientcache/internal/cache/status_test.go index 67e34fc2ab..34b3d2fb2e 100644 --- 
a/internal/clientcache/internal/cache/status_test.go +++ b/internal/clientcache/internal/cache/status_test.go @@ -203,7 +203,7 @@ func TestStatus(t *testing.T) { session("3"), } err := r.refreshSessions(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"}, - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{sess}, [][]string{nil}))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{sess}, [][]string{nil})))) require.NoError(t, err) als := []*aliases.Alias{ @@ -212,7 +212,7 @@ func TestStatus(t *testing.T) { alias("3"), } err = r.refreshResolvableAliases(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"}, - WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId(t, [][]*aliases.Alias{als}, [][]string{nil}))) + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testStaticResourceRetrievalFuncForId(t, [][]*aliases.Alias{als}, [][]string{nil})))) require.NoError(t, err) got, err := ss.Status(ctx) @@ -308,7 +308,7 @@ func TestStatus_unsupported(t *testing.T) { })) err = r.refreshResolvableAliases(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"}, - WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t))) + WithAliasRetrievalFunc(testResolvableAliasStaticResourceRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)))) require.ErrorIs(t, err, ErrRefreshNotSupported) err = r.refreshTargets(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"}, @@ -316,7 +316,7 @@ func TestStatus_unsupported(t *testing.T) { require.ErrorIs(t, err, ErrRefreshNotSupported) err = r.refreshSessions(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"}, - WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t))) + WithSessionRetrievalFunc(testSessionStaticResourceRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))) require.ErrorIs(t, err, ErrRefreshNotSupported) got, err := ss.Status(ctx) diff --git a/internal/clientcache/internal/daemon/testing.go b/internal/clientcache/internal/daemon/testing.go index 0689d9cd4b..ffd24d7f59 100644 --- a/internal/clientcache/internal/daemon/testing.go +++ b/internal/clientcache/internal/daemon/testing.go @@ -91,11 +91,13 @@ func (s *TestServer) AddResources(t *testing.T, p *authtokens.AuthToken, alts [] r, err := cache.NewRepository(ctx, s.CacheServer.store.Load(), &sync.Map{}, s.cmd.ReadTokenFromKeyring, atReadFn) require.NoError(t, err) - altFn := func(ctx context.Context, _, tok, _ string, _ cache.RefreshTokenValue) ([]*aliases.Alias, []string, cache.RefreshTokenValue, error) { + altFn := func(ctx context.Context, _ string, tok, _ string, _ cache.RefreshTokenValue, inPage *aliases.AliasListResult, opt ...cache.Option) (*aliases.AliasListResult, cache.RefreshTokenValue, error) { if tok != p.Token { - return nil, nil, "", nil + return nil, "", nil } - return alts, nil, "addedaliases", nil + return &aliases.AliasListResult{ + Items: alts, + }, "addedaliases", nil } tarFn := func(ctx context.Context, _ string, tok string, _ cache.RefreshTokenValue, inPage *targets.TargetListResult, opt ...cache.Option) (*targets.TargetListResult, cache.RefreshTokenValue, error) { if tok != p.Token { @@ -105,11 +107,13 @@ func (s *TestServer) AddResources(t *testing.T, p *authtokens.AuthToken, alts [] Items: tars, }, "addedtargets", nil } - sessFn := func(ctx context.Context, _, tok string, _ cache.RefreshTokenValue) ([]*sessions.Session, []string, cache.RefreshTokenValue, error) { + sessFn := func(ctx 
context.Context, _, tok string, _ cache.RefreshTokenValue, inPage *sessions.SessionListResult, opt ...cache.Option) (*sessions.SessionListResult, cache.RefreshTokenValue, error) { if tok != p.Token { - return nil, nil, "", nil + return nil, "", nil } - return sess, nil, "addedsessions", nil + return &sessions.SessionListResult{ + Items: sess, + }, "addedsessions", nil } rs, err := cache.NewRefreshService(ctx, r, hclog.NewNullLogger(), 0, 0) require.NoError(t, err) @@ -136,11 +140,15 @@ func (s *TestServer) AddUnsupportedCachingData(t *testing.T, p *authtokens.AuthT }, }, "", cache.ErrRefreshNotSupported } - sessFn := func(ctx context.Context, _, tok string, _ cache.RefreshTokenValue) ([]*sessions.Session, []string, cache.RefreshTokenValue, error) { + sessFn := func(ctx context.Context, _, tok string, _ cache.RefreshTokenValue, inPage *sessions.SessionListResult, opt ...cache.Option) (*sessions.SessionListResult, cache.RefreshTokenValue, error) { if tok != p.Token { - return nil, nil, "", nil + return &sessions.SessionListResult{}, "", nil } - return []*sessions.Session{}, nil, "", cache.ErrRefreshNotSupported + return &sessions.SessionListResult{ + Items: []*sessions.Session{ + {Id: "s_unsupported"}, + }, + }, "", cache.ErrRefreshNotSupported } rs, err := cache.NewRefreshService(ctx, r, hclog.NewNullLogger(), 0, 0) require.NoError(t, err) From 62d4f8453b0e0ba179a0e6c36b45a54c799a4e7e Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Tue, 17 Sep 2024 10:07:43 -0400 Subject: [PATCH 10/15] Don't explode grants from the DB (#5104) This removes the Cartesian product in the DB for GrantsForUser in favor of returning the actual grant scope information and dealing with it in application code. (cherry picked from commit 913cd536823d73994acc4f4d833208c1cc05f91b) --- api/go.mod | 11 +- api/go.sum | 55 +- go.mod | 2 +- .../repository_alias_list_resolvable.go | 170 ++- .../service_list_resolvable_ext_test.go | 308 ++++- internal/cmd/commands/rolescmd/funcs.go | 2 +- internal/daemon/controller/auth/auth.go | 13 +- .../controller/auth/authorized_actions.go | 8 +- .../handlers/accounts/account_service_test.go | 115 ++ .../authmethods/authmethod_service.go | 14 +- .../credentialstore_service.go | 4 +- .../host_catalogs/host_catalog_service.go | 8 +- .../controller/handlers/roles/role_service.go | 10 +- .../handlers/roles/role_service_test.go | 4 +- .../handlers/scopes/scope_service.go | 36 +- .../handlers/sessions/session_service.go | 7 +- .../handlers/targets/target_service.go | 2 +- .../targets/tcp/target_service_test.go | 866 ++++++++---- .../controller/handlers/users/user_service.go | 7 +- internal/iam/options.go | 39 +- internal/iam/query.go | 116 +- internal/iam/repository_role_grant.go | 59 +- internal/iam/repository_role_grant_test.go | 1147 ++++++++++++---- internal/iam/role_grant.go | 4 +- internal/perms/acl.go | 392 ++++-- internal/perms/acl_test.go | 1179 +++++++++++++---- internal/perms/grants.go | 130 +- internal/perms/grants_test.go | 405 ++---- internal/perms/output_fields_test.go | 27 +- internal/session/repository.go | 2 +- internal/session/repository_session_test.go | 38 +- internal/session/service_list_ext_test.go | 6 +- internal/target/options_test.go | 4 +- internal/target/repository.go | 2 +- internal/target/repository_ext_test.go | 40 +- internal/target/repository_test.go | 12 +- internal/target/service_list_ext_test.go | 8 +- internal/tests/api/users/user_test.go | 2 +- 38 files changed, 3636 insertions(+), 1618 deletions(-) diff --git a/api/go.mod b/api/go.mod index 
00b7c7af6b..ce20c34df6 100644 --- a/api/go.mod +++ b/api/go.mod @@ -3,7 +3,7 @@ module github.com/hashicorp/boundary/api go 1.23.1 require ( - github.com/hashicorp/boundary/sdk v0.0.40 + github.com/hashicorp/boundary/sdk v0.0.48 github.com/hashicorp/go-cleanhttp v0.5.2 github.com/hashicorp/go-kms-wrapping/v2 v2.0.14 github.com/hashicorp/go-retryablehttp v0.7.4 @@ -18,7 +18,7 @@ require ( github.com/stretchr/testify v1.8.4 go.uber.org/atomic v1.11.0 golang.org/x/time v0.3.0 - google.golang.org/grpc v1.59.0 + google.golang.org/grpc v1.61.0 google.golang.org/protobuf v1.33.0 nhooyr.io/websocket v1.8.10 ) @@ -36,9 +36,10 @@ require ( github.com/mitchellh/pointerstructure v1.2.1 // indirect github.com/mitchellh/reflectwalk v1.0.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/rogpeppe/go-internal v1.8.1 // indirect github.com/ryanuber/go-glob v1.0.0 // indirect - golang.org/x/crypto v0.14.0 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20231030173426-d783a09b4405 // indirect + golang.org/x/crypto v0.18.0 // indirect + golang.org/x/sys v0.16.0 // indirect + google.golang.org/genproto v0.0.0-20240116215550-a9fa1716bcac // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20240125205218-1f4bbc51befe // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/api/go.sum b/api/go.sum index 6600349d5b..4b7b62d724 100644 --- a/api/go.sum +++ b/api/go.sum @@ -21,12 +21,12 @@ github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= -github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/hashicorp/boundary/sdk v0.0.40 h1:HNJcMWHCjoraPJALTZ9JssSoP/vflew2+VB656nvRlY= -github.com/hashicorp/boundary/sdk v0.0.40/go.mod h1:+XTDYf9YNeKIbGOPJwy7hlO2Le4zgzCtHCG/u+z4THI= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4= +github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hashicorp/boundary/sdk v0.0.48 h1:4HqyX1tS1kuaCa18OSbPGf8ZHJuwdmm1yaxr1u+nxZ4= +github.com/hashicorp/boundary/sdk v0.0.48/go.mod h1:9iOT7kDM6mYcSkKxNuZlv8rP7U5BG1kXoevjLLL8lNQ= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -88,8 +88,8 @@ github.com/jefferai/isbadcipher v0.0.0-20190226160619-51d2077c035f h1:E87tDTVS5W github.com/jefferai/isbadcipher v0.0.0-20190226160619-51d2077c035f/go.mod h1:3J2qVK16Lq8V+wfiL2lPeDZ7UWMxk5LemerHa1p6N00= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= -github.com/kr/pretty v0.3.1 
h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= -github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -117,14 +117,13 @@ github.com/mr-tron/base58 v1.2.0 h1:T/HDJBh4ZCPbU39/+c3rRvE0uKBQlU27+QI8LJ4t64o= github.com/mr-tron/base58 v1.2.0/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc= github.com/oklog/run v1.1.0 h1:GEenZ1cK0+q0+wsJew9qUg/DyD8k3JzYsZAi5gYi2mA= github.com/oklog/run v1.1.0/go.mod h1:sVPdnTZT1zYwAJeCMu2Th4T21pA3FPOQRfWjQlk7DVU= -github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/posener/complete v1.2.3 h1:NP0eAhjcjImqslEwo/1hq7gpajME0fTLTezBKDqfXqo= github.com/posener/complete v1.2.3/go.mod h1:WZIdtGGp+qx0sLrYKtIRAruyNpv6hFCicSgv7Sy7s/s= github.com/rogpeppe/go-internal v1.6.2/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= -github.com/rogpeppe/go-internal v1.8.1 h1:geMPLpDpQOgVyCg5z5GoRwLHepNdb71NXb67XFkP+Eg= -github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkBk= github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc= github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8= @@ -146,8 +145,8 @@ go.uber.org/goleak v1.1.12/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= -golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= +golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc= +golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= @@ -157,19 +156,19 @@ golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod 
h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= -golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= +golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo= +golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= -golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= +golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= -golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -182,22 +181,22 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/genproto v0.0.0-20231016165738-49dd2c1f3d0b h1:+YaDE2r2OG8t/z5qmsh7Y+XXwCbvadxxZ0YY6mTdrVA= -google.golang.org/genproto/googleapis/api v0.0.0-20231030173426-d783a09b4405 h1:HJMDndgxest5n2y77fnErkM62iUsptE/H8p0dC2Huo4= -google.golang.org/genproto/googleapis/api v0.0.0-20231030173426-d783a09b4405/go.mod h1:oT32Z4o8Zv2xPQTg0pbVaPr0MPOH6f14RgXt7zfIpwg= -google.golang.org/genproto/googleapis/rpc v0.0.0-20231016165738-49dd2c1f3d0b h1:ZlWIi1wSK56/8hn4QcBp/j9M7Gt3U/3hZw3mC7vDICo= -google.golang.org/genproto/googleapis/rpc v0.0.0-20231016165738-49dd2c1f3d0b/go.mod h1:swOH3j0KzcDDgGUWr+SNpyTen5YrXjS3eyPzFYKc6lc= -google.golang.org/grpc v1.59.0 h1:Z5Iec2pjwb+LEOqzpB2MR12/eKFhDPhuqW91O+4bwUk= -google.golang.org/grpc v1.59.0/go.mod h1:aUPDwccQo6OTjy7Hct4AfBPD1GptF4fyUjIkQ9YtF98= +google.golang.org/genproto v0.0.0-20240116215550-a9fa1716bcac h1:ZL/Teoy/ZGnzyrqK/Optxxp2pmVh+fmJ97slxSRyzUg= +google.golang.org/genproto v0.0.0-20240116215550-a9fa1716bcac/go.mod h1:+Rvu7ElI+aLzyDQhpHMFMMltsD6m7nqpuWDd2CwJw3k= +google.golang.org/genproto/googleapis/api v0.0.0-20240125205218-1f4bbc51befe h1:0poefMBYvYbs7g5UkjS6HcxBPaTRAmznle9jnxYoAI8= 
+google.golang.org/genproto/googleapis/api v0.0.0-20240125205218-1f4bbc51befe/go.mod h1:4jWUdICTdgc3Ibxmr8nAJiiLHwQBY0UI0XZcEMaFKaA= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240125205218-1f4bbc51befe h1:bQnxqljG/wqi4NTXu2+DJ3n7APcEA882QZ1JvhQAq9o= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240125205218-1f4bbc51befe/go.mod h1:PAREbraiVEVGVdTZsVWjSbbTtSyGbAgIIvni8a8CD5s= +google.golang.org/grpc v1.61.0 h1:TOvOcuXn30kRao+gfcvsebNEa5iZIiLkisYEkf7R7o0= +google.golang.org/grpc v1.61.0/go.mod h1:VUbo7IFqmF1QtCAstipjG0GIoq49KvMe9+h1jFLBNJs= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/go.mod b/go.mod index 186a0bf079..4a46436146 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,7 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.1 github.com/hashicorp/boundary/api v0.0.50 - github.com/hashicorp/boundary/sdk v0.0.46 + github.com/hashicorp/boundary/sdk v0.0.48 github.com/hashicorp/cap v0.5.1-0.20240315182732-faa330bfb8df github.com/hashicorp/dawdle v0.5.0 github.com/hashicorp/eventlogger v0.2.9 diff --git a/internal/alias/target/repository_alias_list_resolvable.go b/internal/alias/target/repository_alias_list_resolvable.go index da78507eaf..bb7a39bec0 100644 --- a/internal/alias/target/repository_alias_list_resolvable.go +++ b/internal/alias/target/repository_alias_list_resolvable.go @@ -10,26 +10,57 @@ import ( "strings" "time" + "github.com/hashicorp/boundary/globals" "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/db/timestamp" "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/perms" + "github.com/hashicorp/boundary/internal/types/scope" ) -// targetAndScopeIdsForDestinations returns the target ids for which there is -// at least one permission. If all targets in a specific scope are granted -// permission for an action, then the scope id is in the returned scope id slice. -func targetAndScopeIdsForDestinations(perms []perms.Permission) ([]string, []string) { - var targetIds, scopeIds []string - for _, perm := range perms { +func splitPermissions(permissions []perms.Permission) (directIds, directScopeIds, childAllScopes []string, allDescendants bool) { + // First check for all descendants. 
Since what we are querying for below is + // for targets (either IDs, or targets within specific scopes), and targets + // are not in global, if this matches we can actually ignore everything + // else. + for _, perm := range permissions { + if perm.GrantScopeId == globals.GrantScopeDescendants && perm.All { + allDescendants = true + return + } + } + + directIds = make([]string, 0, len(permissions)) + directScopeIds = make([]string, 0, len(permissions)) + childAllScopes = make([]string, 0, len(permissions)) + for _, perm := range permissions { switch { + case allDescendants: + // See the above check; we don't need any other info + case perm.GrantScopeId == scope.Global.String() || strings.HasPrefix(perm.GrantScopeId, globals.OrgPrefix): + // There are no targets in global or orgs + case perm.RoleScopeId == scope.Global.String() && perm.GrantScopeId == globals.GrantScopeChildren: + // A role in global that includes children will include only orgs, + // which do not have targets, so ignore + case perm.GrantScopeId == globals.GrantScopeChildren && perm.All: + // Because of the above check this will match only grants from org + // roles. If the grant scope is children and all, we store the scope + // ID. + childAllScopes = append(childAllScopes, perm.RoleScopeId) case perm.All: - scopeIds = append(scopeIds, perm.ScopeId) + // We ignore descendants and if this was a children grant scope and + // perm.All it would match the above case. So this is a grant + // directly on a scope. Since only projects contain targets, we can + // ignore any grant scope ID that doesn't match targets. + if strings.HasPrefix(perm.GrantScopeId, globals.ProjectPrefix) { + directScopeIds = append(directScopeIds, perm.GrantScopeId) + } case len(perm.ResourceIds) > 0: - targetIds = append(targetIds, perm.ResourceIds...) + // It's an ID grant + directIds = append(directIds, perm.ResourceIds...) } } - return targetIds, scopeIds + return } // listResolvableAliases lists aliases which have a destination id set to that @@ -41,7 +72,8 @@ func (r *Repository) listResolvableAliases(ctx context.Context, permissions []pe case len(permissions) == 0: return nil, time.Time{}, errors.New(ctx, errors.InvalidParameter, op, "missing permissions") } - toTargetIds, toTargetsInScopeIds := targetAndScopeIdsForDestinations(permissions) + + directIds, directScopeIds, childAllScopes, allDescendants := splitPermissions(permissions) opts, err := getOpts(opt...) 
if err != nil { @@ -59,16 +91,33 @@ func (r *Repository) listResolvableAliases(ctx context.Context, permissions []pe var args []any var destinationIdClauses []string - if len(toTargetIds) > 0 { - destinationIdClauses = append(destinationIdClauses, "destination_id in @target_ids") - args = append(args, sql.Named("target_ids", toTargetIds)) - } - if len(toTargetsInScopeIds) > 0 { - destinationIdClauses = append(destinationIdClauses, "destination_id in (select public_id from target where project_id in @target_scope_ids)") - args = append(args, sql.Named("target_scope_ids", toTargetsInScopeIds)) - } - if len(destinationIdClauses) == 0 { - return nil, time.Time{}, errors.New(ctx, errors.InvalidParameter, op, "no target ids or scope ids provided") + + switch { + case allDescendants: + // This matches all targets + destinationIdClauses = append(destinationIdClauses, "destination_id in (select public_id from target)") + default: + // Add orgs with all permissions on children + if len(childAllScopes) > 0 { + destinationIdClauses = append(destinationIdClauses, + "destination_id in "+ + "(select public_id from target where project_id in "+ + "(select public_id from iam_scope where parent_id = any(@child_all_scopes)))", + ) + args = append(args, sql.Named("child_all_scopes", "{"+strings.Join(childAllScopes, ",")+"}")) + } + // Add target ids + if len(directIds) > 0 { + destinationIdClauses = append(destinationIdClauses, "destination_id = any(@target_ids)") + args = append(args, sql.Named("target_ids", "{"+strings.Join(directIds, ",")+"}")) + } + if len(directScopeIds) > 0 { + destinationIdClauses = append(destinationIdClauses, "destination_id in (select public_id from target where project_id = any(@target_scope_ids))") + args = append(args, sql.Named("target_scope_ids", "{"+strings.Join(directScopeIds, ",")+"}")) + } + if len(destinationIdClauses) == 0 && len(childAllScopes) == 0 { + return nil, time.Time{}, errors.New(ctx, errors.InvalidParameter, op, "no target ids or scope ids provided") + } } whereClause := fmt.Sprintf("destination_id is not null and (%s)", strings.Join(destinationIdClauses, " or ")) @@ -98,7 +147,8 @@ func (r *Repository) listResolvableAliasesRefresh(ctx context.Context, updatedAf case len(permissions) == 0: return nil, time.Time{}, errors.New(ctx, errors.InvalidParameter, op, "missing permissions") } - toTargetIds, toTargetsInScopeIds := targetAndScopeIdsForDestinations(permissions) + + directIds, directScopeIds, childAllScopes, allDescendants := splitPermissions(permissions) opts, err := getOpts(opt...) 
if err != nil { @@ -116,16 +166,33 @@ func (r *Repository) listResolvableAliasesRefresh(ctx context.Context, updatedAf var args []any var destinationIdClauses []string - if len(toTargetIds) > 0 { - destinationIdClauses = append(destinationIdClauses, "destination_id in @target_ids") - args = append(args, sql.Named("target_ids", toTargetIds)) - } - if len(toTargetsInScopeIds) > 0 { - destinationIdClauses = append(destinationIdClauses, "destination_id in (select public_id from target where project_id in @target_scope_ids)") - args = append(args, sql.Named("target_scope_ids", toTargetsInScopeIds)) - } - if len(destinationIdClauses) == 0 { - return nil, time.Time{}, errors.New(ctx, errors.InvalidParameter, op, "no target ids or scope ids provided") + + switch { + case allDescendants: + // This matches all targets + destinationIdClauses = append(destinationIdClauses, "destination_id in (select public_id from target)") + default: + // Add orgs with all permissions on children + if len(childAllScopes) > 0 { + destinationIdClauses = append(destinationIdClauses, + "destination_id in "+ + "(select public_id from target where project_id in "+ + "(select public_id from iam_scope where parent_id = any(@child_all_scopes)))", + ) + args = append(args, sql.Named("child_all_scopes", "{"+strings.Join(childAllScopes, ",")+"}")) + } + // Add target ids + if len(directIds) > 0 { + destinationIdClauses = append(destinationIdClauses, "destination_id = any(@target_ids)") + args = append(args, sql.Named("target_ids", "{"+strings.Join(directIds, ",")+"}")) + } + if len(directScopeIds) > 0 { + destinationIdClauses = append(destinationIdClauses, "destination_id in (select public_id from target where project_id = any(@target_scope_ids))") + args = append(args, sql.Named("target_scope_ids", "{"+strings.Join(directScopeIds, ",")+"}")) + } + if len(destinationIdClauses) == 0 && len(childAllScopes) == 0 { + return nil, time.Time{}, errors.New(ctx, errors.InvalidParameter, op, "no target ids or scope ids provided") + } } whereClause := fmt.Sprintf("update_time > @updated_after_time and destination_id is not null and (%s)", @@ -162,21 +229,40 @@ func (r *Repository) listRemovedResolvableAliasIds(ctx context.Context, since ti // to be provided. 
return nil, time.Time{}, errors.New(ctx, errors.InvalidParameter, op, "missing permissions") } - toTargetIds, toTargetsInScopeIds := targetAndScopeIdsForDestinations(permissions) + + directIds, directScopeIds, childAllScopes, allDescendants := splitPermissions(permissions) var args []any var destinationIdClauses []string - if len(toTargetIds) > 0 { - destinationIdClauses = append(destinationIdClauses, "destination_id not in @target_ids") - args = append(args, sql.Named("target_ids", toTargetIds)) - } - if len(toTargetsInScopeIds) > 0 { - destinationIdClauses = append(destinationIdClauses, "destination_id not in (select public_id from target where project_id in @target_scope_ids)") - args = append(args, sql.Named("target_scope_ids", toTargetsInScopeIds)) - } - if len(destinationIdClauses) == 0 { - return nil, time.Time{}, errors.New(ctx, errors.InvalidParameter, op, "no target ids or scope ids provided") + + switch { + case allDescendants: + // This matches all targets + destinationIdClauses = append(destinationIdClauses, "destination_id not in (select public_id from target)") + default: + // Add orgs with all permissions on children + if len(childAllScopes) > 0 { + destinationIdClauses = append(destinationIdClauses, + "destination_id not in "+ + "(select public_id from target where project_id in "+ + "(select public_id from iam_scope where parent_id = any(@child_all_scopes)))", + ) + args = append(args, sql.Named("child_all_scopes", "{"+strings.Join(childAllScopes, ",")+"}")) + } + // Add target ids + if len(directIds) > 0 { + destinationIdClauses = append(destinationIdClauses, "destination_id != all(@target_ids)") + args = append(args, sql.Named("target_ids", "{"+strings.Join(directIds, ",")+"}")) + } + if len(directScopeIds) > 0 { + destinationIdClauses = append(destinationIdClauses, "destination_id not in (select public_id from target where project_id = any(@target_scope_ids))") + args = append(args, sql.Named("target_scope_ids", "{"+strings.Join(directScopeIds, ",")+"}")) + } + if len(destinationIdClauses) == 0 && len(childAllScopes) == 0 { + return nil, time.Time{}, errors.New(ctx, errors.InvalidParameter, op, "no target ids or scope ids provided") + } } + whereClause := fmt.Sprintf("update_time > @updated_after_time and (destination_id is null or (%s))", strings.Join(destinationIdClauses, " and ")) args = append(args, diff --git a/internal/alias/target/service_list_resolvable_ext_test.go b/internal/alias/target/service_list_resolvable_ext_test.go index 86a53bb15d..48d6be51c3 100644 --- a/internal/alias/target/service_list_resolvable_ext_test.go +++ b/internal/alias/target/service_list_resolvable_ext_test.go @@ -30,6 +30,11 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" ) +// NOTE: These tests rely on state from previous tests, so they should be run in +// order -- running some subtests on their own will result in errors. It might +// be nice for them to be refactored at some point to start with a known state +// to avoid this. At the time I'm writing this, I'm not doing that because I'm +// not sure if this was a purposeful design choice. 
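The grant-scope handling introduced in this patch can be made concrete with a small standalone sketch of the splitting rule that splitPermissions applies. The Permission struct and the "o_"/"p_" prefixes below are simplified stand-ins for perms.Permission and the globals/scope constants, not the real packages. The resulting ID and scope lists are what the repository functions above bind as Postgres array parameters, a text literal such as "{id1,id2}" passed to any(@target_ids), when building the destination_id clauses.

// Standalone sketch of the grant-scope splitting described above.
// Types and prefixes are simplified stand-ins, not Boundary's own packages.
package main

import (
	"fmt"
	"strings"
)

type Permission struct {
	RoleScopeId  string // scope the role lives in
	GrantScopeId string // "descendants", "children", or a specific scope id
	ResourceIds  []string
	All          bool
}

func split(perms []Permission) (directIds, directScopeIds, childAllScopes []string, allDescendants bool) {
	for _, p := range perms {
		// A role granting "all" on descendants covers every target, so nothing
		// else needs to be collected.
		if p.GrantScopeId == "descendants" && p.All {
			return nil, nil, nil, true
		}
	}
	for _, p := range perms {
		switch {
		case p.GrantScopeId == "global" || strings.HasPrefix(p.GrantScopeId, "o_"):
			// global and org scopes contain no targets
		case p.RoleScopeId == "global" && p.GrantScopeId == "children":
			// children of global are orgs, which contain no targets
		case p.GrantScopeId == "children" && p.All:
			// an org role covering all of its child projects
			childAllScopes = append(childAllScopes, p.RoleScopeId)
		case p.All:
			// a grant directly on a scope; only projects contain targets
			if strings.HasPrefix(p.GrantScopeId, "p_") {
				directScopeIds = append(directScopeIds, p.GrantScopeId)
			}
		case len(p.ResourceIds) > 0:
			directIds = append(directIds, p.ResourceIds...)
		}
	}
	return
}

func main() {
	ids, scopes, childAll, all := split([]Permission{
		{RoleScopeId: "o_1234567890", GrantScopeId: "children", All: true},
		{RoleScopeId: "p_1234567890", GrantScopeId: "p_1234567890", ResourceIds: []string{"ttcp_1234567890"}},
	})
	fmt.Println(ids, scopes, childAll, all)
	// Prints: [ttcp_1234567890] [] [o_1234567890] false
}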
func TestService_ListResolvableAliases(t *testing.T) { fiveDaysAgo := time.Now() // Set database read timeout to avoid duplicates in response @@ -56,12 +61,12 @@ func TestService_ListResolvableAliases(t *testing.T) { } byIdPerms := []perms.Permission{ { - ScopeId: proj.GetPublicId(), - Resource: resource.Target, - Action: action.ListResolvableAliases, - ResourceIds: []string{tar.GetPublicId(), "ttcp_unknownid"}, - OnlySelf: false, - All: false, + GrantScopeId: proj.GetPublicId(), + Resource: resource.Target, + Action: action.ListResolvableAliases, + ResourceIds: []string{tar.GetPublicId(), "ttcp_unknownid"}, + OnlySelf: false, + All: false, }, } // Reverse since we read items in descending order (newest first) @@ -76,16 +81,48 @@ func TestService_ListResolvableAliases(t *testing.T) { } byScopePerms := []perms.Permission{ { - ScopeId: proj2.GetPublicId(), - Resource: resource.Target, - Action: action.ListResolvableAliases, - OnlySelf: false, - All: true, + GrantScopeId: proj2.GetPublicId(), + Resource: resource.Target, + Action: action.ListResolvableAliases, + OnlySelf: false, + All: true, }, } // Reverse since we read items in descending order (newest first) slices.Reverse(byScopeResources) + org3, proj3 := iam.TestScopes(t, iam.TestRepo(t, conn, wrapper)) + tar3 := tcp.TestTarget(ctx, t, conn, proj3.GetPublicId(), "target3") + var byChildrenResources []*target.Alias + for i := 0; i < 5; i++ { + r := target.TestAlias(t, rw, fmt.Sprintf("test%d.alias.by-children", i), target.WithDestinationId(tar3.GetPublicId())) + byChildrenResources = append(byChildrenResources, r) + } + byChildrenPerms := []perms.Permission{ + { + RoleScopeId: org3.GetPublicId(), + RoleParentScopeId: scope.Global.String(), + GrantScopeId: globals.GrantScopeChildren, + Resource: resource.Target, + Action: action.ListResolvableAliases, + OnlySelf: false, + All: true, + }, + } + // Reverse since we read items in descending order (newest first) + slices.Reverse(byChildrenResources) + + byDescendantsPerms := []perms.Permission{ + { + RoleScopeId: scope.Global.String(), + GrantScopeId: globals.GrantScopeDescendants, + Resource: resource.Target, + Action: action.ListResolvableAliases, + OnlySelf: false, + All: true, + }, + } + repo, repoErr := target.NewRepository(ctx, rw, rw, kmsCache) require.NoError(t, repoErr) @@ -296,21 +333,41 @@ func TestService_ListResolvableAliases(t *testing.T) { }) }) + // Build the descendants resources for the first tests + byDescendantsResources := append([]*target.Alias{}, byChildrenResources...) + byDescendantsResources = append(byDescendantsResources, byScopeResources...) + byDescendantsResources = append(byDescendantsResources, byIdResources...) 
+ t.Run("simple pagination", func(t *testing.T) { cases := []struct { name string perms []perms.Permission resourceSlice []*target.Alias + lastPageSize int }{ { name: "by-id", perms: byIdPerms, resourceSlice: byIdResources, + lastPageSize: 1, }, { name: "by-scope", perms: byScopePerms, resourceSlice: byScopeResources, + lastPageSize: 1, + }, + { + name: "by-children", + perms: byChildrenPerms, + resourceSlice: byChildrenResources, + lastPageSize: 1, + }, + { + name: "by-descendants", + perms: byDescendantsPerms, + resourceSlice: byDescendantsResources, + lastPageSize: 11, }, } for _, tc := range cases { @@ -320,7 +377,7 @@ func TestService_ListResolvableAliases(t *testing.T) { require.NotNil(t, resp.ListToken) require.Equal(t, resp.ListToken.GrantsHash, []byte("some hash")) require.False(t, resp.CompleteListing) - require.Equal(t, resp.EstimatedItemCount, 10) + require.Equal(t, 15, resp.EstimatedItemCount) require.Empty(t, resp.DeletedIds) require.Len(t, resp.Items, 1) require.Empty(t, cmp.Diff(resp.Items[0], tc.resourceSlice[0], cmpIgnoreUnexportedOpts), "resources did not match", tc.resourceSlice, "resp", resp.Items) @@ -329,7 +386,7 @@ func TestService_ListResolvableAliases(t *testing.T) { require.NoError(t, err) require.Equal(t, resp2.ListToken.GrantsHash, []byte("some hash")) require.False(t, resp2.CompleteListing) - require.Equal(t, resp2.EstimatedItemCount, 10) + require.Equal(t, 15, resp2.EstimatedItemCount) require.Empty(t, resp2.DeletedIds) require.Len(t, resp2.Items, 1) require.Empty(t, cmp.Diff(resp2.Items[0], tc.resourceSlice[1], cmpIgnoreUnexportedOpts)) @@ -338,7 +395,7 @@ func TestService_ListResolvableAliases(t *testing.T) { require.NoError(t, err) require.Equal(t, resp3.ListToken.GrantsHash, []byte("some hash")) require.False(t, resp3.CompleteListing) - require.Equal(t, resp3.EstimatedItemCount, 10) + require.Equal(t, 15, resp3.EstimatedItemCount) require.Empty(t, resp3.DeletedIds) require.Len(t, resp3.Items, 1) require.Empty(t, cmp.Diff(resp3.Items[0], tc.resourceSlice[2], cmpIgnoreUnexportedOpts)) @@ -347,18 +404,18 @@ func TestService_ListResolvableAliases(t *testing.T) { require.NoError(t, err) require.Equal(t, resp4.ListToken.GrantsHash, []byte("some hash")) require.False(t, resp4.CompleteListing) - require.Equal(t, resp4.EstimatedItemCount, 10) + require.Equal(t, 15, resp4.EstimatedItemCount) require.Empty(t, resp4.DeletedIds) require.Len(t, resp4.Items, 1) require.Empty(t, cmp.Diff(resp4.Items[0], tc.resourceSlice[3], cmpIgnoreUnexportedOpts)) - resp5, err := target.ListResolvableAliasesPage(ctx, []byte("some hash"), 1, resp4.ListToken, repo, tc.perms) + resp5, err := target.ListResolvableAliasesPage(ctx, []byte("some hash"), tc.lastPageSize, resp4.ListToken, repo, tc.perms) require.NoError(t, err) require.Equal(t, resp5.ListToken.GrantsHash, []byte("some hash")) require.True(t, resp5.CompleteListing) - require.Equal(t, resp5.EstimatedItemCount, 10) + require.Equal(t, 15, resp5.EstimatedItemCount) require.Empty(t, resp5.DeletedIds) - require.Len(t, resp5.Items, 1) + require.Len(t, resp5.Items, tc.lastPageSize) require.Empty(t, cmp.Diff(resp5.Items[0], tc.resourceSlice[4], cmpIgnoreUnexportedOpts)) // Finished initial pagination phase, request refresh @@ -367,7 +424,7 @@ func TestService_ListResolvableAliases(t *testing.T) { require.NoError(t, err) require.Equal(t, resp6.ListToken.GrantsHash, []byte("some hash")) require.True(t, resp6.CompleteListing) - require.Equal(t, resp6.EstimatedItemCount, 10) + require.Equal(t, 15, resp6.EstimatedItemCount) require.Empty(t, 
resp6.DeletedIds) require.Empty(t, resp6.Items) @@ -392,7 +449,7 @@ func TestService_ListResolvableAliases(t *testing.T) { require.NoError(t, err) require.Equal(t, resp7.ListToken.GrantsHash, []byte("some hash")) require.False(t, resp7.CompleteListing) - require.Equal(t, resp7.EstimatedItemCount, 12) + require.Equal(t, 17, resp7.EstimatedItemCount) require.Empty(t, resp7.DeletedIds) require.Len(t, resp7.Items, 1) require.Empty(t, cmp.Diff(resp7.Items[0], newR2, cmpIgnoreUnexportedOpts)) @@ -402,7 +459,7 @@ func TestService_ListResolvableAliases(t *testing.T) { require.NoError(t, err) require.Equal(t, resp8.ListToken.GrantsHash, []byte("some hash")) require.True(t, resp8.CompleteListing) - require.Equal(t, resp8.EstimatedItemCount, 12) + require.Equal(t, 17, resp8.EstimatedItemCount) require.Empty(t, resp8.DeletedIds) require.Len(t, resp8.Items, 1) require.Empty(t, cmp.Diff(resp8.Items[0], newR1, cmpIgnoreUnexportedOpts)) @@ -412,14 +469,14 @@ func TestService_ListResolvableAliases(t *testing.T) { require.NoError(t, err) require.Equal(t, resp9.ListToken.GrantsHash, []byte("some hash")) require.True(t, resp9.CompleteListing) - require.Equal(t, resp9.EstimatedItemCount, 12) + require.Equal(t, 17, resp9.EstimatedItemCount) require.Empty(t, resp9.DeletedIds) require.Empty(t, resp9.Items) }) } }) - t.Run("simple pagination with destination id changes", func(t *testing.T) { + t.Run("simple pagination with destination id changes - id", func(t *testing.T) { firstUpdatedA := byScopeResources[0] // this no longer has the destination id that has permissions firstUpdatedA.DestinationId = tar.GetPublicId() @@ -442,7 +499,7 @@ func TestService_ListResolvableAliases(t *testing.T) { require.NotNil(t, resp.ListToken) require.Equal(t, resp.ListToken.GrantsHash, []byte("some hash")) require.False(t, resp.CompleteListing) - require.Equal(t, resp.EstimatedItemCount, 10) + require.Equal(t, resp.EstimatedItemCount, 15) require.Empty(t, resp.DeletedIds) require.Len(t, resp.Items, 1) require.Empty(t, cmp.Diff(resp.Items[0], byScopeResources[0], cmpIgnoreUnexportedOpts)) @@ -452,7 +509,7 @@ func TestService_ListResolvableAliases(t *testing.T) { require.NoError(t, err) require.Equal(t, resp2.ListToken.GrantsHash, []byte("some hash")) require.True(t, resp2.CompleteListing) - require.Equal(t, resp2.EstimatedItemCount, 10) + require.Equal(t, resp2.EstimatedItemCount, 15) require.Empty(t, resp2.DeletedIds) require.Len(t, resp2.Items, 3) require.Empty(t, cmp.Diff(resp2.Items, byScopeResources[1:], cmpIgnoreUnexportedOpts)) @@ -479,12 +536,99 @@ func TestService_ListResolvableAliases(t *testing.T) { require.NoError(t, err) require.Equal(t, resp3.ListToken.GrantsHash, []byte("some hash")) require.True(t, resp3.CompleteListing) - require.Equal(t, resp3.EstimatedItemCount, 10) + require.Equal(t, resp3.EstimatedItemCount, 15) require.Contains(t, resp3.DeletedIds, secondA.GetPublicId()) require.Empty(t, resp3.Items) }) - t.Run("simple pagination with deletion", func(t *testing.T) { + t.Run("simple pagination with destination id changes - children", func(t *testing.T) { + firstUpdatedA := byChildrenResources[0] + // this no longer has the destination id that has permissions + firstUpdatedA.DestinationId = tar.GetPublicId() + firstUpdatedA, _, err := repo.UpdateAlias(ctx, firstUpdatedA, firstUpdatedA.GetVersion(), []string{"DestinationId"}) + require.NoError(t, err) + byChildrenResources = byChildrenResources[1:] + t.Cleanup(func() { + firstUpdatedA.DestinationId = tar3.GetPublicId() + firstUpdatedA, _, err := 
repo.UpdateAlias(ctx, firstUpdatedA, firstUpdatedA.GetVersion(), []string{"DestinationId"}) + require.NoError(t, err) + byChildrenResources = append([]*target.Alias{firstUpdatedA}, byChildrenResources...) + }) + + // Run analyze to update count estimate + _, err = sqlDB.ExecContext(ctx, "analyze") + require.NoError(t, err) + + resp, err := target.ListResolvableAliases(ctx, []byte("some hash"), 1, repo, byChildrenPerms) + require.NoError(t, err) + require.NotNil(t, resp.ListToken) + require.Equal(t, resp.ListToken.GrantsHash, []byte("some hash")) + require.False(t, resp.CompleteListing) + require.Equal(t, resp.EstimatedItemCount, 15) + require.Empty(t, resp.DeletedIds) + require.Len(t, resp.Items, 1) + require.Empty(t, cmp.Diff(resp.Items[0], byChildrenResources[0], cmpIgnoreUnexportedOpts)) + + // request remaining results + resp2, err := target.ListResolvableAliasesPage(ctx, []byte("some hash"), 3, resp.ListToken, repo, byChildrenPerms) + require.NoError(t, err) + require.Equal(t, resp2.ListToken.GrantsHash, []byte("some hash")) + require.True(t, resp2.CompleteListing) + require.Equal(t, resp2.EstimatedItemCount, 15) + require.Empty(t, resp2.DeletedIds) + require.Len(t, resp2.Items, 3) + require.Empty(t, cmp.Diff(resp2.Items, byChildrenResources[1:], cmpIgnoreUnexportedOpts)) + }) + + // We have to re-build the expected set of resources as the original slices + // have been updated with updated values + byDescendantsResources = append([]*target.Alias{}, byChildrenResources...) + byDescendantsResources = append(byDescendantsResources, byScopeResources...) + byDescendantsResources = append(byDescendantsResources, byIdResources...) + + t.Run("simple pagination with destination id changes - descendants", func(t *testing.T) { + firstUpdatedA := byDescendantsResources[0] + // this no longer has the destination id that has permissions + firstUpdatedA.DestinationId = tar.GetPublicId() + firstUpdatedA, _, err := repo.UpdateAlias(ctx, firstUpdatedA, firstUpdatedA.GetVersion(), []string{"DestinationId"}) + require.NoError(t, err) + // Descendants will keep permissions for everything so we don't elide + // one here, but we need to increment the version number to match + byDescendantsResources[0] = firstUpdatedA + t.Cleanup(func() { + firstUpdatedA.DestinationId = tar3.GetPublicId() + firstUpdatedA, _, err = repo.UpdateAlias(ctx, firstUpdatedA, firstUpdatedA.GetVersion(), []string{"DestinationId"}) + require.NoError(t, err) + byDescendantsResources[0] = firstUpdatedA + }) + + // Run analyze to update count estimate + _, err = sqlDB.ExecContext(ctx, "analyze") + require.NoError(t, err) + + resp, err := target.ListResolvableAliases(ctx, []byte("some hash"), 1, repo, byDescendantsPerms) + require.NoError(t, err) + require.NotNil(t, resp.ListToken) + require.Equal(t, resp.ListToken.GrantsHash, []byte("some hash")) + require.False(t, resp.CompleteListing) + require.Equal(t, resp.EstimatedItemCount, 15) + require.Empty(t, resp.DeletedIds) + require.Len(t, resp.Items, 1) + require.Empty(t, cmp.Diff(resp.Items[0], byDescendantsResources[0], cmpIgnoreUnexportedOpts)) + + // request remaining results -- we should see all, because descendants + // is going to maintain permissions + resp2, err := target.ListResolvableAliasesPage(ctx, []byte("some hash"), 14, resp.ListToken, repo, byDescendantsPerms) + require.NoError(t, err) + require.Equal(t, resp2.ListToken.GrantsHash, []byte("some hash")) + require.True(t, resp2.CompleteListing) + require.Equal(t, resp2.EstimatedItemCount, 15) + require.Empty(t, 
resp2.DeletedIds) + require.Len(t, resp2.Items, 14) + require.Empty(t, cmp.Diff(resp2.Items, byDescendantsResources[1:], cmpIgnoreUnexportedOpts)) + }) + + t.Run("simple pagination with deletion - id", func(t *testing.T) { deletedAliasId := byIdResources[0].GetPublicId() _, err := repo.DeleteAlias(ctx, deletedAliasId) require.NoError(t, err) @@ -499,7 +643,7 @@ func TestService_ListResolvableAliases(t *testing.T) { require.NotNil(t, resp.ListToken) require.Equal(t, resp.ListToken.GrantsHash, []byte("some hash")) require.False(t, resp.CompleteListing) - require.Equal(t, resp.EstimatedItemCount, 9) + require.Equal(t, 14, resp.EstimatedItemCount) require.Empty(t, resp.DeletedIds) require.Len(t, resp.Items, 1) require.Empty(t, cmp.Diff(resp.Items[0], byIdResources[0], cmpIgnoreUnexportedOpts)) @@ -509,7 +653,7 @@ func TestService_ListResolvableAliases(t *testing.T) { require.NoError(t, err) require.Equal(t, resp2.ListToken.GrantsHash, []byte("some hash")) require.True(t, resp2.CompleteListing) - require.Equal(t, resp2.EstimatedItemCount, 9) + require.Equal(t, 14, resp2.EstimatedItemCount) require.Empty(t, resp2.DeletedIds) require.Len(t, resp2.Items, 3) require.Empty(t, cmp.Diff(resp2.Items, byIdResources[1:], cmpIgnoreUnexportedOpts)) @@ -528,7 +672,111 @@ func TestService_ListResolvableAliases(t *testing.T) { require.NoError(t, err) require.Equal(t, resp3.ListToken.GrantsHash, []byte("some hash")) require.True(t, resp3.CompleteListing) - require.Equal(t, resp3.EstimatedItemCount, 8) + require.Equal(t, 13, resp3.EstimatedItemCount) + require.Contains(t, resp3.DeletedIds, deletedAliasId) + require.Empty(t, resp3.Items) + }) + + t.Run("simple pagination with deletion - children", func(t *testing.T) { + deletedAliasId := byChildrenResources[0].GetPublicId() + _, err := repo.DeleteAlias(ctx, deletedAliasId) + require.NoError(t, err) + byChildrenResources = byChildrenResources[1:] + + // Run analyze to update count estimate + _, err = sqlDB.ExecContext(ctx, "analyze") + require.NoError(t, err) + + resp, err := target.ListResolvableAliases(ctx, []byte("some hash"), 1, repo, byChildrenPerms) + require.NoError(t, err) + require.NotNil(t, resp.ListToken) + require.Equal(t, resp.ListToken.GrantsHash, []byte("some hash")) + require.False(t, resp.CompleteListing) + require.Equal(t, 12, resp.EstimatedItemCount) + require.Empty(t, resp.DeletedIds) + require.Len(t, resp.Items, 1) + require.Empty(t, cmp.Diff(resp.Items[0], byChildrenResources[0], cmpIgnoreUnexportedOpts)) + + // request remaining results + resp2, err := target.ListResolvableAliasesPage(ctx, []byte("some hash"), 8, resp.ListToken, repo, byChildrenPerms) + require.NoError(t, err) + require.Equal(t, resp2.ListToken.GrantsHash, []byte("some hash")) + require.True(t, resp2.CompleteListing) + require.Equal(t, 12, resp2.EstimatedItemCount) + require.Empty(t, resp2.DeletedIds) + require.Len(t, resp2.Items, 3) + require.Empty(t, cmp.Diff(resp2.Items, byChildrenResources[1:], cmpIgnoreUnexportedOpts)) + + deletedAliasId = byChildrenResources[0].GetPublicId() + _, err = repo.DeleteAlias(ctx, deletedAliasId) + require.NoError(t, err) + byChildrenResources = byChildrenResources[1:] + + // Run analyze to update count estimate + _, err = sqlDB.ExecContext(ctx, "analyze") + require.NoError(t, err) + + // request a refresh, nothing should be returned except the deleted id + resp3, err := target.ListResolvableAliasesRefresh(ctx, []byte("some hash"), 1, resp2.ListToken, repo, byChildrenPerms) + require.NoError(t, err) + require.Equal(t, 
resp3.ListToken.GrantsHash, []byte("some hash")) + require.True(t, resp3.CompleteListing) + require.Equal(t, 11, resp3.EstimatedItemCount) + require.Contains(t, resp3.DeletedIds, deletedAliasId) + require.Empty(t, resp3.Items) + }) + + // We have to re-build the expected set of resources as the original slices + // have been updated with updated values again + byDescendantsResources = append([]*target.Alias{}, byChildrenResources...) + byDescendantsResources = append(byDescendantsResources, byScopeResources...) + byDescendantsResources = append(byDescendantsResources, byIdResources...) + + t.Run("simple pagination with deletion - descendants", func(t *testing.T) { + deletedAliasId := byDescendantsResources[0].GetPublicId() + _, err := repo.DeleteAlias(ctx, deletedAliasId) + require.NoError(t, err) + byDescendantsResources = byDescendantsResources[1:] + + // Run analyze to update count estimate + _, err = sqlDB.ExecContext(ctx, "analyze") + require.NoError(t, err) + + resp, err := target.ListResolvableAliases(ctx, []byte("some hash"), 1, repo, byDescendantsPerms) + require.NoError(t, err) + require.NotNil(t, resp.ListToken) + require.Equal(t, resp.ListToken.GrantsHash, []byte("some hash")) + require.False(t, resp.CompleteListing) + require.Equal(t, 10, resp.EstimatedItemCount) + require.Empty(t, resp.DeletedIds) + require.Len(t, resp.Items, 1) + require.Empty(t, cmp.Diff(resp.Items[0], byDescendantsResources[0], cmpIgnoreUnexportedOpts)) + + // request remaining results + resp2, err := target.ListResolvableAliasesPage(ctx, []byte("some hash"), 13, resp.ListToken, repo, byDescendantsPerms) + require.NoError(t, err) + require.Equal(t, resp2.ListToken.GrantsHash, []byte("some hash")) + require.True(t, resp2.CompleteListing) + require.Equal(t, 10, resp2.EstimatedItemCount) + require.Empty(t, resp2.DeletedIds) + require.Len(t, resp2.Items, 9) + require.Empty(t, cmp.Diff(resp2.Items, byDescendantsResources[1:], cmpIgnoreUnexportedOpts)) + + deletedAliasId = byDescendantsResources[0].GetPublicId() + _, err = repo.DeleteAlias(ctx, deletedAliasId) + require.NoError(t, err) + byDescendantsResources = byDescendantsResources[1:] + + // Run analyze to update count estimate + _, err = sqlDB.ExecContext(ctx, "analyze") + require.NoError(t, err) + + // request a refresh, nothing should be returned except the deleted id + resp3, err := target.ListResolvableAliasesRefresh(ctx, []byte("some hash"), 1, resp2.ListToken, repo, byDescendantsPerms) + require.NoError(t, err) + require.Equal(t, resp3.ListToken.GrantsHash, []byte("some hash")) + require.True(t, resp3.CompleteListing) + require.Equal(t, 9, resp3.EstimatedItemCount) require.Contains(t, resp3.DeletedIds, deletedAliasId) require.Empty(t, resp3.Items) }) diff --git a/internal/cmd/commands/rolescmd/funcs.go b/internal/cmd/commands/rolescmd/funcs.go index a89f2d4495..72b1d58076 100644 --- a/internal/cmd/commands/rolescmd/funcs.go +++ b/internal/cmd/commands/rolescmd/funcs.go @@ -259,7 +259,7 @@ func extraFlagsHandlingFuncImpl(c *Command, _ *base.FlagSets, opts *[]roles.Opti if len(c.flagGrants) > 0 { for _, grant := range c.flagGrants { - parsed, err := perms.Parse(c.Context, scope.Global.String(), grant) + parsed, err := perms.Parse(c.Context, perms.GrantTuple{RoleScopeId: scope.Global.String(), GrantScopeId: scope.Global.String(), Grant: grant}) if err != nil { c.UI.Error(fmt.Errorf("Grant %q could not be parsed successfully: %w", grant, err).Error()) return false diff --git a/internal/daemon/controller/auth/auth.go 
b/internal/daemon/controller/auth/auth.go index fc14ae7044..082ed39902 100644 --- a/internal/daemon/controller/auth/auth.go +++ b/internal/daemon/controller/auth/auth.go @@ -273,6 +273,7 @@ func Verify(ctx context.Context, opt ...Option) (ret VerifyResults) { Id: opts.withId, Pin: opts.withPin, Type: opts.withType, + // Parent Scope ID will be filled in via performAuthCheck } // Global scope has no parent ID; account for this if opts.withId == scope.Global.String() && opts.withType == resource.Scope { @@ -336,7 +337,7 @@ func Verify(ctx context.Context, opt ...Option) (ret VerifyResults) { grants = append(grants, event.Grant{ Grant: g.Grant, RoleId: g.RoleId, - ScopeId: g.ScopeId, + ScopeId: g.GrantScopeId, }) } ea.UserInfo = &event.UserInfo{ @@ -636,6 +637,7 @@ func (v verifier) performAuthCheck(ctx context.Context) ( ParentScopeId: scp.GetParentId(), } } + v.res.ParentScopeId = scopeInfo.ParentScopeId // At this point we don't need to look up grants since it's automatically allowed if v.requestInfo.TokenFormat == uint32(AuthTokenTypeRecoveryKms) { @@ -658,7 +660,7 @@ func (v verifier) performAuthCheck(ctx context.Context) ( // Note: Below, we always skip validation so that we don't error on formats // that we've since restricted, e.g. "ids=foo;actions=create,read". These // will simply not have an effect. - for _, pair := range grantTuples { + for _, tuple := range grantTuples { permsOpts := []perms.Option{ perms.WithUserId(*userData.User.Id), perms.WithSkipFinalValidation(true), @@ -668,11 +670,10 @@ func (v verifier) performAuthCheck(ctx context.Context) ( } parsed, err := perms.Parse( ctx, - pair.ScopeId, - pair.Grant, + tuple, permsOpts...) if err != nil { - retErr = errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("failed to parse grant %#v", pair.Grant))) + retErr = errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("failed to parse grant %#v", tuple.Grant))) return } parsedGrants = append(parsedGrants, parsed) @@ -867,7 +868,7 @@ func (r *VerifyResults) ScopesAuthorizedForList(ctx context.Context, rootScopeId aSet := r.FetchActionSetForType(ctx, resource.Unknown, // This is overridden by `WithResource` option. action.NewActionSet(action.List), - WithResource(&perms.Resource{Type: resourceType, ScopeId: scpId}), + WithResource(&perms.Resource{Type: resourceType, ScopeId: scpId, ParentScopeId: scp.GetParentId()}), ) // We only expect the action set to be nothing, or list. 
In case diff --git a/internal/daemon/controller/auth/authorized_actions.go b/internal/daemon/controller/auth/authorized_actions.go index 3e5cd96d3c..07a1dd95f5 100644 --- a/internal/daemon/controller/auth/authorized_actions.go +++ b/internal/daemon/controller/auth/authorized_actions.go @@ -9,6 +9,7 @@ import ( "github.com/hashicorp/boundary/internal/perms" "github.com/hashicorp/boundary/internal/types/action" "github.com/hashicorp/boundary/internal/types/resource" + "github.com/hashicorp/boundary/sdk/pbs/controller/api/resources/scopes" "github.com/hashicorp/go-secure-stdlib/strutil" "google.golang.org/protobuf/types/known/structpb" ) @@ -22,11 +23,12 @@ import ( func CalculateAuthorizedCollectionActions(ctx context.Context, authResults VerifyResults, mapToRange map[resource.Type]action.ActionSet, - scopeId, pin string, + scopeInfo *scopes.ScopeInfo, pin string, ) (map[string]*structpb.ListValue, error) { res := &perms.Resource{ - ScopeId: scopeId, - Pin: pin, + ScopeId: scopeInfo.GetId(), + Pin: pin, + ParentScopeId: scopeInfo.GetParentScopeId(), } // Range over the defined collections and check permissions against those // collections. diff --git a/internal/daemon/controller/handlers/accounts/account_service_test.go b/internal/daemon/controller/handlers/accounts/account_service_test.go index a1498a542a..5e33ea3741 100644 --- a/internal/daemon/controller/handlers/accounts/account_service_test.go +++ b/internal/daemon/controller/handlers/accounts/account_service_test.go @@ -4385,3 +4385,118 @@ func TestChangePassword(t *testing.T) { }) } } + +// The purpose of this test is mainly to ensure that we are properly fetching +// membership information in GrantsForUser across managed group types +func TestGrantsAcrossManagedGroups(t *testing.T) { + ctx := context.Background() + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrap := db.TestWrapper(t) + kmsCache := kms.TestKms(t, conn, wrap) + + org, _ := iam.TestScopes(t, iam.TestRepo(t, conn, wrap)) + + databaseWrapper, err := kmsCache.GetWrapper(ctx, org.PublicId, kms.KeyPurposeDatabase) + require.NoError(t, err) + oidcAm := oidc.TestAuthMethod( + t, conn, databaseWrapper, org.PublicId, oidc.ActivePrivateState, + "alice-rp", "fido", + oidc.WithIssuer(oidc.TestConvertToUrls(t, "https://www.alice.com")[0]), + oidc.WithSigningAlgs(oidc.RS256), + oidc.WithApiUrl(oidc.TestConvertToUrls(t, "https://www.alice.com/callback")[0]), + ) + oidcAcct := oidc.TestAccount(t, conn, oidcAm, "test-subject") + // Create a managed group that will always match, so we can test that it is + // returned in results + oidcMg := oidc.TestManagedGroup(t, conn, oidcAm, `"/token/sub" matches ".*"`) + oidc.TestManagedGroupMember(t, conn, oidcMg.GetPublicId(), oidcAcct.GetPublicId()) + + ldapAm := ldap.TestAuthMethod(t, conn, databaseWrapper, org.PublicId, []string{"ldaps://ldap1"}) + ldapAcct := ldap.TestAccount(t, conn, ldapAm, "test-acct", + ldap.WithMemberOfGroups(ctx, "admin"), + ldap.WithFullName(ctx, "test-name"), + ldap.WithEmail(ctx, "test-email"), + ldap.WithDn(ctx, "test-dn"), + ) + ldapMg := ldap.TestManagedGroup(t, conn, ldapAm, []string{"admin"}) + + iamRepoFn := func() (*iam.Repository, error) { + return iam.NewRepository(ctx, rw, rw, kmsCache) + } + iamRepo, err := iamRepoFn() + require.NoError(t, err) + + user := iam.TestUser(t, iamRepo, org.PublicId, iam.WithAccountIds(oidcAcct.GetPublicId(), ldapAcct.GetPublicId())) + + // Create two roles, each containing a single managed group, and add a + // unique grant to each + oidcRole := iam.TestRole(t, 
conn, org.GetPublicId(), iam.WithGrantScopeIds([]string{globals.GrantScopeChildren})) + iam.TestManagedGroupRole(t, conn, oidcRole.GetPublicId(), oidcMg.GetPublicId()) + iam.TestRoleGrant(t, conn, oidcRole.GetPublicId(), "ids=ttcp_oidc;actions=read") + ldapRole := iam.TestRole(t, conn, org.GetPublicId(), iam.WithGrantScopeIds([]string{globals.GrantScopeChildren})) + iam.TestManagedGroupRole(t, conn, ldapRole.GetPublicId(), ldapMg.GetPublicId()) + iam.TestRoleGrant(t, conn, ldapRole.GetPublicId(), "ids=ttcp_ldap;actions=read") + + grants, err := iamRepo.GrantsForUser(ctx, user.GetPublicId()) + require.NoError(t, err) + + // Verify we see both grants + var foundOidc, foundLdap bool + for _, grant := range grants { + if grant.Grant == "ids=ttcp_oidc;actions=read" { + foundOidc = true + } + if grant.Grant == "ids=ttcp_ldap;actions=read" { + foundLdap = true + } + } + assert.True(t, foundOidc) + assert.True(t, foundLdap) + + // Delete the ldap managed group + ldapRepo, err := ldap.NewRepository(ctx, rw, rw, kmsCache) + require.NoError(t, err) + numDeleted, err := ldapRepo.DeleteManagedGroup(ctx, org.GetPublicId(), ldapMg.GetPublicId()) + require.NoError(t, err) + assert.Equal(t, 1, numDeleted) + + // Verify we don't see the ldap grant anymore + grants, err = iamRepo.GrantsForUser(ctx, user.GetPublicId()) + require.NoError(t, err) + foundOidc = false + foundLdap = false + for _, grant := range grants { + if grant.Grant == "ids=ttcp_oidc;actions=read" { + foundOidc = true + } + if grant.Grant == "ids=ttcp_ldap;actions=read" { + foundLdap = true + } + } + assert.True(t, foundOidc) + assert.False(t, foundLdap) + + // Delete the oidc managed group + oidcRepo, err := oidc.NewRepository(ctx, rw, rw, kmsCache) + require.NoError(t, err) + numDeleted, err = oidcRepo.DeleteManagedGroup(ctx, org.GetPublicId(), oidcMg.GetPublicId()) + require.NoError(t, err) + assert.Equal(t, 1, numDeleted) + + // Verify we don't see the oidc grant anymore + grants, err = iamRepo.GrantsForUser(ctx, user.GetPublicId()) + require.NoError(t, err) + foundOidc = false + foundLdap = false + for _, grant := range grants { + if grant.Grant == "ids=ttcp_oidc;actions=read" { + foundOidc = true + } + if grant.Grant == "ids=ttcp_ldap;actions=read" { + foundLdap = true + } + } + assert.False(t, foundOidc) + assert.False(t, foundLdap) +} diff --git a/internal/daemon/controller/handlers/authmethods/authmethod_service.go b/internal/daemon/controller/handlers/authmethods/authmethod_service.go index 879cd6a21c..109868e324 100644 --- a/internal/daemon/controller/handlers/authmethods/authmethod_service.go +++ b/internal/daemon/controller/handlers/authmethods/authmethod_service.go @@ -343,7 +343,7 @@ func (s Service) GetAuthMethod(ctx context.Context, req *pbs.GetAuthMethodReques outputOpts = append(outputOpts, handlers.WithAuthorizedActions(authResults.FetchActionSetForId(ctx, am.GetPublicId(), IdActions[globals.ResourceInfoFromPrefix(am.GetPublicId()).Subtype]).Strings())) } if outputFields.Has(globals.AuthorizedCollectionActionsField) { - collectionActions, err := requestauth.CalculateAuthorizedCollectionActions(ctx, authResults, collectionTypeMap, authResults.Scope.Id, am.GetPublicId()) + collectionActions, err := requestauth.CalculateAuthorizedCollectionActions(ctx, authResults, collectionTypeMap, authResults.Scope, am.GetPublicId()) if err != nil { return nil, err } @@ -388,7 +388,7 @@ func (s Service) CreateAuthMethod(ctx context.Context, req *pbs.CreateAuthMethod outputOpts = append(outputOpts, 
handlers.WithAuthorizedActions(authResults.FetchActionSetForId(ctx, am.GetPublicId(), IdActions[globals.ResourceInfoFromPrefix(am.GetPublicId()).Subtype]).Strings())) } if outputFields.Has(globals.AuthorizedCollectionActionsField) { - collectionActions, err := requestauth.CalculateAuthorizedCollectionActions(ctx, authResults, collectionTypeMap, authResults.Scope.Id, am.GetPublicId()) + collectionActions, err := requestauth.CalculateAuthorizedCollectionActions(ctx, authResults, collectionTypeMap, authResults.Scope, am.GetPublicId()) if err != nil { return nil, err } @@ -438,7 +438,7 @@ func (s Service) UpdateAuthMethod(ctx context.Context, req *pbs.UpdateAuthMethod outputOpts = append(outputOpts, handlers.WithAuthorizedActions(authResults.FetchActionSetForId(ctx, am.GetPublicId(), IdActions[globals.ResourceInfoFromPrefix(am.GetPublicId()).Subtype]).Strings())) } if outputFields.Has(globals.AuthorizedCollectionActionsField) { - collectionActions, err := requestauth.CalculateAuthorizedCollectionActions(ctx, authResults, collectionTypeMap, authResults.Scope.Id, am.GetPublicId()) + collectionActions, err := requestauth.CalculateAuthorizedCollectionActions(ctx, authResults, collectionTypeMap, authResults.Scope, am.GetPublicId()) if err != nil { return nil, err } @@ -492,7 +492,7 @@ func (s Service) ChangeState(ctx context.Context, req *pbs.ChangeStateRequest) ( outputOpts = append(outputOpts, handlers.WithAuthorizedActions(authResults.FetchActionSetForId(ctx, am.GetPublicId(), IdActions[globals.ResourceInfoFromPrefix(am.GetPublicId()).Subtype]).Strings())) } if outputFields.Has(globals.AuthorizedCollectionActionsField) { - collectionActions, err := requestauth.CalculateAuthorizedCollectionActions(ctx, authResults, collectionTypeMap, authResults.Scope.Id, am.GetPublicId()) + collectionActions, err := requestauth.CalculateAuthorizedCollectionActions(ctx, authResults, collectionTypeMap, authResults.Scope, am.GetPublicId()) if err != nil { return nil, err } @@ -1442,6 +1442,10 @@ func (s Service) convertToAuthenticateResponse(ctx context.Context, req *pbs.Aut ScopeId: authResults.Scope.Id, Type: resource.AuthToken, } + // Auth methods are only at global or org, so we can figure out the parent + if strings.HasPrefix(res.ScopeId, scope.Org.Prefix()) { + res.ParentScopeId = scope.Global.String() + } tokenType := req.GetType() if tokenType == "" { // Fall back to deprecated field if type is not set @@ -1588,7 +1592,7 @@ func newOutputOpts(ctx context.Context, item auth.AuthMethod, scopeInfoMap map[s outputOpts = append(outputOpts, handlers.WithAuthorizedActions(authorizedActions)) } if outputFields.Has(globals.AuthorizedCollectionActionsField) { - collectionActions, err := requestauth.CalculateAuthorizedCollectionActions(ctx, authResults, collectionTypeMap, authResults.Scope.Id, item.GetPublicId()) + collectionActions, err := requestauth.CalculateAuthorizedCollectionActions(ctx, authResults, collectionTypeMap, authResults.Scope, item.GetPublicId()) if err != nil { return nil, false, err } diff --git a/internal/daemon/controller/handlers/credentialstores/credentialstore_service.go b/internal/daemon/controller/handlers/credentialstores/credentialstore_service.go index ae3f881d7d..694f2bd588 100644 --- a/internal/daemon/controller/handlers/credentialstores/credentialstore_service.go +++ b/internal/daemon/controller/handlers/credentialstores/credentialstore_service.go @@ -996,10 +996,10 @@ func calculateAuthorizedCollectionActions(ctx context.Context, authResults auth. 
var err error switch globals.ResourceInfoFromPrefix(id).Subtype { case vault.Subtype: - collectionActions, err = auth.CalculateAuthorizedCollectionActions(ctx, authResults, vaultCollectionTypeMap, authResults.Scope.Id, id) + collectionActions, err = auth.CalculateAuthorizedCollectionActions(ctx, authResults, vaultCollectionTypeMap, authResults.Scope, id) case static.Subtype: - collectionActions, err = auth.CalculateAuthorizedCollectionActions(ctx, authResults, staticCollectionTypeMap, authResults.Scope.Id, id) + collectionActions, err = auth.CalculateAuthorizedCollectionActions(ctx, authResults, staticCollectionTypeMap, authResults.Scope, id) } if err != nil { return nil, err diff --git a/internal/daemon/controller/handlers/host_catalogs/host_catalog_service.go b/internal/daemon/controller/handlers/host_catalogs/host_catalog_service.go index 1e96304983..26b0f7b3fd 100644 --- a/internal/daemon/controller/handlers/host_catalogs/host_catalog_service.go +++ b/internal/daemon/controller/handlers/host_catalogs/host_catalog_service.go @@ -342,7 +342,7 @@ func (s Service) GetHostCatalog(ctx context.Context, req *pbs.GetHostCatalogRequ subtype = hostplugin.Subtype } if subtype != "" { - collectionActions, err := auth.CalculateAuthorizedCollectionActions(ctx, authResults, collectionTypeMap[subtype], authResults.Scope.Id, hc.GetPublicId()) + collectionActions, err := auth.CalculateAuthorizedCollectionActions(ctx, authResults, collectionTypeMap[subtype], authResults.Scope, hc.GetPublicId()) if err != nil { return nil, err } @@ -399,7 +399,7 @@ func (s Service) CreateHostCatalog(ctx context.Context, req *pbs.CreateHostCatal subtype = hostplugin.Subtype } if subtype != "" { - collectionActions, err := auth.CalculateAuthorizedCollectionActions(ctx, authResults, collectionTypeMap[subtype], authResults.Scope.Id, hc.GetPublicId()) + collectionActions, err := auth.CalculateAuthorizedCollectionActions(ctx, authResults, collectionTypeMap[subtype], authResults.Scope, hc.GetPublicId()) if err != nil { return nil, err } @@ -462,7 +462,7 @@ func (s Service) UpdateHostCatalog(ctx context.Context, req *pbs.UpdateHostCatal subtype = hostplugin.Subtype } if subtype != "" { - collectionActions, err := auth.CalculateAuthorizedCollectionActions(ctx, authResults, collectionTypeMap[subtype], authResults.Scope.Id, hc.GetPublicId()) + collectionActions, err := auth.CalculateAuthorizedCollectionActions(ctx, authResults, collectionTypeMap[subtype], authResults.Scope, hc.GetPublicId()) if err != nil { return nil, err } @@ -795,7 +795,7 @@ func newOutputOpts( subtype = hostplugin.Subtype } if subtype != "" { - collectionActions, err := auth.CalculateAuthorizedCollectionActions(ctx, authResults, collectionTypeMap[subtype], authResults.Scope.Id, item.GetPublicId()) + collectionActions, err := auth.CalculateAuthorizedCollectionActions(ctx, authResults, collectionTypeMap[subtype], authResults.Scope, item.GetPublicId()) if err != nil { return nil, false, err } diff --git a/internal/daemon/controller/handlers/roles/role_service.go b/internal/daemon/controller/handlers/roles/role_service.go index 091777296b..81c743bcf5 100644 --- a/internal/daemon/controller/handlers/roles/role_service.go +++ b/internal/daemon/controller/handlers/roles/role_service.go @@ -1123,7 +1123,7 @@ func toProto(ctx context.Context, in *iam.Role, principals []*iam.PrincipalRole, } if outputFields.Has(globals.GrantsField) { for _, g := range grants { - parsed, err := perms.Parse(ctx, in.GetScopeId(), g.GetRawGrant()) + parsed, err := perms.Parse(ctx, 
perms.GrantTuple{RoleScopeId: in.GetPublicId(), GrantScopeId: in.GetScopeId(), Grant: g.GetRawGrant()}) if err != nil { // This should never happen as we validate on the way in, but let's // return what we can since we are still returning the raw grant @@ -1319,7 +1319,7 @@ func validateAddRoleGrantsRequest(ctx context.Context, req *pbs.AddRoleGrantsReq badFields["grant_strings"] = "Grant strings must not be empty." break } - grant, err := perms.Parse(ctx, "p_anything", v) + grant, err := perms.Parse(ctx, perms.GrantTuple{RoleScopeId: req.GetId(), GrantScopeId: "p_anything", Grant: v}) if err != nil { badFields["grant_strings"] = fmt.Sprintf("Improperly formatted grant %q.", v) break @@ -1356,9 +1356,9 @@ func validateSetRoleGrantsRequest(ctx context.Context, req *pbs.SetRoleGrantsReq badFields["grant_strings"] = "Grant strings must not be empty." break } - grant, err := perms.Parse(ctx, "p_anything", v) + grant, err := perms.Parse(ctx, perms.GrantTuple{RoleScopeId: req.GetId(), GrantScopeId: "p_anything", Grant: v}) if err != nil { - badFields["grant_strings"] = fmt.Sprintf("Improperly formatted grant %q.", v) + badFields["grant_strings"] = fmt.Sprintf("Improperly formatted grant %q: %s.", v, err.Error()) break } _, actStrs := grant.Actions() @@ -1396,7 +1396,7 @@ func validateRemoveRoleGrantsRequest(ctx context.Context, req *pbs.RemoveRoleGra badFields["grant_strings"] = "Grant strings must not be empty." break } - if _, err := perms.Parse(ctx, "p_anything", v); err != nil { + if _, err := perms.Parse(ctx, perms.GrantTuple{RoleScopeId: req.GetId(), GrantScopeId: "p_anything", Grant: v}); err != nil { badFields["grant_strings"] = fmt.Sprintf("Improperly formatted grant %q.", v) break } diff --git a/internal/daemon/controller/handlers/roles/role_service_test.go b/internal/daemon/controller/handlers/roles/role_service_test.go index 05524d693f..17c0535d52 100644 --- a/internal/daemon/controller/handlers/roles/role_service_test.go +++ b/internal/daemon/controller/handlers/roles/role_service_test.go @@ -1026,7 +1026,7 @@ func TestCreate(t *testing.T) { func TestUpdate(t *testing.T) { ctx := context.Background() grantString := "ids=*;type=*;actions=*" - g, err := perms.Parse(context.Background(), "global", grantString) + g, err := perms.Parse(context.Background(), perms.GrantTuple{RoleScopeId: "global", GrantScopeId: "global", Grant: grantString}) require.NoError(t, err) _, actions := g.Actions() grant := &pb.Grant{ @@ -2127,7 +2127,7 @@ func checkEqualGrants(t *testing.T, expected []string, got *pb.Role) { return got.GrantStrings[i] < got.GrantStrings[j] }) for i, v := range expected { - parsed, err := perms.Parse(context.Background(), "o_abc123", v) + parsed, err := perms.Parse(context.Background(), perms.GrantTuple{RoleScopeId: "o_abc123", GrantScopeId: "o_abc123", Grant: v}) require.NoError(err) assert.Equal(expected[i], got.GrantStrings[i]) assert.Equal(expected[i], got.Grants[i].GetRaw()) diff --git a/internal/daemon/controller/handlers/scopes/scope_service.go b/internal/daemon/controller/handlers/scopes/scope_service.go index 17cdf0ad2c..951015de39 100644 --- a/internal/daemon/controller/handlers/scopes/scope_service.go +++ b/internal/daemon/controller/handlers/scopes/scope_service.go @@ -337,7 +337,14 @@ func (s *Service) GetScope(ctx context.Context, req *pbs.GetScopeRequest) (*pbs. 
outputOpts = append(outputOpts, handlers.WithAuthorizedActions(authResults.FetchActionSetForId(ctx, p.GetPublicId(), idActionsById(p.GetPublicId())).Strings())) } if outputFields.Has(globals.AuthorizedCollectionActionsField) { - collectionActions, err := auth.CalculateAuthorizedCollectionActions(ctx, authResults, scopeCollectionTypeMapMap[p.Type], p.GetPublicId(), "") + scopeInfo := &pb.ScopeInfo{ + Id: p.GetPublicId(), + Type: p.Type, + Name: p.GetName(), + Description: p.GetDescription(), + ParentScopeId: p.GetParentId(), + } + collectionActions, err := auth.CalculateAuthorizedCollectionActions(ctx, authResults, scopeCollectionTypeMapMap[p.Type], scopeInfo, "") if err != nil { return nil, err } @@ -382,7 +389,14 @@ func (s *Service) CreateScope(ctx context.Context, req *pbs.CreateScopeRequest) outputOpts = append(outputOpts, handlers.WithAuthorizedActions(authResults.FetchActionSetForId(ctx, p.GetPublicId(), idActionsById(p.GetPublicId())).Strings())) } if outputFields.Has(globals.AuthorizedCollectionActionsField) { - collectionActions, err := auth.CalculateAuthorizedCollectionActions(ctx, authResults, scopeCollectionTypeMapMap[p.Type], p.GetPublicId(), "") + scopeInfo := &pb.ScopeInfo{ + Id: p.GetPublicId(), + Type: p.Type, + Name: p.GetName(), + Description: p.GetDescription(), + ParentScopeId: p.GetParentId(), + } + collectionActions, err := auth.CalculateAuthorizedCollectionActions(ctx, authResults, scopeCollectionTypeMapMap[p.Type], scopeInfo, "") if err != nil { return nil, err } @@ -427,7 +441,14 @@ func (s *Service) UpdateScope(ctx context.Context, req *pbs.UpdateScopeRequest) outputOpts = append(outputOpts, handlers.WithAuthorizedActions(authResults.FetchActionSetForId(ctx, p.GetPublicId(), idActionsById(p.GetPublicId())).Strings())) } if outputFields.Has(globals.AuthorizedCollectionActionsField) { - collectionActions, err := auth.CalculateAuthorizedCollectionActions(ctx, authResults, scopeCollectionTypeMapMap[p.Type], p.GetPublicId(), "") + scopeInfo := &pb.ScopeInfo{ + Id: p.GetPublicId(), + Type: p.Type, + Name: p.GetName(), + Description: p.GetDescription(), + ParentScopeId: p.GetParentId(), + } + collectionActions, err := auth.CalculateAuthorizedCollectionActions(ctx, authResults, scopeCollectionTypeMapMap[p.Type], scopeInfo, "") if err != nil { return nil, err } @@ -1137,7 +1158,14 @@ func newOutputOpts(ctx context.Context, item *iam.Scope, authResults auth.Verify outputOpts = append(outputOpts, handlers.WithAuthorizedActions(authorizedActions)) } if outputFields.Has(globals.AuthorizedCollectionActionsField) { - collectionActions, err := auth.CalculateAuthorizedCollectionActions(ctx, authResults, scopeCollectionTypeMapMap[item.Type], item.GetPublicId(), "") + scopeInfo := &pb.ScopeInfo{ + Id: item.GetPublicId(), + Type: item.Type, + Name: item.GetName(), + Description: item.GetDescription(), + ParentScopeId: item.GetParentId(), + } + collectionActions, err := auth.CalculateAuthorizedCollectionActions(ctx, authResults, scopeCollectionTypeMapMap[item.Type], scopeInfo, "") if err != nil { return nil, false, err } diff --git a/internal/daemon/controller/handlers/sessions/session_service.go b/internal/daemon/controller/handlers/sessions/session_service.go index 7df9fed100..59230bb0a1 100644 --- a/internal/daemon/controller/handlers/sessions/session_service.go +++ b/internal/daemon/controller/handlers/sessions/session_service.go @@ -575,9 +575,10 @@ func validateCancelRequest(req *pbs.CancelSessionRequest) error { func newOutputOpts(ctx context.Context, item *session.Session, 
scopeIds map[string]*scopes.ScopeInfo, authResults auth.VerifyResults) ([]handlers.Option, bool) { res := perms.Resource{ - Type: resource.Session, - Id: item.GetPublicId(), - ScopeId: item.GetProjectId(), + Type: resource.Session, + Id: item.GetPublicId(), + ScopeId: item.GetProjectId(), + ParentScopeId: scopeIds[item.ProjectId].ParentScopeId, } authorizedActions := authResults.FetchActionSetForId(ctx, item.GetPublicId(), IdActions, auth.WithResource(&res)).Strings() if len(authorizedActions) == 0 { diff --git a/internal/daemon/controller/handlers/targets/target_service.go b/internal/daemon/controller/handlers/targets/target_service.go index 48e5be7358..e05eeaec49 100644 --- a/internal/daemon/controller/handlers/targets/target_service.go +++ b/internal/daemon/controller/handlers/targets/target_service.go @@ -2006,7 +2006,7 @@ func validateListRequest(ctx context.Context, req *pbs.ListTargetsRequest) error } func newOutputOpts(ctx context.Context, item target.Target, authResults auth.VerifyResults, authzScopes map[string]*scopes.ScopeInfo) []handlers.Option { - pr := perms.Resource{Id: item.GetPublicId(), ScopeId: item.GetProjectId(), Type: resource.Target} + pr := perms.Resource{Id: item.GetPublicId(), ScopeId: item.GetProjectId(), Type: resource.Target, ParentScopeId: authzScopes[item.GetProjectId()].GetParentScopeId()} outputFields := authResults.FetchOutputFields(pr, action.List).SelfOrDefaults(authResults.UserId) outputOpts := make([]handlers.Option, 0, 3) diff --git a/internal/daemon/controller/handlers/targets/tcp/target_service_test.go b/internal/daemon/controller/handlers/targets/tcp/target_service_test.go index bc9db3588b..92274e83e1 100644 --- a/internal/daemon/controller/handlers/targets/tcp/target_service_test.go +++ b/internal/daemon/controller/handlers/targets/tcp/target_service_test.go @@ -489,13 +489,8 @@ func TestList(t *testing.T) { } } -func TestListPagination(t *testing.T) { - // Set database read timeout to avoid duplicates in response - oldReadTimeout := globals.RefreshReadLookbackDuration - globals.RefreshReadLookbackDuration = 0 - t.Cleanup(func() { - globals.RefreshReadLookbackDuration = oldReadTimeout - }) +func TestListGrantScopes(t *testing.T) { + t.Parallel() ctx := context.Background() conn, _ := db.TestSetup(t, "postgres") sqlDB, err := conn.SqlDB(ctx) @@ -515,315 +510,628 @@ func TestListPagination(t *testing.T) { serversRepoFn := func() (*server.Repository, error) { return server.NewRepository(ctx, rw, rw, kms) } - repo, err := target.NewRepository(ctx, rw, rw, kms) - require.NoError(t, err) - org, proj := iam.TestScopes(t, iamRepo) - at := authtoken.TestAuthToken(t, conn, kms, org.GetPublicId()) - r := iam.TestRole(t, conn, proj.GetPublicId()) - _ = iam.TestUserRole(t, conn, r.GetPublicId(), at.GetIamUserId()) - _ = iam.TestRoleGrant(t, conn, r.GetPublicId(), "ids=*;type=*;actions=*") - hc := static.TestCatalogs(t, conn, proj.GetPublicId(), 1)[0] - hss := static.TestSets(t, conn, hc.GetPublicId(), 2) - s, err := testService(t, context.Background(), conn, kms, wrapper) - require.NoError(t, err) + at := authtoken.TestAuthToken(t, conn, kms, scope.Global.String()) - var allTargets []*pb.Target - for i := 0; i < 10; i++ { - tar := tcp.TestTarget(ctx, t, conn, proj.GetPublicId(), fmt.Sprintf("tar%d", i), target.WithHostSources([]string{hss[0].GetPublicId(), hss[1].GetPublicId()})) - allTargets = append(allTargets, &pb.Target{ - Id: tar.GetPublicId(), - ScopeId: proj.GetPublicId(), - Name: wrapperspb.String(tar.GetName()), - Scope: &scopes.ScopeInfo{Id: 
proj.GetPublicId(), Type: scope.Project.String(), ParentScopeId: org.GetPublicId()}, - CreatedTime: tar.GetCreateTime().GetTimestamp(), - UpdatedTime: tar.GetUpdateTime().GetTimestamp(), - Version: tar.GetVersion(), - Type: tcp.Subtype.String(), - Attrs: &pb.Target_TcpTargetAttributes{}, - SessionMaxSeconds: wrapperspb.UInt32(28800), - SessionConnectionLimit: wrapperspb.Int32(-1), - AuthorizedActions: testAuthorizedActions, - Address: &wrapperspb.StringValue{}, - }) + var projects []*iam.Scope + org1, proj1 := iam.TestScopes(t, iamRepo) + projects = append(projects, proj1) + org2, proj2 := iam.TestScopes(t, iamRepo) + projects = append(projects, proj2) + + var totalTars []*pb.Target + for i, proj := range projects { + for j := 0; j < 5; j++ { + name := fmt.Sprintf("tar-%d-%d", i, j) + tar := tcp.TestTarget(ctx, t, conn, proj.GetPublicId(), name, target.WithAddress(fmt.Sprintf("1.1.%d.%d", i, j))) + totalTars = append(totalTars, &pb.Target{ + Id: tar.GetPublicId(), + ScopeId: proj.GetPublicId(), + Name: wrapperspb.String(name), + Scope: &scopes.ScopeInfo{Id: proj.GetPublicId(), Type: scope.Project.String(), ParentScopeId: proj.ParentId}, + CreatedTime: tar.GetCreateTime().GetTimestamp(), + UpdatedTime: tar.GetUpdateTime().GetTimestamp(), + Version: tar.GetVersion(), + Type: tcp.Subtype.String(), + Attrs: &pb.Target_TcpTargetAttributes{}, + SessionMaxSeconds: wrapperspb.UInt32(28800), + SessionConnectionLimit: wrapperspb.Int32(-1), + AuthorizedActions: testAuthorizedActions, + Address: &wrapperspb.StringValue{Value: fmt.Sprintf("1.1.%d.%d", i, j)}, + }) + } } - // Reverse since we read items in descending order (newest first) - slices.Reverse(allTargets) // Run analyze to update postgres estimates _, err = sqlDB.ExecContext(ctx, "analyze") require.NoError(t, err) - requestInfo := authpb.RequestInfo{ - TokenFormat: uint32(auth.AuthTokenTypeBearer), - PublicId: at.GetPublicId(), - Token: at.GetToken(), - } - requestContext := context.WithValue(context.Background(), requests.ContextRequestInformationKey, &requests.RequestContext{}) - ctx = auth.NewVerifierContext(requestContext, iamRepoFn, tokenRepoFn, serversRepoFn, kms, &requestInfo) + _ = org1 + _ = org2 - // Start paginating, recursively - req := &pbs.ListTargetsRequest{ - ScopeId: "global", - Recursive: true, - Filter: "", - ListToken: "", - PageSize: 2, - } - got, err := s.ListTargets(ctx, req) - require.NoError(t, err) - require.Len(t, got.GetItems(), 2) - // Compare without comparing the list token - assert.Empty(t, - cmp.Diff( - got, - &pbs.ListTargetsResponse{ - Items: allTargets[0:2], - ResponseType: "delta", + cases := []struct { + name string + pageSize uint32 + setupFunc func(t *testing.T) + res *pbs.ListTargetsResponse + err error + }{ + { + name: "global-with-direct-grants-wildcard", + setupFunc: func(t *testing.T) { + globalRole := iam.TestRole(t, conn, scope.Global.String(), iam.WithGrantScopeIds([]string{proj1.GetPublicId(), proj2.GetPublicId()})) + _ = iam.TestUserRole(t, conn, globalRole.GetPublicId(), at.GetIamUserId()) + _ = iam.TestRoleGrant(t, conn, globalRole.GetPublicId(), "ids=*;type=*;actions=*") + }, + res: &pbs.ListTargetsResponse{ + Items: totalTars, + ResponseType: "complete", SortBy: "created_time", SortDir: "desc", - RemovedIds: nil, EstItemCount: 10, }, - cmpopts.SortSlices(func(a, b string) bool { - return a < b - }), - protocmp.Transform(), - protocmp.IgnoreFields(&pbs.ListTargetsResponse{}, "list_token"), - ), - ) - - // Request second page - req.ListToken = got.ListToken - got, err = s.ListTargets(ctx, 
req) - require.NoError(t, err) - require.Len(t, got.GetItems(), 2) - // Compare without comparing the list token - assert.Empty(t, - cmp.Diff( - got, - &pbs.ListTargetsResponse{ - Items: allTargets[2:4], - ResponseType: "delta", + }, + { + name: "global-with-direct-grants-non-wildcard", + setupFunc: func(t *testing.T) { + globalRole := iam.TestRole(t, conn, scope.Global.String(), iam.WithGrantScopeIds([]string{proj1.GetPublicId(), proj2.GetPublicId()})) + _ = iam.TestUserRole(t, conn, globalRole.GetPublicId(), at.GetIamUserId()) + _ = iam.TestRoleGrant(t, conn, globalRole.GetPublicId(), "ids=*;type=target;actions=list") + _ = iam.TestRoleGrant(t, conn, globalRole.GetPublicId(), fmt.Sprintf("ids=%s,%s;actions=*", totalTars[0].Id, totalTars[1].Id)) + }, + res: &pbs.ListTargetsResponse{ + Items: totalTars[0:2], + ResponseType: "complete", SortBy: "created_time", SortDir: "desc", - RemovedIds: nil, - EstItemCount: 10, + EstItemCount: 2, }, - cmpopts.SortSlices(func(a, b string) bool { - return a < b - }), - protocmp.Transform(), - protocmp.IgnoreFields(&pbs.ListTargetsResponse{}, "list_token"), - ), - ) - - // Request rest of results - req.ListToken = got.ListToken - req.PageSize = 10 - got, err = s.ListTargets(ctx, req) - require.NoError(t, err) - require.Len(t, got.GetItems(), 6) - // Compare without comparing the list token - assert.Empty(t, - cmp.Diff( - got, - &pbs.ListTargetsResponse{ - Items: allTargets[4:], + }, + { + name: "global-with-descendants-wildcard", + setupFunc: func(t *testing.T) { + globalRole := iam.TestRole(t, conn, scope.Global.String(), iam.WithGrantScopeIds([]string{globals.GrantScopeDescendants})) + _ = iam.TestUserRole(t, conn, globalRole.GetPublicId(), at.GetIamUserId()) + _ = iam.TestRoleGrant(t, conn, globalRole.GetPublicId(), "ids=*;type=*;actions=*") + }, + res: &pbs.ListTargetsResponse{ + Items: totalTars, ResponseType: "complete", SortBy: "created_time", SortDir: "desc", - RemovedIds: nil, EstItemCount: 10, }, - cmpopts.SortSlices(func(a, b string) bool { - return a < b - }), - protocmp.Transform(), - protocmp.IgnoreFields(&pbs.ListTargetsResponse{}, "list_token"), - ), - ) - - // Create another target - tar := tcp.TestTarget(ctx, t, conn, proj.GetPublicId(), "test-target", target.WithHostSources([]string{hss[0].GetPublicId(), hss[1].GetPublicId()})) - newTarget := &pb.Target{ - Id: tar.GetPublicId(), - ScopeId: proj.GetPublicId(), - Name: wrapperspb.String(tar.GetName()), - Scope: &scopes.ScopeInfo{Id: proj.GetPublicId(), Type: scope.Project.String(), ParentScopeId: org.GetPublicId()}, - CreatedTime: tar.GetCreateTime().GetTimestamp(), - UpdatedTime: tar.GetUpdateTime().GetTimestamp(), - Version: tar.GetVersion(), - Type: tcp.Subtype.String(), - Attrs: &pb.Target_TcpTargetAttributes{}, - SessionMaxSeconds: wrapperspb.UInt32(28800), - SessionConnectionLimit: wrapperspb.Int32(-1), - AuthorizedActions: testAuthorizedActions, - Address: &wrapperspb.StringValue{}, - } - // Add to the front since it's most recently updated - allTargets = append([]*pb.Target{newTarget}, allTargets...) 
- - // Delete one of the other targets - _, err = repo.DeleteTarget(ctx, allTargets[len(allTargets)-1].Id) - require.NoError(t, err) - deletedTarget := allTargets[len(allTargets)-1] - allTargets = allTargets[:len(allTargets)-1] - - // Update one of the other targets - allTargets[1].Name = wrapperspb.String("new-name") - allTargets[1].Version = 2 - updatedTarget := &tcp.Target{ - Target: &store.Target{ - PublicId: allTargets[1].Id, - Name: allTargets[1].Name.GetValue(), - ProjectId: allTargets[1].ScopeId, }, - } - tg, _, err := repo.UpdateTarget(ctx, updatedTarget, 1, []string{"name"}) - require.NoError(t, err) - allTargets[1].UpdatedTime = tg.GetUpdateTime().GetTimestamp() - allTargets[1].Version = tg.GetVersion() - // Add to the front since it's most recently updated - allTargets = append( - []*pb.Target{allTargets[1]}, - append( - []*pb.Target{allTargets[0]}, - allTargets[2:]..., - )..., - ) - - // Run analyze to update postgres estimates - _, err = sqlDB.ExecContext(ctx, "analyze") - require.NoError(t, err) - - // Request updated results - req.ListToken = got.ListToken - req.PageSize = 1 - got, err = s.ListTargets(ctx, req) - require.NoError(t, err) - require.Len(t, got.GetItems(), 1) - // Compare without comparing the list token - assert.Empty(t, - cmp.Diff( - got, - &pbs.ListTargetsResponse{ - Items: []*pb.Target{allTargets[0]}, - ResponseType: "delta", - SortBy: "updated_time", - SortDir: "desc", - // Should contain the deleted target - RemovedIds: []string{deletedTarget.Id}, - EstItemCount: 10, + { + name: "org-with-direct-grants-wildcard", + setupFunc: func(t *testing.T) { + org1Role := iam.TestRole(t, conn, org1.GetPublicId(), iam.WithGrantScopeIds([]string{proj1.GetPublicId()})) + _ = iam.TestUserRole(t, conn, org1Role.GetPublicId(), at.GetIamUserId()) + _ = iam.TestRoleGrant(t, conn, org1Role.GetPublicId(), "ids=*;type=*;actions=*") + org2Role := iam.TestRole(t, conn, org2.GetPublicId(), iam.WithGrantScopeIds([]string{proj2.GetPublicId()})) + _ = iam.TestUserRole(t, conn, org2Role.GetPublicId(), at.GetIamUserId()) + _ = iam.TestRoleGrant(t, conn, org2Role.GetPublicId(), "ids=*;type=*;actions=*") }, - cmpopts.SortSlices(func(a, b string) bool { - return a < b - }), - protocmp.Transform(), - protocmp.IgnoreFields(&pbs.ListTargetsResponse{}, "list_token"), - ), - ) - - // Get next page - req.ListToken = got.ListToken - got, err = s.ListTargets(ctx, req) - require.NoError(t, err) - require.Len(t, got.GetItems(), 1) - // Compare without comparing the list token - assert.Empty(t, - cmp.Diff( - got, - &pbs.ListTargetsResponse{ - Items: []*pb.Target{allTargets[1]}, + res: &pbs.ListTargetsResponse{ + Items: totalTars, ResponseType: "complete", - SortBy: "updated_time", + SortBy: "created_time", SortDir: "desc", - RemovedIds: nil, EstItemCount: 10, }, - cmpopts.SortSlices(func(a, b string) bool { - return a < b - }), - protocmp.Transform(), - protocmp.IgnoreFields(&pbs.ListTargetsResponse{}, "list_token"), - ), - ) - - // Request new page with filter requiring looping - // to fill the page. 
- req.ListToken = "" - req.PageSize = 1 - req.Filter = fmt.Sprintf(`"/item/id"==%q or "/item/id"==%q`, allTargets[len(allTargets)-2].Id, allTargets[len(allTargets)-1].Id) - got, err = s.ListTargets(ctx, req) - require.NoError(t, err) - require.Len(t, got.GetItems(), 1) - assert.Empty(t, - cmp.Diff( - got, - &pbs.ListTargetsResponse{ - Items: []*pb.Target{allTargets[len(allTargets)-2]}, - ResponseType: "delta", + }, + { + name: "org-with-direct-grants-non-wildcard", + setupFunc: func(t *testing.T) { + org1Role := iam.TestRole(t, conn, org1.GetPublicId(), iam.WithGrantScopeIds([]string{proj1.GetPublicId()})) + _ = iam.TestUserRole(t, conn, org1Role.GetPublicId(), at.GetIamUserId()) + _ = iam.TestRoleGrant(t, conn, org1Role.GetPublicId(), "ids=*;type=target;actions=list") + _ = iam.TestRoleGrant(t, conn, org1Role.GetPublicId(), fmt.Sprintf("ids=%s,%s;actions=*", totalTars[0].Id, totalTars[1].Id)) + org2Role := iam.TestRole(t, conn, org2.GetPublicId(), iam.WithGrantScopeIds([]string{proj2.GetPublicId()})) + _ = iam.TestUserRole(t, conn, org2Role.GetPublicId(), at.GetIamUserId()) + _ = iam.TestRoleGrant(t, conn, org2Role.GetPublicId(), "ids=*;type=target;actions=list") + _ = iam.TestRoleGrant(t, conn, org2Role.GetPublicId(), fmt.Sprintf("ids=%s,%s;actions=*", totalTars[5].Id, totalTars[6].Id)) + }, + res: &pbs.ListTargetsResponse{ + Items: append([]*pb.Target{}, append(append([]*pb.Target{}, totalTars[0:2]...), totalTars[5:7]...)...), + ResponseType: "complete", SortBy: "created_time", SortDir: "desc", - // Should be empty again - RemovedIds: nil, - EstItemCount: 10, + EstItemCount: 4, }, - cmpopts.SortSlices(func(a, b string) bool { - return a < b - }), - protocmp.Transform(), - protocmp.IgnoreFields(&pbs.ListTargetsResponse{}, "list_token"), - ), - ) - req.ListToken = got.ListToken - // Get the second page - got, err = s.ListTargets(ctx, req) - require.NoError(t, err) - require.Len(t, got.GetItems(), 1) - assert.Empty(t, - cmp.Diff( - got, - &pbs.ListTargetsResponse{ - Items: []*pb.Target{allTargets[len(allTargets)-1]}, + }, + { + name: "org-with-children-wildcard", + setupFunc: func(t *testing.T) { + org1Role := iam.TestRole(t, conn, org1.GetPublicId(), iam.WithGrantScopeIds([]string{globals.GrantScopeChildren})) + _ = iam.TestUserRole(t, conn, org1Role.GetPublicId(), at.GetIamUserId()) + _ = iam.TestRoleGrant(t, conn, org1Role.GetPublicId(), "ids=*;type=*;actions=*") + }, + res: &pbs.ListTargetsResponse{ + Items: totalTars[0:5], ResponseType: "complete", SortBy: "created_time", SortDir: "desc", - RemovedIds: nil, - EstItemCount: 10, + EstItemCount: 5, }, - cmpopts.SortSlices(func(a, b string) bool { - return a < b - }), - protocmp.Transform(), - protocmp.IgnoreFields(&pbs.ListTargetsResponse{}, "list_token"), - ), - ) - - // Create unauthenticated user - unauthAt := authtoken.TestAuthToken(t, conn, kms, org.GetPublicId()) - unauthR := iam.TestRole(t, conn, proj.GetPublicId()) - _ = iam.TestUserRole(t, conn, unauthR.GetPublicId(), unauthAt.GetIamUserId()) - - // Make a request with the unauthenticated user, - // ensure the response contains the pagination parameters. 
- requestInfo = authpb.RequestInfo{ - TokenFormat: uint32(auth.AuthTokenTypeBearer), - PublicId: unauthAt.GetPublicId(), - Token: unauthAt.GetToken(), + }, } - requestContext = context.WithValue(context.Background(), requests.ContextRequestInformationKey, &requests.RequestContext{}) - ctx = auth.NewVerifierContext(requestContext, iamRepoFn, tokenRepoFn, serversRepoFn, kms, &requestInfo) + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + _, err := sqlDB.Exec("delete from iam_role") + require.NoError(err) + tc.setupFunc(t) + + s, err := testService(t, context.Background(), conn, kms, wrapper) + require.NoError(err, "Couldn't create new target service.") + + requestInfo := authpb.RequestInfo{ + TokenFormat: uint32(auth.AuthTokenTypeBearer), + PublicId: at.GetPublicId(), + Token: at.GetToken(), + } + requestContext := context.WithValue(context.Background(), requests.ContextRequestInformationKey, &requests.RequestContext{}) + ctx := auth.NewVerifierContext(requestContext, iamRepoFn, tokenRepoFn, serversRepoFn, kms, &requestInfo) + got, gErr := s.ListTargets(ctx, &pbs.ListTargetsRequest{ + ScopeId: scope.Global.String(), + Recursive: true, + PageSize: tc.pageSize, + }) + if tc.err != nil { + require.Error(gErr) + assert.True(errors.Is(gErr, tc.err), "got error %v, wanted %v", gErr, tc.err) + return + } + require.NoError(gErr) + assert.Equal(len(tc.res.Items), len(got.Items)) + wantById := make(map[string]*pb.Target, len(tc.res.Items)) + for _, t := range tc.res.Items { + wantById[t.Id] = t + } + for _, t := range got.Items { + want, ok := wantById[t.Id] + assert.True(ok, "Got unexpected target with id: %s", t.Id) + assert.Empty(cmp.Diff( + t, + want, + protocmp.Transform(), + cmpopts.SortSlices(func(a, b string) bool { + return a < b + }), + ), "got %v, wanted %v", t, want) + } + }) + } +} + +func TestListPagination(t *testing.T) { + testListPagination := func(t *testing.T, useDescendants bool) { + // Set database read timeout to avoid duplicates in response + oldReadTimeout := globals.RefreshReadLookbackDuration + globals.RefreshReadLookbackDuration = 0 + t.Cleanup(func() { + globals.RefreshReadLookbackDuration = oldReadTimeout + }) + ctx := context.Background() + conn, _ := db.TestSetup(t, "postgres") + sqlDB, err := conn.SqlDB(ctx) + require.NoError(t, err) + wrapper := db.TestWrapper(t) + kms := kms.TestKms(t, conn, wrapper) + + rw := db.New(conn) + + iamRepo := iam.TestRepo(t, conn, wrapper) + iamRepoFn := func() (*iam.Repository, error) { + return iamRepo, nil + } + tokenRepoFn := func() (*authtoken.Repository, error) { + return authtoken.NewRepository(ctx, rw, rw, kms) + } + serversRepoFn := func() (*server.Repository, error) { + return server.NewRepository(ctx, rw, rw, kms) + } + repo, err := target.NewRepository(ctx, rw, rw, kms) + require.NoError(t, err) + + // We're going to run the same test in two projects; one with + // descendants and one with direct grants in one project and a child + // grant from org in another project + org1, proj1 := iam.TestScopes(t, iamRepo) + org2, proj2 := iam.TestScopes(t, iamRepo) + at := authtoken.TestAuthToken(t, conn, kms, scope.Global.String()) + if useDescendants { + r := iam.TestRole(t, conn, scope.Global.String(), iam.WithGrantScopeIds([]string{globals.GrantScopeDescendants})) + _ = iam.TestUserRole(t, conn, r.GetPublicId(), at.GetIamUserId()) + _ = iam.TestRoleGrant(t, conn, r.GetPublicId(), "ids=*;type=*;actions=*") + } else { + r1 := iam.TestRole(t, conn, proj1.GetPublicId()) 
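			// This branch grants access directly on proj1 and, via r2 below, through a
			// children grant on org2 that reaches proj2; together these should make
			// the same targets visible as the single descendants grant in the branch
			// above, so both configurations can be paginated identically.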
+ _ = iam.TestUserRole(t, conn, r1.GetPublicId(), at.GetIamUserId()) + _ = iam.TestRoleGrant(t, conn, r1.GetPublicId(), "ids=*;type=*;actions=*") + r2 := iam.TestRole(t, conn, org2.GetPublicId(), iam.WithGrantScopeIds([]string{globals.GrantScopeChildren})) + _ = iam.TestUserRole(t, conn, r2.GetPublicId(), at.GetIamUserId()) + _ = iam.TestRoleGrant(t, conn, r2.GetPublicId(), "ids=*;type=*;actions=*") + } + hc := static.TestCatalogs(t, conn, proj1.GetPublicId(), 1)[0] + hss := static.TestSets(t, conn, hc.GetPublicId(), 2) + s, err := testService(t, context.Background(), conn, kms, wrapper) + require.NoError(t, err) + + var allTargets []*pb.Target + for i := 0; i < 10; i++ { + tar := tcp.TestTarget(ctx, t, conn, proj1.GetPublicId(), fmt.Sprintf("tar-1-%d", i), target.WithHostSources([]string{hss[0].GetPublicId(), hss[1].GetPublicId()})) + allTargets = append(allTargets, &pb.Target{ + Id: tar.GetPublicId(), + ScopeId: proj1.GetPublicId(), + Name: wrapperspb.String(tar.GetName()), + Scope: &scopes.ScopeInfo{Id: proj1.GetPublicId(), Type: scope.Project.String(), ParentScopeId: org1.GetPublicId()}, + CreatedTime: tar.GetCreateTime().GetTimestamp(), + UpdatedTime: tar.GetUpdateTime().GetTimestamp(), + Version: tar.GetVersion(), + Type: tcp.Subtype.String(), + Attrs: &pb.Target_TcpTargetAttributes{}, + SessionMaxSeconds: wrapperspb.UInt32(28800), + SessionConnectionLimit: wrapperspb.Int32(-1), + AuthorizedActions: testAuthorizedActions, + Address: &wrapperspb.StringValue{}, + }) + } + for i := 0; i < 10; i++ { + tar := tcp.TestTarget(ctx, t, conn, proj2.GetPublicId(), fmt.Sprintf("tar-2-%d", i), target.WithAddress(fmt.Sprintf("127.0.0.%d", i))) + allTargets = append(allTargets, &pb.Target{ + Id: tar.GetPublicId(), + ScopeId: proj2.GetPublicId(), + Name: wrapperspb.String(tar.GetName()), + Scope: &scopes.ScopeInfo{Id: proj2.GetPublicId(), Type: scope.Project.String(), ParentScopeId: org2.GetPublicId()}, + CreatedTime: tar.GetCreateTime().GetTimestamp(), + UpdatedTime: tar.GetUpdateTime().GetTimestamp(), + Version: tar.GetVersion(), + Type: tcp.Subtype.String(), + Attrs: &pb.Target_TcpTargetAttributes{}, + SessionMaxSeconds: wrapperspb.UInt32(28800), + SessionConnectionLimit: wrapperspb.Int32(-1), + AuthorizedActions: testAuthorizedActions, + Address: &wrapperspb.StringValue{Value: fmt.Sprintf("127.0.0.%d", i)}, + }) + } + // Reverse since we read items in descending order (newest first) + slices.Reverse(allTargets) + + // Run analyze to update postgres estimates + _, err = sqlDB.ExecContext(ctx, "analyze") + require.NoError(t, err) + + requestInfo := authpb.RequestInfo{ + TokenFormat: uint32(auth.AuthTokenTypeBearer), + PublicId: at.GetPublicId(), + Token: at.GetToken(), + } + requestContext := context.WithValue(context.Background(), requests.ContextRequestInformationKey, &requests.RequestContext{}) + ctx = auth.NewVerifierContext(requestContext, iamRepoFn, tokenRepoFn, serversRepoFn, kms, &requestInfo) + + // Start paginating, recursively + req := &pbs.ListTargetsRequest{ + ScopeId: "global", + Recursive: true, + Filter: "", + ListToken: "", + PageSize: 2, + } + got, err := s.ListTargets(ctx, req) + require.NoError(t, err) + require.Len(t, got.GetItems(), 2) + // Compare without comparing the list token + assert.Empty(t, + cmp.Diff( + got, + &pbs.ListTargetsResponse{ + Items: allTargets[0:2], + ResponseType: "delta", + SortBy: "created_time", + SortDir: "desc", + RemovedIds: nil, + EstItemCount: 20, + }, + cmpopts.SortSlices(func(a, b string) bool { + return a < b + }), + protocmp.Transform(), + 
protocmp.IgnoreFields(&pbs.ListTargetsResponse{}, "list_token"), + ), + ) + + // Request second page + req.ListToken = got.ListToken + got, err = s.ListTargets(ctx, req) + require.NoError(t, err) + require.Len(t, got.GetItems(), 2) + // Compare without comparing the list token + assert.Empty(t, + cmp.Diff( + got, + &pbs.ListTargetsResponse{ + Items: allTargets[2:4], + ResponseType: "delta", + SortBy: "created_time", + SortDir: "desc", + RemovedIds: nil, + EstItemCount: 20, + }, + cmpopts.SortSlices(func(a, b string) bool { + return a < b + }), + protocmp.Transform(), + protocmp.IgnoreFields(&pbs.ListTargetsResponse{}, "list_token"), + ), + ) + + // Request rest of results + req.ListToken = got.ListToken + req.PageSize = 20 + got, err = s.ListTargets(ctx, req) + require.NoError(t, err) + require.Len(t, got.GetItems(), 16) + // Compare without comparing the list token + assert.Empty(t, + cmp.Diff( + got, + &pbs.ListTargetsResponse{ + Items: allTargets[4:], + ResponseType: "complete", + SortBy: "created_time", + SortDir: "desc", + RemovedIds: nil, + EstItemCount: 20, + }, + cmpopts.SortSlices(func(a, b string) bool { + return a < b + }), + protocmp.Transform(), + protocmp.IgnoreFields(&pbs.ListTargetsResponse{}, "list_token"), + ), + ) - got, err = s.ListTargets(ctx, &pbs.ListTargetsRequest{ - ScopeId: "global", - Recursive: true, + // Create another target + tar := tcp.TestTarget(ctx, t, conn, proj1.GetPublicId(), "test-target-1", target.WithHostSources([]string{hss[0].GetPublicId(), hss[1].GetPublicId()})) + newTarget := &pb.Target{ + Id: tar.GetPublicId(), + ScopeId: proj1.GetPublicId(), + Name: wrapperspb.String(tar.GetName()), + Scope: &scopes.ScopeInfo{Id: proj1.GetPublicId(), Type: scope.Project.String(), ParentScopeId: org1.GetPublicId()}, + CreatedTime: tar.GetCreateTime().GetTimestamp(), + UpdatedTime: tar.GetUpdateTime().GetTimestamp(), + Version: tar.GetVersion(), + Type: tcp.Subtype.String(), + Attrs: &pb.Target_TcpTargetAttributes{}, + SessionMaxSeconds: wrapperspb.UInt32(28800), + SessionConnectionLimit: wrapperspb.Int32(-1), + AuthorizedActions: testAuthorizedActions, + Address: &wrapperspb.StringValue{}, + } + // Add to the front since it's most recently updated + allTargets = append([]*pb.Target{newTarget}, allTargets...) + tar = tcp.TestTarget(ctx, t, conn, proj2.GetPublicId(), "test-target-2", target.WithAddress(fmt.Sprintf("127.0.0.11"))) + newTarget = &pb.Target{ + Id: tar.GetPublicId(), + ScopeId: proj2.GetPublicId(), + Name: wrapperspb.String(tar.GetName()), + Scope: &scopes.ScopeInfo{Id: proj2.GetPublicId(), Type: scope.Project.String(), ParentScopeId: org2.GetPublicId()}, + CreatedTime: tar.GetCreateTime().GetTimestamp(), + UpdatedTime: tar.GetUpdateTime().GetTimestamp(), + Version: tar.GetVersion(), + Type: tcp.Subtype.String(), + Attrs: &pb.Target_TcpTargetAttributes{}, + SessionMaxSeconds: wrapperspb.UInt32(28800), + SessionConnectionLimit: wrapperspb.Int32(-1), + AuthorizedActions: testAuthorizedActions, + Address: &wrapperspb.StringValue{Value: fmt.Sprintf("127.0.0.11")}, + } + allTargets = append([]*pb.Target{newTarget}, allTargets...) 
+ + // Leaving this function here as it is very useful if test objects change + /* + printNames := func(step string, tars []*pb.Target) { + names := make([]string, len(tars)) + for i, t := range tars { + names[i] = t.GetName().GetValue() + } + log.Println(step, pretty.Sprint(strings.Join(names, ", "))) + } + */ + + // printNames("before delete ", allTargets) + + // Delete one of the other targets in each project + _, err = repo.DeleteTarget(ctx, allTargets[len(allTargets)-11].Id) + require.NoError(t, err) + deletedTarget1 := allTargets[len(allTargets)-11] + allTargets = append(allTargets[:len(allTargets)-11], allTargets[len(allTargets)-11+1:]...) + // printNames("after first delete ", allTargets) + + _, err = repo.DeleteTarget(ctx, allTargets[len(allTargets)-1].Id) + require.NoError(t, err) + deletedTarget2 := allTargets[len(allTargets)-1] + allTargets = allTargets[:len(allTargets)-1] + // printNames("after second delete", allTargets) + + // Update two of the other targets + allTargets[2].Name = wrapperspb.String("new-name-1") + allTargets[2].Version = 2 + updatedTarget := &tcp.Target{ + Target: &store.Target{ + PublicId: allTargets[2].Id, + Name: allTargets[2].Name.GetValue(), + ProjectId: allTargets[2].ScopeId, + }, + } + tg, _, err := repo.UpdateTarget(ctx, updatedTarget, 1, []string{"name"}) + require.NoError(t, err) + allTargets[2].UpdatedTime = tg.GetUpdateTime().GetTimestamp() + allTargets[2].Version = tg.GetVersion() + // Add to the front since it's most recently updated + newAllTargets := append([]*pb.Target{allTargets[2]}, allTargets[0:2]...) + newAllTargets = append(newAllTargets, allTargets[3:]...) + allTargets = newAllTargets + // printNames("after first update ", allTargets) + allTargets[11].Name = wrapperspb.String("new-name-11") + allTargets[11].Version = 2 + updatedTarget = &tcp.Target{ + Target: &store.Target{ + PublicId: allTargets[11].Id, + Name: allTargets[11].Name.GetValue(), + ProjectId: allTargets[11].ScopeId, + }, + } + tg, _, err = repo.UpdateTarget(ctx, updatedTarget, 1, []string{"name"}) + require.NoError(t, err) + allTargets[11].UpdatedTime = tg.GetUpdateTime().GetTimestamp() + allTargets[11].Version = tg.GetVersion() + // Add to the front since it's most recently updated + newAllTargets = append([]*pb.Target{allTargets[11]}, allTargets[0:11]...) + newAllTargets = append(newAllTargets, allTargets[12:]...) 
+ allTargets = newAllTargets + // printNames("after second update", allTargets) + + // Run analyze to update postgres estimates + _, err = sqlDB.ExecContext(ctx, "analyze") + require.NoError(t, err) + + // Request updated results + req.ListToken = got.ListToken + req.PageSize = 2 + got, err = s.ListTargets(ctx, req) + require.NoError(t, err) + require.Len(t, got.GetItems(), 2) + // Compare without comparing the list token + assert.Empty(t, + cmp.Diff( + got, + &pbs.ListTargetsResponse{ + Items: []*pb.Target{allTargets[0], allTargets[1]}, + ResponseType: "delta", + SortBy: "updated_time", + SortDir: "desc", + // Should contain the deleted target + RemovedIds: []string{deletedTarget1.Id, deletedTarget2.Id}, + EstItemCount: 20, + }, + cmpopts.SortSlices(func(a, b string) bool { + return a < b + }), + protocmp.Transform(), + protocmp.IgnoreFields(&pbs.ListTargetsResponse{}, "list_token"), + ), + ) + + // Get next page + req.ListToken = got.ListToken + got, err = s.ListTargets(ctx, req) + require.NoError(t, err) + require.Len(t, got.GetItems(), 2) + // Compare without comparing the list token + assert.Empty(t, + cmp.Diff( + got, + &pbs.ListTargetsResponse{ + Items: []*pb.Target{allTargets[2], allTargets[3]}, + ResponseType: "complete", + SortBy: "updated_time", + SortDir: "desc", + RemovedIds: nil, + EstItemCount: 20, + }, + cmpopts.SortSlices(func(a, b string) bool { + return a < b + }), + protocmp.Transform(), + protocmp.IgnoreFields(&pbs.ListTargetsResponse{}, "list_token"), + ), + ) + + // Request new page with filter requiring looping + // to fill the page. + req.ListToken = "" + req.PageSize = 1 + req.Filter = fmt.Sprintf(`"/item/id"==%q or "/item/id"==%q`, allTargets[len(allTargets)-2].Id, allTargets[len(allTargets)-1].Id) + got, err = s.ListTargets(ctx, req) + require.NoError(t, err) + require.Len(t, got.GetItems(), 1) + assert.Empty(t, + cmp.Diff( + got, + &pbs.ListTargetsResponse{ + Items: []*pb.Target{allTargets[len(allTargets)-2]}, + ResponseType: "delta", + SortBy: "created_time", + SortDir: "desc", + // Should be empty again + RemovedIds: nil, + EstItemCount: 20, + }, + cmpopts.SortSlices(func(a, b string) bool { + return a < b + }), + protocmp.Transform(), + protocmp.IgnoreFields(&pbs.ListTargetsResponse{}, "list_token"), + ), + ) + req.ListToken = got.ListToken + // Get the second page + got, err = s.ListTargets(ctx, req) + require.NoError(t, err) + require.Len(t, got.GetItems(), 1) + assert.Empty(t, + cmp.Diff( + got, + &pbs.ListTargetsResponse{ + Items: []*pb.Target{allTargets[len(allTargets)-1]}, + ResponseType: "complete", + SortBy: "created_time", + SortDir: "desc", + RemovedIds: nil, + EstItemCount: 20, + }, + cmpopts.SortSlices(func(a, b string) bool { + return a < b + }), + protocmp.Transform(), + protocmp.IgnoreFields(&pbs.ListTargetsResponse{}, "list_token"), + ), + ) + + // Create unauthenticated user + unauthAt := authtoken.TestAuthToken(t, conn, kms, org1.GetPublicId()) + unauthR := iam.TestRole(t, conn, proj1.GetPublicId()) + _ = iam.TestUserRole(t, conn, unauthR.GetPublicId(), unauthAt.GetIamUserId()) + + // Make a request with the unauthenticated user, + // ensure the response contains the pagination parameters. 
+ requestInfo = authpb.RequestInfo{ + TokenFormat: uint32(auth.AuthTokenTypeBearer), + PublicId: unauthAt.GetPublicId(), + Token: unauthAt.GetToken(), + } + requestContext = context.WithValue(context.Background(), requests.ContextRequestInformationKey, &requests.RequestContext{}) + ctx = auth.NewVerifierContext(requestContext, iamRepoFn, tokenRepoFn, serversRepoFn, kms, &requestInfo) + + got, err = s.ListTargets(ctx, &pbs.ListTargetsRequest{ + ScopeId: "global", + Recursive: true, + }) + require.NoError(t, err) + assert.Empty(t, got.Items) + assert.Equal(t, "created_time", got.SortBy) + assert.Equal(t, "desc", got.SortDir) + assert.Equal(t, "complete", got.ResponseType) + } + + t.Run("with-descendants", func(t *testing.T) { + testListPagination(t, true) + }) + t.Run("without-descendants", func(t *testing.T) { + testListPagination(t, false) }) - require.NoError(t, err) - assert.Empty(t, got.Items) - assert.Equal(t, "created_time", got.SortBy) - assert.Equal(t, "desc", got.SortDir) - assert.Equal(t, "complete", got.ResponseType) } func TestDelete(t *testing.T) { diff --git a/internal/daemon/controller/handlers/users/user_service.go b/internal/daemon/controller/handlers/users/user_service.go index 8ca5ac2852..83b09ed97b 100644 --- a/internal/daemon/controller/handlers/users/user_service.go +++ b/internal/daemon/controller/handlers/users/user_service.go @@ -506,7 +506,7 @@ func (s Service) ListResolvableAliases(ctx context.Context, req *pbs.ListResolva } } - permissions := acl.ListResolvablePermissions(resource.Target, targets.IdActions) + permissions := acl.ListResolvableAliasesPermissions(resource.Target, targets.IdActions) if len(permissions) == 0 { // if there are no permitted targets then there will be no aliases that @@ -615,15 +615,14 @@ func (s Service) aclAndGrantHashForUser(ctx context.Context, userId string) (per // Note: Below, we always skip validation so that we don't error on formats // that we've since restricted, e.g. "ids=foo;actions=create,read". These // will simply not have an effect. - for _, pair := range grantTuples { + for _, tuple := range grantTuples { permsOpts := []perms.Option{ perms.WithUserId(userId), perms.WithSkipFinalValidation(true), } parsed, err := perms.Parse( ctx, - pair.ScopeId, - pair.Grant, + tuple, permsOpts...) 
if err != nil { return perms.ACL{}, nil, errors.Wrap(ctx, err, op) diff --git a/internal/iam/options.go b/internal/iam/options.go index 7edbebabeb..990da7ace0 100644 --- a/internal/iam/options.go +++ b/internal/iam/options.go @@ -26,22 +26,23 @@ type Option func(*options) // options = how options are represented type options struct { - withPublicId string - withName string - withDescription string - withLimit int - withGrantScopeIds []string - withSkipVetForWrite bool - withDisassociate bool - withSkipAdminRoleCreation bool - withSkipDefaultRoleCreation bool - withUserId string - withRandomReader io.Reader - withAccountIds []string - withPrimaryAuthMethodId string - withReader db.Reader - withWriter db.Writer - withStartPageAfterItem pagination.Item + withPublicId string + withName string + withDescription string + withLimit int + withGrantScopeIds []string + withSkipVetForWrite bool + withDisassociate bool + withSkipAdminRoleCreation bool + withSkipDefaultRoleCreation bool + withUserId string + withRandomReader io.Reader + withAccountIds []string + withPrimaryAuthMethodId string + withReader db.Reader + withWriter db.Writer + withStartPageAfterItem pagination.Item + withTestCacheMultiGrantTuples *[]multiGrantTuple } func getDefaultOptions() options { @@ -175,3 +176,9 @@ func WithStartPageAfterItem(item pagination.Item) Option { o.withStartPageAfterItem = item } } + +func withTestCacheMultiGrantTuples(cache *[]multiGrantTuple) Option { + return func(o *options) { + o.withTestCacheMultiGrantTuples = cache + } +} diff --git a/internal/iam/query.go b/internal/iam/query.go index c5f1dc5808..74af0b713e 100644 --- a/internal/iam/query.go +++ b/internal/iam/query.go @@ -129,15 +129,21 @@ const ( from auth_account where iam_user_id in (select id from users) ), - user_managed_groups (id) as ( + user_oidc_managed_groups (id) as ( select managed_group_id - from auth_managed_group_member_account + from auth_oidc_managed_group_member_account + where member_id in (select id from user_accounts) + ), + user_ldap_managed_groups (id) as ( + select managed_group_id + from auth_ldap_managed_group_member_account where member_id in (select id from user_accounts) ), managed_group_roles (role_id) as ( - select role_id + select distinct role_id from iam_managed_group_role - where principal_id in (select id from user_managed_groups) + where principal_id in (select id from user_oidc_managed_groups) + or principal_id in (select id from user_ldap_managed_groups) ), group_roles (role_id) as ( select role_id @@ -159,83 +165,45 @@ const ( select role_id from managed_group_roles ), - roles (role_id, role_scope_id) as ( + -- Now that we have the role IDs, expand the information to include scope + roles (role_id, role_scope_id, role_parent_scope_id) as ( select iam_role.public_id, - iam_role.scope_id + iam_role.scope_id, + iam_scope.parent_id from iam_role - where public_id in (select role_id from user_group_roles) + join iam_scope + on iam_scope.public_id = iam_role.scope_id + where iam_role.public_id in (select role_id from user_group_roles) ), - role_grant_scopes (role_id, role_scope_id, grant_scope_id) as ( + grant_scopes (role_id, grant_scope_ids) as ( select roles.role_id, - roles.role_scope_id, - iam_role_grant_scope.scope_id_or_special + string_agg(iam_role_grant_scope.scope_id_or_special, '^') as grant_scope_ids from roles - inner join iam_role_grant_scope - on roles.role_id = iam_role_grant_scope.role_id - ), - -- For all role_ids with a special scope_id of 'descendants', we want to - -- perform a cartesian 
product to pair the role_id with all non-global scopes. - descendant_grant_scopes (role_id, grant_scope_id) as ( - select role_grant_scopes.role_id as role_id, - iam_scope.public_id as grant_scope_id - from role_grant_scopes, - iam_scope - where iam_scope.public_id != 'global' - and role_grant_scopes.grant_scope_id = 'descendants' - ), - children_grant_scopes (role_id, grant_scope_id) as ( - select role_grant_scopes.role_id as role_id, - iam_scope.public_id as grant_scope_id - from role_grant_scopes - join iam_scope - on iam_scope.parent_id = role_grant_scopes.role_scope_id - where role_grant_scopes.grant_scope_id = 'children' + join iam_role_grant_scope + on iam_role_grant_scope.role_id = roles.role_id + group by roles.role_id ), - this_grant_scopes (role_id, grant_scope_id) as ( - select role_grant_scopes.role_id as role_id, - role_grant_scopes.role_scope_id as grant_scope_id - from role_grant_scopes - where role_grant_scopes.grant_scope_id = 'this' - ), - direct_grant_scopes (role_id, grant_scope_id) as ( - select role_grant_scopes.role_id as role_id, - role_grant_scopes.grant_scope_id as grant_scope_id - from role_grant_scopes - where role_grant_scopes.grant_scope_id not in ('descendants', 'children', 'this') - ), - grant_scopes (role_id, grant_scope_id) as ( - select - role_id as role_id, - grant_scope_id as grant_scope_id - from descendant_grant_scopes - union - select - role_id as role_id, - grant_scope_id as grant_scope_id - from children_grant_scopes - union - select - role_id as role_id, - grant_scope_id as grant_scope_id - from this_grant_scopes - union - select - role_id as role_id, - grant_scope_id as grant_scope_id - from direct_grant_scopes - ), - final (role_id, grant_scope_id, canonical_grant) as ( - select grant_scopes.role_id, - grant_scopes.grant_scope_id, - iam_role_grant.canonical_grant - from grant_scopes - join iam_role_grant - on grant_scopes.role_id = iam_role_grant.role_id + grants (role_id, grants) as ( + select roles.role_id, + string_agg(iam_role_grant.canonical_grant, '^') as grants + from roles + join iam_role_grant + on iam_role_grant.role_id = roles.role_id + group by roles.role_id ) - select role_id as role_id, - grant_scope_id as scope_id, - canonical_grant as grant - from final; + -- Finally, take the resulting roles and pull grant scope IDs and canonical grants. + -- We will split these out in application logic to keep the result set size low. 
+ select + roles.role_id as role_id, + roles.role_scope_id as role_scope_id, + roles.role_parent_scope_id as role_parent_scope_id, + grant_scopes.grant_scope_ids as grant_scope_ids, + grants.grants as grants + from roles + join grant_scopes + on grant_scopes.role_id = roles.role_id + join grants + on grants.role_id = roles.role_id; ` estimateCountRoles = ` diff --git a/internal/iam/repository_role_grant.go b/internal/iam/repository_role_grant.go index 116800053e..d9a8d63172 100644 --- a/internal/iam/repository_role_grant.go +++ b/internal/iam/repository_role_grant.go @@ -6,6 +6,8 @@ package iam import ( "context" "fmt" + "sort" + "strings" "github.com/hashicorp/boundary/globals" "github.com/hashicorp/boundary/internal/db" @@ -165,7 +167,7 @@ func (r *Repository) DeleteRoleGrants(ctx context.Context, roleId string, roleVe deleteRoleGrants := make([]*RoleGrant, 0, len(grants)) for _, grant := range grants { // Use a fake scope, just want to get out a canonical string - perm, err := perms.Parse(ctx, "o_abcd1234", grant, perms.WithSkipFinalValidation(true)) + perm, err := perms.Parse(ctx, perms.GrantTuple{RoleScopeId: "o_abcd1234", GrantScopeId: "o_abcd1234", Grant: grant}, perms.WithSkipFinalValidation(true)) if err != nil { return errors.Wrap(ctx, err, op, errors.WithMsg("parsing grant string")) } @@ -257,7 +259,7 @@ func (r *Repository) SetRoleGrants(ctx context.Context, roleId string, roleVersi deleteRoleGrants := make([]*RoleGrant, 0, len(grants)) for _, grant := range grants { // Use a fake scope, just want to get out a canonical string - perm, err := perms.Parse(ctx, "o_abcd1234", grant, perms.WithSkipFinalValidation(true)) + perm, err := perms.Parse(ctx, perms.GrantTuple{RoleScopeId: "o_abcd1234", GrantScopeId: "o_abcd1234", Grant: grant}, perms.WithSkipFinalValidation(true)) if err != nil { return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op, errors.WithMsg("error parsing grant string")) } @@ -410,12 +412,22 @@ func (r *Repository) ListRoleGrantScopes(ctx context.Context, roleIds []string, return roleGrantScopes, nil } -func (r *Repository) GrantsForUser(ctx context.Context, userId string, _ ...Option) (perms.GrantTuples, error) { +type multiGrantTuple struct { + RoleId string + RoleScopeId string + RoleParentScopeId string + GrantScopeIds string + Grants string +} + +func (r *Repository) GrantsForUser(ctx context.Context, userId string, opt ...Option) (perms.GrantTuples, error) { const op = "iam.(Repository).GrantsForUser" if userId == "" { return nil, errors.New(ctx, errors.InvalidParameter, op, "missing user id") } + opts := getOpts(opt...) 
+ const ( anonUser = `where public_id in (?)` authUser = `where public_id in ('u_anon', 'u_auth', ?)` @@ -429,7 +441,7 @@ func (r *Repository) GrantsForUser(ctx context.Context, userId string, _ ...Opti query = fmt.Sprintf(grantsForUserQuery, authUser) } - var grants []perms.GrantTuple + var grants []multiGrantTuple rows, err := r.reader.Query(ctx, query, []any{userId}) if err != nil { return nil, errors.Wrap(ctx, err, op) @@ -443,5 +455,42 @@ func (r *Repository) GrantsForUser(ctx context.Context, userId string, _ ...Opti if err := rows.Err(); err != nil { return nil, errors.Wrap(ctx, err, op) } - return grants, nil + + ret := make(perms.GrantTuples, 0, len(grants)*3) + for _, grant := range grants { + for _, grantScopeId := range strings.Split(grant.GrantScopeIds, "^") { + for _, canonicalGrant := range strings.Split(grant.Grants, "^") { + gt := perms.GrantTuple{ + RoleId: grant.RoleId, + RoleScopeId: grant.RoleScopeId, + RoleParentScopeId: grant.RoleParentScopeId, + GrantScopeId: grantScopeId, + Grant: canonicalGrant, + } + if gt.GrantScopeId == globals.GrantScopeThis || gt.GrantScopeId == "" { + gt.GrantScopeId = grant.RoleScopeId + } + ret = append(ret, gt) + } + } + } + + if opts.withTestCacheMultiGrantTuples != nil { + for i, grant := range grants { + grant.testStableSort() + grants[i] = grant + } + *opts.withTestCacheMultiGrantTuples = grants + } + + return ret, nil +} + +func (m *multiGrantTuple) testStableSort() { + grantScopeIds := strings.Split(m.GrantScopeIds, "^") + sort.Strings(grantScopeIds) + m.GrantScopeIds = strings.Join(grantScopeIds, "^") + gts := strings.Split(m.Grants, "^") + sort.Strings(gts) + m.Grants = strings.Join(gts, "^") } diff --git a/internal/iam/repository_role_grant_test.go b/internal/iam/repository_role_grant_test.go index b90216a547..ba90363692 100644 --- a/internal/iam/repository_role_grant_test.go +++ b/internal/iam/repository_role_grant_test.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "math/rand" + "strings" "testing" "time" @@ -15,6 +16,8 @@ import ( "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/oplog" "github.com/hashicorp/boundary/internal/perms" + "github.com/hashicorp/boundary/internal/types/action" + "github.com/hashicorp/boundary/internal/types/resource" "github.com/hashicorp/boundary/internal/types/scope" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -138,17 +141,17 @@ func TestRepository_ListRoleGrants(t *testing.T) { opt []Option } tests := []struct { - name string - createCnt int - createScopeId string - args args - wantCnt int - wantErr bool + name string + createCnt int + createGrantScopeId string + args args + wantCnt int + wantErr bool }{ { - name: "no-limit", - createCnt: repo.defaultLimit + 2, - createScopeId: org.PublicId, + name: "no-limit", + createCnt: repo.defaultLimit + 2, + createGrantScopeId: org.PublicId, args: args{ opt: []Option{WithLimit(-1)}, }, @@ -156,9 +159,9 @@ func TestRepository_ListRoleGrants(t *testing.T) { wantErr: false, }, { - name: "no-limit-proj-group", - createCnt: repo.defaultLimit + 2, - createScopeId: proj.PublicId, + name: "no-limit-proj-group", + createCnt: repo.defaultLimit + 2, + createGrantScopeId: proj.PublicId, args: args{ opt: []Option{WithLimit(-1)}, }, @@ -166,16 +169,16 @@ func TestRepository_ListRoleGrants(t *testing.T) { wantErr: false, }, { - name: "default-limit", - createCnt: repo.defaultLimit + 2, - createScopeId: org.PublicId, - wantCnt: repo.defaultLimit, - wantErr: false, + name: "default-limit", + createCnt: 
repo.defaultLimit + 2, + createGrantScopeId: org.PublicId, + wantCnt: repo.defaultLimit, + wantErr: false, }, { - name: "custom-limit", - createCnt: repo.defaultLimit + 2, - createScopeId: org.PublicId, + name: "custom-limit", + createCnt: repo.defaultLimit + 2, + createGrantScopeId: org.PublicId, args: args{ opt: []Option{WithLimit(3)}, }, @@ -183,9 +186,9 @@ func TestRepository_ListRoleGrants(t *testing.T) { wantErr: false, }, { - name: "bad-role-id", - createCnt: 2, - createScopeId: org.PublicId, + name: "bad-role-id", + createCnt: 2, + createGrantScopeId: org.PublicId, args: args{ withRoleId: "bad-id", }, @@ -197,7 +200,7 @@ func TestRepository_ListRoleGrants(t *testing.T) { t.Run(tt.name, func(t *testing.T) { assert, require := assert.New(t), require.New(t) db.TestDeleteWhere(t, conn, func() any { r := allocRole(); return &r }(), "1=1") - role := TestRole(t, conn, tt.createScopeId) + role := TestRole(t, conn, tt.createGrantScopeId) roleGrants := make([]string, 0, tt.createCnt) for i := 0; i < tt.createCnt; i++ { roleGrants = append(roleGrants, fmt.Sprintf("ids=h_%d;actions=*", i)) @@ -573,7 +576,6 @@ func TestRepository_SetRoleGrants_Parameters(t *testing.T) { } func TestGrantsForUser(t *testing.T) { - require, assert := require.New(t), assert.New(t) ctx := context.Background() conn, _ := db.TestSetup(t, "postgres") @@ -594,9 +596,9 @@ func TestGrantsForUser(t *testing.T) { WithSkipDefaultRoleCreation(true), ) noGrantOrg1Role := TestRole(t, conn, noGrantOrg1.PublicId) - TestRoleGrant(t, conn, noGrantOrg1Role.PublicId, "ids=o_noGrantOrg1;actions=*") + TestRoleGrant(t, conn, noGrantOrg1Role.PublicId, "ids=*;type=scope;actions=*") noGrantProj1Role := TestRole(t, conn, noGrantProj1.PublicId) - TestRoleGrant(t, conn, noGrantProj1Role.PublicId, "ids=p_noGrantProj1;actions=*") + TestRoleGrant(t, conn, noGrantProj1Role.PublicId, "ids=*;type=*;actions=*") noGrantOrg2, noGrantProj2 := TestScopes( t, repo, @@ -604,9 +606,9 @@ func TestGrantsForUser(t *testing.T) { WithSkipDefaultRoleCreation(true), ) noGrantOrg2Role := TestRole(t, conn, noGrantOrg2.PublicId) - TestRoleGrant(t, conn, noGrantOrg2Role.PublicId, "ids=o_noGrantOrg2;actions=*") + TestRoleGrant(t, conn, noGrantOrg2Role.PublicId, "ids=*;type=scope;actions=*") noGrantProj2Role := TestRole(t, conn, noGrantProj2.PublicId) - TestRoleGrant(t, conn, noGrantProj2Role.PublicId, "ids=p_noGrantProj2;actions=*") + TestRoleGrant(t, conn, noGrantProj2Role.PublicId, "ids=*;type=*;actions=*") // The second org/project set contains direct grants, but without // inheritance. We create two roles in each project. 
@@ -623,23 +625,22 @@ func TestGrantsForUser(t *testing.T) { WithSkipAdminRoleCreation(true), WithSkipDefaultRoleCreation(true), ) - directGrantOrg1Role := TestRole(t, conn, directGrantOrg1.PublicId, - WithGrantScopeIds([]string{ - globals.GrantScopeThis, - directGrantProj1a.PublicId, - directGrantProj1b.PublicId, - })) + directGrantOrg1Role := TestRole(t, conn, directGrantOrg1.PublicId) TestUserRole(t, conn, directGrantOrg1Role.PublicId, user.PublicId) - directGrantOrg1RoleGrant := "ids=o_directGrantOrg1;actions=*" - TestRoleGrant(t, conn, directGrantOrg1Role.PublicId, directGrantOrg1RoleGrant) + directGrantOrg1RoleGrant1 := "ids=*;type=*;actions=*" + TestRoleGrant(t, conn, directGrantOrg1Role.PublicId, directGrantOrg1RoleGrant1) + directGrantOrg1RoleGrant2 := "ids=*;type=role;actions=list,read" + TestRoleGrant(t, conn, directGrantOrg1Role.PublicId, directGrantOrg1RoleGrant2) + directGrantProj1aRole := TestRole(t, conn, directGrantProj1a.PublicId) TestUserRole(t, conn, directGrantProj1aRole.PublicId, user.PublicId) - directGrantProj1aRoleGrant := "ids=p_directGrantProj1a;actions=*" + directGrantProj1aRoleGrant := "ids=*;type=target;actions=authorize-session,read" TestRoleGrant(t, conn, directGrantProj1aRole.PublicId, directGrantProj1aRoleGrant) directGrantProj1bRole := TestRole(t, conn, directGrantProj1b.PublicId) TestUserRole(t, conn, directGrantProj1bRole.PublicId, user.PublicId) - directGrantProj1bRoleGrant := "ids=p_directGrantProj1b;actions=*" + directGrantProj1bRoleGrant := "ids=*;type=session;actions=list,read" TestRoleGrant(t, conn, directGrantProj1bRole.PublicId, directGrantProj1bRoleGrant) + directGrantOrg2, directGrantProj2a := TestScopes( t, repo, @@ -657,32 +658,27 @@ func TestGrantsForUser(t *testing.T) { WithGrantScopeIds([]string{ globals.GrantScopeThis, directGrantProj2a.PublicId, - directGrantProj2b.PublicId, })) TestUserRole(t, conn, directGrantOrg2Role.PublicId, user.PublicId) - directGrantOrg2RoleGrant := "ids=o_directGrantOrg2;actions=*" - TestRoleGrant(t, conn, directGrantOrg2Role.PublicId, directGrantOrg2RoleGrant) + directGrantOrg2RoleGrant1 := "ids=*;type=user;actions=*" + TestRoleGrant(t, conn, directGrantOrg2Role.PublicId, directGrantOrg2RoleGrant1) + directGrantOrg2RoleGrant2 := "ids=*;type=group;actions=list,read" + TestRoleGrant(t, conn, directGrantOrg2Role.PublicId, directGrantOrg2RoleGrant2) + directGrantProj2aRole := TestRole(t, conn, directGrantProj2a.PublicId) TestUserRole(t, conn, directGrantProj2aRole.PublicId, user.PublicId) - directGrantProj2aRoleGrant := "ids=p_directGrantProj2a;actions=*" + directGrantProj2aRoleGrant := "ids=hcst_abcd1234,hcst_1234abcd;actions=*" TestRoleGrant(t, conn, directGrantProj2aRole.PublicId, directGrantProj2aRoleGrant) directGrantProj2bRole := TestRole(t, conn, directGrantProj2b.PublicId) TestUserRole(t, conn, directGrantProj2bRole.PublicId, user.PublicId) - directGrantProj2bRoleGrant := "ids=p_directGrantProj2b;actions=*" + directGrantProj2bRoleGrant := "ids=cs_abcd1234;actions=read,update" TestRoleGrant(t, conn, directGrantProj2bRole.PublicId, directGrantProj2bRoleGrant) // For the third set we create a couple of orgs/projects and then use - // "children". We expect to see no grant on the org but for both projects. - childGrantOrg1, childGrantProj1a := TestScopes( - t, - repo, - WithSkipAdminRoleCreation(true), - WithSkipDefaultRoleCreation(true), - ) - childGrantProj1b := TestProject( + // globals.GrantScopeChildren. 
+ childGrantOrg1, childGrantOrg1Proj := TestScopes( t, repo, - childGrantOrg1.PublicId, WithSkipAdminRoleCreation(true), WithSkipDefaultRoleCreation(true), ) @@ -691,28 +687,23 @@ func TestGrantsForUser(t *testing.T) { globals.GrantScopeChildren, })) TestUserRole(t, conn, childGrantOrg1Role.PublicId, user.PublicId) - childGrantOrg1RoleGrant := "ids=o_childGrantOrg1;actions=*" + childGrantOrg1RoleGrant := "ids=*;type=host-set;actions=add-hosts,remove-hosts" TestRoleGrant(t, conn, childGrantOrg1Role.PublicId, childGrantOrg1RoleGrant) - childGrantOrg2, childGrantProj2a := TestScopes( + childGrantOrg2, childGrantOrg2Proj := TestScopes( t, repo, WithSkipAdminRoleCreation(true), WithSkipDefaultRoleCreation(true), ) - childGrantProj2b := TestProject( - t, - repo, - childGrantOrg2.PublicId, - WithSkipAdminRoleCreation(true), - WithSkipDefaultRoleCreation(true), - ) childGrantOrg2Role := TestRole(t, conn, childGrantOrg2.PublicId, WithGrantScopeIds([]string{ globals.GrantScopeChildren, })) TestUserRole(t, conn, childGrantOrg2Role.PublicId, user.PublicId) - childGrantOrg2RoleGrant := "ids=o_childGrantOrg2;actions=*" - TestRoleGrant(t, conn, childGrantOrg2Role.PublicId, childGrantOrg2RoleGrant) + childGrantOrg2RoleGrant1 := "ids=*;type=session;actions=cancel:self" + TestRoleGrant(t, conn, childGrantOrg2Role.PublicId, childGrantOrg2RoleGrant1) + childGrantOrg2RoleGrant2 := "ids=*;type=session;actions=read:self" + TestRoleGrant(t, conn, childGrantOrg2Role.PublicId, childGrantOrg2RoleGrant2) // Finally, let's create some roles at global scope with children and // descendants grants @@ -721,7 +712,7 @@ func TestGrantsForUser(t *testing.T) { globals.GrantScopeChildren, })) TestUserRole(t, conn, childGrantGlobalRole.PublicId, globals.AnyAuthenticatedUserId) - childGrantGlobalRoleGrant := "ids=*;type=host;actions=*" + childGrantGlobalRoleGrant := "ids=*;type=account;actions=*" TestRoleGrant(t, conn, childGrantGlobalRole.PublicId, childGrantGlobalRoleGrant) descendantGrantGlobalRole := TestRole(t, conn, scope.Global.String(), WithGrantScopeIds([]string{ @@ -731,235 +722,817 @@ func TestGrantsForUser(t *testing.T) { descendantGrantGlobalRoleGrant := "ids=*;type=credential;actions=*" TestRoleGrant(t, conn, descendantGrantGlobalRole.PublicId, descendantGrantGlobalRoleGrant) - /* - // Useful if needing to debug - t.Log( - "\nnoGrantOrg1", noGrantOrg1.PublicId, noGrantOrg1Role.PublicId, - "\nnoGrantProj1", noGrantProj1.PublicId, noGrantProj1Role.PublicId, - "\nnoGrantOrg2", noGrantOrg2.PublicId, noGrantOrg2Role.PublicId, - "\nnoGrantProj2", noGrantProj2.PublicId, noGrantProj2Role.PublicId, - "\ndirectGrantOrg1", directGrantOrg1.PublicId, directGrantOrg1Role.PublicId, - "\ndirectGrantProj1a", directGrantProj1a.PublicId, directGrantProj1aRole.PublicId, - "\ndirectGrantProj1b", directGrantProj1b.PublicId, directGrantProj1bRole.PublicId, - "\ndirectGrantOrg2", directGrantOrg2.PublicId, directGrantOrg2Role.PublicId, - "\ndirectGrantProj2a", directGrantProj2a.PublicId, directGrantProj2aRole.PublicId, - "\ndirectGrantProj2b", directGrantProj2b.PublicId, directGrantProj2bRole.PublicId, - "\nchildGrantOrg1", childGrantOrg1.PublicId, childGrantOrg1Role.PublicId, - "\nchildGrantProj1a", childGrantProj1a.PublicId, - "\nchildGrantProj1b", childGrantProj1b.PublicId, - "\nchildGrantOrg2", childGrantOrg2.PublicId, childGrantOrg2Role.PublicId, - "\nchildGrantProj2a", childGrantProj2a.PublicId, - "\nchildGrantProj2b", childGrantProj2b.PublicId, - "\nchildGrantGlobalRole", childGrantGlobalRole.PublicId, - 
"\ndescendantGrantGlobalRole", descendantGrantGlobalRole.PublicId, - ) - */ - - // We expect to see: - // - // * No grants from noOrg/noProj - // * Grants from direct orgs/projs: - // * directGrantOrg1/directGrantOrg2 on org and respective projects (6 grants total) - // * directGrantProj on respective projects (4 grants total) - expGrantTuples := []perms.GrantTuple{ - // No grants from noOrg/noProj - // Grants from direct org1 to org1/proj1a/proj1b: - { - RoleId: directGrantOrg1Role.PublicId, - ScopeId: directGrantOrg1.PublicId, - Grant: directGrantOrg1RoleGrant, - }, - { - RoleId: directGrantOrg1Role.PublicId, - ScopeId: directGrantProj1a.PublicId, - Grant: directGrantOrg1RoleGrant, - }, - { - RoleId: directGrantOrg1Role.PublicId, - ScopeId: directGrantProj1b.PublicId, - Grant: directGrantOrg1RoleGrant, - }, - // Grants from direct org 1 proj 1a: - { - RoleId: directGrantProj1aRole.PublicId, - ScopeId: directGrantProj1a.PublicId, - Grant: directGrantProj1aRoleGrant, - }, - // Grant from direct org 1 proj 1 b: - { - RoleId: directGrantProj1bRole.PublicId, - ScopeId: directGrantProj1b.PublicId, - Grant: directGrantProj1bRoleGrant, - }, + t.Run("db-grants", func(t *testing.T) { + // Here we should see exactly what the DB has returned, before we do some + // local exploding of grants and grant scopes + expMultiGrantTuples := []multiGrantTuple{ + // No grants from noOrg/noProj + // Direct org1/2: + { + RoleId: directGrantOrg1Role.PublicId, + RoleScopeId: directGrantOrg1.PublicId, + RoleParentScopeId: scope.Global.String(), + GrantScopeIds: globals.GrantScopeThis, + Grants: strings.Join([]string{directGrantOrg1RoleGrant1, directGrantOrg1RoleGrant2}, "^"), + }, + { + RoleId: directGrantOrg2Role.PublicId, + RoleScopeId: directGrantOrg2.PublicId, + RoleParentScopeId: scope.Global.String(), + GrantScopeIds: strings.Join([]string{globals.GrantScopeThis, directGrantProj2a.PublicId}, "^"), + Grants: strings.Join([]string{directGrantOrg2RoleGrant1, directGrantOrg2RoleGrant2}, "^"), + }, + // Proj orgs 1/2: + { + RoleId: directGrantProj1aRole.PublicId, + RoleScopeId: directGrantProj1a.PublicId, + RoleParentScopeId: directGrantOrg1.PublicId, + GrantScopeIds: globals.GrantScopeThis, + Grants: directGrantProj1aRoleGrant, + }, + { + RoleId: directGrantProj1bRole.PublicId, + RoleScopeId: directGrantProj1b.PublicId, + RoleParentScopeId: directGrantOrg1.PublicId, + GrantScopeIds: globals.GrantScopeThis, + Grants: directGrantProj1bRoleGrant, + }, + { + RoleId: directGrantProj2aRole.PublicId, + RoleScopeId: directGrantProj2a.PublicId, + RoleParentScopeId: directGrantOrg2.PublicId, + GrantScopeIds: globals.GrantScopeThis, + Grants: directGrantProj2aRoleGrant, + }, + { + RoleId: directGrantProj2bRole.PublicId, + RoleScopeId: directGrantProj2b.PublicId, + RoleParentScopeId: directGrantOrg2.PublicId, + GrantScopeIds: globals.GrantScopeThis, + Grants: directGrantProj2bRoleGrant, + }, + // Child grants from orgs 1/2: + { + RoleId: childGrantOrg1Role.PublicId, + RoleScopeId: childGrantOrg1.PublicId, + RoleParentScopeId: scope.Global.String(), + GrantScopeIds: globals.GrantScopeChildren, + Grants: childGrantOrg1RoleGrant, + }, + { + RoleId: childGrantOrg2Role.PublicId, + RoleScopeId: childGrantOrg2.PublicId, + RoleParentScopeId: scope.Global.String(), + GrantScopeIds: globals.GrantScopeChildren, + Grants: strings.Join([]string{childGrantOrg2RoleGrant1, childGrantOrg2RoleGrant2}, "^"), + }, + // Children of global and descendants of global + { + RoleId: descendantGrantGlobalRole.PublicId, + RoleScopeId: 
scope.Global.String(), + GrantScopeIds: globals.GrantScopeDescendants, + Grants: descendantGrantGlobalRoleGrant, + }, + { + RoleId: childGrantGlobalRole.PublicId, + RoleScopeId: scope.Global.String(), + GrantScopeIds: globals.GrantScopeChildren, + Grants: childGrantGlobalRoleGrant, + }, + } + for i, tuple := range expMultiGrantTuples { + tuple.testStableSort() + expMultiGrantTuples[i] = tuple + } + multiGrantTuplesCache := new([]multiGrantTuple) + _, err := repo.GrantsForUser(ctx, user.PublicId, withTestCacheMultiGrantTuples(multiGrantTuplesCache)) + require.NoError(t, err) - // Grants from direct org1 to org2/proj2a/proj2b: - { - RoleId: directGrantOrg2Role.PublicId, - ScopeId: directGrantOrg2.PublicId, - Grant: directGrantOrg2RoleGrant, - }, - { - RoleId: directGrantOrg2Role.PublicId, - ScopeId: directGrantProj2a.PublicId, - Grant: directGrantOrg2RoleGrant, - }, - { - RoleId: directGrantOrg2Role.PublicId, - ScopeId: directGrantProj2b.PublicId, - Grant: directGrantOrg2RoleGrant, - }, - // Grants from direct org 2 proj 2a: - { - RoleId: directGrantProj2aRole.PublicId, - ScopeId: directGrantProj2a.PublicId, - Grant: directGrantProj2aRoleGrant, - }, - // Grant from direct org 2 proj 2 b: - { - RoleId: directGrantProj2bRole.PublicId, - ScopeId: directGrantProj2b.PublicId, - Grant: directGrantProj2bRoleGrant, - }, + // log.Println("multiGrantTuplesCache", pretty.Sprint(*multiGrantTuplesCache)) + assert.ElementsMatch(t, *multiGrantTuplesCache, expMultiGrantTuples) + }) - // Child grants from child org1 to proj1a/proj1b: - { - RoleId: childGrantOrg1Role.PublicId, - ScopeId: childGrantProj1a.PublicId, - Grant: childGrantOrg1RoleGrant, - }, - { - RoleId: childGrantOrg1Role.PublicId, - ScopeId: childGrantProj1b.PublicId, - Grant: childGrantOrg1RoleGrant, - }, - // Child grants from child org2 to proj2a/proj2b: - { - RoleId: childGrantOrg2Role.PublicId, - ScopeId: childGrantProj2a.PublicId, - Grant: childGrantOrg2RoleGrant, - }, - { - RoleId: childGrantOrg2Role.PublicId, - ScopeId: childGrantProj2b.PublicId, - Grant: childGrantOrg2RoleGrant, - }, + t.Run("exploded-grants", func(t *testing.T) { + // We expect to see: + // + // * No grants from noOrg/noProj + // * Grants from direct orgs/projs: + // * directGrantOrg1/directGrantOrg2 on org and respective projects (6 grants total per org) + // * directGrantProj on respective projects (4 grants total) + expGrantTuples := []perms.GrantTuple{ + // No grants from noOrg/noProj + // Grants from direct org1 to org1/proj1a/proj1b: + { + RoleId: directGrantOrg1Role.PublicId, + RoleScopeId: directGrantOrg1.PublicId, + RoleParentScopeId: scope.Global.String(), + GrantScopeId: directGrantOrg1.PublicId, + Grant: directGrantOrg1RoleGrant1, + }, + { + RoleId: directGrantOrg1Role.PublicId, + RoleScopeId: directGrantOrg1.PublicId, + RoleParentScopeId: scope.Global.String(), + GrantScopeId: directGrantOrg1.PublicId, + Grant: directGrantOrg1RoleGrant2, + }, + // Grants from direct org 1 proj 1a: + { + RoleId: directGrantProj1aRole.PublicId, + RoleScopeId: directGrantProj1a.PublicId, + RoleParentScopeId: directGrantOrg1.PublicId, + GrantScopeId: directGrantProj1a.PublicId, + Grant: directGrantProj1aRoleGrant, + }, + // Grant from direct org 1 proj 1 b: + { + RoleId: directGrantProj1bRole.PublicId, + RoleScopeId: directGrantProj1b.PublicId, + RoleParentScopeId: directGrantOrg1.PublicId, + GrantScopeId: directGrantProj1b.PublicId, + Grant: directGrantProj1bRoleGrant, + }, - // Grants from global to every org: - { - RoleId: childGrantGlobalRole.PublicId, - ScopeId: 
noGrantOrg1.PublicId, - Grant: childGrantGlobalRoleGrant, - }, - { - RoleId: childGrantGlobalRole.PublicId, - ScopeId: noGrantOrg2.PublicId, - Grant: childGrantGlobalRoleGrant, - }, - { - RoleId: childGrantGlobalRole.PublicId, - ScopeId: directGrantOrg1.PublicId, - Grant: childGrantGlobalRoleGrant, - }, - { - RoleId: childGrantGlobalRole.PublicId, - ScopeId: directGrantOrg2.PublicId, - Grant: childGrantGlobalRoleGrant, - }, - { - RoleId: childGrantGlobalRole.PublicId, - ScopeId: childGrantOrg1.PublicId, - Grant: childGrantGlobalRoleGrant, - }, - { - RoleId: childGrantGlobalRole.PublicId, - ScopeId: childGrantOrg2.PublicId, - Grant: childGrantGlobalRoleGrant, - }, + // Grants from direct org2 to org2/proj2a/proj2b: + { + RoleId: directGrantOrg2Role.PublicId, + RoleScopeId: directGrantOrg2.PublicId, + RoleParentScopeId: scope.Global.String(), + GrantScopeId: directGrantOrg2.PublicId, + Grant: directGrantOrg2RoleGrant1, + }, + { + RoleId: directGrantOrg2Role.PublicId, + RoleScopeId: directGrantOrg2.PublicId, + RoleParentScopeId: scope.Global.String(), + GrantScopeId: directGrantProj2a.PublicId, + Grant: directGrantOrg2RoleGrant1, + }, + { + RoleId: directGrantOrg2Role.PublicId, + RoleScopeId: directGrantOrg2.PublicId, + RoleParentScopeId: scope.Global.String(), + GrantScopeId: directGrantOrg2.PublicId, + Grant: directGrantOrg2RoleGrant2, + }, + { + RoleId: directGrantOrg2Role.PublicId, + RoleScopeId: directGrantOrg2.PublicId, + RoleParentScopeId: scope.Global.String(), + GrantScopeId: directGrantProj2a.PublicId, + Grant: directGrantOrg2RoleGrant2, + }, + // Grants from direct org 2 proj 2a: + { + RoleId: directGrantProj2aRole.PublicId, + RoleScopeId: directGrantProj2a.PublicId, + RoleParentScopeId: directGrantOrg2.PublicId, + GrantScopeId: directGrantProj2a.PublicId, + Grant: directGrantProj2aRoleGrant, + }, + // Grant from direct org 2 proj 2 b: + { + RoleId: directGrantProj2bRole.PublicId, + RoleScopeId: directGrantProj2b.PublicId, + RoleParentScopeId: directGrantOrg2.PublicId, + GrantScopeId: directGrantProj2b.PublicId, + Grant: directGrantProj2bRoleGrant, + }, + // Child grants from child org1 to proj1a/proj1b: + { + RoleId: childGrantOrg1Role.PublicId, + RoleScopeId: childGrantOrg1.PublicId, + RoleParentScopeId: scope.Global.String(), + GrantScopeId: globals.GrantScopeChildren, + Grant: childGrantOrg1RoleGrant, + }, + // Child grants from child org2 to proj2a/proj2b: + { + RoleId: childGrantOrg2Role.PublicId, + RoleScopeId: childGrantOrg2.PublicId, + RoleParentScopeId: scope.Global.String(), + GrantScopeId: globals.GrantScopeChildren, + Grant: childGrantOrg2RoleGrant1, + }, + { + RoleId: childGrantOrg2Role.PublicId, + RoleScopeId: childGrantOrg2.PublicId, + RoleParentScopeId: scope.Global.String(), + GrantScopeId: globals.GrantScopeChildren, + Grant: childGrantOrg2RoleGrant2, + }, - // Grants from global to every org and project: - { - RoleId: descendantGrantGlobalRole.PublicId, - ScopeId: noGrantOrg1.PublicId, - Grant: descendantGrantGlobalRoleGrant, - }, - { - RoleId: descendantGrantGlobalRole.PublicId, - ScopeId: noGrantProj1.PublicId, - Grant: descendantGrantGlobalRoleGrant, - }, - { - RoleId: descendantGrantGlobalRole.PublicId, - ScopeId: noGrantOrg2.PublicId, - Grant: descendantGrantGlobalRoleGrant, - }, - { - RoleId: descendantGrantGlobalRole.PublicId, - ScopeId: noGrantProj2.PublicId, - Grant: descendantGrantGlobalRoleGrant, - }, - { - RoleId: descendantGrantGlobalRole.PublicId, - ScopeId: directGrantOrg1.PublicId, - Grant: descendantGrantGlobalRoleGrant, - }, - { - RoleId: 
descendantGrantGlobalRole.PublicId, - ScopeId: directGrantProj1a.PublicId, - Grant: descendantGrantGlobalRoleGrant, - }, - { - RoleId: descendantGrantGlobalRole.PublicId, - ScopeId: directGrantProj1b.PublicId, - Grant: descendantGrantGlobalRoleGrant, - }, - { - RoleId: descendantGrantGlobalRole.PublicId, - ScopeId: directGrantOrg2.PublicId, - Grant: descendantGrantGlobalRoleGrant, - }, - { - RoleId: descendantGrantGlobalRole.PublicId, - ScopeId: directGrantProj2a.PublicId, - Grant: descendantGrantGlobalRoleGrant, - }, - { - RoleId: descendantGrantGlobalRole.PublicId, - ScopeId: directGrantProj2b.PublicId, - Grant: descendantGrantGlobalRoleGrant, - }, - { - RoleId: descendantGrantGlobalRole.PublicId, - ScopeId: childGrantOrg1.PublicId, - Grant: descendantGrantGlobalRoleGrant, - }, + // Grants from global to every org: + { + RoleId: childGrantGlobalRole.PublicId, + RoleScopeId: scope.Global.String(), + GrantScopeId: globals.GrantScopeChildren, + Grant: childGrantGlobalRoleGrant, + }, + + // Grants from global to every org and project: + { + RoleId: descendantGrantGlobalRole.PublicId, + RoleScopeId: scope.Global.String(), + GrantScopeId: globals.GrantScopeDescendants, + Grant: descendantGrantGlobalRoleGrant, + }, + } + + multiGrantTuplesCache := new([]multiGrantTuple) + grantTuples, err := repo.GrantsForUser(ctx, user.PublicId, withTestCacheMultiGrantTuples(multiGrantTuplesCache)) + require.NoError(t, err) + assert.ElementsMatch(t, grantTuples, expGrantTuples) + }) + + t.Run("acl-grants", func(t *testing.T) { + grantTuples, err := repo.GrantsForUser(ctx, user.PublicId) + require.NoError(t, err) + grants := make([]perms.Grant, 0, len(grantTuples)) + for _, gt := range grantTuples { + grant, err := perms.Parse(ctx, gt) + require.NoError(t, err) + grants = append(grants, grant) + } + acl := perms.NewACL(grants...) 
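As a quick orientation for the flow exercised just above — GrantsForUser tuples are parsed into grants and assembled into an ACL before any Allowed checks — the following sketch condenses that flow into a single helper. It is illustrative only, not part of the patch: the sketch package name, the exampleACLCheck helper, the IDs, and the grant string are placeholders, and it assumes the snippet lives inside the Boundary module so the internal perms, scope, resource, and action packages resolve.

package sketch

import (
	"context"

	"github.com/hashicorp/boundary/internal/perms"
	"github.com/hashicorp/boundary/internal/types/action"
	"github.com/hashicorp/boundary/internal/types/resource"
	"github.com/hashicorp/boundary/internal/types/scope"
)

// exampleACLCheck is an illustrative helper showing the tuple -> Parse ->
// NewACL -> Allowed flow used by the surrounding test. All IDs and the grant
// string below are placeholders.
func exampleACLCheck(ctx context.Context) (bool, error) {
	// A grant tuple shaped like the ones GrantsForUser returns after the
	// '^'-aggregated rows are split apart.
	tuple := perms.GrantTuple{
		RoleId:            "r_example",
		RoleScopeId:       "o_example",
		RoleParentScopeId: scope.Global.String(),
		GrantScopeId:      "p_example",
		Grant:             "ids=*;type=target;actions=read",
	}

	// perms.Parse now takes the whole tuple rather than separate scope-id and
	// grant-string arguments.
	grant, err := perms.Parse(ctx, tuple, perms.WithSkipFinalValidation(true))
	if err != nil {
		return false, err
	}

	// Build an ACL from the parsed grant and ask whether reading a target in
	// the granted project scope would be authorized.
	acl := perms.NewACL(grant)
	res := perms.Resource{
		ScopeId:       "p_example",
		Id:            "ttcp_example",
		Type:          resource.Target,
		ParentScopeId: "o_example",
	}
	return acl.Allowed(res, action.Read, "u_example").Authorized, nil
}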
+ + t.Run("descendant-grants", func(t *testing.T) { + descendantGrants := acl.DescendantsGrants() + expDescendantGrants := []perms.AclGrant{ + { + RoleScopeId: scope.Global.String(), + GrantScopeId: globals.GrantScopeDescendants, + Id: "*", + Type: resource.Credential, + ActionSet: perms.ActionSet{action.All: true}, + }, + } + assert.ElementsMatch(t, descendantGrants, expDescendantGrants) + }) + + t.Run("child-grants", func(t *testing.T) { + childrenGrants := acl.ChildrenScopeGrantMap() + expChildrenGrants := map[string][]perms.AclGrant{ + childGrantOrg1.PublicId: { + { + RoleScopeId: childGrantOrg1.PublicId, + RoleParentScopeId: scope.Global.String(), + GrantScopeId: globals.GrantScopeChildren, + Id: "*", + Type: resource.HostSet, + ActionSet: perms.ActionSet{action.AddHosts: true, action.RemoveHosts: true}, + }, + }, + childGrantOrg2.PublicId: { + { + RoleScopeId: childGrantOrg2.PublicId, + RoleParentScopeId: scope.Global.String(), + GrantScopeId: globals.GrantScopeChildren, + Id: "*", + Type: resource.Session, + ActionSet: perms.ActionSet{action.CancelSelf: true}, + }, + { + RoleScopeId: childGrantOrg2.PublicId, + RoleParentScopeId: scope.Global.String(), + GrantScopeId: globals.GrantScopeChildren, + Id: "*", + Type: resource.Session, + ActionSet: perms.ActionSet{action.ReadSelf: true}, + }, + }, + scope.Global.String(): { + { + RoleScopeId: scope.Global.String(), + GrantScopeId: globals.GrantScopeChildren, + Id: "*", + Type: resource.Account, + ActionSet: perms.ActionSet{action.All: true}, + }, + }, + } + assert.Len(t, childrenGrants, len(expChildrenGrants)) + for k, v := range childrenGrants { + assert.ElementsMatch(t, v, expChildrenGrants[k]) + } + }) + + t.Run("direct-grants", func(t *testing.T) { + directGrants := acl.DirectScopeGrantMap() + expDirectGrants := map[string][]perms.AclGrant{ + directGrantOrg1.PublicId: { + { + RoleScopeId: directGrantOrg1.PublicId, + RoleParentScopeId: scope.Global.String(), + GrantScopeId: directGrantOrg1.PublicId, + Id: "*", + Type: resource.All, + ActionSet: perms.ActionSet{action.All: true}, + }, + { + RoleScopeId: directGrantOrg1.PublicId, + RoleParentScopeId: scope.Global.String(), + GrantScopeId: directGrantOrg1.PublicId, + Id: "*", + Type: resource.Role, + ActionSet: perms.ActionSet{action.List: true, action.Read: true}, + }, + }, + directGrantProj1a.PublicId: { + { + RoleScopeId: directGrantProj1a.PublicId, + RoleParentScopeId: directGrantOrg1.PublicId, + GrantScopeId: directGrantProj1a.PublicId, + Id: "*", + Type: resource.Target, + ActionSet: perms.ActionSet{action.AuthorizeSession: true, action.Read: true}, + }, + }, + directGrantProj1b.PublicId: { + { + RoleScopeId: directGrantProj1b.PublicId, + RoleParentScopeId: directGrantOrg1.PublicId, + GrantScopeId: directGrantProj1b.PublicId, + Id: "*", + Type: resource.Session, + ActionSet: perms.ActionSet{action.List: true, action.Read: true}, + }, + }, + directGrantOrg2.PublicId: { + { + RoleScopeId: directGrantOrg2.PublicId, + RoleParentScopeId: scope.Global.String(), + GrantScopeId: directGrantOrg2.PublicId, + Id: "*", + Type: resource.User, + ActionSet: perms.ActionSet{action.All: true}, + }, + { + RoleScopeId: directGrantOrg2.PublicId, + RoleParentScopeId: scope.Global.String(), + GrantScopeId: directGrantOrg2.PublicId, + Id: "*", + Type: resource.Group, + ActionSet: perms.ActionSet{action.List: true, action.Read: true}, + }, + }, + directGrantProj2a.PublicId: { + { + RoleScopeId: directGrantOrg2.PublicId, + RoleParentScopeId: scope.Global.String(), + GrantScopeId: 
directGrantProj2a.PublicId, + Id: "*", + Type: resource.User, + ActionSet: perms.ActionSet{action.All: true}, + }, + { + RoleScopeId: directGrantOrg2.PublicId, + RoleParentScopeId: scope.Global.String(), + GrantScopeId: directGrantProj2a.PublicId, + Id: "*", + Type: resource.Group, + ActionSet: perms.ActionSet{action.List: true, action.Read: true}, + }, + { + RoleScopeId: directGrantProj2a.PublicId, + RoleParentScopeId: directGrantOrg2.PublicId, + GrantScopeId: directGrantProj2a.PublicId, + Id: "hcst_abcd1234", + Type: resource.Unknown, + ActionSet: perms.ActionSet{action.All: true}, + }, + { + RoleScopeId: directGrantProj2a.PublicId, + RoleParentScopeId: directGrantOrg2.PublicId, + GrantScopeId: directGrantProj2a.PublicId, + Id: "hcst_1234abcd", + Type: resource.Unknown, + ActionSet: perms.ActionSet{action.All: true}, + }, + }, + directGrantProj2b.PublicId: { + { + RoleScopeId: directGrantProj2b.PublicId, + RoleParentScopeId: directGrantOrg2.PublicId, + GrantScopeId: directGrantProj2b.PublicId, + Id: "cs_abcd1234", + Type: resource.Unknown, + ActionSet: perms.ActionSet{action.Update: true, action.Read: true}, + }, + }, + } + /* + log.Println("org1", directGrantOrg1.PublicId) + log.Println("proj1a", directGrantProj1a.PublicId) + log.Println("proj1b", directGrantProj1b.PublicId) + log.Println("org2", directGrantOrg2.PublicId) + log.Println("proj2a", directGrantProj2a.PublicId) + log.Println("proj2b", directGrantProj2b.PublicId) + */ + assert.Len(t, directGrants, len(expDirectGrants)) + for k, v := range directGrants { + assert.ElementsMatch(t, v, expDirectGrants[k]) + } + }) + }) + t.Run("real-world", func(t *testing.T) { + // These tests cases crib from the initial setup of the grants, and + // include a number of cases to ensure the ones that should work do and + // various that should not do not + type testCase struct { + name string + res perms.Resource + act action.Type + shouldWork bool + } + testCases := []testCase{} + + // These test cases should fail because the grants are in roles where + // the user is not a principal { - RoleId: descendantGrantGlobalRole.PublicId, - ScopeId: childGrantProj1a.PublicId, - Grant: descendantGrantGlobalRoleGrant, - }, + testCases = append(testCases, testCase{ + name: "nogrant-a", + res: perms.Resource{ + ScopeId: noGrantOrg1.PublicId, + Id: "u_abcd1234", + Type: resource.Scope, + ParentScopeId: scope.Global.String(), + }, + act: action.Read, + }, testCase{ + name: "nogrant-b", + res: perms.Resource{ + ScopeId: noGrantProj1.PublicId, + Id: "u_abcd1234", + Type: resource.User, + ParentScopeId: noGrantOrg1.String(), + }, + act: action.Read, + }, testCase{ + name: "nogrant-c", + res: perms.Resource{ + ScopeId: noGrantOrg2.PublicId, + Id: "u_abcd1234", + Type: resource.Scope, + ParentScopeId: scope.Global.String(), + }, + act: action.Read, + }, testCase{ + name: "nogrant-d", + res: perms.Resource{ + ScopeId: noGrantProj2.PublicId, + Id: "u_abcd1234", + Type: resource.User, + ParentScopeId: noGrantOrg2.String(), + }, + act: action.Read, + }, + ) + } + // These test cases are for org1 and its projects where the grants are + // direct, not via children/descendants. They test some actions that + // should work and some that shouldn't. 
{ - RoleId: descendantGrantGlobalRole.PublicId, - ScopeId: childGrantProj1b.PublicId, - Grant: descendantGrantGlobalRoleGrant, - }, + testCases = append(testCases, testCase{ + name: "direct-a", + res: perms.Resource{ + ScopeId: directGrantOrg1.PublicId, + Id: "u_abcd1234", + Type: resource.User, + ParentScopeId: scope.Global.String(), + }, + act: action.Read, + shouldWork: true, + }, testCase{ + name: "direct-b", + res: perms.Resource{ + ScopeId: directGrantOrg1.PublicId, + Id: "r_abcd1234", + Type: resource.Role, + ParentScopeId: scope.Global.String(), + }, + act: action.Read, + shouldWork: true, + }, testCase{ + name: "direct-c", + res: perms.Resource{ + ScopeId: directGrantProj1a.PublicId, + Id: "ttcp_abcd1234", + Type: resource.Target, + ParentScopeId: directGrantOrg1.PublicId, + }, + act: action.AuthorizeSession, + shouldWork: true, + }, testCase{ + name: "direct-d", + res: perms.Resource{ + ScopeId: directGrantProj1a.PublicId, + Id: "s_abcd1234", + Type: resource.Session, + ParentScopeId: directGrantOrg1.PublicId, + }, + act: action.Read, + }, testCase{ + name: "direct-e", + res: perms.Resource{ + ScopeId: directGrantProj1b.PublicId, + Id: "ttcp_abcd1234", + Type: resource.Target, + ParentScopeId: directGrantOrg1.PublicId, + }, + act: action.AuthorizeSession, + }, testCase{ + name: "direct-f", + res: perms.Resource{ + ScopeId: directGrantProj1b.PublicId, + Id: "s_abcd1234", + Type: resource.Session, + ParentScopeId: directGrantOrg1.PublicId, + }, + act: action.Read, + shouldWork: true, + }, + ) + } + // These test cases are for org2 and its projects where the grants are + // direct, not via children/descendants. They test some actions that + // should work and some that shouldn't. { - RoleId: descendantGrantGlobalRole.PublicId, - ScopeId: childGrantOrg2.PublicId, - Grant: descendantGrantGlobalRoleGrant, - }, + testCases = append(testCases, testCase{ + name: "direct-g", + res: perms.Resource{ + ScopeId: directGrantOrg2.PublicId, + Id: "u_abcd1234", + Type: resource.User, + ParentScopeId: scope.Global.String(), + }, + act: action.Update, + shouldWork: true, + }, testCase{ + name: "direct-m", + res: perms.Resource{ + ScopeId: directGrantOrg2.PublicId, + Id: "g_abcd1234", + Type: resource.Group, + ParentScopeId: scope.Global.String(), + }, + act: action.Update, + }, testCase{ + name: "direct-h", + res: perms.Resource{ + ScopeId: directGrantOrg2.PublicId, + Id: "acct_abcd1234", + Type: resource.Account, + ParentScopeId: scope.Global.String(), + }, + act: action.Delete, + shouldWork: true, + }, testCase{ + name: "direct-i", + res: perms.Resource{ + ScopeId: directGrantProj2a.PublicId, + Type: resource.Group, + ParentScopeId: directGrantOrg2.PublicId, + }, + act: action.List, + shouldWork: true, + }, testCase{ + name: "direct-j", + res: perms.Resource{ + ScopeId: directGrantProj2a.PublicId, + Id: "r_abcd1234", + Type: resource.Role, + ParentScopeId: directGrantOrg2.PublicId, + }, + act: action.Read, + }, testCase{ + name: "direct-n", + res: perms.Resource{ + ScopeId: directGrantProj2a.PublicId, + Id: "u_abcd1234", + Type: resource.User, + ParentScopeId: directGrantOrg2.PublicId, + }, + act: action.Read, + shouldWork: true, + }, testCase{ + name: "direct-k", + res: perms.Resource{ + ScopeId: directGrantProj2a.PublicId, + Id: "hcst_abcd1234", + Type: resource.HostCatalog, + ParentScopeId: directGrantOrg2.PublicId, + }, + act: action.Read, + shouldWork: true, + }, testCase{ + name: "direct-l", + res: perms.Resource{ + ScopeId: directGrantProj2b.PublicId, + Id: "cs_abcd1234", + Type: 
resource.CredentialStore, + ParentScopeId: directGrantOrg2.PublicId, + }, + act: action.Update, + shouldWork: true, + }, + testCase{ + name: "direct-m", + res: perms.Resource{ + ScopeId: directGrantProj2b.PublicId, + Id: "cl_abcd1234", + Type: resource.CredentialLibrary, + ParentScopeId: directGrantOrg2.PublicId, + }, + act: action.Update, + }, + ) + } + // These test cases are child grants { - RoleId: descendantGrantGlobalRole.PublicId, - ScopeId: childGrantProj2a.PublicId, - Grant: descendantGrantGlobalRoleGrant, - }, + testCases = append(testCases, testCase{ + name: "children-a", + res: perms.Resource{ + ScopeId: scope.Global.String(), + Id: "a_abcd1234", + Type: resource.Account, + }, + act: action.Update, + }, testCase{ + name: "children-b", + res: perms.Resource{ + ScopeId: noGrantOrg1.PublicId, + Id: "a_abcd1234", + Type: resource.Account, + ParentScopeId: scope.Global.String(), + }, + act: action.Update, + shouldWork: true, + }, testCase{ + name: "children-c", + res: perms.Resource{ + ScopeId: directGrantOrg1.PublicId, + Id: "a_abcd1234", + Type: resource.Account, + ParentScopeId: scope.Global.String(), + }, + act: action.Update, + shouldWork: true, + }, testCase{ + name: "children-d", + res: perms.Resource{ + ScopeId: directGrantOrg2.PublicId, + Id: "a_abcd1234", + Type: resource.Account, + ParentScopeId: scope.Global.String(), + }, + act: action.Update, + shouldWork: true, + }, testCase{ + name: "children-e", + res: perms.Resource{ + ScopeId: childGrantOrg2.PublicId, + Id: "s_abcd1234", + Type: resource.Session, + ParentScopeId: scope.Global.String(), + }, + act: action.CancelSelf, + }, testCase{ + name: "children-f", + res: perms.Resource{ + ScopeId: childGrantOrg1Proj.PublicId, + Id: "s_abcd1234", + Type: resource.Session, + ParentScopeId: childGrantOrg1.PublicId, + }, + act: action.CancelSelf, + }, testCase{ + name: "children-g", + res: perms.Resource{ + ScopeId: childGrantOrg2Proj.PublicId, + Id: "s_abcd1234", + Type: resource.Session, + ParentScopeId: childGrantOrg2.PublicId, + }, + act: action.CancelSelf, + shouldWork: true, + }, testCase{ + name: "children-h", + res: perms.Resource{ + ScopeId: childGrantOrg2Proj.PublicId, + Id: "s_abcd1234", + Type: resource.Session, + ParentScopeId: childGrantOrg2.PublicId, + }, + act: action.CancelSelf, + shouldWork: true, + }, testCase{ + name: "children-i", + res: perms.Resource{ + ScopeId: childGrantOrg1.PublicId, + Id: "hsst_abcd1234", + Type: resource.HostSet, + ParentScopeId: scope.Global.String(), + }, + act: action.AddHosts, + }, testCase{ + name: "children-j", + res: perms.Resource{ + ScopeId: childGrantOrg1Proj.PublicId, + Id: "hsst_abcd1234", + Type: resource.HostSet, + ParentScopeId: childGrantOrg1.PublicId, + }, + act: action.AddHosts, + shouldWork: true, + }, testCase{ + name: "children-k", + res: perms.Resource{ + ScopeId: childGrantOrg2Proj.PublicId, + Id: "hsst_abcd1234", + Type: resource.HostSet, + ParentScopeId: childGrantOrg2.PublicId, + }, + act: action.AddHosts, + }, + ) + } + // These test cases are global descendants grants { - RoleId: descendantGrantGlobalRole.PublicId, - ScopeId: childGrantProj2b.PublicId, - Grant: descendantGrantGlobalRoleGrant, - }, - } - - grantTuples, err := repo.GrantsForUser(ctx, user.PublicId) - require.NoError(err) - assert.ElementsMatch(grantTuples, expGrantTuples) + testCases = append(testCases, testCase{ + name: "descendants-a", + res: perms.Resource{ + ScopeId: scope.Global.String(), + Id: "cs_abcd1234", + Type: resource.Credential, + }, + act: action.Update, + }, testCase{ + name: 
"descendants-b", + res: perms.Resource{ + ScopeId: noGrantProj1.PublicId, + Id: "cs_abcd1234", + Type: resource.Credential, + ParentScopeId: noGrantOrg1.PublicId, + }, + act: action.Update, + shouldWork: true, + }, testCase{ + name: "descendants-c", + res: perms.Resource{ + ScopeId: directGrantOrg2.PublicId, + Id: "cs_abcd1234", + Type: resource.Credential, + ParentScopeId: scope.Global.String(), + }, + act: action.Update, + shouldWork: true, + }, testCase{ + name: "descendants-d", + res: perms.Resource{ + ScopeId: directGrantProj1a.PublicId, + Id: "cs_abcd1234", + Type: resource.Credential, + ParentScopeId: directGrantOrg1.PublicId, + }, + act: action.Update, + shouldWork: true, + }, testCase{ + name: "descendants-e", + res: perms.Resource{ + ScopeId: directGrantProj1a.PublicId, + Id: "cs_abcd1234", + Type: resource.Credential, + ParentScopeId: directGrantOrg1.PublicId, + }, + act: action.Update, + shouldWork: true, + }, testCase{ + name: "descendants-f", + res: perms.Resource{ + ScopeId: directGrantProj2b.PublicId, + Id: "cs_abcd1234", + Type: resource.Credential, + ParentScopeId: directGrantOrg2.PublicId, + }, + act: action.Update, + shouldWork: true, + }, + ) + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + grantTuples, err := repo.GrantsForUser(ctx, user.PublicId) + require.NoError(t, err) + grants := make([]perms.Grant, 0, len(grantTuples)) + for _, gt := range grantTuples { + grant, err := perms.Parse(ctx, gt) + require.NoError(t, err) + grants = append(grants, grant) + } + acl := perms.NewACL(grants...) + assert.True(t, acl.Allowed(tc.res, tc.act, "u_abc123").Authorized == tc.shouldWork) + }) + } + }) } diff --git a/internal/iam/role_grant.go b/internal/iam/role_grant.go index c84ad0d1b9..1e991ca8ca 100644 --- a/internal/iam/role_grant.go +++ b/internal/iam/role_grant.go @@ -40,7 +40,7 @@ func NewRoleGrant(ctx context.Context, roleId string, grant string, _ ...Option) // Validate that the grant parses successfully. Note that we fake the scope // here to avoid a lookup as the scope is only relevant at actual ACL // checking time and we just care that it parses correctly. - perm, err := perms.Parse(ctx, "o_abcd1234", grant) + perm, err := perms.Parse(ctx, perms.GrantTuple{RoleScopeId: "o_abcd1234", GrantScopeId: "o_abcd1234", Grant: grant}) if err != nil { return nil, errors.Wrap(ctx, err, op, errors.WithMsg("parsing grant string")) } @@ -80,7 +80,7 @@ func (g *RoleGrant) VetForWrite(ctx context.Context, _ db.Reader, _ db.OpType, _ // checking time and we just care that it parses correctly. We may have // already done this in NewRoleGrant, but we re-check and set it here // anyways because it should still be part of the vetting process. - perm, err := perms.Parse(ctx, "o_abcd1234", g.RawGrant) + perm, err := perms.Parse(ctx, perms.GrantTuple{RoleScopeId: "o_abcd1234", GrantScopeId: "o_abcd1234", Grant: g.RawGrant}) if err != nil { return errors.Wrap(ctx, err, op, errors.WithMsg("parsing grant string")) } diff --git a/internal/perms/acl.go b/internal/perms/acl.go index db006e3b36..2f6037e001 100644 --- a/internal/perms/acl.go +++ b/internal/perms/acl.go @@ -9,23 +9,30 @@ import ( "github.com/hashicorp/boundary/globals" "github.com/hashicorp/boundary/internal/types/action" "github.com/hashicorp/boundary/internal/types/resource" + "github.com/hashicorp/boundary/internal/types/scope" "github.com/hashicorp/boundary/sdk/pbs/controller/api/resources/scopes" ) // AclGrant is used to decouple API-based grants from those we utilize for ACLs. 
// Notably it uses a single ID per grant instead of multiple IDs. type AclGrant struct { - // The scope ID, which will be a project ID or an org ID - scope Scope + // The scope ID of the role that sourced this grant + RoleScopeId string + + // The parent scope ID of the role that sourced this grant + RoleParentScopeId string + + // The grant's applied scope ID + GrantScopeId string // The ID to use - id string + Id string // The type, if provided - typ resource.Type + Type resource.Type // The set of actions being granted - actions actionSet + ActionSet ActionSet // The set of output fields granted OutputFields *OutputFields @@ -33,14 +40,48 @@ type AclGrant struct { // Actions returns the actions as a slice from the internal map, along with the // string representations of those actions. -func (a AclGrant) Actions() ([]action.Type, []string) { - return a.actions.Actions() +func (ag AclGrant) Actions() ([]action.Type, []string) { + return ag.ActionSet.Actions() +} + +func (ag AclGrant) Clone() AclGrant { + ret := AclGrant{ + RoleScopeId: ag.RoleScopeId, + RoleParentScopeId: ag.RoleParentScopeId, + GrantScopeId: ag.GrantScopeId, + Id: ag.Id, + Type: ag.Type, + } + if ag.ActionSet != nil { + ret.ActionSet = make(map[action.Type]bool, len(ag.ActionSet)) + for k, v := range ag.ActionSet { + ret.ActionSet[k] = v + } + } + if ag.OutputFields != nil { + ret.OutputFields = new(OutputFields) + ret.OutputFields.fields = make(map[string]bool, len(ag.OutputFields.fields)) + for k, v := range ag.OutputFields.fields { + ret.OutputFields.fields[k] = v + } + } + return ret } // ACL provides an entry point into the permissions engine for determining if an -// action is allowed on a resource based on a principal's (user or group) grants. +// action is allowed on a resource based on a principal's (user or group) +// grants. type ACL struct { - scopeMap map[string][]AclGrant + // directScopeMap is a map of scope IDs to grants valid for that scope ID + // where the grant scope ID was specified directly + directScopeMap map[string][]AclGrant + // childrenScopeMap is a map of _parent_ scope IDs to grants, so that when + // we are checking a resource we can see if there were any "children" grant + // scope IDs that match + childrenScopeMap map[string][]AclGrant + // descendantsGrants is a list of grants that apply to all descendants of + // global + descendantsGrants []AclGrant } // ACLResults provides a type for the permission's engine results so that we can @@ -52,15 +93,19 @@ type ACLResults struct { OutputFields *OutputFields // This is included but unexported for testing/debugging - scopeMap map[string][]AclGrant + directScopeMap map[string][]AclGrant + childrenScopeMap map[string][]AclGrant + descendantsGrants []AclGrant } // Permission provides information about the specific // resources that a user has been granted access to for a given scope, resource, and action. type Permission struct { - ScopeId string // The scope id for which the permission applies. - Resource resource.Type - Action action.Type + RoleScopeId string // The scope id of the granting role + RoleParentScopeId string // The parent scope id of the granting role + GrantScopeId string // Same as the scope ID unless "children" or "descendants" was used. + Resource resource.Type + Action action.Type ResourceIds []string // Any specific resource ids that have been referred in the grant's `id` field, if applicable. OnlySelf bool // The grant only allows actions against the user's own resources. 
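For orientation, the single scopeMap has been split three ways. A short sketch (grant strings and IDs are hypothetical; the accessor names come from the DirectScopeGrantMap/ChildrenScopeGrantMap/DescendantsGrants helpers added further down) of where NewACL files each grant:

    // Sketch only: identical grant strings land in different buckets purely
    // based on the grant scope ID carried by the tuple.
    mk := func(roleScope, roleParent, grantScope string) perms.Grant {
        g, err := perms.Parse(ctx, perms.GrantTuple{
            RoleScopeId:       roleScope,
            RoleParentScopeId: roleParent,
            GrantScopeId:      grantScope,
            Grant:             "ids=*;type=session;actions=read",
        })
        if err != nil {
            panic(err)
        }
        return g
    }
    acl := perms.NewACL(
        mk("o_1", scope.Global.String(), "o_1"),                      // direct
        mk("o_1", scope.Global.String(), globals.GrantScopeChildren), // children
        mk(scope.Global.String(), "", globals.GrantScopeDescendants), // descendants
    )
    _ = acl.DirectScopeGrantMap()["o_1"]   // keyed by the grant scope itself
    _ = acl.ChildrenScopeGrantMap()["o_1"] // keyed by the granting role's scope
    _ = acl.DescendantsGrants()            // flat list; applies below global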
@@ -88,37 +133,85 @@ type Resource struct { // Pin if defined would constrain the resource within the collection of the // pin id. Pin string `json:"pin,omitempty"` + + // ParentScopeId is the parent scope of the resource. + ParentScopeId string `json:"-"` } // NewACL creates an ACL from the grants provided. Note that this converts the // API-based Grants to AclGrants. func NewACL(grants ...Grant) ACL { ret := ACL{ - scopeMap: make(map[string][]AclGrant, len(grants)), + directScopeMap: make(map[string][]AclGrant, len(grants)), + childrenScopeMap: make(map[string][]AclGrant, len(grants)), + descendantsGrants: make([]AclGrant, 0, len(grants)), } for _, grant := range grants { - switch { - case len(grant.ids) > 0: - for _, id := range grant.ids { - ret.scopeMap[grant.scope.Id] = append(ret.scopeMap[grant.scope.Id], aclGrantFromGrant(grant, id)) - } - default: + ids := grant.ids + if len(ids) == 0 { // This handles the no-ID case as well as the deprecated single-ID case - ret.scopeMap[grant.scope.Id] = append(ret.scopeMap[grant.scope.Id], aclGrantFromGrant(grant, grant.id)) + ids = []string{grant.id} + } + for _, id := range ids { + switch grant.grantScopeId { + case globals.GrantScopeDescendants: + ret.descendantsGrants = append(ret.descendantsGrants, aclGrantFromGrant(grant, id)) + case globals.GrantScopeChildren: + // We use the role's scope here because we're evaluating the + // grants themselves, not the resource, so we want to know the + // scope of the role that said "children" + ret.childrenScopeMap[grant.roleScopeId] = append(ret.childrenScopeMap[grant.roleScopeId], aclGrantFromGrant(grant, id)) + default: + ret.directScopeMap[grant.grantScopeId] = append(ret.directScopeMap[grant.grantScopeId], aclGrantFromGrant(grant, id)) + } } } return ret } +func (a ACL) DirectScopeGrantMap() map[string][]AclGrant { + ret := make(map[string][]AclGrant, len(a.directScopeMap)) + for k, v := range a.directScopeMap { + newSlice := make([]AclGrant, len(v)) + for i, g := range v { + newSlice[i] = g.Clone() + } + ret[k] = newSlice + } + return ret +} + +func (a ACL) ChildrenScopeGrantMap() map[string][]AclGrant { + ret := make(map[string][]AclGrant, len(a.childrenScopeMap)) + for k, v := range a.childrenScopeMap { + newSlice := make([]AclGrant, len(v)) + for i, g := range v { + newSlice[i] = g.Clone() + } + ret[k] = newSlice + } + return ret +} + +func (a ACL) DescendantsGrants() []AclGrant { + ret := make([]AclGrant, len(a.descendantsGrants)) + for i, v := range a.descendantsGrants { + ret[i] = v.Clone() + } + return ret +} + func aclGrantFromGrant(grant Grant, id string) AclGrant { return AclGrant{ - scope: grant.scope, - id: id, - typ: grant.typ, - actions: grant.actions, - OutputFields: grant.OutputFields, + RoleScopeId: grant.roleScopeId, + RoleParentScopeId: grant.roleParentScopeId, + GrantScopeId: grant.grantScopeId, + Id: id, + Type: grant.typ, + ActionSet: grant.actions, + OutputFields: grant.OutputFields, } } @@ -126,9 +219,16 @@ func aclGrantFromGrant(grant Grant, id string) AclGrant { func (a ACL) Allowed(r Resource, aType action.Type, userId string, opt ...Option) (results ACLResults) { opts := getOpts(opt...) - // First, get the grants within the specified scope - grants := a.scopeMap[r.ScopeId] - results.scopeMap = a.scopeMap + // First, get the grants within the specified scopes + grants := a.directScopeMap[r.ScopeId] + grants = append(grants, a.childrenScopeMap[r.ParentScopeId]...) + if r.ScopeId != scope.Global.String() { + // Descendants grants do not apply to global! 
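One practical consequence of the lookup above: a resource must carry its ParentScopeId for children grants to match, since the children map is keyed by the parent rather than by the resource's own scope. A sketch with hypothetical IDs:

    // Sketch only: a children grant on o_1 authorizes a target in a project
    // under o_1 only when the resource's parent scope is populated.
    g, err := perms.Parse(ctx, perms.GrantTuple{
        RoleScopeId:       "o_1",
        RoleParentScopeId: scope.Global.String(),
        GrantScopeId:      globals.GrantScopeChildren,
        Grant:             "ids=*;type=target;actions=read",
    })
    if err != nil {
        panic(err)
    }
    acl := perms.NewACL(g)

    withParent := acl.Allowed(perms.Resource{
        ScopeId:       "p_1234567890",
        ParentScopeId: "o_1",
        Id:            "ttcp_1234567890",
        Type:          resource.Target,
    }, action.Read, "u_1234567890").Authorized // true

    withoutParent := acl.Allowed(perms.Resource{
        ScopeId: "p_1234567890",
        Id:      "ttcp_1234567890",
        Type:    resource.Target,
    }, action.Read, "u_1234567890").Authorized // false: children map never consulted

Descendants grants, by contrast, are appended for every resource scope except global itself, matching the comment above.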
+ grants = append(grants, a.descendantsGrants...) + } + results.directScopeMap = a.directScopeMap + results.childrenScopeMap = a.childrenScopeMap + results.descendantsGrants = a.descendantsGrants var parentAction action.Type split := strings.Split(aType.String(), ":") @@ -139,7 +239,7 @@ func (a ACL) Allowed(r Resource, aType action.Type, userId string, opt ...Option for _, grant := range grants { var outputFieldsOnly bool switch { - case len(grant.actions) == 0: + case len(grant.ActionSet) == 0: // Continue with the next grant, unless we have output fields // specified in which case we continue to be able to apply the // output fields depending on ID and type. @@ -148,13 +248,13 @@ func (a ACL) Allowed(r Resource, aType action.Type, userId string, opt ...Option } else { continue } - case grant.actions[aType]: + case grant.ActionSet[aType]: // We have this action - case grant.actions[parentAction]: + case grant.ActionSet[parentAction]: // We don't have this action, but it's a subaction and we have the // parent action. As an example, if we are looking for "read:self" // and have "read", this is sufficient. - case grant.actions[action.All]: + case grant.ActionSet[action.All]: // All actions are allowed default: // No actions in the grant match what we're looking for, so continue @@ -188,35 +288,35 @@ func (a ACL) Allowed(r Resource, aType action.Type, userId string, opt ...Option switch { // Allow discovery of scopes, so that auth methods within can be // discovered - case grant.typ == r.Type && - grant.typ == resource.Scope && + case grant.Type == r.Type && + grant.Type == resource.Scope && (aType == action.List || aType == action.NoOp): found = true // Allow discovery of and authenticating to auth methods - case grant.typ == r.Type && - grant.typ == resource.AuthMethod && + case grant.Type == r.Type && + grant.Type == resource.AuthMethod && (aType == action.List || aType == action.NoOp || aType == action.Authenticate): found = true } // Case 2: - // id=;actions= where ID cannot be a wildcard; or - // id=;output_fields= where fields cannot be a + // id=;actions= where ID cannot be a wildcard; or + // id=;output_fields= where fields cannot be a // wildcard. - case grant.id == r.Id && - grant.id != "" && - grant.id != "*" && - (grant.typ == resource.Unknown || grant.typ == globals.ResourceInfoFromPrefix(grant.id).Type) && + case grant.Id == r.Id && + grant.Id != "" && + grant.Id != "*" && + (grant.Type == resource.Unknown || grant.Type == globals.ResourceInfoFromPrefix(grant.Id).Type) && !action.List.IsActionOrParent(aType) && !action.Create.IsActionOrParent(aType): found = true - // Case 3: type=;actions= when action is list or + // Case 3: type=;actions= when action is list or // create (cannot be a wildcard). Must be a top level collection, // otherwise must be one of the two formats specified in cases 4 or 5. - // Or, type=resource.type;output_fields= and no action. This is + // Or, type=resource.Type;output_fields= and no action. This is // more of a semantic difference compared to 4 more than a security // difference; this type is for clarity as it ties more closely to the // concept of create and list as actions on a collection, operating on a @@ -228,10 +328,10 @@ func (a ACL) Allowed(r Resource, aType action.Type, userId string, opt ...Option // "two ways of doing things" but it's a reasonable UX tradeoff given // that "all IDs" can reasonably be construed to include "and the one // I'm making" and "all of them for listing". 
- case grant.id == "" && + case grant.Id == "" && r.Id == "" && - grant.typ == r.Type && - grant.typ != resource.Unknown && + grant.Type == r.Type && + grant.Type != resource.Unknown && resource.TopLevelType(r.Type) && (action.List.IsActionOrParent(aType) || action.Create.IsActionOrParent(aType)): @@ -239,24 +339,24 @@ func (a ACL) Allowed(r Resource, aType action.Type, userId string, opt ...Option found = true // Case 4: - // id=*;type=;actions= where type cannot be + // id=*;type=;actions= where type cannot be // unknown but can be a wildcard to allow any resource at all; or - // id=*;type=;output_fields= with no action. - case grant.id == "*" && - grant.typ != resource.Unknown && - (grant.typ == r.Type || - grant.typ == resource.All): + // id=*;type=;output_fields= with no action. + case grant.Id == "*" && + grant.Type != resource.Unknown && + (grant.Type == r.Type || + grant.Type == resource.All): found = true // Case 5: - // id=;type=;actions= where type can be a + // id=;type=;actions= where type can be a // wildcard and this this is operating on a non-top-level type. Same for // output fields only. - case grant.id != "" && - grant.id == r.Pin && - grant.typ != resource.Unknown && - (grant.typ == r.Type || grant.typ == resource.All) && + case grant.Id != "" && + grant.Id == r.Pin && + grant.Type != resource.Unknown && + (grant.Type == r.Type || grant.Type == resource.All) && !resource.TopLevelType(r.Type): found = true @@ -276,25 +376,110 @@ func (a ACL) Allowed(r Resource, aType action.Type, userId string, opt ...Option return } -// ListResolvablePermissions builds a set of Permissions based on the grants in -// the ACL. The permissions will only be created if there is at least +// ListResolvableAliasesPermissions builds a set of Permissions based on the +// grants in the ACL. The permissions will only be created if there is at least // one grant of the provided resource type that includes at least one of the -// provided actions in the action set. -// Note that unlike the ListPermissions method, this method does not attempt to -// generate permissions for the u_recovery user. To get the resolvable aliases -// for u_recovery, the user could simply query all aliases with a destination id. -func (a ACL) ListResolvablePermissions(requestedType resource.Type, actions action.ActionSet) []Permission { - perms := make([]Permission, 0, len(a.scopeMap)) - for scopeId := range a.scopeMap { - // Consider all scopes in the grants. They may not exist, but if that is - // the case the +// provided actions in the action set. Note that unlike the ListPermissions +// method, this method does not attempt to generate permissions for the +// u_recovery user. To get the resolvable aliases for u_recovery, the user could +// simply query all aliases with a destination id. +func (a ACL) ListResolvableAliasesPermissions(requestedType resource.Type, actions action.ActionSet) []Permission { + perms := make([]Permission, 0, len(a.directScopeMap)+len(a.childrenScopeMap)+len(a.descendantsGrants)) + + childScopeMap := a.childrenScopeMap + scopeMap := a.directScopeMap + + // Unilaterally add the descendants grants, if any. Not specifying an Id or + // ParentScopeId in ScopeInfo means that the only grants that might match + // are descendants, and we tell buildPermission to include descendants. 
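The early return above works because a matching descendants grant already yields the broadest possible permission; only a direct grant on global could add anything, so the children map is dropped and the direct map is narrowed to global. A sketch mirroring the global descendants test case added later in this patch:

    // Sketch only: a single descendants grant resolves to one permission whose
    // GrantScopeId is the "descendants" sentinel rather than a concrete scope.
    g, err := perms.Parse(ctx, perms.GrantTuple{
        RoleScopeId:  scope.Global.String(),
        GrantScopeId: globals.GrantScopeDescendants,
        Grant:        "ids=ttcp_1234567890;actions=read",
    })
    if err != nil {
        panic(err)
    }
    acl := perms.NewACL(g)
    got := acl.ListResolvableAliasesPermissions(resource.Target, action.NewActionSet(action.Read))
    // got has one entry:
    //   RoleScopeId:  scope.Global.String()
    //   GrantScopeId: globals.GrantScopeDescendants
    //   Action:       action.ListResolvableAliases
    //   ResourceIds:  []string{"ttcp_1234567890"}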
+ p := Permission{ + RoleScopeId: scope.Global.String(), + GrantScopeId: globals.GrantScopeDescendants, + Resource: requestedType, + Action: action.ListResolvableAliases, + OnlySelf: true, // default to only self to be restrictive + } + if a.buildPermission(&scopes.ScopeInfo{}, requestedType, actions, true, &p) { + perms = append(perms, p) + // Shortcut here because this is all we need -- this will turn into all + // scopes. We only need to check for "global" in the direct map. + if _, ok := a.directScopeMap[scope.Global.String()]; !ok { + return perms + } + childScopeMap = nil + scopeMap = map[string][]AclGrant{scope.Global.String(): a.directScopeMap[scope.Global.String()]} + } + + // Next look at children grants; provide only the parent scope ID and tell + // buildPermission to ignore descendants so that we know that the + // permissions being looked at come from a child relationship. Cache the + // scope IDs so we can ignore direct grants. + childrenScopes := map[string]struct{}{} + for scopeId := range childScopeMap { p := Permission{ - ScopeId: scopeId, - Resource: requestedType, - Action: action.ListResolvableAliases, - OnlySelf: true, // default to only self to be restrictive + RoleScopeId: scopeId, + GrantScopeId: globals.GrantScopeChildren, + Resource: requestedType, + Action: action.ListResolvableAliases, + OnlySelf: true, // default to only self to be restrictive + } + if scopeId != scope.Global.String() { // Must be an org then so global is parent + p.RoleParentScopeId = scope.Global.String() + } + if a.buildPermission(&scopes.ScopeInfo{ParentScopeId: scopeId}, requestedType, actions, false, &p) { + perms = append(perms, p) + childrenScopes[scopeId] = struct{}{} + } + } + + // Now look at direct grants; provide only Id so that we know the + // permissions being looked at will include those specific scopes. 
+ for grantScopeId, grants := range scopeMap { + p := Permission{ + GrantScopeId: grantScopeId, + Resource: requestedType, + Action: action.ListResolvableAliases, + OnlySelf: true, // default to only self to be restrictive + } + + if len(grants) > 0 { + // Since scopeIds will be the same for all of these grants, and it's + // not children or descendants, we can get it from any of the grants + p.RoleParentScopeId = grants[0].RoleParentScopeId + p.RoleScopeId = grants[0].RoleScopeId + } + + switch { + case grantScopeId == p.RoleScopeId: + // If the role and grant scope IDs are the same, they share a + // parent, so we can look at the role's parent scope ID in the + // children scopes map + if _, ok := childrenScopes[p.RoleParentScopeId]; ok { + // We already looked at this scope in the children grants, so skip it + continue + } + case strings.HasPrefix(p.RoleScopeId, scope.Org.Prefix()): + // Since direct grants must be in the same scope or downstream, if + // the role scope ID is an org and the role and grant scopes are + // different, the grant is on a project, so look for children from + // the org + if _, ok := childrenScopes[p.RoleScopeId]; ok { + // We already found grants at this scope in the children grants, + // so skip it + continue + } + default: + // Since direct grants must be the same scope or downstream, the + // only possibility left for a children grant is that the parent is + // global and the grant is on the org -- if it was for projects it + // would need to be a descendants grant + if _, ok := childrenScopes[scope.Global.String()]; ok { + // We already looked at this scope in the children grants, so skip it + continue + } } - if a.buildPermission(scopeId, requestedType, actions, &p) { + + if a.buildPermission(&scopes.ScopeInfo{Id: grantScopeId}, requestedType, actions, false, &p) { perms = append(perms, p) } } @@ -307,18 +492,31 @@ func (a ACL) ListResolvablePermissions(requestedType resource.Type, actions acti // or for action.All in order for a Permission to be created for the scope. // The set of "id actions" is resource dependant, but will generally include all // actions that can be taken on an individual resource. -func (a ACL) ListPermissions(requestedScopes map[string]*scopes.ScopeInfo, +func (a ACL) ListPermissions( + requestedScopes map[string]*scopes.ScopeInfo, requestedType resource.Type, idActions action.ActionSet, userId string, ) []Permission { perms := make([]Permission, 0, len(requestedScopes)) - for scopeId := range requestedScopes { + for scopeId, scopeInfo := range requestedScopes { + if scopeInfo == nil { + continue + } + // Note: this function is called either with the scope resulting from + // authentication (which would have the scope info for the specific + // resource) or recursive scopes, which are fully resolved. The scopes + // included have already been run through acl.Allowed() to see if the + // user has access to the resource, so the grant scope ID can correctly + // be set here to be the same as the role scope ID even if it's + // technically coming from children/descendants grants. 
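The switch above exists to avoid double-counting scopes that a children grant already reaches. A sketch of the overlap case (hypothetical IDs; the fuller version is the org_with_this_with_child_scope_direct_grants test added later in this patch):

    // Sketch only: a children grant on o_1 plus a direct grant on p_1a, a
    // project under o_1, collapses to a single children permission.
    childGrant, err := perms.Parse(ctx, perms.GrantTuple{
        RoleScopeId:       "o_1",
        RoleParentScopeId: scope.Global.String(),
        GrantScopeId:      globals.GrantScopeChildren,
        Grant:             "ids=*;type=target;actions=list,read",
    })
    if err != nil {
        panic(err)
    }
    projGrant, err := perms.Parse(ctx, perms.GrantTuple{
        RoleScopeId:       "p_1a",
        RoleParentScopeId: "o_1",
        GrantScopeId:      "p_1a",
        Grant:             "ids=ttcp_1234567890;actions=read",
    })
    if err != nil {
        panic(err)
    }
    acl := perms.NewACL(childGrant, projGrant)
    got := acl.ListResolvableAliasesPermissions(resource.Target, action.NewActionSet(action.Read))
    // got has one entry: RoleScopeId "o_1", GrantScopeId
    // globals.GrantScopeChildren, All true. The p_1a direct grant adds no
    // second permission because o_1's children grant already covers that scope.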
p := Permission{ - ScopeId: scopeId, - Resource: requestedType, - Action: action.List, - OnlySelf: true, // default to only self to be restrictive + RoleScopeId: scopeId, + RoleParentScopeId: scopeInfo.ParentScopeId, + GrantScopeId: scopeId, + Resource: requestedType, + Action: action.List, + OnlySelf: true, // default to only self to be restrictive } if userId == globals.RecoveryUserId { p.All = true @@ -326,7 +524,7 @@ func (a ACL) ListPermissions(requestedScopes map[string]*scopes.ScopeInfo, perms = append(perms, p) continue } - if a.buildPermission(scopeId, requestedType, idActions, &p) { + if a.buildPermission(scopeInfo, requestedType, idActions, false, &p) { perms = append(perms, p) } } @@ -336,27 +534,43 @@ func (a ACL) ListPermissions(requestedScopes map[string]*scopes.ScopeInfo, // buildPermission populates the provided permission with either the resource ids // or marking All to true if there are grants that have an action that match // one of the provided idActions for the provided type -func (a ACL) buildPermission(scopeId string, +func (a ACL) buildPermission( + scopeInfo *scopes.ScopeInfo, requestedType resource.Type, idActions action.ActionSet, + includeDescendants bool, p *Permission, ) bool { // Get grants for a specific scope id from the source of truth. - grants := a.scopeMap[scopeId] + if scopeInfo == nil { + return false + } + var grants []AclGrant + if scopeInfo.Id != "" { + grants = a.directScopeMap[scopeInfo.Id] + } + if scopeInfo.ParentScopeId != "" { + grants = append(grants, a.childrenScopeMap[scopeInfo.ParentScopeId]...) + } + // If the scope is global it needs to be a direct grant; descendants doesn't + // include global + if includeDescendants || (scopeInfo.Id != "" && scopeInfo.Id != scope.Global.String()) { + grants = append(grants, a.descendantsGrants...) + } for _, grant := range grants { // This grant doesn't match what we're looking for, ignore. - if grant.typ != requestedType && grant.typ != resource.All && globals.ResourceInfoFromPrefix(grant.id).Type != requestedType { + if grant.Type != requestedType && grant.Type != resource.All && globals.ResourceInfoFromPrefix(grant.Id).Type != requestedType { continue } // We found a grant that matches the requested resource type: // Search to see if one or all actions in the action set have been granted. 
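Because buildPermission now needs scope parentage, callers of ListPermissions must supply populated ScopeInfo values rather than nil entries (nil scopes are skipped outright). A sketch matching the updated "Allow all ids" expectations further down, with hypothetical IDs:

    // Sketch only: a direct grant on o_1 listed through the new
    // ScopeInfo-keyed request map.
    g, err := perms.Parse(ctx, perms.GrantTuple{
        RoleScopeId:       "o_1",
        RoleParentScopeId: scope.Global.String(),
        GrantScopeId:      "o_1",
        Grant:             "ids=*;type=session;actions=list,read",
    })
    if err != nil {
        panic(err)
    }
    acl := perms.NewACL(g)
    got := acl.ListPermissions(
        map[string]*scopes.ScopeInfo{
            "o_1": {Id: "o_1", ParentScopeId: scope.Global.String()},
        },
        resource.Session,
        action.NewActionSet(action.Read),
        "u_1234567890", // any non-recovery user ID
    )
    // got has one entry: RoleScopeId and GrantScopeId "o_1", RoleParentScopeId
    // "global", Action action.List, All true, OnlySelf false.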
found := false - if ok := grant.actions[action.All]; ok { + if ok := grant.ActionSet[action.All]; ok { found = true } else { for idA := range idActions { - if ok := grant.actions[idA]; ok { + if ok := grant.ActionSet[idA]; ok { found = true break } @@ -375,13 +589,13 @@ func (a ACL) buildPermission(scopeId string, } p.OnlySelf = p.OnlySelf && excludeList.OnlySelf() - switch grant.id { + switch grant.Id { case "*": p.All = true case "": continue default: - p.ResourceIds = append(p.ResourceIds, grant.id) + p.ResourceIds = append(p.ResourceIds, grant.Id) } } diff --git a/internal/perms/acl_test.go b/internal/perms/acl_test.go index 12d3423f3b..02cebf3cd9 100644 --- a/internal/perms/acl_test.go +++ b/internal/perms/acl_test.go @@ -19,8 +19,10 @@ import ( ) type scopeGrant struct { - scope string - grants []string + roleScope string + roleParentScopeId string + grantScope string + grants []string } func Test_ACLAllowed(t *testing.T) { @@ -45,7 +47,8 @@ func Test_ACLAllowed(t *testing.T) { // A set of common grants to use in the following tests commonGrants := []scopeGrant{ { - scope: "o_a", + roleScope: "o_a", + grantScope: "o_a", grants: []string{ "ids=ampw_bar,ampw_baz;actions=read,update", "ids=ampw_bop;actions=read:self,update", @@ -55,7 +58,8 @@ func Test_ACLAllowed(t *testing.T) { }, }, { - scope: "o_b", + roleScope: "o_b", + grantScope: "o_b", grants: []string{ "ids=*;type=host-set;actions=list,create", "ids=hcst_mypin;type=host;actions=*;output_fields=name,description", @@ -64,7 +68,8 @@ func Test_ACLAllowed(t *testing.T) { }, }, { - scope: "o_d", + roleScope: "o_d", + grantScope: "o_d", grants: []string{ "ids=*;type=*;actions=create,update", "ids=*;type=session;actions=*", @@ -74,7 +79,8 @@ func Test_ACLAllowed(t *testing.T) { } templateGrants := []scopeGrant{ { - scope: "o_c", + roleScope: "o_c", + grantScope: "o_c", grants: []string{ "ids={{user.id }};actions=read,update", "ids={{ account.id}};actions=change-password", @@ -104,7 +110,7 @@ func Test_ACLAllowed(t *testing.T) { }, { name: "top level create with type only", - resource: Resource{ScopeId: "o_a", Type: resource.HostCatalog}, + resource: Resource{ParentScopeId: scope.Global.String(), ScopeId: "o_a", Type: resource.HostCatalog}, scopeGrants: commonGrants, actionsAuthorized: []actionAuthorized{ {action: action.Create, authorized: true}, @@ -113,7 +119,7 @@ func Test_ACLAllowed(t *testing.T) { }, { name: "matching scope and id no matching action", - resource: Resource{ScopeId: "o_a", Id: "a_foo", Type: resource.Role}, + resource: Resource{ParentScopeId: scope.Global.String(), ScopeId: "o_a", Id: "a_foo", Type: resource.Role}, scopeGrants: commonGrants, actionsAuthorized: []actionAuthorized{ {action: action.Update}, @@ -122,7 +128,7 @@ func Test_ACLAllowed(t *testing.T) { }, { name: "matching scope and id and matching action first id", - resource: Resource{ScopeId: "o_a", Id: "ampw_bar"}, + resource: Resource{ParentScopeId: scope.Global.String(), ScopeId: "o_a", Id: "ampw_bar"}, scopeGrants: commonGrants, actionsAuthorized: []actionAuthorized{ {action: action.Read, authorized: true}, @@ -132,7 +138,7 @@ func Test_ACLAllowed(t *testing.T) { }, { name: "matching scope and id and matching action second id", - resource: Resource{ScopeId: "o_a", Id: "ampw_baz"}, + resource: Resource{ParentScopeId: scope.Global.String(), ScopeId: "o_a", Id: "ampw_baz"}, scopeGrants: commonGrants, actionsAuthorized: []actionAuthorized{ {action: action.Read, authorized: true}, @@ -142,7 +148,7 @@ func Test_ACLAllowed(t *testing.T) { }, { name: "matching 
scope and type and all action with valid pin", - resource: Resource{ScopeId: "o_b", Pin: "hcst_mypin", Type: resource.Host}, + resource: Resource{ParentScopeId: scope.Global.String(), ScopeId: "o_b", Pin: "hcst_mypin", Type: resource.Host}, scopeGrants: commonGrants, actionsAuthorized: []actionAuthorized{ {action: action.Read, authorized: true, outputFields: []string{"description", "id", "name"}}, @@ -152,7 +158,7 @@ func Test_ACLAllowed(t *testing.T) { }, { name: "matching scope and type and all action but bad pin", - resource: Resource{ScopeId: "o_b", Pin: "notmypin", Type: resource.Host}, + resource: Resource{ParentScopeId: scope.Global.String(), ScopeId: "o_b", Pin: "notmypin", Type: resource.Host}, scopeGrants: commonGrants, actionsAuthorized: []actionAuthorized{ {action: action.Read, outputFields: []string{"id"}}, @@ -162,7 +168,7 @@ func Test_ACLAllowed(t *testing.T) { }, { name: "matching scope and id and some action", - resource: Resource{ScopeId: "o_b", Id: "myhost", Type: resource.HostSet}, + resource: Resource{ParentScopeId: scope.Global.String(), ScopeId: "o_b", Id: "myhost", Type: resource.HostSet}, scopeGrants: commonGrants, actionsAuthorized: []actionAuthorized{ {action: action.List, authorized: true, outputFields: []string{"id"}}, @@ -172,7 +178,7 @@ func Test_ACLAllowed(t *testing.T) { }, { name: "matching scope and id and all action but bad specifier", - resource: Resource{ScopeId: "o_b", Id: "id_g"}, + resource: Resource{ParentScopeId: scope.Global.String(), ScopeId: "o_b", Id: "id_g"}, scopeGrants: commonGrants, actionsAuthorized: []actionAuthorized{ {action: action.Read, outputFields: []string{"id"}}, @@ -182,7 +188,7 @@ func Test_ACLAllowed(t *testing.T) { }, { name: "matching scope and not matching type", - resource: Resource{ScopeId: "o_a", Type: resource.HostCatalog}, + resource: Resource{ParentScopeId: scope.Global.String(), ScopeId: "o_a", Type: resource.HostCatalog}, scopeGrants: commonGrants, actionsAuthorized: []actionAuthorized{ {action: action.Update}, @@ -191,7 +197,7 @@ func Test_ACLAllowed(t *testing.T) { }, { name: "matching scope and matching type", - resource: Resource{ScopeId: "o_a", Type: resource.HostSet}, + resource: Resource{ParentScopeId: scope.Global.String(), ScopeId: "o_a", Type: resource.HostSet}, scopeGrants: commonGrants, actionsAuthorized: []actionAuthorized{ {action: action.List, authorized: true}, @@ -201,7 +207,7 @@ func Test_ACLAllowed(t *testing.T) { }, { name: "matching scope, type, action, random id and bad pin first id", - resource: Resource{ScopeId: "o_a", Id: "anything", Type: resource.HostCatalog, Pin: "ampw_bar"}, + resource: Resource{ParentScopeId: scope.Global.String(), ScopeId: "o_a", Id: "anything", Type: resource.HostCatalog, Pin: "ampw_bar"}, scopeGrants: commonGrants, actionsAuthorized: []actionAuthorized{ {action: action.Update}, @@ -211,7 +217,7 @@ func Test_ACLAllowed(t *testing.T) { }, { name: "matching scope, type, action, random id and bad pin second id", - resource: Resource{ScopeId: "o_a", Id: "anything", Type: resource.HostCatalog, Pin: "ampw_baz"}, + resource: Resource{ParentScopeId: scope.Global.String(), ScopeId: "o_a", Id: "anything", Type: resource.HostCatalog, Pin: "ampw_baz"}, scopeGrants: commonGrants, actionsAuthorized: []actionAuthorized{ {action: action.Update}, @@ -221,7 +227,7 @@ func Test_ACLAllowed(t *testing.T) { }, { name: "wrong scope and matching type", - resource: Resource{ScopeId: "o_bad", Type: resource.HostSet}, + resource: Resource{ParentScopeId: scope.Global.String(), ScopeId: "o_bad", 
Type: resource.HostSet}, scopeGrants: commonGrants, actionsAuthorized: []actionAuthorized{ {action: action.List}, @@ -231,7 +237,7 @@ func Test_ACLAllowed(t *testing.T) { }, { name: "any id", - resource: Resource{ScopeId: "o_b", Type: resource.AuthMethod}, + resource: Resource{ParentScopeId: scope.Global.String(), ScopeId: "o_b", Type: resource.AuthMethod}, scopeGrants: commonGrants, actionsAuthorized: []actionAuthorized{ {action: action.List, outputFields: []string{"id"}}, @@ -241,7 +247,7 @@ func Test_ACLAllowed(t *testing.T) { }, { name: "bad templated user id", - resource: Resource{ScopeId: "o_c"}, + resource: Resource{ParentScopeId: scope.Global.String(), ScopeId: "o_c"}, scopeGrants: append(commonGrants, templateGrants...), actionsAuthorized: []actionAuthorized{ {action: action.List}, @@ -252,7 +258,7 @@ func Test_ACLAllowed(t *testing.T) { }, { name: "good templated user id", - resource: Resource{ScopeId: "o_c", Id: "u_abcd1234"}, + resource: Resource{ParentScopeId: scope.Global.String(), ScopeId: "o_c", Id: "u_abcd1234"}, scopeGrants: append(commonGrants, templateGrants...), actionsAuthorized: []actionAuthorized{ {action: action.Read, authorized: true}, @@ -262,7 +268,7 @@ func Test_ACLAllowed(t *testing.T) { }, { name: "bad templated old account id", - resource: Resource{ScopeId: "o_c"}, + resource: Resource{ParentScopeId: scope.Global.String(), ScopeId: "o_c"}, scopeGrants: append(commonGrants, templateGrants...), actionsAuthorized: []actionAuthorized{ {action: action.List}, @@ -273,7 +279,7 @@ func Test_ACLAllowed(t *testing.T) { }, { name: "good templated old account id", - resource: Resource{ScopeId: "o_c", Id: fmt.Sprintf("%s_1234567890", globals.PasswordAccountPreviousPrefix)}, + resource: Resource{ParentScopeId: scope.Global.String(), ScopeId: "o_c", Id: fmt.Sprintf("%s_1234567890", globals.PasswordAccountPreviousPrefix)}, scopeGrants: append(commonGrants, templateGrants...), actionsAuthorized: []actionAuthorized{ {action: action.ChangePassword, authorized: true}, @@ -283,7 +289,7 @@ func Test_ACLAllowed(t *testing.T) { }, { name: "bad templated new account id", - resource: Resource{ScopeId: "o_c"}, + resource: Resource{ParentScopeId: scope.Global.String(), ScopeId: "o_c"}, scopeGrants: append(commonGrants, templateGrants...), actionsAuthorized: []actionAuthorized{ {action: action.List}, @@ -294,7 +300,7 @@ func Test_ACLAllowed(t *testing.T) { }, { name: "good templated new account id", - resource: Resource{ScopeId: "o_c", Id: fmt.Sprintf("%s_1234567890", globals.PasswordAccountPrefix)}, + resource: Resource{ParentScopeId: scope.Global.String(), ScopeId: "o_c", Id: fmt.Sprintf("%s_1234567890", globals.PasswordAccountPrefix)}, scopeGrants: append(commonGrants, templateGrants...), actionsAuthorized: []actionAuthorized{ {action: action.ChangePassword, authorized: true}, @@ -304,7 +310,7 @@ func Test_ACLAllowed(t *testing.T) { }, { name: "all type", - resource: Resource{ScopeId: "o_d", Type: resource.Account}, + resource: Resource{ParentScopeId: scope.Global.String(), ScopeId: "o_d", Type: resource.Account}, scopeGrants: commonGrants, actionsAuthorized: []actionAuthorized{ {action: action.Create, authorized: true}, @@ -314,7 +320,7 @@ func Test_ACLAllowed(t *testing.T) { }, { name: "list with top level list", - resource: Resource{ScopeId: "o_a", Type: resource.Target}, + resource: Resource{ParentScopeId: scope.Global.String(), ScopeId: "o_a", Type: resource.Target}, scopeGrants: commonGrants, actionsAuthorized: []actionAuthorized{ {action: action.List, authorized: true}, @@ 
-322,7 +328,7 @@ func Test_ACLAllowed(t *testing.T) { }, { name: "list sessions with wildcard actions", - resource: Resource{ScopeId: "o_d", Type: resource.Session}, + resource: Resource{ParentScopeId: scope.Global.String(), ScopeId: "o_d", Type: resource.Session}, scopeGrants: commonGrants, actionsAuthorized: []actionAuthorized{ {action: action.List, authorized: true}, @@ -330,7 +336,7 @@ func Test_ACLAllowed(t *testing.T) { }, { name: "read self with top level read first id", - resource: Resource{ScopeId: "o_a", Id: "ampw_bar"}, + resource: Resource{ParentScopeId: scope.Global.String(), ScopeId: "o_a", Id: "ampw_bar"}, scopeGrants: commonGrants, actionsAuthorized: []actionAuthorized{ {action: action.Read, authorized: true}, @@ -339,7 +345,7 @@ func Test_ACLAllowed(t *testing.T) { }, { name: "read self with top level read second id", - resource: Resource{ScopeId: "o_a", Id: "ampw_baz"}, + resource: Resource{ParentScopeId: scope.Global.String(), ScopeId: "o_a", Id: "ampw_baz"}, scopeGrants: commonGrants, actionsAuthorized: []actionAuthorized{ {action: action.Read, authorized: true}, @@ -348,7 +354,7 @@ func Test_ACLAllowed(t *testing.T) { }, { name: "read self only", - resource: Resource{ScopeId: "o_a", Id: "ampw_bop"}, + resource: Resource{ParentScopeId: scope.Global.String(), ScopeId: "o_a", Id: "ampw_bop"}, scopeGrants: commonGrants, actionsAuthorized: []actionAuthorized{ {action: action.Read}, @@ -360,7 +366,8 @@ func Test_ACLAllowed(t *testing.T) { resource: Resource{ScopeId: scope.Global.String(), Type: resource.Worker}, scopeGrants: []scopeGrant{ { - scope: scope.Global.String(), + roleScope: scope.Global.String(), + grantScope: scope.Global.String(), grants: []string{ "type=worker;actions=create", }, @@ -375,7 +382,8 @@ func Test_ACLAllowed(t *testing.T) { resource: Resource{ScopeId: scope.Global.String(), Type: resource.Worker}, scopeGrants: []scopeGrant{ { - scope: scope.Global.String(), + roleScope: scope.Global.String(), + grantScope: scope.Global.String(), grants: []string{ "type=worker;actions=create:worker-led", }, @@ -392,7 +400,7 @@ func Test_ACLAllowed(t *testing.T) { var grants []Grant for _, sg := range test.scopeGrants { for _, g := range sg.grants { - grant, err := Parse(ctx, sg.scope, g, WithAccountId(test.accountId), WithUserId(test.userId)) + grant, err := Parse(ctx, GrantTuple{RoleScopeId: sg.roleScope, GrantScopeId: sg.grantScope, Grant: g}, WithAccountId(test.accountId), WithUserId(test.userId)) require.NoError(t, err) grants = append(grants, grant) } @@ -429,8 +437,10 @@ func TestACL_ListResolvableAliasesPermissions(t *testing.T) { name: "Requested resource mismatch", aclGrants: []scopeGrant{ { - scope: "o_1", - grants: []string{"ids=*;type=target;actions=list,read"}, // List & Read for all Targets + roleScope: "o_1", + roleParentScopeId: scope.Global.String(), + grantScope: "o_1", + grants: []string{"ids=*;type=target;actions=list,read"}, // List & Read for all Targets }, }, resourceType: resource.Session, // We're requesting sessions. 
@@ -441,8 +451,10 @@ func TestACL_ListResolvableAliasesPermissions(t *testing.T) { name: "Requested actions not available for the requested scope id", aclGrants: []scopeGrant{ { - scope: "o_1", - grants: []string{"ids=*;type=session;actions=delete"}, + roleScope: "o_1", + roleParentScopeId: scope.Global.String(), + grantScope: "o_1", + grants: []string{"ids=*;type=session;actions=delete"}, }, }, resourceType: resource.Session, @@ -453,8 +465,10 @@ func TestACL_ListResolvableAliasesPermissions(t *testing.T) { name: "No specific id or wildcard provided for `id` field", aclGrants: []scopeGrant{ { - scope: "o_1", - grants: []string{"type=*;actions=list,read"}, + roleScope: "o_1", + roleParentScopeId: scope.Global.String(), + grantScope: "o_1", + grants: []string{"type=*;actions=list,read"}, }, }, resourceType: resource.Session, @@ -466,20 +480,24 @@ func TestACL_ListResolvableAliasesPermissions(t *testing.T) { name: "Allow all ids", aclGrants: []scopeGrant{ { - scope: "o_1", - grants: []string{"ids=*;type=session;actions=update,read"}, + roleScope: "o_1", + roleParentScopeId: scope.Global.String(), + grantScope: "o_1", + grants: []string{"ids=*;type=session;actions=update,read"}, }, }, resourceType: resource.Session, actionSet: action.NewActionSet(action.Read), expPermissions: []Permission{ { - ScopeId: "o_1", - Resource: resource.Session, - Action: action.ListResolvableAliases, - ResourceIds: nil, - OnlySelf: false, - All: true, + RoleScopeId: "o_1", + RoleParentScopeId: scope.Global.String(), + GrantScopeId: "o_1", + Resource: resource.Session, + Action: action.ListResolvableAliases, + ResourceIds: nil, + OnlySelf: false, + All: true, }, }, }, @@ -487,20 +505,24 @@ func TestACL_ListResolvableAliasesPermissions(t *testing.T) { name: "Allow all ids, :self actions", aclGrants: []scopeGrant{ { - scope: "o_1", - grants: []string{"ids=*;type=session;actions=list,read:self"}, + roleScope: "o_1", + roleParentScopeId: scope.Global.String(), + grantScope: "o_1", + grants: []string{"ids=*;type=session;actions=list,read:self"}, }, }, resourceType: resource.Session, actionSet: action.NewActionSet(action.ReadSelf), expPermissions: []Permission{ { - ScopeId: "o_1", - Resource: resource.Session, - Action: action.ListResolvableAliases, - ResourceIds: nil, - OnlySelf: true, - All: true, + RoleScopeId: "o_1", + RoleParentScopeId: scope.Global.String(), + GrantScopeId: "o_1", + Resource: resource.Session, + Action: action.ListResolvableAliases, + ResourceIds: nil, + OnlySelf: true, + All: true, }, }, }, @@ -508,7 +530,9 @@ func TestACL_ListResolvableAliasesPermissions(t *testing.T) { name: "Allow specific IDs", aclGrants: []scopeGrant{ { - scope: "o_1", + roleScope: "o_1", + roleParentScopeId: scope.Global.String(), + grantScope: "o_1", grants: []string{ "ids=s_1;type=session;actions=list,read", "ids=s_2,s_3;type=session;actions=list,read", @@ -519,12 +543,14 @@ func TestACL_ListResolvableAliasesPermissions(t *testing.T) { actionSet: action.NewActionSet(action.Read), expPermissions: []Permission{ { - ScopeId: "o_1", - Resource: resource.Session, - Action: action.ListResolvableAliases, - ResourceIds: []string{"s_1", "s_2", "s_3"}, - OnlySelf: false, - All: false, + RoleScopeId: "o_1", + RoleParentScopeId: scope.Global.String(), + GrantScopeId: "o_1", + Resource: resource.Session, + Action: action.ListResolvableAliases, + ResourceIds: []string{"s_1", "s_2", "s_3"}, + OnlySelf: false, + All: false, }, }, }, @@ -532,20 +558,24 @@ func TestACL_ListResolvableAliasesPermissions(t *testing.T) { name: "No specific type 
1", aclGrants: []scopeGrant{ { - scope: "o_1", - grants: []string{"ids=*;type=*;actions=list,read:self"}, + roleScope: "o_1", + roleParentScopeId: scope.Global.String(), + grantScope: "o_1", + grants: []string{"ids=*;type=*;actions=list,read:self"}, }, }, resourceType: resource.Session, actionSet: action.NewActionSet(action.ReadSelf), expPermissions: []Permission{ { - ScopeId: "o_1", - Resource: resource.Session, - Action: action.ListResolvableAliases, - ResourceIds: nil, - OnlySelf: true, - All: true, + RoleScopeId: "o_1", + RoleParentScopeId: scope.Global.String(), + GrantScopeId: "o_1", + Resource: resource.Session, + Action: action.ListResolvableAliases, + ResourceIds: nil, + OnlySelf: true, + All: true, }, }, }, @@ -553,20 +583,24 @@ func TestACL_ListResolvableAliasesPermissions(t *testing.T) { name: "List + No-op action with id wildcard", aclGrants: []scopeGrant{ { - scope: "o_1", - grants: []string{"ids=*;type=session;actions=list,no-op"}, + roleScope: "o_1", + roleParentScopeId: scope.Global.String(), + grantScope: "o_1", + grants: []string{"ids=*;type=session;actions=list,no-op"}, }, }, resourceType: resource.Session, actionSet: action.NewActionSet(action.NoOp), expPermissions: []Permission{ { - ScopeId: "o_1", - Resource: resource.Session, - Action: action.ListResolvableAliases, - ResourceIds: nil, - OnlySelf: false, - All: true, + RoleScopeId: "o_1", + RoleParentScopeId: scope.Global.String(), + GrantScopeId: "o_1", + Resource: resource.Session, + Action: action.ListResolvableAliases, + ResourceIds: nil, + OnlySelf: false, + All: true, }, }, }, @@ -574,8 +608,10 @@ func TestACL_ListResolvableAliasesPermissions(t *testing.T) { name: "List + No-op action with id wildcard, read present", aclGrants: []scopeGrant{ { - scope: "o_1", - grants: []string{"ids=*;type=session;actions=list,no-op"}, + roleScope: "o_1", + roleParentScopeId: scope.Global.String(), + grantScope: "o_1", + grants: []string{"ids=*;type=session;actions=list,no-op"}, }, }, resourceType: resource.Session, @@ -586,7 +622,9 @@ func TestACL_ListResolvableAliasesPermissions(t *testing.T) { name: "List + No-op action with specific ids", aclGrants: []scopeGrant{ { - scope: "o_1", + roleScope: "o_1", + roleParentScopeId: scope.Global.String(), + grantScope: "o_1", grants: []string{ "ids=s_1;type=session;actions=list,no-op", "ids=s_2,s_3;type=session;actions=list,no-op", @@ -597,12 +635,14 @@ func TestACL_ListResolvableAliasesPermissions(t *testing.T) { actionSet: action.NewActionSet(action.NoOp), expPermissions: []Permission{ { - ScopeId: "o_1", - Resource: resource.Session, - Action: action.ListResolvableAliases, - ResourceIds: []string{"s_1", "s_2", "s_3"}, - OnlySelf: false, - All: false, + RoleScopeId: "o_1", + RoleParentScopeId: scope.Global.String(), + GrantScopeId: "o_1", + Resource: resource.Session, + Action: action.ListResolvableAliases, + ResourceIds: []string{"s_1", "s_2", "s_3"}, + OnlySelf: false, + All: false, }, }, }, @@ -610,20 +650,24 @@ func TestACL_ListResolvableAliasesPermissions(t *testing.T) { name: "No specific type 2", aclGrants: []scopeGrant{ { - scope: "o_1", - grants: []string{"ids=*;type=*;actions=list,read:self"}, + roleScope: "o_1", + roleParentScopeId: scope.Global.String(), + grantScope: "o_1", + grants: []string{"ids=*;type=*;actions=list,read:self"}, }, }, resourceType: resource.Host, actionSet: action.NewActionSet(action.ReadSelf), expPermissions: []Permission{ { - ScopeId: "o_1", - Resource: resource.Host, - Action: action.ListResolvableAliases, - ResourceIds: nil, - OnlySelf: true, - 
All: true, + RoleScopeId: "o_1", + RoleParentScopeId: scope.Global.String(), + GrantScopeId: "o_1", + Resource: resource.Host, + Action: action.ListResolvableAliases, + ResourceIds: nil, + OnlySelf: true, + All: true, }, }, }, @@ -631,7 +675,9 @@ func TestACL_ListResolvableAliasesPermissions(t *testing.T) { name: "Grant hierarchy is respected", aclGrants: []scopeGrant{ { - scope: "o_1", + roleScope: "o_1", + roleParentScopeId: scope.Global.String(), + grantScope: "o_1", grants: []string{ "ids=*;type=*;actions=*", "ids=*;type=session;actions=cancel:self,list,read:self", @@ -642,12 +688,14 @@ func TestACL_ListResolvableAliasesPermissions(t *testing.T) { actionSet: action.NewActionSet(action.NoOp, action.Read, action.ReadSelf, action.Cancel, action.CancelSelf), expPermissions: []Permission{ { - ScopeId: "o_1", - Resource: resource.Session, - Action: action.ListResolvableAliases, - ResourceIds: nil, - OnlySelf: false, - All: true, + RoleScopeId: "o_1", + RoleParentScopeId: scope.Global.String(), + GrantScopeId: "o_1", + Resource: resource.Session, + Action: action.ListResolvableAliases, + ResourceIds: nil, + OnlySelf: false, + All: true, }, }, }, @@ -655,20 +703,24 @@ func TestACL_ListResolvableAliasesPermissions(t *testing.T) { name: "Full access 1", aclGrants: []scopeGrant{ { - scope: "o_1", - grants: []string{"ids=*;type=*;actions=*"}, + roleScope: "o_1", + roleParentScopeId: scope.Global.String(), + grantScope: "o_1", + grants: []string{"ids=*;type=*;actions=*"}, }, }, resourceType: resource.Session, actionSet: action.NewActionSet(action.Read, action.Create, action.Delete), expPermissions: []Permission{ { - ScopeId: "o_1", - Resource: resource.Session, - Action: action.ListResolvableAliases, - ResourceIds: nil, - OnlySelf: false, - All: true, + RoleScopeId: "o_1", + RoleParentScopeId: scope.Global.String(), + GrantScopeId: "o_1", + Resource: resource.Session, + Action: action.ListResolvableAliases, + ResourceIds: nil, + OnlySelf: false, + All: true, }, }, }, @@ -676,20 +728,24 @@ func TestACL_ListResolvableAliasesPermissions(t *testing.T) { name: "Full access 2", aclGrants: []scopeGrant{ { - scope: "o_1", - grants: []string{"ids=*;type=*;actions=*"}, + roleScope: "o_1", + roleParentScopeId: scope.Global.String(), + grantScope: "o_1", + grants: []string{"ids=*;type=*;actions=*"}, }, }, resourceType: resource.Host, actionSet: action.NewActionSet(action.Read, action.Create, action.Delete), expPermissions: []Permission{ { - ScopeId: "o_1", - Resource: resource.Host, - Action: action.ListResolvableAliases, - ResourceIds: nil, - OnlySelf: false, - All: true, + RoleScopeId: "o_1", + RoleParentScopeId: scope.Global.String(), + GrantScopeId: "o_1", + Resource: resource.Host, + Action: action.ListResolvableAliases, + ResourceIds: nil, + OnlySelf: false, + All: true, }, }, }, @@ -697,35 +753,43 @@ func TestACL_ListResolvableAliasesPermissions(t *testing.T) { name: "Multiple scopes", aclGrants: []scopeGrant{ { - scope: "o_1", + roleScope: "o_1", + roleParentScopeId: scope.Global.String(), + grantScope: "o_1", grants: []string{ "ids=s_1;type=session;actions=create,read", "ids=s_2,s_3;type=session;actions=update,read", }, }, { - scope: "o_2", - grants: []string{"ids=*;type=session;actions=read:self"}, + roleScope: "o_2", + grantScope: "o_2", + roleParentScopeId: scope.Global.String(), + grants: []string{"ids=*;type=session;actions=read:self"}, }, }, resourceType: resource.Session, actionSet: action.NewActionSet(action.Read, action.ReadSelf), expPermissions: []Permission{ { - ScopeId: "o_1", - Resource: 
resource.Session, - Action: action.ListResolvableAliases, - ResourceIds: []string{"s_1", "s_2", "s_3"}, - OnlySelf: false, - All: false, + RoleScopeId: "o_1", + RoleParentScopeId: scope.Global.String(), + GrantScopeId: "o_1", + Resource: resource.Session, + Action: action.ListResolvableAliases, + ResourceIds: []string{"s_1", "s_2", "s_3"}, + OnlySelf: false, + All: false, }, { - ScopeId: "o_2", - Resource: resource.Session, - Action: action.ListResolvableAliases, - ResourceIds: nil, - OnlySelf: true, - All: true, + RoleScopeId: "o_2", + RoleParentScopeId: scope.Global.String(), + GrantScopeId: "o_2", + Resource: resource.Session, + Action: action.ListResolvableAliases, + ResourceIds: nil, + OnlySelf: true, + All: true, }, }, }, @@ -735,7 +799,9 @@ func TestACL_ListResolvableAliasesPermissions(t *testing.T) { actionSet: action.NewActionSet(action.Read, action.Cancel), aclGrants: []scopeGrant{ { - scope: "p_1", + roleScope: "p_1", + roleParentScopeId: "o_1", + grantScope: "p_1", grants: []string{ "type=target;actions=list", "ids=ttcp_1234567890;actions=read", @@ -744,12 +810,287 @@ func TestACL_ListResolvableAliasesPermissions(t *testing.T) { }, expPermissions: []Permission{ { - ScopeId: "p_1", - Resource: resource.Target, - Action: action.ListResolvableAliases, - ResourceIds: []string{"ttcp_1234567890"}, - All: false, - OnlySelf: false, + RoleScopeId: "p_1", + RoleParentScopeId: "o_1", + GrantScopeId: "p_1", + Resource: resource.Target, + Action: action.ListResolvableAliases, + ResourceIds: []string{"ttcp_1234567890"}, + All: false, + OnlySelf: false, + }, + }, + }, + { + name: "global_no_this_with_descendants", + resourceType: resource.Target, + actionSet: action.NewActionSet(action.Read, action.Cancel), + aclGrants: []scopeGrant{ + { + roleScope: "global", + grantScope: globals.GrantScopeDescendants, + grants: []string{ + "type=target;actions=list", + "ids=ttcp_1234567890;actions=read", + }, + }, + }, + expPermissions: []Permission{ + { + RoleScopeId: scope.Global.String(), + GrantScopeId: globals.GrantScopeDescendants, + Resource: resource.Target, + Action: action.ListResolvableAliases, + ResourceIds: []string{"ttcp_1234567890"}, + All: false, + OnlySelf: false, + }, + }, + }, + { + name: "global_with_this_with_descendants", + resourceType: resource.Target, + actionSet: action.NewActionSet(action.Read, action.Cancel), + aclGrants: []scopeGrant{ + { + roleScope: "global", + grantScope: globals.GrantScopeDescendants, + grants: []string{ + "type=target;actions=list", + "ids=ttcp_1234567890;actions=read", + }, + }, + { + roleScope: "global", + grantScope: "global", + grants: []string{ + "type=target;actions=list", + "ids=ttcp_1234567890;actions=read", + }, + }, + }, + expPermissions: []Permission{ + { + RoleScopeId: scope.Global.String(), + GrantScopeId: globals.GrantScopeDescendants, + Resource: resource.Target, + Action: action.ListResolvableAliases, + ResourceIds: []string{"ttcp_1234567890"}, + All: false, + OnlySelf: false, + }, + { + RoleScopeId: scope.Global.String(), + GrantScopeId: scope.Global.String(), + Resource: resource.Target, + Action: action.ListResolvableAliases, + ResourceIds: []string{"ttcp_1234567890"}, + All: false, + OnlySelf: false, + }, + }, + }, + { + name: "global_no_this_with_valid_children", + resourceType: resource.Target, + actionSet: action.NewActionSet(action.Read, action.Cancel), + aclGrants: []scopeGrant{ + { + roleScope: "global", + grantScope: globals.GrantScopeChildren, + grants: []string{ + "type=target;actions=list", + 
"ids=ttcp_1234567890;actions=read", + }, + }, + }, + expPermissions: []Permission{ + { + RoleScopeId: scope.Global.String(), + GrantScopeId: globals.GrantScopeChildren, + Resource: resource.Target, + Action: action.ListResolvableAliases, + ResourceIds: []string{"ttcp_1234567890"}, + All: false, + OnlySelf: false, + }, + }, + }, + { + name: "org_no_this_with_children_and_direct_grant", + resourceType: resource.Target, + actionSet: action.NewActionSet(action.Read, action.Cancel), + aclGrants: []scopeGrant{ + { + roleScope: "o_1", + roleParentScopeId: scope.Global.String(), + grantScope: globals.GrantScopeChildren, + grants: []string{ + "type=target;actions=list", + "ids=ttcp_1234567890;actions=read", + }, + }, + { + roleScope: "o_2", + roleParentScopeId: scope.Global.String(), + grantScope: "o_2", + grants: []string{ + "type=target;actions=list", + "ids=ttcp_1234567890;actions=read", + }, + }, + }, + expPermissions: []Permission{ + { + RoleScopeId: "o_1", + RoleParentScopeId: scope.Global.String(), + GrantScopeId: globals.GrantScopeChildren, + Resource: resource.Target, + Action: action.ListResolvableAliases, + ResourceIds: []string{"ttcp_1234567890"}, + All: false, + OnlySelf: false, + }, + { + RoleScopeId: "o_2", + RoleParentScopeId: scope.Global.String(), + GrantScopeId: "o_2", + Resource: resource.Target, + Action: action.ListResolvableAliases, + ResourceIds: []string{"ttcp_1234567890"}, + All: false, + OnlySelf: false, + }, + }, + }, + { + name: "org_with_this_with_children_and_direct_grant", + resourceType: resource.Target, + actionSet: action.NewActionSet(action.Read, action.Cancel), + aclGrants: []scopeGrant{ + { + roleScope: "o_1", + roleParentScopeId: scope.Global.String(), + grantScope: globals.GrantScopeChildren, + grants: []string{ + "type=target;actions=list", + "ids=ttcp_1234567890;actions=read", + }, + }, + { + roleScope: "o_1", + roleParentScopeId: scope.Global.String(), + grantScope: "o_1", + grants: []string{ + "type=target;actions=list", + "ids=ttcp_1234567890;actions=read", + }, + }, + }, + expPermissions: []Permission{ + { + RoleScopeId: "o_1", + RoleParentScopeId: scope.Global.String(), + GrantScopeId: globals.GrantScopeChildren, + Resource: resource.Target, + Action: action.ListResolvableAliases, + ResourceIds: []string{"ttcp_1234567890"}, + All: false, + OnlySelf: false, + }, + { + RoleScopeId: "o_1", + RoleParentScopeId: scope.Global.String(), + GrantScopeId: "o_1", + Resource: resource.Target, + Action: action.ListResolvableAliases, + ResourceIds: []string{"ttcp_1234567890"}, + All: false, + OnlySelf: false, + }, + }, + }, + { + name: "org_with_this_with_child_scope_direct_grants", + resourceType: resource.Target, + actionSet: action.NewActionSet(action.Read, action.Cancel), + aclGrants: []scopeGrant{ + { + roleScope: "o_1", + roleParentScopeId: scope.Global.String(), + grantScope: globals.GrantScopeChildren, + grants: []string{ + "type=target;actions=list", + "ids=ttcp_1234567890;actions=read", + }, + }, + { + roleScope: "o_1", + roleParentScopeId: scope.Global.String(), + grantScope: "o_1", + grants: []string{ + "type=target;actions=list", + "ids=ttcp_1234567890;actions=read", + }, + }, + { + roleScope: "p_1a", + roleParentScopeId: "o_1", + grantScope: "p_1a", + grants: []string{ + "type=target;actions=list", + "ids=ttcp_1234567890;actions=read", + }, + }, + { + roleScope: "p_1b", + roleParentScopeId: "o_1", + grantScope: "p_1b", + grants: []string{ + "type=target;actions=list", + "ids=ttcp_1234567890;actions=read", + }, + }, + { + roleScope: "p_2", + 
roleParentScopeId: "o_2", + grantScope: "p_2", + grants: []string{ + "type=target;actions=list", + "ids=ttcp_1234567890;actions=read", + }, + }, + }, + expPermissions: []Permission{ + { + RoleScopeId: "o_1", + RoleParentScopeId: scope.Global.String(), + GrantScopeId: globals.GrantScopeChildren, + Resource: resource.Target, + Action: action.ListResolvableAliases, + ResourceIds: []string{"ttcp_1234567890"}, + All: false, + OnlySelf: false, + }, + { + RoleScopeId: "o_1", + RoleParentScopeId: scope.Global.String(), + GrantScopeId: "o_1", + Resource: resource.Target, + Action: action.ListResolvableAliases, + ResourceIds: []string{"ttcp_1234567890"}, + All: false, + OnlySelf: false, + }, + { + RoleScopeId: "p_2", + RoleParentScopeId: "o_2", + GrantScopeId: "p_2", + Resource: resource.Target, + Action: action.ListResolvableAliases, + ResourceIds: []string{"ttcp_1234567890"}, + All: false, + OnlySelf: false, }, }, }, @@ -759,15 +1100,18 @@ func TestACL_ListResolvableAliasesPermissions(t *testing.T) { t.Run(tt.name, func(t *testing.T) { var grants []Grant for _, sg := range tt.aclGrants { + if sg.roleScope == "" { + sg.roleScope = sg.grantScope + } for _, g := range sg.grants { - grant, err := Parse(ctx, sg.scope, g, WithSkipFinalValidation(tt.skipGrantValidationChecking)) + grant, err := Parse(ctx, GrantTuple{RoleScopeId: sg.roleScope, RoleParentScopeId: sg.roleParentScopeId, GrantScopeId: sg.grantScope, Grant: g}, WithSkipFinalValidation(tt.skipGrantValidationChecking)) require.NoError(t, err) grants = append(grants, grant) } } acl := NewACL(grants...) - perms := acl.ListResolvablePermissions(tt.resourceType, tt.actionSet) + perms := acl.ListResolvableAliasesPermissions(tt.resourceType, tt.actionSet) require.ElementsMatch(t, tt.expPermissions, perms) }) } @@ -792,8 +1136,8 @@ func TestACL_ListPermissions(t *testing.T) { name: "Requested scope(s) not present in ACL scope map", aclGrants: []scopeGrant{ { - scope: "o_1", - grants: []string{"ids=*;type=session;actions=list,read"}, + grantScope: "o_1", + grants: []string{"ids=*;type=session;actions=list,read"}, }, }, scopes: map[string]*scopes.ScopeInfo{ @@ -808,8 +1152,8 @@ func TestACL_ListPermissions(t *testing.T) { name: "Requested resource mismatch", aclGrants: []scopeGrant{ { - scope: "o_1", - grants: []string{"ids=*;type=target;actions=list,read"}, // List & Read for all Targets + grantScope: "o_1", + grants: []string{"ids=*;type=target;actions=list,read"}, // List & Read for all Targets }, }, scopes: map[string]*scopes.ScopeInfo{"o_1": nil}, @@ -821,8 +1165,8 @@ func TestACL_ListPermissions(t *testing.T) { name: "Requested actions not available for the requested scope id", aclGrants: []scopeGrant{ { - scope: "o_1", - grants: []string{"ids=*;type=session;actions=delete"}, + grantScope: "o_1", + grants: []string{"ids=*;type=session;actions=delete"}, }, }, scopes: map[string]*scopes.ScopeInfo{"o_1": nil}, @@ -834,8 +1178,8 @@ func TestACL_ListPermissions(t *testing.T) { name: "No specific id or wildcard provided for `id` field", aclGrants: []scopeGrant{ { - scope: "o_1", - grants: []string{"type=*;actions=list,read"}, + grantScope: "o_1", + grants: []string{"type=*;actions=list,read"}, }, }, scopes: map[string]*scopes.ScopeInfo{"o_1": nil}, @@ -848,21 +1192,23 @@ func TestACL_ListPermissions(t *testing.T) { name: "Allow all ids", aclGrants: []scopeGrant{ { - scope: "o_1", - grants: []string{"ids=*;type=session;actions=list,read"}, + grantScope: "o_1", + grants: []string{"ids=*;type=session;actions=list,read"}, }, }, - scopes: 
map[string]*scopes.ScopeInfo{"o_1": nil}, + scopes: map[string]*scopes.ScopeInfo{"o_1": {Id: "o_1", ParentScopeId: scope.Global.String()}}, resourceType: resource.Session, actionSet: action.NewActionSet(action.Read), expPermissions: []Permission{ { - ScopeId: "o_1", - Resource: resource.Session, - Action: action.List, - ResourceIds: nil, - OnlySelf: false, - All: true, + RoleScopeId: "o_1", + RoleParentScopeId: scope.Global.String(), + GrantScopeId: "o_1", + Resource: resource.Session, + Action: action.List, + ResourceIds: nil, + OnlySelf: false, + All: true, }, }, }, @@ -870,21 +1216,23 @@ func TestACL_ListPermissions(t *testing.T) { name: "Allow all ids, :self actions", aclGrants: []scopeGrant{ { - scope: "o_1", - grants: []string{"ids=*;type=session;actions=list,read:self"}, + grantScope: "o_1", + grants: []string{"ids=*;type=session;actions=list,read:self"}, }, }, - scopes: map[string]*scopes.ScopeInfo{"o_1": nil}, + scopes: map[string]*scopes.ScopeInfo{"o_1": {Id: "o_1", ParentScopeId: scope.Global.String()}}, resourceType: resource.Session, actionSet: action.NewActionSet(action.ReadSelf), expPermissions: []Permission{ { - ScopeId: "o_1", - Resource: resource.Session, - Action: action.List, - ResourceIds: nil, - OnlySelf: true, - All: true, + RoleScopeId: "o_1", + RoleParentScopeId: scope.Global.String(), + GrantScopeId: "o_1", + Resource: resource.Session, + Action: action.List, + ResourceIds: nil, + OnlySelf: true, + All: true, }, }, }, @@ -892,24 +1240,26 @@ func TestACL_ListPermissions(t *testing.T) { name: "Allow specific IDs", aclGrants: []scopeGrant{ { - scope: "o_1", + grantScope: "o_1", grants: []string{ "ids=s_1;type=session;actions=list,read", "ids=s_2,s_3;type=session;actions=list,read", }, }, }, - scopes: map[string]*scopes.ScopeInfo{"o_1": nil}, + scopes: map[string]*scopes.ScopeInfo{"o_1": {Id: "o_1", ParentScopeId: scope.Global.String()}}, resourceType: resource.Session, actionSet: action.NewActionSet(action.Read), expPermissions: []Permission{ { - ScopeId: "o_1", - Resource: resource.Session, - Action: action.List, - ResourceIds: []string{"s_1", "s_2", "s_3"}, - OnlySelf: false, - All: false, + RoleScopeId: "o_1", + RoleParentScopeId: scope.Global.String(), + GrantScopeId: "o_1", + Resource: resource.Session, + Action: action.List, + ResourceIds: []string{"s_1", "s_2", "s_3"}, + OnlySelf: false, + All: false, }, }, }, @@ -917,21 +1267,23 @@ func TestACL_ListPermissions(t *testing.T) { name: "No specific type 1", aclGrants: []scopeGrant{ { - scope: "o_1", - grants: []string{"ids=*;type=*;actions=list,read:self"}, + grantScope: "o_1", + grants: []string{"ids=*;type=*;actions=list,read:self"}, }, }, - scopes: map[string]*scopes.ScopeInfo{"o_1": nil}, + scopes: map[string]*scopes.ScopeInfo{"o_1": {Id: "o_1", ParentScopeId: scope.Global.String()}}, resourceType: resource.Session, actionSet: action.NewActionSet(action.ReadSelf), expPermissions: []Permission{ { - ScopeId: "o_1", - Resource: resource.Session, - Action: action.List, - ResourceIds: nil, - OnlySelf: true, - All: true, + RoleScopeId: "o_1", + RoleParentScopeId: scope.Global.String(), + GrantScopeId: "o_1", + Resource: resource.Session, + Action: action.List, + ResourceIds: nil, + OnlySelf: true, + All: true, }, }, }, @@ -939,21 +1291,23 @@ func TestACL_ListPermissions(t *testing.T) { name: "List + No-op action with id wildcard", aclGrants: []scopeGrant{ { - scope: "o_1", - grants: []string{"ids=*;type=session;actions=list,no-op"}, + grantScope: "o_1", + grants: []string{"ids=*;type=session;actions=list,no-op"}, 
}, }, - scopes: map[string]*scopes.ScopeInfo{"o_1": nil}, + scopes: map[string]*scopes.ScopeInfo{"o_1": {Id: "o_1", ParentScopeId: scope.Global.String()}}, resourceType: resource.Session, actionSet: action.NewActionSet(action.NoOp), expPermissions: []Permission{ { - ScopeId: "o_1", - Resource: resource.Session, - Action: action.List, - ResourceIds: nil, - OnlySelf: false, - All: true, + RoleScopeId: "o_1", + RoleParentScopeId: scope.Global.String(), + GrantScopeId: "o_1", + Resource: resource.Session, + Action: action.List, + ResourceIds: nil, + OnlySelf: false, + All: true, }, }, }, @@ -961,11 +1315,11 @@ func TestACL_ListPermissions(t *testing.T) { name: "List + No-op action with id wildcard, read present", aclGrants: []scopeGrant{ { - scope: "o_1", - grants: []string{"ids=*;type=session;actions=list,no-op"}, + grantScope: "o_1", + grants: []string{"ids=*;type=session;actions=list,no-op"}, }, }, - scopes: map[string]*scopes.ScopeInfo{"o_1": nil}, + scopes: map[string]*scopes.ScopeInfo{"o_1": {Id: "o_1", ParentScopeId: scope.Global.String()}}, resourceType: resource.Session, actionSet: action.NewActionSet(action.Read), expPermissions: []Permission{}, @@ -974,24 +1328,26 @@ func TestACL_ListPermissions(t *testing.T) { name: "List + No-op action with specific ids", aclGrants: []scopeGrant{ { - scope: "o_1", + grantScope: "o_1", grants: []string{ "ids=s_1;type=session;actions=list,no-op", "ids=s_2,s_3;type=session;actions=list,no-op", }, }, }, - scopes: map[string]*scopes.ScopeInfo{"o_1": nil}, + scopes: map[string]*scopes.ScopeInfo{"o_1": {Id: "o_1", ParentScopeId: scope.Global.String()}}, resourceType: resource.Session, actionSet: action.NewActionSet(action.NoOp), expPermissions: []Permission{ { - ScopeId: "o_1", - Resource: resource.Session, - Action: action.List, - ResourceIds: []string{"s_1", "s_2", "s_3"}, - OnlySelf: false, - All: false, + RoleScopeId: "o_1", + RoleParentScopeId: scope.Global.String(), + GrantScopeId: "o_1", + Resource: resource.Session, + Action: action.List, + ResourceIds: []string{"s_1", "s_2", "s_3"}, + OnlySelf: false, + All: false, }, }, }, @@ -999,21 +1355,23 @@ func TestACL_ListPermissions(t *testing.T) { name: "No specific type 2", aclGrants: []scopeGrant{ { - scope: "o_1", - grants: []string{"ids=*;type=*;actions=list,read:self"}, + grantScope: "o_1", + grants: []string{"ids=*;type=*;actions=list,read:self"}, }, }, - scopes: map[string]*scopes.ScopeInfo{"o_1": nil}, + scopes: map[string]*scopes.ScopeInfo{"o_1": {Id: "o_1", ParentScopeId: scope.Global.String()}}, resourceType: resource.Host, actionSet: action.NewActionSet(action.ReadSelf), expPermissions: []Permission{ { - ScopeId: "o_1", - Resource: resource.Host, - Action: action.List, - ResourceIds: nil, - OnlySelf: true, - All: true, + RoleScopeId: "o_1", + RoleParentScopeId: scope.Global.String(), + GrantScopeId: "o_1", + Resource: resource.Host, + Action: action.List, + ResourceIds: nil, + OnlySelf: true, + All: true, }, }, }, @@ -1021,24 +1379,26 @@ func TestACL_ListPermissions(t *testing.T) { name: "Grant hierarchy is respected", aclGrants: []scopeGrant{ { - scope: "o_1", + grantScope: "o_1", grants: []string{ "ids=*;type=*;actions=*", "ids=*;type=session;actions=cancel:self,list,read:self", }, }, }, - scopes: map[string]*scopes.ScopeInfo{"o_1": nil}, + scopes: map[string]*scopes.ScopeInfo{"o_1": {Id: "o_1", ParentScopeId: scope.Global.String()}}, resourceType: resource.Session, actionSet: action.NewActionSet(action.NoOp, action.Read, action.ReadSelf, action.Cancel, action.CancelSelf), expPermissions: 
[]Permission{ { - ScopeId: "o_1", - Resource: resource.Session, - Action: action.List, - ResourceIds: nil, - OnlySelf: false, - All: true, + RoleScopeId: "o_1", + RoleParentScopeId: scope.Global.String(), + GrantScopeId: "o_1", + Resource: resource.Session, + Action: action.List, + ResourceIds: nil, + OnlySelf: false, + All: true, }, }, }, @@ -1046,21 +1406,23 @@ func TestACL_ListPermissions(t *testing.T) { name: "Full access 1", aclGrants: []scopeGrant{ { - scope: "o_1", - grants: []string{"ids=*;type=*;actions=*"}, + grantScope: "o_1", + grants: []string{"ids=*;type=*;actions=*"}, }, }, - scopes: map[string]*scopes.ScopeInfo{"o_1": nil}, + scopes: map[string]*scopes.ScopeInfo{"o_1": {Id: "o_1", ParentScopeId: scope.Global.String()}}, resourceType: resource.Session, actionSet: action.NewActionSet(action.Read, action.Create, action.Delete), expPermissions: []Permission{ { - ScopeId: "o_1", - Resource: resource.Session, - Action: action.List, - ResourceIds: nil, - OnlySelf: false, - All: true, + RoleScopeId: "o_1", + RoleParentScopeId: scope.Global.String(), + GrantScopeId: "o_1", + Resource: resource.Session, + Action: action.List, + ResourceIds: nil, + OnlySelf: false, + All: true, }, }, }, @@ -1068,21 +1430,23 @@ func TestACL_ListPermissions(t *testing.T) { name: "Full access 2", aclGrants: []scopeGrant{ { - scope: "o_1", - grants: []string{"ids=*;type=*;actions=*"}, + grantScope: "o_1", + grants: []string{"ids=*;type=*;actions=*"}, }, }, - scopes: map[string]*scopes.ScopeInfo{"o_1": nil}, + scopes: map[string]*scopes.ScopeInfo{"o_1": {Id: "o_1", ParentScopeId: scope.Global.String()}}, resourceType: resource.Host, actionSet: action.NewActionSet(action.Read, action.Create, action.Delete), expPermissions: []Permission{ { - ScopeId: "o_1", - Resource: resource.Host, - Action: action.List, - ResourceIds: nil, - OnlySelf: false, - All: true, + RoleScopeId: "o_1", + RoleParentScopeId: scope.Global.String(), + GrantScopeId: "o_1", + Resource: resource.Host, + Action: action.List, + ResourceIds: nil, + OnlySelf: false, + All: true, }, }, }, @@ -1090,97 +1454,137 @@ func TestACL_ListPermissions(t *testing.T) { name: "Multiple scopes", aclGrants: []scopeGrant{ { - scope: "o_1", + grantScope: "o_1", grants: []string{ "ids=s_1;type=session;actions=list,read", "ids=s_2,s_3;type=session;actions=list,read", }, }, { - scope: "o_2", - grants: []string{"ids=*;type=session;actions=list,read:self"}, + grantScope: "o_2", + grants: []string{"ids=*;type=session;actions=list,read:self"}, }, }, - scopes: map[string]*scopes.ScopeInfo{"o_1": nil, "o_2": nil}, + scopes: map[string]*scopes.ScopeInfo{"o_1": {Id: "o_1", ParentScopeId: scope.Global.String()}, "o_2": {Id: "o_2", ParentScopeId: scope.Global.String()}}, resourceType: resource.Session, actionSet: action.NewActionSet(action.Read, action.ReadSelf), expPermissions: []Permission{ { - ScopeId: "o_1", - Resource: resource.Session, - Action: action.List, - ResourceIds: []string{"s_1", "s_2", "s_3"}, - OnlySelf: false, - All: false, + RoleScopeId: "o_1", + RoleParentScopeId: scope.Global.String(), + GrantScopeId: "o_1", + Resource: resource.Session, + Action: action.List, + ResourceIds: []string{"s_1", "s_2", "s_3"}, + OnlySelf: false, + All: false, }, { - ScopeId: "o_2", - Resource: resource.Session, - Action: action.List, - ResourceIds: nil, - OnlySelf: true, - All: true, + RoleScopeId: "o_2", + RoleParentScopeId: scope.Global.String(), + GrantScopeId: "o_2", + Resource: resource.Session, + Action: action.List, + ResourceIds: nil, + OnlySelf: true, + All: true, 
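// Illustrative sketch (editor's addition, assuming the imports already in this
// test file): the shape these list-permission cases now expect. Permission no
// longer carries a single ScopeId; it records the role's scope, that scope's
// parent, and the grant scope separately, and the scopes map handed to the ACL
// must supply populated ScopeInfo values (Id plus ParentScopeId) instead of nil
// so that children/descendants grants can be resolved. Identifiers here are
// illustrative only.
scopesFixture := map[string]*scopes.ScopeInfo{
	"o_1": {Id: "o_1", ParentScopeId: scope.Global.String()},
}
wantPerm := Permission{
	RoleScopeId:       "o_1",
	RoleParentScopeId: scope.Global.String(),
	GrantScopeId:      "o_1",
	Resource:          resource.Session,
	Action:            action.List,
	All:               true,
}
_, _ = scopesFixture, wantPerm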
}, }, }, { name: "Allow recovery user full access to sessions", userId: globals.RecoveryUserId, - scopes: map[string]*scopes.ScopeInfo{"o_1": nil, "o_2": nil}, + scopes: map[string]*scopes.ScopeInfo{"o_1": {Id: "o_1", ParentScopeId: scope.Global.String()}, "o_2": {Id: "o_2", ParentScopeId: scope.Global.String()}}, resourceType: resource.Session, actionSet: action.NewActionSet(action.Read, action.Create, action.Delete), expPermissions: []Permission{ { - ScopeId: "o_1", - Resource: resource.Session, - Action: action.List, - ResourceIds: nil, - OnlySelf: false, - All: true, + RoleScopeId: "o_1", + RoleParentScopeId: scope.Global.String(), + GrantScopeId: "o_1", + Resource: resource.Session, + Action: action.List, + ResourceIds: nil, + OnlySelf: false, + All: true, }, { - ScopeId: "o_2", - Resource: resource.Session, - Action: action.List, - ResourceIds: nil, - OnlySelf: false, - All: true, + RoleScopeId: "o_2", + RoleParentScopeId: scope.Global.String(), + GrantScopeId: "o_2", + Resource: resource.Session, + Action: action.List, + ResourceIds: nil, + OnlySelf: false, + All: true, }, }, }, { name: "Allow recovery user full access to targets", userId: globals.RecoveryUserId, - scopes: map[string]*scopes.ScopeInfo{"o_1": nil, "o_2": nil}, + scopes: map[string]*scopes.ScopeInfo{"o_1": {Id: "o_1", ParentScopeId: scope.Global.String()}, "o_2": {Id: "o_2", ParentScopeId: scope.Global.String()}}, resourceType: resource.Target, actionSet: action.NewActionSet(action.Read, action.Create, action.Delete), expPermissions: []Permission{ { - ScopeId: "o_1", - Resource: resource.Target, - Action: action.List, - ResourceIds: nil, - OnlySelf: false, - All: true, + RoleScopeId: "o_1", + RoleParentScopeId: scope.Global.String(), + GrantScopeId: "o_1", + Resource: resource.Target, + Action: action.List, + ResourceIds: nil, + OnlySelf: false, + All: true, }, { - ScopeId: "o_2", - Resource: resource.Target, - Action: action.List, - ResourceIds: nil, - OnlySelf: false, - All: true, + RoleScopeId: "o_2", + RoleParentScopeId: scope.Global.String(), + GrantScopeId: "o_2", + Resource: resource.Target, + Action: action.List, + ResourceIds: nil, + OnlySelf: false, + All: true, }, }, }, { name: "separate_type_id_resource_grants", - scopes: map[string]*scopes.ScopeInfo{"p_1": nil}, + scopes: map[string]*scopes.ScopeInfo{"p_1": {Id: "p_1", ParentScopeId: "o_1"}}, + resourceType: resource.Target, + actionSet: action.NewActionSet(action.Read, action.Cancel), + aclGrants: []scopeGrant{ + { + grantScope: "p_1", + grants: []string{ + "type=target;actions=list", + "ids=ttcp_1234567890;actions=read", + }, + }, + }, + expPermissions: []Permission{ + { + RoleScopeId: "p_1", + RoleParentScopeId: "o_1", + GrantScopeId: "p_1", + Resource: resource.Target, + Action: action.List, + ResourceIds: []string{"ttcp_1234567890"}, + All: false, + OnlySelf: false, + }, + }, + }, + { + name: "global_no_this_with_descendants", + scopes: map[string]*scopes.ScopeInfo{"p_1": {Id: "p_1", ParentScopeId: "o_1"}, "global": {Id: "global", ParentScopeId: ""}}, resourceType: resource.Target, actionSet: action.NewActionSet(action.Read, action.Cancel), aclGrants: []scopeGrant{ { - scope: "p_1", + roleScope: "global", + grantScope: globals.GrantScopeDescendants, grants: []string{ "type=target;actions=list", "ids=ttcp_1234567890;actions=read", @@ -1189,12 +1593,196 @@ func TestACL_ListPermissions(t *testing.T) { }, expPermissions: []Permission{ { - ScopeId: "p_1", - Resource: resource.Target, - Action: action.List, - ResourceIds: []string{"ttcp_1234567890"}, - 
All: false, - OnlySelf: false, + RoleScopeId: "p_1", + RoleParentScopeId: "o_1", + GrantScopeId: "p_1", + Resource: resource.Target, + Action: action.List, + ResourceIds: []string{"ttcp_1234567890"}, + All: false, + OnlySelf: false, + }, + }, + }, + { + name: "global_with_this_with_descendants", + scopes: map[string]*scopes.ScopeInfo{"p_1": {Id: "p_1", ParentScopeId: "o_1"}, "global": {Id: "global", ParentScopeId: ""}}, + resourceType: resource.Target, + actionSet: action.NewActionSet(action.Read, action.Cancel), + aclGrants: []scopeGrant{ + { + roleScope: "global", + grantScope: globals.GrantScopeDescendants, + grants: []string{ + "type=target;actions=list", + "ids=ttcp_1234567890;actions=read", + }, + }, + { + roleScope: "global", + grantScope: "global", + grants: []string{ + "type=target;actions=list", + "ids=ttcp_1234567890;actions=read", + }, + }, + }, + expPermissions: []Permission{ + { + RoleScopeId: "p_1", + RoleParentScopeId: "o_1", + GrantScopeId: "p_1", + Resource: resource.Target, + Action: action.List, + ResourceIds: []string{"ttcp_1234567890"}, + All: false, + OnlySelf: false, + }, + { + RoleScopeId: "global", + GrantScopeId: "global", + Resource: resource.Target, + Action: action.List, + ResourceIds: []string{"ttcp_1234567890"}, + All: false, + OnlySelf: false, + }, + }, + }, + { + name: "global_no_this_with_invalid_children", + scopes: map[string]*scopes.ScopeInfo{"p_1": {Id: "p_1", ParentScopeId: "o_1"}, "global": {Id: "global", ParentScopeId: ""}}, + resourceType: resource.Target, + actionSet: action.NewActionSet(action.Read, action.Cancel), + aclGrants: []scopeGrant{ + { + roleScope: "global", + grantScope: globals.GrantScopeChildren, + grants: []string{ + "type=target;actions=list", + "ids=ttcp_1234567890;actions=read", + }, + }, + }, + expPermissions: nil, + }, + { + name: "global_no_this_with_valid_children", + scopes: map[string]*scopes.ScopeInfo{"p_1": {Id: "p_1", ParentScopeId: "o_1"}, "o_2": {Id: "o_2", ParentScopeId: "global"}}, + resourceType: resource.Target, + actionSet: action.NewActionSet(action.Read, action.Cancel), + aclGrants: []scopeGrant{ + { + roleScope: "global", + grantScope: globals.GrantScopeChildren, + grants: []string{ + "type=target;actions=list", + "ids=ttcp_1234567890;actions=read", + }, + }, + }, + expPermissions: []Permission{ + { + RoleScopeId: "o_2", + RoleParentScopeId: "global", + GrantScopeId: "o_2", + Resource: resource.Target, + Action: action.List, + ResourceIds: []string{"ttcp_1234567890"}, + All: false, + OnlySelf: false, + }, + }, + }, + { + name: "org_no_this_with_children_and_direct_grant", + scopes: map[string]*scopes.ScopeInfo{"p_1": {Id: "p_1", ParentScopeId: "o_1"}, "o_2": {Id: "o_2", ParentScopeId: "global"}}, + resourceType: resource.Target, + actionSet: action.NewActionSet(action.Read, action.Cancel), + aclGrants: []scopeGrant{ + { + roleScope: "o_1", + grantScope: globals.GrantScopeChildren, + grants: []string{ + "type=target;actions=list", + "ids=ttcp_1234567890;actions=read", + }, + }, + { + roleScope: "o_2", + grantScope: "o_2", + grants: []string{ + "type=target;actions=list", + "ids=ttcp_1234567890;actions=read", + }, + }, + }, + expPermissions: []Permission{ + { + RoleScopeId: "p_1", + RoleParentScopeId: "o_1", + GrantScopeId: "p_1", + Resource: resource.Target, + Action: action.List, + ResourceIds: []string{"ttcp_1234567890"}, + All: false, + OnlySelf: false, + }, + { + RoleScopeId: "o_2", + RoleParentScopeId: "global", + GrantScopeId: "o_2", + Resource: resource.Target, + Action: action.List, + ResourceIds: 
[]string{"ttcp_1234567890"}, + All: false, + OnlySelf: false, + }, + }, + }, + { + name: "org_with_this_with_children_and_direct_grant", + scopes: map[string]*scopes.ScopeInfo{"p_1": {Id: "p_1", ParentScopeId: "o_1"}, "o_2": {Id: "o_2", ParentScopeId: "global"}, "o_1": {Id: "o_1", ParentScopeId: "global"}}, + resourceType: resource.Target, + actionSet: action.NewActionSet(action.Read, action.Cancel), + aclGrants: []scopeGrant{ + { + roleScope: "o_1", + grantScope: globals.GrantScopeChildren, + grants: []string{ + "type=target;actions=list", + "ids=ttcp_1234567890;actions=read", + }, + }, + { + roleScope: "o_1", + grantScope: "o_1", + grants: []string{ + "type=target;actions=list", + "ids=ttcp_1234567890;actions=read", + }, + }, + }, + expPermissions: []Permission{ + { + RoleScopeId: "o_1", + RoleParentScopeId: "global", + GrantScopeId: "o_1", + Resource: resource.Target, + Action: action.List, + ResourceIds: []string{"ttcp_1234567890"}, + All: false, + OnlySelf: false, + }, + { + RoleScopeId: "p_1", + RoleParentScopeId: "o_1", + GrantScopeId: "p_1", + Resource: resource.Target, + Action: action.List, + ResourceIds: []string{"ttcp_1234567890"}, + All: false, + OnlySelf: false, }, }, }, @@ -1208,8 +1796,11 @@ func TestACL_ListPermissions(t *testing.T) { } var grants []Grant for _, sg := range tt.aclGrants { + if sg.roleScope == "" { + sg.roleScope = sg.grantScope + } for _, g := range sg.grants { - grant, err := Parse(ctx, sg.scope, g, WithSkipFinalValidation(tt.skipGrantValidationChecking)) + grant, err := Parse(ctx, GrantTuple{RoleScopeId: sg.roleScope, GrantScopeId: sg.grantScope, Grant: g}, WithSkipFinalValidation(tt.skipGrantValidationChecking)) require.NoError(t, err) grants = append(grants, grant) } @@ -1315,7 +1906,7 @@ func Test_AnonRestrictions(t *testing.T) { grant = fmt.Sprintf(grant, action.Type(j).String()) } - parsedGrant, err := Parse(ctx, scope.Global.String(), grant, WithSkipFinalValidation(true)) + parsedGrant, err := Parse(ctx, GrantTuple{RoleScopeId: scope.Global.String(), GrantScopeId: scope.Global.String(), Grant: grant}, WithSkipFinalValidation(true)) require.NoError(err) acl := NewACL(parsedGrant) diff --git a/internal/perms/grants.go b/internal/perms/grants.go index fa8a061356..97d5bc8948 100644 --- a/internal/perms/grants.go +++ b/internal/perms/grants.go @@ -22,11 +22,11 @@ import ( "golang.org/x/exp/slices" ) -type actionSet map[action.Type]bool +type ActionSet map[action.Type]bool // Actions is a helper that goes through the map and returns both the actual // types of actions as a slice and the equivalent strings -func (a actionSet) Actions() (typs []action.Type, strs []string) { +func (a ActionSet) Actions() (typs []action.Type, strs []string) { typs = make([]action.Type, 0, len(a)) strs = make([]string, 0, len(a)) for k, v := range a { @@ -43,9 +43,11 @@ func (a actionSet) Actions() (typs []action.Type, strs []string) { // GrantTuple is simply a struct that can be reference from other code to return // a set of scopes and grants to parse type GrantTuple struct { - RoleId string - ScopeId string - Grant string + RoleId string + RoleScopeId string + RoleParentScopeId string + GrantScopeId string + Grant string } type GrantTuples []GrantTuple @@ -56,7 +58,7 @@ func (g GrantTuples) GrantHash(ctx context.Context) ([]byte, error) { // TODO: Should this return an error when the GrantTuples is empty? 
var values []string for _, grant := range g { - values = append(values, grant.Grant, grant.RoleId, grant.ScopeId) + values = append(values, grant.Grant, grant.RoleId, grant.GrantScopeId) } // Sort for deterministic output slices.Sort(values) @@ -106,14 +108,17 @@ type Scope struct { // Id is the public id of the iam.Scope Id string - // Type is the scope's type (org or project) - Type scope.Type + // ParentId is the parent scope ID + ParentId string } // Grant is a Go representation of a parsed grant type Grant struct { - // The scope, containing the ID and type - scope Scope + // The role scope ID + roleScopeId string + + // The role's parent scope ID, if any + roleParentScopeId string // The ID of the grant, if provided. Deprecated in favor of ids. id string @@ -121,11 +126,14 @@ type Grant struct { // The IDs in the grant, if provided ids []string + // The grant scope ID of the grant + grantScopeId string + // The type, if provided typ resource.Type // The set of actions being granted - actions actionSet + actions ActionSet // The set of output fields granted OutputFields *OutputFields @@ -145,6 +153,11 @@ func (g Grant) Ids() []string { return g.ids } +// GrantScopeId returns the grant scope ID the grant refers to, if any +func (g Grant) GrantScopeId() string { + return g.grantScopeId +} + // Type returns the type the grant refers to, or Unknown func (g Grant) Type() resource.Type { return g.typ @@ -172,10 +185,12 @@ func (g Grant) hasActionOrSubaction(act action.Type) bool { func (g Grant) clone() *Grant { ret := &Grant{ - scope: g.scope, - id: g.id, - ids: g.ids, - typ: g.typ, + roleScopeId: g.roleScopeId, + roleParentScopeId: g.roleParentScopeId, + id: g.id, + ids: g.ids, + grantScopeId: g.grantScopeId, + typ: g.typ, } if g.ids != nil { ret.ids = make([]string, len(g.ids)) @@ -435,47 +450,42 @@ func (g *Grant) unmarshalText(ctx context.Context, grantString string) error { // // The scope must be the org and project where this grant originated, not the // request. 
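// Illustrative sketch (editor's addition, assuming this package's existing
// imports): how a caller now fills in a GrantTuple. The tuple carries the
// role's own scope, that scope's parent, and the grant scope as separate
// fields, and GrantHash folds the grant scope ID (rather than the old single
// scope ID) into its deterministic input. The IDs below are illustrative.
tuples := GrantTuples{
	{
		RoleId:            "r_1234567890",
		RoleScopeId:       scope.Global.String(),
		RoleParentScopeId: "",
		GrantScopeId:      globals.GrantScopeDescendants,
		Grant:             "ids=*;type=target;actions=list,read",
	},
}
hash, err := tuples.GrantHash(context.Background())
_, _ = hash, err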
-func Parse(ctx context.Context, scopeId, grantString string, opt ...Option) (Grant, error) { +func Parse(ctx context.Context, tuple GrantTuple, opt ...Option) (Grant, error) { const op = "perms.Parse" - if len(grantString) == 0 { + if len(tuple.Grant) == 0 { return Grant{}, errors.New(ctx, errors.InvalidParameter, op, "missing grant string") } - if scopeId == "" { - return Grant{}, errors.New(ctx, errors.InvalidParameter, op, "missing scope id") + if tuple.RoleScopeId == "" { + return Grant{}, errors.New(ctx, errors.InvalidParameter, op, "missing role scope id") + } + if tuple.GrantScopeId == "" { + return Grant{}, errors.New(ctx, errors.InvalidParameter, op, "missing grant scope id") } - grantString = strings.ToValidUTF8(grantString, string(unicode.ReplacementChar)) + tuple.Grant = strings.ToValidUTF8(tuple.Grant, string(unicode.ReplacementChar)) grant := Grant{ - scope: Scope{Id: strings.ToValidUTF8(scopeId, string(unicode.ReplacementChar))}, - } - switch { - case scopeId == scope.Global.String(): - grant.scope.Type = scope.Global - case strings.HasPrefix(scopeId, scope.Org.Prefix()): - grant.scope.Type = scope.Org - case strings.HasPrefix(scopeId, scope.Project.Prefix()): - grant.scope.Type = scope.Project - default: - return Grant{}, errors.New(ctx, errors.InvalidParameter, op, "invalid scope type") + roleScopeId: strings.ToValidUTF8(tuple.RoleScopeId, string(unicode.ReplacementChar)), + roleParentScopeId: tuple.RoleParentScopeId, + grantScopeId: tuple.GrantScopeId, } switch { - case grantString[0] == '{': - if err := grant.unmarshalJSON(ctx, []byte(grantString)); err != nil { + case tuple.Grant[0] == '{': + if err := grant.unmarshalJSON(ctx, []byte(tuple.Grant)); err != nil { return Grant{}, errors.Wrap(ctx, err, op, errors.WithMsg("unable to parse JSON grant string")) } default: - if err := grant.unmarshalText(ctx, grantString); err != nil { + if err := grant.unmarshalText(ctx, tuple.Grant); err != nil { return Grant{}, errors.Wrap(ctx, err, op, errors.WithMsg("unable to parse grant string")) } } if grant.id != "" && len(grant.ids) > 0 { - return Grant{}, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("input grant string %q contains both %q and %q fields", grantString, "id", "ids")) + return Grant{}, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("input grant string %q contains both %q and %q fields", tuple.Grant, "id", "ids")) } if len(grant.ids) > 1 && slices.Contains(grant.ids, "*") { - return Grant{}, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("input grant string %q contains both wildcard and non-wildcard values in %q field", grantString, "ids")) + return Grant{}, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("input grant string %q contains both wildcard and non-wildcard values in %q field", tuple.Grant, "ids")) } opts := getOpts(opt...) 
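// Illustrative usage sketch (editor's addition, assuming this package's
// existing imports): Parse now takes the whole GrantTuple rather than a bare
// scope ID and grant string, which lets the parsed Grant remember both the
// role scope and the grant scope, including the special "children" and
// "descendants" grant scopes handled by the validation hunk below. The IDs are
// illustrative, and WithSkipFinalValidation is used only to keep the sketch
// self-contained.
g, err := Parse(ctx, GrantTuple{
	RoleScopeId:       "o_1234567890",
	RoleParentScopeId: scope.Global.String(),
	GrantScopeId:      globals.GrantScopeChildren,
	Grant:             "ids=*;type=target;actions=list,read",
}, WithSkipFinalValidation(true))
if err != nil {
	// handle the parse error
}
_ = g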
@@ -498,7 +508,7 @@ func Parse(ctx context.Context, scopeId, grantString string, opt ...Option) (Gra continue } if seenType != globals.ResourceInfoFromPrefix(id).Type { - return Grant{}, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("input grant string %q contains ids of differently-typed resources", grantString)) + return Grant{}, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("input grant string %q contains ids of differently-typed resources", tuple.Grant)) } } } @@ -646,19 +656,43 @@ func Parse(ctx context.Context, scopeId, grantString string, opt ...Option) (Gra grantForValidation := grant.clone() grantForValidation.id = grantIds[i] acl := NewACL(*grantForValidation) - r := Resource{ - ScopeId: scopeId, - Id: grantIds[i], - Type: grant.typ, - } - if !resource.TopLevelType(grant.typ) { - r.Pin = grantIds[i] + // For special scope names we aren't sure where the resource + // might be, so check possible scopes and see if any are valid + scopesToCheck := make([]string, 0, 2) + var parentScopeId string + switch { + case grant.grantScopeId == globals.GrantScopeDescendants: + scopesToCheck = append(scopesToCheck, "o_1234567890", "p_1234567890") + case grant.grantScopeId == globals.GrantScopeChildren: + if grant.roleScopeId == scope.Global.String() { + scopesToCheck = append(scopesToCheck, "o_1234567890") + parentScopeId = scope.Global.String() + } else { + scopesToCheck = append(scopesToCheck, "p_1234567890") + parentScopeId = grant.roleScopeId + } + default: + scopesToCheck = append(scopesToCheck, grant.grantScopeId) } var allowed bool - for k := range grant.actions { - results := acl.Allowed(r, k, globals.AnonymousUserId, WithSkipAnonymousUserRestrictions(true)) - if results.Authorized { - allowed = true + for _, scopeId := range scopesToCheck { + r := Resource{ + ScopeId: scopeId, + Id: grantIds[i], + Type: grant.typ, + ParentScopeId: parentScopeId, + } + if !resource.TopLevelType(grant.typ) { + r.Pin = grantIds[i] + } + for k := range grant.actions { + results := acl.Allowed(r, k, globals.AnonymousUserId, WithSkipAnonymousUserRestrictions(true)) + if results.Authorized { + allowed = true + break + } + } + if allowed { break } } diff --git a/internal/perms/grants_test.go b/internal/perms/grants_test.go index 071b29f989..f2bcd68804 100644 --- a/internal/perms/grants_test.go +++ b/internal/perms/grants_test.go @@ -11,7 +11,6 @@ import ( "github.com/hashicorp/boundary/globals" "github.com/hashicorp/boundary/internal/types/action" "github.com/hashicorp/boundary/internal/types/resource" - "github.com/hashicorp/boundary/internal/types/scope" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -141,22 +140,14 @@ func Test_MarshalingAndCloning(t *testing.T) { tests := []input{ { - name: "empty", - input: Grant{ - scope: Scope{ - Type: scope.Org, - }, - }, + name: "empty", jsonOutput: `{}`, canonicalString: ``, }, { name: "type and id", input: Grant{ - id: "baz", - scope: Scope{ - Type: scope.Project, - }, + id: "baz", typ: resource.Group, }, jsonOutput: `{"id":"baz","type":"group"}`, @@ -166,9 +157,6 @@ func Test_MarshalingAndCloning(t *testing.T) { name: "type and ids", input: Grant{ ids: []string{"baz", "bop"}, - scope: Scope{ - Type: scope.Project, - }, typ: resource.Group, }, jsonOutput: `{"ids":["baz","bop"],"type":"group"}`, @@ -178,9 +166,6 @@ func Test_MarshalingAndCloning(t *testing.T) { name: "type and ids single id", input: Grant{ ids: []string{"baz"}, - scope: Scope{ - Type: scope.Project, - }, typ: resource.Group, }, jsonOutput: 
`{"ids":["baz"],"type":"group"}`, @@ -189,10 +174,7 @@ func Test_MarshalingAndCloning(t *testing.T) { { name: "output fields id", input: Grant{ - id: "baz", - scope: Scope{ - Type: scope.Project, - }, + id: "baz", typ: resource.Group, OutputFields: &OutputFields{ fields: map[string]bool{ @@ -209,9 +191,6 @@ func Test_MarshalingAndCloning(t *testing.T) { name: "output fields ids", input: Grant{ ids: []string{"baz", "bop"}, - scope: Scope{ - Type: scope.Project, - }, typ: resource.Group, OutputFields: &OutputFields{ fields: map[string]bool{ @@ -227,10 +206,7 @@ func Test_MarshalingAndCloning(t *testing.T) { { name: "everything id", input: Grant{ - id: "baz", - scope: Scope{ - Type: scope.Project, - }, + id: "baz", typ: resource.Group, actions: map[action.Type]bool{ action.Create: true, @@ -252,9 +228,6 @@ func Test_MarshalingAndCloning(t *testing.T) { name: "everything ids", input: Grant{ ids: []string{"baz", "bop"}, - scope: Scope{ - Type: scope.Project, - }, typ: resource.Group, actions: map[action.Type]bool{ action.Create: true, @@ -663,160 +636,14 @@ func Test_Parse(t *testing.T) { } tests := []input{ - { - name: "empty", - err: `perms.Parse: missing grant string: parameter violation: error #100`, - }, - { - name: "bad json", - input: "{2:193}", - err: `perms.Parse: unable to parse JSON grant string: perms.(Grant).unmarshalJSON: error occurred during decode, encoding issue: error #303: invalid character '2' looking for beginning of object key string`, - }, - { - name: "bad text", - input: "id=foo=bar", - err: `perms.Parse: unable to parse grant string: perms.(Grant).unmarshalText: segment "id=foo=bar" not formatted correctly, wrong number of equal signs: parameter violation: error #100`, - }, - { - name: "bad type", - input: "ids=s_foobar;type=barfoo;actions=read", - err: `perms.Parse: unable to parse grant string: perms.(Grant).unmarshalText: unknown type specifier "barfoo": parameter violation: error #100`, - }, - { - name: "bad actions", - input: "ids=hcst_foobar;type=host-catalog;actions=createread", - err: `perms.Parse: perms.(Grant).parseAndValidateActions: unknown action "createread": parameter violation: error #100`, - }, - { - name: "bad id type", - input: "id=foobar;actions=read", - err: `perms.Parse: parsed grant string "id=foobar;actions=read" contains an id "foobar" of an unknown resource type: parameter violation: error #100`, - }, - { - name: "bad ids type first position", - input: "ids=foobar,hcst_foobar;actions=read", - err: `perms.Parse: input grant string "ids=foobar,hcst_foobar;actions=read" contains ids of differently-typed resources: parameter violation: error #100`, - }, - { - name: "bad ids type second position", - input: "ids=hcst_foobar,foobar;actions=read", - err: `perms.Parse: input grant string "ids=hcst_foobar,foobar;actions=read" contains ids of differently-typed resources: parameter violation: error #100`, - }, - { - name: "bad create action for ids", - input: "ids=u_foobar;actions=create", - err: `perms.Parse: parsed grant string "ids=u_foobar;actions=create" contains create or list action in a format that does not allow these: parameter violation: error #100`, - }, - { - name: "bad create action for ids with other perms", - input: "ids=u_foobar;actions=read,create", - err: `perms.Parse: parsed grant string "ids=u_foobar;actions=create,read" contains create or list action in a format that does not allow these: parameter violation: error #100`, - }, - { - name: "bad list action for id", - input: "id=u_foobar;actions=list", - err: `perms.Parse: parsed grant 
string "id=u_foobar;actions=list" contains create or list action in a format that does not allow these: parameter violation: error #100`, - }, - { - name: "bad list action for type with other perms", - input: "type=host-catalog;actions=list,read", - err: `perms.Parse: parsed grant string "type=host-catalog;actions=list,read" contains non-create or non-list action in a format that only allows these: parameter violation: error #100`, - }, - { - name: "wildcard id and actions without collection", - input: "id=*;actions=read", - err: `perms.Parse: parsed grant string "id=*;actions=read" contains wildcard id and no specified type: parameter violation: error #100`, - }, - { - name: "wildcard ids and actions without collection", - input: "ids=*;actions=read", - err: `perms.Parse: parsed grant string "ids=*;actions=read" contains wildcard id and no specified type: parameter violation: error #100`, - }, - { - name: "wildcard id and actions with list", - input: "id=*;actions=read,list", - err: `perms.Parse: parsed grant string "id=*;actions=list,read" contains wildcard id and no specified type: parameter violation: error #100`, - }, - { - name: "wildcard ids and actions with list", - input: "ids=*;actions=read,list", - err: `perms.Parse: parsed grant string "ids=*;actions=list,read" contains wildcard id and no specified type: parameter violation: error #100`, - }, - { - name: "wildcard type with no ids", - input: "type=*;actions=read,list", - err: `perms.Parse: parsed grant string "type=*;actions=list,read" contains wildcard type with no id value: parameter violation: error #100`, - }, - { - name: "mixed wildcard and non wildcard ids first position", - input: "ids=*,u_foobar;actions=read,list", - err: `perms.Parse: input grant string "ids=*,u_foobar;actions=read,list" contains both wildcard and non-wildcard values in "ids" field: parameter violation: error #100`, - }, - { - name: "mixed wildcard and non wildcard ids second position", - input: "ids=u_foobar,*;actions=read,list", - err: `perms.Parse: input grant string "ids=u_foobar,*;actions=read,list" contains both wildcard and non-wildcard values in "ids" field: parameter violation: error #100`, - }, - { - name: "empty ids and type", - input: "actions=create", - err: `perms.Parse: parsed grant string "actions=create" contains no id or type: parameter violation: error #100`, - }, - { - name: "wildcard type non child id", - input: "id=ttcp_1234567890;type=*;actions=create", - err: `perms.Parse: parsed grant string "id=ttcp_1234567890;type=*;actions=create" contains an id that does not support child types: parameter violation: error #100`, - }, - { - name: "wildcard type non child ids first position", - input: "ids=ttcp_1234567890,ttcp_1234567890;type=*;actions=create", - err: `perms.Parse: parsed grant string "ids=ttcp_1234567890,ttcp_1234567890;type=*;actions=create" contains an id that does not support child types: parameter violation: error #100`, - }, - { - name: "wildcard type non child ids second position", - input: "ids=ttcp_1234567890,ttcp_1234567890;type=*;actions=create", - err: `perms.Parse: parsed grant string "ids=ttcp_1234567890,ttcp_1234567890;type=*;actions=create" contains an id that does not support child types: parameter violation: error #100`, - }, - { - name: "specified resource type non child id", - input: "id=hcst_1234567890;type=account;actions=read", - err: `perms.Parse: parsed grant string "id=hcst_1234567890;type=account;actions=read" contains type account that is not a child type of the type (host-catalog) of the specified 
id: parameter violation: error #100`, - }, - { - name: "specified resource type non child ids first position", - input: "ids=hcst_1234567890,hcst_1234567890;type=account;actions=read", - err: `perms.Parse: parsed grant string "ids=hcst_1234567890,hcst_1234567890;type=account;actions=read" contains type account that is not a child type of the type (host-catalog) of the specified id: parameter violation: error #100`, - }, - { - name: "specified resource type non child ids second position", - input: "ids=hcst_1234567890,hcst_1234567890;type=account;actions=read", - err: `perms.Parse: parsed grant string "ids=hcst_1234567890,hcst_1234567890;type=account;actions=read" contains type account that is not a child type of the type (host-catalog) of the specified id: parameter violation: error #100`, - }, - { - name: "no id with one bad action", - input: "type=host-set;actions=read", - err: `perms.Parse: parsed grant string "type=host-set;actions=read" contains non-create or non-list action in a format that only allows these: parameter violation: error #100`, - }, - { - name: "no id with two bad action", - input: "type=host-set;actions=read,create", - err: `perms.Parse: parsed grant string "type=host-set;actions=create,read" contains non-create or non-list action in a format that only allows these: parameter violation: error #100`, - }, - { - name: "no id with three bad action", - input: "type=host-set;actions=list,read,create", - err: `perms.Parse: parsed grant string "type=host-set;actions=create,list,read" contains non-create or non-list action in a format that only allows these: parameter violation: error #100`, - }, { name: "empty output fields", input: "id=*;type=*;actions=read,list;output_fields=", expected: Grant{ - scope: Scope{ - Id: "o_scope", - Type: scope.Org, - }, - id: "*", - typ: resource.All, + roleScopeId: "o_scope", + grantScopeId: "o_scope", + id: "*", + typ: resource.All, actions: map[action.Type]bool{ action.Read: true, action.List: true, @@ -830,12 +657,10 @@ func Test_Parse(t *testing.T) { name: "empty output fields json", input: `{"id": "*", "type": "*", "actions": ["read", "list"], "output_fields": []}`, expected: Grant{ - scope: Scope{ - Id: "o_scope", - Type: scope.Org, - }, - id: "*", - typ: resource.All, + roleScopeId: "o_scope", + grantScopeId: "o_scope", + id: "*", + typ: resource.All, actions: map[action.Type]bool{ action.Read: true, action.List: true, @@ -849,12 +674,10 @@ func Test_Parse(t *testing.T) { name: "wildcard id and type and actions with list", input: "id=*;type=*;actions=read,list", expected: Grant{ - scope: Scope{ - Id: "o_scope", - Type: scope.Org, - }, - id: "*", - typ: resource.All, + roleScopeId: "o_scope", + grantScopeId: "o_scope", + id: "*", + typ: resource.All, actions: map[action.Type]bool{ action.Read: true, action.List: true, @@ -865,12 +688,10 @@ func Test_Parse(t *testing.T) { name: "wildcard ids and type and actions with list", input: "ids=*;type=*;actions=read,list", expected: Grant{ - scope: Scope{ - Id: "o_scope", - Type: scope.Org, - }, - ids: []string{"*"}, - typ: resource.All, + roleScopeId: "o_scope", + grantScopeId: "o_scope", + ids: []string{"*"}, + typ: resource.All, actions: map[action.Type]bool{ action.Read: true, action.List: true, @@ -881,11 +702,9 @@ func Test_Parse(t *testing.T) { name: "good json type", input: `{"type":"host-catalog","actions":["create"]}`, expected: Grant{ - scope: Scope{ - Id: "o_scope", - Type: scope.Org, - }, - typ: resource.HostCatalog, + roleScopeId: "o_scope", + grantScopeId: "o_scope", + typ: 
resource.HostCatalog, actions: map[action.Type]bool{ action.Create: true, }, @@ -895,12 +714,10 @@ func Test_Parse(t *testing.T) { name: "good json id", input: `{"id":"u_foobar","actions":["read"]}`, expected: Grant{ - scope: Scope{ - Id: "o_scope", - Type: scope.Org, - }, - id: "u_foobar", - typ: resource.Unknown, + roleScopeId: "o_scope", + grantScopeId: "o_scope", + id: "u_foobar", + typ: resource.Unknown, actions: map[action.Type]bool{ action.Read: true, }, @@ -910,12 +727,10 @@ func Test_Parse(t *testing.T) { name: "good json ids", input: `{"ids":["hcst_foobar", "hcst_foobaz"],"actions":["read"]}`, expected: Grant{ - scope: Scope{ - Id: "o_scope", - Type: scope.Org, - }, - ids: []string{"hcst_foobar", "hcst_foobaz"}, - typ: resource.Unknown, + roleScopeId: "o_scope", + grantScopeId: "o_scope", + ids: []string{"hcst_foobar", "hcst_foobaz"}, + typ: resource.Unknown, actions: map[action.Type]bool{ action.Read: true, }, @@ -925,12 +740,10 @@ func Test_Parse(t *testing.T) { name: "good json output fields id", input: `{"id":"u_foobar","actions":["read"],"output_fields":["version","id","name"]}`, expected: Grant{ - scope: Scope{ - Id: "o_scope", - Type: scope.Org, - }, - id: "u_foobar", - typ: resource.Unknown, + roleScopeId: "o_scope", + grantScopeId: "o_scope", + id: "u_foobar", + typ: resource.Unknown, actions: map[action.Type]bool{ action.Read: true, }, @@ -947,12 +760,10 @@ func Test_Parse(t *testing.T) { name: "good json output fields ids", input: `{"ids":["u_foobar"],"actions":["read"],"output_fields":["version","ids","name"]}`, expected: Grant{ - scope: Scope{ - Id: "o_scope", - Type: scope.Org, - }, - ids: []string{"u_foobar"}, - typ: resource.Unknown, + roleScopeId: "o_scope", + grantScopeId: "o_scope", + ids: []string{"u_foobar"}, + typ: resource.Unknown, actions: map[action.Type]bool{ action.Read: true, }, @@ -969,12 +780,10 @@ func Test_Parse(t *testing.T) { name: "good json output fields no action", input: `{"id":"u_foobar","output_fields":["version","id","name"]}`, expected: Grant{ - scope: Scope{ - Id: "o_scope", - Type: scope.Org, - }, - id: "u_foobar", - typ: resource.Unknown, + roleScopeId: "o_scope", + grantScopeId: "o_scope", + id: "u_foobar", + typ: resource.Unknown, OutputFields: &OutputFields{ fields: map[string]bool{ "version": true, @@ -988,11 +797,9 @@ func Test_Parse(t *testing.T) { name: "good text type", input: `type=host-catalog;actions=create`, expected: Grant{ - scope: Scope{ - Id: "o_scope", - Type: scope.Org, - }, - typ: resource.HostCatalog, + roleScopeId: "o_scope", + grantScopeId: "o_scope", + typ: resource.HostCatalog, actions: map[action.Type]bool{ action.Create: true, }, @@ -1002,12 +809,10 @@ func Test_Parse(t *testing.T) { name: "good text id", input: `id=u_foobar;actions=read`, expected: Grant{ - scope: Scope{ - Id: "o_scope", - Type: scope.Org, - }, - id: "u_foobar", - typ: resource.Unknown, + roleScopeId: "o_scope", + grantScopeId: "o_scope", + id: "u_foobar", + typ: resource.Unknown, actions: map[action.Type]bool{ action.Read: true, }, @@ -1017,12 +822,10 @@ func Test_Parse(t *testing.T) { name: "good text ids", input: `ids=hcst_foobar,hcst_foobaz;actions=read`, expected: Grant{ - scope: Scope{ - Id: "o_scope", - Type: scope.Org, - }, - ids: []string{"hcst_foobar", "hcst_foobaz"}, - typ: resource.Unknown, + roleScopeId: "o_scope", + grantScopeId: "o_scope", + ids: []string{"hcst_foobar", "hcst_foobaz"}, + typ: resource.Unknown, actions: map[action.Type]bool{ action.Read: true, }, @@ -1032,12 +835,10 @@ func Test_Parse(t *testing.T) { name: "good 
output fields id", input: `id=u_foobar;actions=read;output_fields=version,id,name`, expected: Grant{ - scope: Scope{ - Id: "o_scope", - Type: scope.Org, - }, - id: "u_foobar", - typ: resource.Unknown, + roleScopeId: "o_scope", + grantScopeId: "o_scope", + id: "u_foobar", + typ: resource.Unknown, actions: map[action.Type]bool{ action.Read: true, }, @@ -1054,12 +855,10 @@ func Test_Parse(t *testing.T) { name: "good output fields ids", input: `ids=hcst_foobar,hcst_foobaz;actions=read;output_fields=version,ids,name`, expected: Grant{ - scope: Scope{ - Id: "o_scope", - Type: scope.Org, - }, - ids: []string{"hcst_foobar", "hcst_foobaz"}, - typ: resource.Unknown, + roleScopeId: "o_scope", + grantScopeId: "o_scope", + ids: []string{"hcst_foobar", "hcst_foobaz"}, + typ: resource.Unknown, actions: map[action.Type]bool{ action.Read: true, }, @@ -1077,12 +876,10 @@ func Test_Parse(t *testing.T) { input: `id=hcst_foobar;actions=read`, scopeOverride: "p_1234", expected: Grant{ - scope: Scope{ - Id: "p_1234", - Type: scope.Project, - }, - id: "hcst_foobar", - typ: resource.Unknown, + roleScopeId: "p_1234", + grantScopeId: "p_1234", + id: "hcst_foobar", + typ: resource.Unknown, actions: map[action.Type]bool{ action.Read: true, }, @@ -1093,12 +890,10 @@ func Test_Parse(t *testing.T) { input: `id=acctpw_foobar;actions=read`, scopeOverride: "o_1234", expected: Grant{ - scope: Scope{ - Id: "o_1234", - Type: scope.Org, - }, - id: "acctpw_foobar", - typ: resource.Unknown, + roleScopeId: "o_1234", + grantScopeId: "o_1234", + id: "acctpw_foobar", + typ: resource.Unknown, actions: map[action.Type]bool{ action.Read: true, }, @@ -1109,12 +904,10 @@ func Test_Parse(t *testing.T) { input: `id=acctpw_foobar;actions=read`, scopeOverride: "global", expected: Grant{ - scope: Scope{ - Id: "global", - Type: scope.Global, - }, - id: "acctpw_foobar", - typ: resource.Unknown, + roleScopeId: "global", + grantScopeId: "global", + id: "acctpw_foobar", + typ: resource.Unknown, actions: map[action.Type]bool{ action.Read: true, }, @@ -1137,11 +930,9 @@ func Test_Parse(t *testing.T) { input: `id={{ user.id}};actions=read,update`, userId: "u_abcd1234", expected: Grant{ - scope: Scope{ - Id: "o_scope", - Type: scope.Org, - }, - id: "u_abcd1234", + roleScopeId: "o_scope", + grantScopeId: "o_scope", + id: "u_abcd1234", actions: map[action.Type]bool{ action.Update: true, action.Read: true, @@ -1171,11 +962,9 @@ func Test_Parse(t *testing.T) { input: `id={{ account.id}};actions=update,read`, accountId: fmt.Sprintf("%s_1234567890", globals.PasswordAccountPreviousPrefix), expected: Grant{ - scope: Scope{ - Id: "o_scope", - Type: scope.Org, - }, - id: fmt.Sprintf("%s_1234567890", globals.PasswordAccountPreviousPrefix), + roleScopeId: "o_scope", + grantScopeId: "o_scope", + id: fmt.Sprintf("%s_1234567890", globals.PasswordAccountPreviousPrefix), actions: map[action.Type]bool{ action.Update: true, action.Read: true, @@ -1187,11 +976,9 @@ func Test_Parse(t *testing.T) { input: `id={{ account.id}};actions=update,read`, accountId: fmt.Sprintf("%s_1234567890", globals.PasswordAccountPrefix), expected: Grant{ - scope: Scope{ - Id: "o_scope", - Type: scope.Org, - }, - id: fmt.Sprintf("%s_1234567890", globals.PasswordAccountPrefix), + roleScopeId: "o_scope", + grantScopeId: "o_scope", + id: fmt.Sprintf("%s_1234567890", globals.PasswordAccountPrefix), actions: map[action.Type]bool{ action.Update: true, action.Read: true, @@ -1204,11 +991,9 @@ func Test_Parse(t *testing.T) { userId: "u_abcd1234", accountId: fmt.Sprintf("%s_1234567890", 
globals.PasswordAccountPrefix), expected: Grant{ - scope: Scope{ - Id: "o_scope", - Type: scope.Org, - }, - ids: []string{"u_abcd1234", "acctpw_1234567890"}, + roleScopeId: "o_scope", + grantScopeId: "o_scope", + ids: []string{"u_abcd1234", "acctpw_1234567890"}, actions: map[action.Type]bool{ action.Update: true, action.Read: true, @@ -1217,13 +1002,17 @@ func Test_Parse(t *testing.T) { }, } - _, err := Parse(ctx, "", "") + _, err := Parse(ctx, GrantTuple{RoleScopeId: "", GrantScopeId: "", Grant: ""}) require.Error(t, err) assert.Equal(t, "perms.Parse: missing grant string: parameter violation: error #100", err.Error()) - _, err = Parse(ctx, "", "{}") + _, err = Parse(ctx, GrantTuple{RoleScopeId: "", GrantScopeId: "", Grant: "{}"}) + require.Error(t, err) + assert.Equal(t, "perms.Parse: missing role scope id: parameter violation: error #100", err.Error()) + + _, err = Parse(ctx, GrantTuple{RoleScopeId: "p_abcd", GrantScopeId: "", Grant: "{}"}) require.Error(t, err) - assert.Equal(t, "perms.Parse: missing scope id: parameter violation: error #100", err.Error()) + assert.Equal(t, "perms.Parse: missing grant scope id: parameter violation: error #100", err.Error()) for _, test := range tests { t.Run(test.name, func(t *testing.T) { @@ -1233,7 +1022,7 @@ func Test_Parse(t *testing.T) { if test.scopeOverride != "" { scope = test.scopeOverride } - grant, err := Parse(ctx, scope, test.input, WithUserId(test.userId), WithAccountId(test.accountId)) + grant, err := Parse(ctx, GrantTuple{RoleScopeId: scope, GrantScopeId: scope, Grant: test.input}, WithUserId(test.userId), WithAccountId(test.accountId)) if test.err != "" { require.Error(err) assert.Equal(test.err, err.Error()) @@ -1329,11 +1118,11 @@ func FuzzParse(f *testing.F) { } f.Fuzz(func(t *testing.T, grant string) { - g, err := Parse(ctx, "global", grant, WithSkipFinalValidation(true)) + g, err := Parse(ctx, GrantTuple{GrantScopeId: "global", Grant: grant}, WithSkipFinalValidation(true)) if err != nil { return } - g2, err := Parse(ctx, "global", g.CanonicalString(), WithSkipFinalValidation(true)) + g2, err := Parse(ctx, GrantTuple{GrantScopeId: "global", Grant: g.CanonicalString()}, WithSkipFinalValidation(true)) if err != nil { t.Fatal("Failed to parse canonical string:", err) } @@ -1344,7 +1133,7 @@ func FuzzParse(f *testing.F) { if err != nil { t.Error("Failed to marshal JSON:", err) } - g3, err := Parse(ctx, "global", string(jsonBytes), WithSkipFinalValidation(true)) + g3, err := Parse(ctx, GrantTuple{GrantScopeId: "global", Grant: string(jsonBytes)}, WithSkipFinalValidation(true)) if err != nil { t.Fatal("Failed to parse json string:", err) } diff --git a/internal/perms/output_fields_test.go b/internal/perms/output_fields_test.go index 3aba89dcb5..e9d9c8825c 100644 --- a/internal/perms/output_fields_test.go +++ b/internal/perms/output_fields_test.go @@ -10,6 +10,7 @@ import ( "github.com/hashicorp/boundary/globals" "github.com/hashicorp/boundary/internal/types/action" "github.com/hashicorp/boundary/internal/types/resource" + "github.com/hashicorp/boundary/internal/types/scope" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -149,14 +150,14 @@ func Test_ACLOutputFields(t *testing.T) { tests := []input{ { name: "default", - resource: Resource{ScopeId: "o_myorg", Id: "u_bar", Type: resource.Role}, + resource: Resource{ScopeId: "o_myorg", ParentScopeId: scope.Global.String(), Id: "u_bar", Type: resource.Role}, action: action.Read, grants: []string{"ids=u_bar;actions=read,update"}, authorized: true, }, { name: 
"single value", - resource: Resource{ScopeId: "o_myorg", Id: "u_bar", Type: resource.Role}, + resource: Resource{ScopeId: "o_myorg", ParentScopeId: scope.Global.String(), Id: "u_bar", Type: resource.Role}, grants: []string{"ids=u_bar;actions=read,update;output_fields=id"}, action: action.Read, fields: []string{"id"}, @@ -164,7 +165,7 @@ func Test_ACLOutputFields(t *testing.T) { }, { name: "compound no overlap", - resource: Resource{ScopeId: "o_myorg", Id: "u_bar", Type: resource.Role}, + resource: Resource{ScopeId: "o_myorg", ParentScopeId: scope.Global.String(), Id: "u_bar", Type: resource.Role}, grants: []string{ "ids=u_bar;actions=read,update;output_fields=id", "ids=*;type=host-catalog;actions=read,update;output_fields=version", @@ -175,7 +176,7 @@ func Test_ACLOutputFields(t *testing.T) { }, { name: "compound", - resource: Resource{ScopeId: "o_myorg", Id: "u_bar", Type: resource.Role}, + resource: Resource{ScopeId: "o_myorg", ParentScopeId: scope.Global.String(), Id: "u_bar", Type: resource.Role}, grants: []string{ "ids=u_bar;actions=read,update;output_fields=id", "ids=*;type=role;output_fields=version", @@ -186,7 +187,7 @@ func Test_ACLOutputFields(t *testing.T) { }, { name: "wildcard with type", - resource: Resource{ScopeId: "o_myorg", Id: "u_bar", Type: resource.Role}, + resource: Resource{ScopeId: "o_myorg", ParentScopeId: scope.Global.String(), Id: "u_bar", Type: resource.Role}, grants: []string{ "ids=u_bar;actions=read,update;output_fields=read", "ids=*;type=role;output_fields=*", @@ -197,7 +198,7 @@ func Test_ACLOutputFields(t *testing.T) { }, { name: "wildcard with wildcard type", - resource: Resource{ScopeId: "o_myorg", Id: "u_bar", Type: resource.Role}, + resource: Resource{ScopeId: "o_myorg", ParentScopeId: scope.Global.String(), Id: "u_bar", Type: resource.Role}, grants: []string{ "ids=u_bar;actions=read,update;output_fields=read", "ids=*;type=*;output_fields=*", @@ -208,7 +209,7 @@ func Test_ACLOutputFields(t *testing.T) { }, { name: "subaction exact", - resource: Resource{ScopeId: "o_myorg", Id: "u_bar", Type: resource.Role}, + resource: Resource{ScopeId: "o_myorg", ParentScopeId: scope.Global.String(), Id: "u_bar", Type: resource.Role}, grants: []string{ "ids=u_bar;actions=read:self,update;output_fields=version", }, @@ -220,7 +221,7 @@ func Test_ACLOutputFields(t *testing.T) { // If the action is a subaction, parent output fields will apply, in // addition to subaction. This matches authorization. name: "subaction parent action", - resource: Resource{ScopeId: "o_myorg", Id: "u_bar", Type: resource.Role}, + resource: Resource{ScopeId: "o_myorg", ParentScopeId: scope.Global.String(), Id: "u_bar", Type: resource.Role}, grants: []string{ "ids=u_bar;actions=read,update;output_fields=version", "ids=u_bar;actions=read:self;output_fields=id", @@ -235,7 +236,7 @@ func Test_ACLOutputFields(t *testing.T) { // non-self actions. This is useful to allow more visibility to self // actions and less in the general case. 
name: "subaction child action", - resource: Resource{ScopeId: "o_myorg", Id: "u_bar", Type: resource.Role}, + resource: Resource{ScopeId: "o_myorg", ParentScopeId: scope.Global.String(), Id: "u_bar", Type: resource.Role}, grants: []string{ "ids=u_bar;actions=read:self,update;output_fields=version", "ids=u_bar;actions=read;output_fields=id", @@ -246,7 +247,7 @@ func Test_ACLOutputFields(t *testing.T) { }, { name: "initial grant unauthorized with star", - resource: Resource{ScopeId: "o_myorg", Id: "u_bar", Type: resource.Role}, + resource: Resource{ScopeId: "o_myorg", ParentScopeId: scope.Global.String(), Id: "u_bar", Type: resource.Role}, grants: []string{ "ids=u_bar;output_fields=*", "ids=u_bar;actions=delete;output_fields=id", @@ -257,7 +258,7 @@ func Test_ACLOutputFields(t *testing.T) { }, { name: "unauthorized id only", - resource: Resource{ScopeId: "o_myorg", Id: "u_bar", Type: resource.Role}, + resource: Resource{ScopeId: "o_myorg", ParentScopeId: scope.Global.String(), Id: "u_bar", Type: resource.Role}, grants: []string{ "ids=u_bar;output_fields=name", }, @@ -266,7 +267,7 @@ func Test_ACLOutputFields(t *testing.T) { }, { name: "unauthorized type only", - resource: Resource{ScopeId: "o_myorg", Type: resource.Role}, + resource: Resource{ScopeId: "o_myorg", ParentScopeId: scope.Global.String(), Type: resource.Role}, grants: []string{ "type=role;output_fields=name", }, @@ -279,7 +280,7 @@ func Test_ACLOutputFields(t *testing.T) { t.Run(test.name, func(t *testing.T) { var grants []Grant for _, g := range test.grants { - grant, err := Parse(ctx, "o_myorg", g) + grant, err := Parse(ctx, GrantTuple{RoleScopeId: "o_myorg", GrantScopeId: "o_myorg", Grant: g}) require.NoError(t, err) grants = append(grants, grant) } diff --git a/internal/session/repository.go b/internal/session/repository.go index 8f73923753..b7e98b56ef 100644 --- a/internal/session/repository.go +++ b/internal/session/repository.go @@ -98,7 +98,7 @@ func (r *Repository) listPermissionWhereClauses() ([]string, []any) { var clauses []string clauses = append(clauses, fmt.Sprintf("project_id = @project_id_%d", inClauseCnt)) - args = append(args, sql.Named(fmt.Sprintf("project_id_%d", inClauseCnt), p.ScopeId)) + args = append(args, sql.Named(fmt.Sprintf("project_id_%d", inClauseCnt), p.GrantScopeId)) if len(p.ResourceIds) > 0 { clauses = append(clauses, fmt.Sprintf("public_id = any(@public_id_%d)", inClauseCnt)) diff --git a/internal/session/repository_session_test.go b/internal/session/repository_session_test.go index 001dbd0a66..9e28a90818 100644 --- a/internal/session/repository_session_test.go +++ b/internal/session/repository_session_test.go @@ -54,9 +54,9 @@ func TestRepository_ListSession(t *testing.T) { UserId: composedOf.UserId, Permissions: []perms.Permission{ { - ScopeId: composedOf.ProjectId, - Resource: resource.Session, - Action: action.List, + GrantScopeId: composedOf.ProjectId, + Resource: resource.Session, + Action: action.List, }, }, } @@ -109,9 +109,9 @@ func TestRepository_ListSession(t *testing.T) { perms: &perms.UserPermissions{ Permissions: []perms.Permission{ { - ScopeId: "o_thisIsNotValid", - Resource: resource.Session, - Action: action.List, + GrantScopeId: "o_thisIsNotValid", + Resource: resource.Session, + Action: action.List, }, }, }, @@ -126,9 +126,9 @@ func TestRepository_ListSession(t *testing.T) { perms: &perms.UserPermissions{ Permissions: []perms.Permission{ { - ScopeId: composedOf.ProjectId, - Resource: resource.Session, - Action: action.Read, + GrantScopeId: composedOf.ProjectId, + Resource: 
resource.Session, + Action: action.Read, }, }, }, @@ -200,10 +200,10 @@ func TestRepository_ListSession(t *testing.T) { UserId: s.UserId, Permissions: []perms.Permission{ { - ScopeId: s.ProjectId, - Resource: resource.Session, - Action: action.List, - OnlySelf: true, + GrantScopeId: s.ProjectId, + Resource: resource.Session, + Action: action.List, + OnlySelf: true, }, }, } @@ -227,9 +227,9 @@ func TestRepository_ListSession(t *testing.T) { UserId: composedOf.UserId, Permissions: []perms.Permission{ { - ScopeId: composedOf.ProjectId, - Resource: resource.Session, - Action: action.List, + GrantScopeId: composedOf.ProjectId, + Resource: resource.Session, + Action: action.List, }, }, } @@ -336,9 +336,9 @@ func TestRepository_ListSessions_Multiple_Scopes(t *testing.T) { for i := 0; i < numPerScope; i++ { composedOf := TestSessionParams(t, conn, wrapper, iamRepo) p = append(p, perms.Permission{ - ScopeId: composedOf.ProjectId, - Resource: resource.Session, - Action: action.List, + GrantScopeId: composedOf.ProjectId, + Resource: resource.Session, + Action: action.List, }) s := TestSession(t, conn, wrapper, composedOf) _ = TestState(t, conn, s.PublicId, StatusActive) diff --git a/internal/session/service_list_ext_test.go b/internal/session/service_list_ext_test.go index 539e1fc11e..175720a7bf 100644 --- a/internal/session/service_list_ext_test.go +++ b/internal/session/service_list_ext_test.go @@ -47,9 +47,9 @@ func TestService_List(t *testing.T) { UserId: composedOf.UserId, Permissions: []perms.Permission{ { - ScopeId: composedOf.ProjectId, - Resource: resource.Session, - Action: action.List, + GrantScopeId: composedOf.ProjectId, + Resource: resource.Session, + Action: action.List, }, }, } diff --git a/internal/target/options_test.go b/internal/target/options_test.go index 1069ba80e9..3e27f270a5 100644 --- a/internal/target/options_test.go +++ b/internal/target/options_test.go @@ -168,9 +168,9 @@ func Test_GetOpts(t *testing.T) { }) t.Run("WithPermissions", func(t *testing.T) { assert := assert.New(t) - opts := GetOpts(WithPermissions([]perms.Permission{{ScopeId: "test1"}, {ScopeId: "test2"}})) + opts := GetOpts(WithPermissions([]perms.Permission{{GrantScopeId: "test1"}, {GrantScopeId: "test2"}})) testOpts := getDefaultOptions() - testOpts.WithPermissions = []perms.Permission{{ScopeId: "test1"}, {ScopeId: "test2"}} + testOpts.WithPermissions = []perms.Permission{{GrantScopeId: "test1"}, {GrantScopeId: "test2"}} assert.Equal(opts, testOpts) }) t.Run("WithCredentialLibraries", func(t *testing.T) { diff --git a/internal/target/repository.go b/internal/target/repository.go index 1b27afb32c..c91624dba5 100644 --- a/internal/target/repository.go +++ b/internal/target/repository.go @@ -387,7 +387,7 @@ func (r *Repository) listPermissionWhereClauses() ([]string, []any) { var clauses []string clauses = append(clauses, fmt.Sprintf("project_id = @project_id_%d", inClauseCnt)) - args = append(args, sql.Named(fmt.Sprintf("project_id_%d", inClauseCnt), p.ScopeId)) + args = append(args, sql.Named(fmt.Sprintf("project_id_%d", inClauseCnt), p.GrantScopeId)) if len(p.ResourceIds) > 0 { clauses = append(clauses, fmt.Sprintf("public_id = any(@public_id_%d)", inClauseCnt)) diff --git a/internal/target/repository_ext_test.go b/internal/target/repository_ext_test.go index 53a5004c1f..9f253c62b3 100644 --- a/internal/target/repository_ext_test.go +++ b/internal/target/repository_ext_test.go @@ -155,16 +155,16 @@ func TestRepository_ListTargets(t *testing.T) { repo, err := target.NewRepository(ctx, rw, rw, testKms, 
target.WithPermissions([]perms.Permission{ { - ScopeId: proj1.PublicId, - Resource: resource.Target, - Action: action.List, - All: true, + GrantScopeId: proj1.PublicId, + Resource: resource.Target, + Action: action.List, + All: true, }, { - ScopeId: proj2.PublicId, - Resource: resource.Target, - Action: action.List, - All: true, + GrantScopeId: proj2.PublicId, + Resource: resource.Target, + Action: action.List, + All: true, }, }), ) @@ -327,16 +327,16 @@ func TestRepository_ListTargets_Multiple_Scopes(t *testing.T) { repo, err := target.NewRepository(ctx, rw, rw, testKms, target.WithPermissions([]perms.Permission{ { - ScopeId: proj1.PublicId, - Resource: resource.Target, - Action: action.List, - All: true, + GrantScopeId: proj1.PublicId, + Resource: resource.Target, + Action: action.List, + All: true, }, { - ScopeId: proj2.PublicId, - Resource: resource.Target, - Action: action.List, - All: true, + GrantScopeId: proj2.PublicId, + Resource: resource.Target, + Action: action.List, + All: true, }, }), ) @@ -371,10 +371,10 @@ func TestRepository_ListRoles_Above_Default_Count(t *testing.T) { repo, err := target.NewRepository(ctx, rw, rw, testKms, target.WithPermissions([]perms.Permission{ { - ScopeId: proj.PublicId, - Resource: resource.Target, - Action: action.List, - All: true, + GrantScopeId: proj.PublicId, + Resource: resource.Target, + Action: action.List, + All: true, }, })) require.NoError(t, err) diff --git a/internal/target/repository_test.go b/internal/target/repository_test.go index 4663fae122..237537cd05 100644 --- a/internal/target/repository_test.go +++ b/internal/target/repository_test.go @@ -91,8 +91,8 @@ func TestNewRepository(t *testing.T) { kms: testKms, opts: []Option{ WithPermissions([]perms.Permission{ - {ScopeId: "test1", Resource: resource.Target}, - {ScopeId: "test2", Resource: resource.Target}, + {GrantScopeId: "test1", Resource: resource.Target}, + {GrantScopeId: "test2", Resource: resource.Target}, }), }, }, @@ -102,8 +102,8 @@ func TestNewRepository(t *testing.T) { kms: testKms, defaultLimit: db.DefaultLimit, permissions: []perms.Permission{ - {ScopeId: "test1", Resource: resource.Target}, - {ScopeId: "test2", Resource: resource.Target}, + {GrantScopeId: "test1", Resource: resource.Target}, + {GrantScopeId: "test2", Resource: resource.Target}, }, }, wantErr: false, @@ -116,8 +116,8 @@ func TestNewRepository(t *testing.T) { kms: testKms, opts: []Option{ WithPermissions([]perms.Permission{ - {ScopeId: "test1", Resource: resource.Target}, - {ScopeId: "test2", Resource: resource.Host}, + {GrantScopeId: "test1", Resource: resource.Target}, + {GrantScopeId: "test2", Resource: resource.Host}, }), }, }, diff --git a/internal/target/service_list_ext_test.go b/internal/target/service_list_ext_test.go index ee23b2e752..7be5f01c55 100644 --- a/internal/target/service_list_ext_test.go +++ b/internal/target/service_list_ext_test.go @@ -65,10 +65,10 @@ func TestService_List(t *testing.T) { repo, err := target.NewRepository(ctx, rw, rw, testKms, target.WithPermissions([]perms.Permission{ { - ScopeId: proj1.PublicId, - Resource: resource.Target, - Action: action.List, - All: true, + GrantScopeId: proj1.PublicId, + Resource: resource.Target, + Action: action.List, + All: true, }, }), ) diff --git a/internal/tests/api/users/user_test.go b/internal/tests/api/users/user_test.go index b2243cbc1d..06353ba996 100644 --- a/internal/tests/api/users/user_test.go +++ b/internal/tests/api/users/user_test.go @@ -187,7 +187,7 @@ func TestListResolvableAliases(t *testing.T) { tarClient := 
targets.NewClient(client) resp, err := tarClient.List(tc.Context(), "global", targets.WithRecursive(true)) require.NoError(err) - assert.Len(resp.Items, 2) + require.Len(resp.Items, 2) firstTargetId := resp.Items[0].Id secondTargetId := resp.Items[1].Id From c9578b441a82a7c72f599bbb4f4c8938d8113583 Mon Sep 17 00:00:00 2001 From: Timothy Messier Date: Fri, 13 Sep 2024 19:31:04 +0000 Subject: [PATCH 11/15] perf(db): Add indexes on foreign key to improve delete performance This adds indexes for a few categories: 1. Foreign keys on the session table. These are set to `null` when the referenced row is deleted. These indexes will help to more efficiently set these `null` values in the case where a target is deleted, or an auth token is deleted. 2. Foreign keys from other tables to the session table. These are either set to `null` or cascade deleted when a session is deleted. These indexes help with these update/deletes when a session is deleted. 3. A multi-column index on session_state. This helps with the query that is used to delete all terminated sessions that have been terminated for over an hour. (cherry picked from commit 263a96e9ddd197cb36f2c4c74e91ac034112321e) --- .../postgres/91/02_indexes_fk_delete.up.sql | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 internal/db/schema/migrations/oss/postgres/91/02_indexes_fk_delete.up.sql diff --git a/internal/db/schema/migrations/oss/postgres/91/02_indexes_fk_delete.up.sql b/internal/db/schema/migrations/oss/postgres/91/02_indexes_fk_delete.up.sql new file mode 100644 index 0000000000..4b1b4c3a97 --- /dev/null +++ b/internal/db/schema/migrations/oss/postgres/91/02_indexes_fk_delete.up.sql @@ -0,0 +1,46 @@ +-- Copyright (c) HashiCorp, Inc. +-- SPDX-License-Identifier: BUSL-1.1 + +begin; + -- Index to help when setting target_id to null on a session + -- when the corresponding target is deleted. + drop index if exists session_target_id_ix; + create index session_target_id_ix + on session (target_id); + + -- Index to help when setting auth_token_id to null on a session + -- when the corresponding auth_token is deleted. + drop index if exists session_auth_token_id_ix; + create index session_auth_token_id_ix + on session (auth_token_id); + + -- Index to help when setting session_id to null on a credential_vault_credential + -- when the corresponding session is deleted. + drop index if exists credential_vault_credential_session_id_ix; + create index credential_vault_credential_session_id_ix + on credential_vault_credential (session_id); + + -- Index to help delete cascade of session_worker_protocol + -- when the corresponding session is deleted. + drop index if exists session_worker_protocol_session_id_ix; + create index session_worker_protocol_session_id_ix + on session_worker_protocol (session_id); + + -- Index to help when setting session_id to null on recording_connection + -- when the corresponding session is deleted. + drop index if exists recording_connection_session_id_ix; + create index recording_connection_session_id_ix + on recording_connection (session_id); + + -- Index to help delete of terminated sessions.
+ drop index if exists session_state_state_start_time_ix; + create index session_state_state_terminated_start_time_ix + on session_state (state, start_time) + where state = 'terminated'; + + analyze session, + credential_vault_credential, + session_worker_protocol, + recording_connection, + session_state; +commit; From 091ed6bacb1bad1229186beb88f1825c02ff14fb Mon Sep 17 00:00:00 2001 From: Timothy Messier Date: Mon, 16 Sep 2024 13:43:32 +0000 Subject: [PATCH 12/15] perf(db): Reorder index columns for primary keys When processing a controller API request, the system uses a query to fetch the grants for the requesting user. This query requires pulling together information from several tables, and is currently performing many sequential scans to do so. For a number of the tables involved, there are indexes from multi-column primary keys; however, due to the order of the columns in the index, the Postgres planner would need to do a full scan of the index, which can be less efficient than a sequential scan, so instead it will perform a sequential scan of the table. This recreates the primary keys while swapping the order of the columns in the primary key definition, and thus the order in the index. By doing so, the planner will not need to perform a full index scan, and will be more likely to use the index when executing the grants query. (cherry picked from commit 5577b3c358e9d45cd7ec02fc951e45d52e50595e) --- .../91/03_indexes_grants_query.up.sql | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 internal/db/schema/migrations/oss/postgres/91/03_indexes_grants_query.up.sql diff --git a/internal/db/schema/migrations/oss/postgres/91/03_indexes_grants_query.up.sql b/internal/db/schema/migrations/oss/postgres/91/03_indexes_grants_query.up.sql new file mode 100644 index 0000000000..97c779dbdd --- /dev/null +++ b/internal/db/schema/migrations/oss/postgres/91/03_indexes_grants_query.up.sql @@ -0,0 +1,39 @@ +-- Copyright (c) HashiCorp, Inc. +-- SPDX-License-Identifier: BUSL-1.1 + +begin; + -- For each of these tables, swap the ordering of the + -- columns in the index for the primary key. + -- This helps the grants query that contains + -- several where clauses on what is currently the second + -- column in these indexes. By swapping the order, this + -- will make it more likely that the query planner will + -- choose to use the index.
+ -- See: https://www.postgresql.org/docs/current/indexes-multicolumn.html + + alter table auth_oidc_managed_group_member_account + drop constraint auth_oidc_managed_group_member_account_pkey, + add primary key (member_id, managed_group_id); + + alter table iam_managed_group_role + drop constraint iam_managed_group_role_pkey, + add primary key (principal_id, role_id); + + alter table iam_group_member_user + drop constraint iam_group_member_user_pkey, + add primary key (member_id, group_id); + + alter table iam_group_role + drop constraint iam_group_role_pkey, + add primary key (principal_id, role_id); + + alter table iam_user_role + drop constraint iam_user_role_pkey, + add primary key (principal_id, role_id); + + analyze auth_oidc_managed_group_member_account, + iam_managed_group_role, + iam_group_member_user, + iam_group_role, + iam_user_role; +commit; From 33037b488718f88442f6601cd7015f5b91da795e Mon Sep 17 00:00:00 2001 From: Timothy Messier Date: Mon, 16 Sep 2024 13:52:37 +0000 Subject: [PATCH 13/15] perf(db): Use statement trigger for marking deleted sessions When sessions are deleted, a trigger is used to insert records into the session_deleted table. This table is utilized by the sessions list endpoint when using a refresh token to inform a client of the sessions that have been deleted since the last request. We delete sessions in bulk via a controller job to delete sessions that have been terminated over an hour ago, which results in the trigger running a large number of separate insert statements while processing the delete statement. This changes the trigger to run once for the delete statement, instead of for each row, resulting in a single bulk insert statement to the session_deleted table. This new trigger function also avoids the use of `on conflict`. When testing this function, while the single statement was still faster than running multiple inserts, the `on conflict` still added significant overhead, even when there were no conflicts. It should be safe to perform the insert without the `on conflict`, since the same ID should never be deleted more than once if it is successfully deleted. (cherry picked from commit 3569b36c6eb2bee7a4e36e7584fcc6bbf1124f95) --- ..._insert_session_delete_on_statement.up.sql | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 internal/db/schema/migrations/oss/postgres/91/04_insert_session_delete_on_statement.up.sql diff --git a/internal/db/schema/migrations/oss/postgres/91/04_insert_session_delete_on_statement.up.sql b/internal/db/schema/migrations/oss/postgres/91/04_insert_session_delete_on_statement.up.sql new file mode 100644 index 0000000000..99e8eebad2 --- /dev/null +++ b/internal/db/schema/migrations/oss/postgres/91/04_insert_session_delete_on_statement.up.sql @@ -0,0 +1,25 @@ +-- Copyright (c) HashiCorp, Inc. +-- SPDX-License-Identifier: BUSL-1.1 + +begin; + create function bulk_insert_deleted_ids() returns trigger + as $$ + begin + execute format('insert into %I (public_id, delete_time) + select o.public_id, now() + from old_table o;', + tg_argv[0]); + return null; + end; + $$ language plpgsql; + comment on function bulk_insert_deleted_ids is + 'bulk_insert_deleted_ids is a function that inserts records into the table ' + 'specified by the first trigger argument. 
It takes the public IDs from the ' + 'set of rows that were deleted and the current timestamp.'; + + drop trigger insert_deleted_id on session; + create trigger bulk_insert_deleted_ids + after delete on session + referencing old table as old_table + for each statement execute function bulk_insert_deleted_ids('session_deleted'); +commit; From a92c88dfb5640ee7ef9a3ceab5c0c7b26dca9d43 Mon Sep 17 00:00:00 2001 From: Timothy Messier Date: Mon, 16 Sep 2024 16:48:03 +0000 Subject: [PATCH 14/15] refact(db): Replace get_deleted_tables function with a view This function was returning the set of deletion tables. A view seems better suited for this task since it would allow for applying additional filters to the result set. This was particularly necessary to easily make changes to some sqltests due to switching the delete trigger for the session table. (cherry picked from commit 32d41639892950b21fafc39b50c10b997c6d5121) --- .../81/01_deleted_tables_and_triggers.up.sql | 1 + .../91/05_deletion_tables_view.up.sql | 18 +++++++++++ .../tests/purge/deleted_table_tests.sql | 32 +++++++++++++++++-- internal/pagination/purge/purge_test.go | 2 +- internal/pagination/purge/query.go | 3 +- 5 files changed, 51 insertions(+), 5 deletions(-) create mode 100644 internal/db/schema/migrations/oss/postgres/91/05_deletion_tables_view.up.sql diff --git a/internal/db/schema/migrations/oss/postgres/81/01_deleted_tables_and_triggers.up.sql b/internal/db/schema/migrations/oss/postgres/81/01_deleted_tables_and_triggers.up.sql index 6d47514eab..844625ac08 100644 --- a/internal/db/schema/migrations/oss/postgres/81/01_deleted_tables_and_triggers.up.sql +++ b/internal/db/schema/migrations/oss/postgres/81/01_deleted_tables_and_triggers.up.sql @@ -308,6 +308,7 @@ begin; 'affected by the trigger and the current timestamp. It is used to populate rows ' 'of the deleted tables.'; + -- Removed in 91/05_deletion_tables_view and replaced with a view. create function get_deletion_tables() returns setof name as $$ select c.relname diff --git a/internal/db/schema/migrations/oss/postgres/91/05_deletion_tables_view.up.sql b/internal/db/schema/migrations/oss/postgres/91/05_deletion_tables_view.up.sql new file mode 100644 index 0000000000..ccdf35647c --- /dev/null +++ b/internal/db/schema/migrations/oss/postgres/91/05_deletion_tables_view.up.sql @@ -0,0 +1,18 @@ +-- Copyright (c) HashiCorp, Inc. +-- SPDX-License-Identifier: BUSL-1.1 + +begin; + -- Originally added in 81/01_deleted_tables_and_triggers.up.sql + -- This is being replaced with a view. + drop function get_deletion_tables; + + -- This view uses the pg_catalog to find all tables that end in _deleted and are visible.
+ -- See: https://www.postgresql.org/docs/current/catalog-pg-class.html + -- https://www.postgresql.org/docs/current/functions-info.html#FUNCTIONS-INFO-SCHEMA + create view deletion_table as + select c.relname as tablename + from pg_catalog.pg_class c + where c.relkind in ('r') -- r = ordinary table + and c.relname operator(pg_catalog.~) '^(.+_deleted)$' collate pg_catalog.default + and pg_catalog.pg_table_is_visible(c.oid); +commit; diff --git a/internal/db/sqltest/tests/purge/deleted_table_tests.sql b/internal/db/sqltest/tests/purge/deleted_table_tests.sql index 9d303f2a04..b0bb36752a 100644 --- a/internal/db/sqltest/tests/purge/deleted_table_tests.sql +++ b/internal/db/sqltest/tests/purge/deleted_table_tests.sql @@ -21,6 +21,15 @@ begin; ); $$ language sql; + -- tests that the deletion table has the bulk insert trigger + create function has_bulk_insert_trigger(deletion_table_name name) returns text + as $$ + select * from collect_tap( + has_trigger(op_table(deletion_table_name), 'bulk_insert_deleted_ids'), + trigger_is(op_table(deletion_table_name), 'bulk_insert_deleted_ids', 'bulk_insert_deleted_ids') + ); + $$ language sql; + -- tests the public_id column create function has_public_id(deletion_table_name name) returns text as $$ @@ -72,15 +81,32 @@ begin; ); $$ language sql; + -- like above, but using the bulk delete trigger + create function test_bulk_deletion_table(deletion_table_name name) returns text + as $$ + select * from collect_tap( + has_correct_tables(deletion_table_name), + has_public_id(deletion_table_name), + has_delete_time(deletion_table_name), + has_delete_time_index(deletion_table_name), + has_bulk_insert_trigger(deletion_table_name) + ); + $$ language sql; + -- 11 tests for each deletion table select plan(a.table_count::integer) from ( select 11 * count(*) as table_count - from get_deletion_tables() + from deletion_table ) as a; - select test_deletion_table(a) - from get_deletion_tables() a; + select test_deletion_table(a.tablename) + from deletion_table a + where a.tablename not in ('session_deleted'); + + select test_bulk_deletion_table(a.tablename) + from deletion_table a + where a.tablename in ('session_deleted'); select * from finish(); rollback; diff --git a/internal/pagination/purge/purge_test.go b/internal/pagination/purge/purge_test.go index 8b1c41bb42..6ef25ba4fb 100644 --- a/internal/pagination/purge/purge_test.go +++ b/internal/pagination/purge/purge_test.go @@ -26,7 +26,7 @@ func TestPurgeTables(t *testing.T) { t.Errorf("error getting db connection %s", err) } - rows, err := db.Query("select get_deletion_tables()") + rows, err := db.Query("select tablename from deletion_table") if err != nil { t.Errorf("unable to query for deletion tables %s", err) } diff --git a/internal/pagination/purge/query.go b/internal/pagination/purge/query.go index 966016ccd1..daccd014a9 100644 --- a/internal/pagination/purge/query.go +++ b/internal/pagination/purge/query.go @@ -5,7 +5,8 @@ package purge const ( getDeletionTablesQuery = ` -select get_deletion_tables(); +select tablename + from deletion_table; ` deleteQueryTemplate = ` delete from %s where delete_time < now() - interval '30 days' From 632d048f7bcaa06b60d68b11ad8f31fcbd252240 Mon Sep 17 00:00:00 2001 From: Timothy Messier Date: Mon, 23 Sep 2024 15:29:16 +0000 Subject: [PATCH 15/15] lint: Fix linter errors (cherry picked from commit 51a2b2035f9c6bb51fe8f5f1defeeb514183fecd) --- internal/clientcache/internal/cache/refresh.go | 5 +++-- internal/clientcache/internal/daemon/server_test.go | 4 +++- 2 files changed, 6 
insertions(+), 3 deletions(-) diff --git a/internal/clientcache/internal/cache/refresh.go b/internal/clientcache/internal/cache/refresh.go index 015ca44800..a7193ac963 100644 --- a/internal/clientcache/internal/cache/refresh.go +++ b/internal/clientcache/internal/cache/refresh.go @@ -19,8 +19,9 @@ import ( "github.com/hashicorp/go-hclog" ) -// This is used as an internal error to indicate that a refresh is already in -// progress, as a signal that the caller may want to handle it a different way. +// ErrRefreshInProgress is used as an internal error to indicate that a refresh +// is already in progress, as a signal that the caller may want to handle it a +// different way. var ErrRefreshInProgress error = stderrors.New("cache refresh in progress") type RefreshService struct { diff --git a/internal/clientcache/internal/daemon/server_test.go b/internal/clientcache/internal/daemon/server_test.go index 660f73dd3f..fe212fa917 100644 --- a/internal/clientcache/internal/daemon/server_test.go +++ b/internal/clientcache/internal/daemon/server_test.go @@ -28,7 +28,9 @@ func Test_openStore(t *testing.T) { require.NotNil(t, store) assert.FileExists(t, tmpDir+"/test.db") rw := db.New(store) - rw.Query(ctx, "select * from target", nil) + rows, err := rw.Query(ctx, "select * from target", nil) + require.NoError(t, err) + rows.Close() }) t.Run("homedir", func(t *testing.T) { tmpDir := t.TempDir()