From 3cdd07889be5e4e6f21d044cb60fad1face4be08 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Wed, 28 Aug 2024 10:15:04 -0700 Subject: [PATCH] Restrict cache results to 250 entries by default (#5049) A query parameter `max_result_set_size` can be specified to set a per-request limit, with `-1` meaning all values. If there were more entries than the max size used for the request, an `incomplete` boolean is set in the response. --- internal/clientcache/cmd/search/search.go | 43 ++++- internal/clientcache/internal/cache/consts.go | 17 ++ .../clientcache/internal/cache/options.go | 21 ++- .../internal/cache/options_test.go | 15 +- .../internal/cache/refresh_test.go | 166 +++++++++++------- .../cache/repository_resolvable_aliases.go | 22 ++- .../repository_resolvable_aliases_test.go | 142 +++++++++++++-- .../internal/cache/repository_sessions.go | 22 ++- .../cache/repository_sessions_test.go | 142 +++++++++++++-- .../internal/cache/repository_targets.go | 22 ++- .../internal/cache/repository_targets_test.go | 142 +++++++++++++-- internal/clientcache/internal/cache/search.go | 68 ++++--- .../internal/daemon/404_handler.go | 2 +- .../internal/daemon/search_handler.go | 37 ++-- 14 files changed, 703 insertions(+), 158 deletions(-) create mode 100644 internal/clientcache/internal/cache/consts.go diff --git a/internal/clientcache/cmd/search/search.go b/internal/clientcache/cmd/search/search.go index f5fee5e475..e095671811 100644 --- a/internal/clientcache/cmd/search/search.go +++ b/internal/clientcache/cmd/search/search.go @@ -7,6 +7,7 @@ import ( "context" stderrors "errors" "fmt" + "math" "net/url" "os" "strings" @@ -40,9 +41,10 @@ var ( type SearchCommand struct { *base.Command - flagQuery string - flagResource string - flagForceRefresh bool + flagQuery string + flagResource string + flagForceRefresh bool + flagMaxResultSetSize int64 } func (c *SearchCommand) Synopsis() string { @@ -112,6 +114,12 @@ func (c *SearchCommand) Flags() *base.FlagSets { Usage: `Specifies the resource type to search over`, Completion: complete.PredictSet(supportedResourceTypes...), }) + f.Int64Var(&base.Int64Var{ + Name: "max-result-set-size", + Target: &c.flagMaxResultSetSize, + Usage: `Specifies an override to the default maximum result set size. Set to -1 to disable the limit. 0 will use the default.`, + Completion: complete.PredictNothing, + }) f.BoolVar(&base.BoolVar{ Name: "force-refresh", Target: &c.flagForceRefresh, @@ -148,6 +156,15 @@ func (c *SearchCommand) Run(args []string) int { return base.CommandUserError } + switch { + case c.flagMaxResultSetSize < -1: + c.PrintCliError(stderrors.New("Max result set size must be greater than or equal to -1")) + return base.CommandUserError + case c.flagMaxResultSetSize > math.MaxInt: + c.PrintCliError(stderrors.New(fmt.Sprintf("Max result set size must be less than or equal to the %v", math.MaxInt))) + return base.CommandUserError + } + resp, result, apiErr, err := c.Search(ctx) if err != nil { c.PrintCliError(err) @@ -164,6 +181,9 @@ func (c *SearchCommand) Run(args []string) int { return base.CommandCliError } default: + if result.Incomplete { + c.UI.Warn("The maximum result set size was reached and the search results are incomplete. Please narrow your search or adjust the -max-result-set-size parameter.") + } switch { case len(result.ResolvableAliases) > 0: c.UI.Output(printAliasListTable(result.ResolvableAliases)) @@ -199,6 +219,9 @@ func (c *SearchCommand) Search(ctx context.Context) (*api.Response, *daemon.Sear authTokenId: strings.Join(tSlice[:2], "_"), forceRefresh: c.flagForceRefresh, } + if c.flagMaxResultSetSize != 0 { + tf.maxResultSetSize = int(c.flagMaxResultSetSize) + } var opts []client.Option if c.FlagOutputCurlString { opts = append(opts, client.WithOutputCurlString()) @@ -230,6 +253,9 @@ func search(ctx context.Context, daemonPath string, fb filterBy, opt ...client.O if fb.forceRefresh { q.Add("force_refresh", "true") } + if fb.maxResultSetSize != 0 { + q.Add("max_result_set_size", fmt.Sprintf("%d", fb.maxResultSetSize)) + } resp, err := c.Get(ctx, "/v1/search", q, opt...) if err != nil { return nil, nil, nil, fmt.Errorf("Error when sending request to the cache: %w.", err) @@ -424,9 +450,10 @@ func printSessionListTable(items []*sessions.Session) string { } type filterBy struct { - flagFilter string - flagQuery string - authTokenId string - resource string - forceRefresh bool + flagFilter string + flagQuery string + authTokenId string + resource string + forceRefresh bool + maxResultSetSize int } diff --git a/internal/clientcache/internal/cache/consts.go b/internal/clientcache/internal/cache/consts.go new file mode 100644 index 0000000000..2321907d6d --- /dev/null +++ b/internal/clientcache/internal/cache/consts.go @@ -0,0 +1,17 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +/* +Package cache contains the domain logic for the client cache. +*/ +package cache + +const ( + // defaultLimitedResultSetSize is the default number of results to + // return when limiting + defaultLimitedResultSetSize = 250 + + // unlimitedMaxResultSetSize is the value to use when we want to return all + // results + unlimitedMaxResultSetSize = -1 +) diff --git a/internal/clientcache/internal/cache/options.go b/internal/clientcache/internal/cache/options.go index 6f68ebcea3..09e34ee9d8 100644 --- a/internal/clientcache/internal/cache/options.go +++ b/internal/clientcache/internal/cache/options.go @@ -4,6 +4,8 @@ package cache import ( + stderrors "errors" + "github.com/hashicorp/go-dbw" ) @@ -16,6 +18,7 @@ type options struct { withTargetRetrievalFunc TargetRetrievalFunc withSessionRetrievalFunc SessionRetrievalFunc withIgnoreSearchStaleness bool + withMaxResultSetSize int } // Option - how options are passed as args @@ -23,7 +26,8 @@ type Option func(*options) error func getDefaultOptions() options { return options{ - withDbType: dbw.Sqlite, + withDbType: dbw.Sqlite, + withMaxResultSetSize: defaultLimitedResultSetSize, } } @@ -94,3 +98,18 @@ func WithIgnoreSearchStaleness(b bool) Option { return nil } } + +// WithMaxResultSetSize provides an option for limiting the result set, e.g. +// when no filter is provided on a list. A 0 does nothing (keeps the default). +func WithMaxResultSetSize(with int) Option { + return func(o *options) error { + switch { + case with == 0: + return nil + case with < -1: + return stderrors.New("max result set size must be -1 or greater") + } + o.withMaxResultSetSize = with + return nil + } +} diff --git a/internal/clientcache/internal/cache/options_test.go b/internal/clientcache/internal/cache/options_test.go index 2475356203..aed96cd5ba 100644 --- a/internal/clientcache/internal/cache/options_test.go +++ b/internal/clientcache/internal/cache/options_test.go @@ -22,7 +22,8 @@ func Test_GetOpts(t *testing.T) { opts, err := getOpts() require.NoError(t, err) testOpts := options{ - withDbType: dbw.Sqlite, + withDbType: dbw.Sqlite, + withMaxResultSetSize: defaultLimitedResultSetSize, } assert.Equal(t, opts, testOpts) }) @@ -93,4 +94,16 @@ func Test_GetOpts(t *testing.T) { testOpts.withIgnoreSearchStaleness = true assert.Equal(t, opts, testOpts) }) + t.Run("withMaxResultSetSize", func(t *testing.T) { + opts, err := getOpts(WithMaxResultSetSize(defaultLimitedResultSetSize)) + require.NoError(t, err) + testOpts := getDefaultOptions() + testOpts.withMaxResultSetSize = defaultLimitedResultSetSize + assert.Equal(t, opts, testOpts) + opts, err = getOpts(WithMaxResultSetSize(0)) + require.Nil(t, err) + assert.Equal(t, opts, testOpts) + _, err = getOpts(WithMaxResultSetSize(-2)) + require.Error(t, err) + }) } diff --git a/internal/clientcache/internal/cache/refresh_test.go b/internal/clientcache/internal/cache/refresh_test.go index f7c0c003aa..a24dedf198 100644 --- a/internal/clientcache/internal/cache/refresh_test.go +++ b/internal/clientcache/internal/cache/refresh_test.go @@ -440,15 +440,18 @@ func TestRefreshForSearch(t *testing.T) { assert.NoError(t, rs.RefreshForSearch(ctx, at.Id, Targets, opts...)) cachedTargets, err := r.ListTargets(ctx, at.Id) - assert.NoError(t, err) - assert.Empty(t, cachedTargets) + require.NoError(t, err) + assert.Empty(t, cachedTargets.Targets) + assert.Empty(t, cachedTargets.ResolvableAliases) + assert.Empty(t, cachedTargets.Sessions) + assert.False(t, cachedTargets.Incomplete) // Now load up a few resources and a token, and trying again should // see the RefreshForSearch update more fields. assert.NoError(t, rs.Refresh(ctx, opts...)) cachedTargets, err = r.ListTargets(ctx, at.Id) assert.NoError(t, err) - assert.ElementsMatch(t, retTargets[:3], cachedTargets) + assert.ElementsMatch(t, retTargets[:3], cachedTargets.Targets) // Let 2 milliseconds pass so the items are stale enough time.Sleep(2 * time.Millisecond) @@ -456,7 +459,7 @@ func TestRefreshForSearch(t *testing.T) { assert.NoError(t, rs.RefreshForSearch(ctx, at.Id, Targets, opts...)) cachedTargets, err = r.ListTargets(ctx, at.Id) assert.NoError(t, err) - assert.ElementsMatch(t, retTargets[2:], cachedTargets) + assert.ElementsMatch(t, retTargets[2:], cachedTargets.Targets) }) t.Run("targets forced refreshed for searching", func(t *testing.T) { @@ -492,27 +495,30 @@ func TestRefreshForSearch(t *testing.T) { assert.NoError(t, rs.RefreshForSearch(ctx, at.Id, Targets, opts...)) cachedTargets, err := r.ListTargets(ctx, at.Id) - assert.NoError(t, err) - assert.Empty(t, cachedTargets) + require.NoError(t, err) + assert.Empty(t, cachedTargets.Targets) + assert.Empty(t, cachedTargets.ResolvableAliases) + assert.Empty(t, cachedTargets.Sessions) + assert.False(t, cachedTargets.Incomplete) // Now load up a few resources and a token, and trying again should // see the RefreshForSearch update more fields. assert.NoError(t, rs.Refresh(ctx, opts...)) cachedTargets, err = r.ListTargets(ctx, at.Id) assert.NoError(t, err) - assert.ElementsMatch(t, retTargets[:3], cachedTargets) + assert.ElementsMatch(t, retTargets[:3], cachedTargets.Targets) // No refresh happened because it is not considered stale assert.NoError(t, rs.RefreshForSearch(ctx, at.Id, Targets, opts...)) cachedTargets, err = r.ListTargets(ctx, at.Id) assert.NoError(t, err) - assert.ElementsMatch(t, retTargets[:3], cachedTargets) + assert.ElementsMatch(t, retTargets[:3], cachedTargets.Targets) // Now force refresh assert.NoError(t, rs.RefreshForSearch(ctx, at.Id, Targets, append(opts, WithIgnoreSearchStaleness(true))...)) cachedTargets, err = r.ListTargets(ctx, at.Id) assert.NoError(t, err) - assert.ElementsMatch(t, retTargets[2:], cachedTargets) + assert.ElementsMatch(t, retTargets[2:], cachedTargets.Targets) }) t.Run("no refresh token no refresh for search", func(t *testing.T) { @@ -537,8 +543,11 @@ func TestRefreshForSearch(t *testing.T) { assert.ErrorContains(t, err, ErrRefreshNotSupported.Error()) got, err := r.ListTargets(ctx, at.Id) - assert.NoError(t, err) - assert.Empty(t, got) + require.NoError(t, err) + assert.Empty(t, got.Targets) + assert.Empty(t, got.ResolvableAliases) + assert.Empty(t, got.Sessions) + assert.False(t, got.Incomplete) // Now that we know that this user doesn't support refresh tokens, they // wont be refreshed any more, and we wont see the error when refreshing @@ -556,8 +565,11 @@ func TestRefreshForSearch(t *testing.T) { assert.Nil(t, err) got, err = r.ListTargets(ctx, at.Id) - assert.NoError(t, err) - assert.Empty(t, got) + require.NoError(t, err) + assert.Empty(t, got.Targets) + assert.Empty(t, got.ResolvableAliases) + assert.Empty(t, got.Sessions) + assert.False(t, got.Incomplete) // Now simulate the controller updating to support refresh tokens and // the resources starting to be cached. @@ -569,7 +581,7 @@ func TestRefreshForSearch(t *testing.T) { got, err = r.ListTargets(ctx, at.Id) assert.NoError(t, err) - assert.Len(t, got, 2) + assert.Len(t, got.Targets, 2) }) t.Run("sessions refreshed for searching", func(t *testing.T) { @@ -605,19 +617,22 @@ func TestRefreshForSearch(t *testing.T) { // First call doesn't sync anything because no sessions were already synced yet assert.NoError(t, rs.RefreshForSearch(ctx, at.Id, Sessions, opts...)) cachedSessions, err := r.ListSessions(ctx, at.Id) - assert.NoError(t, err) - assert.Empty(t, cachedSessions) + require.NoError(t, err) + assert.Empty(t, cachedSessions.Targets) + assert.Empty(t, cachedSessions.ResolvableAliases) + assert.Empty(t, cachedSessions.Sessions) + assert.False(t, cachedSessions.Incomplete) assert.NoError(t, rs.Refresh(ctx, opts...)) cachedSessions, err = r.ListSessions(ctx, at.Id) assert.NoError(t, err) - assert.ElementsMatch(t, retSess[:3], cachedSessions) + assert.ElementsMatch(t, retSess[:3], cachedSessions.Sessions) // Second call removes the first 2 resources from the cache and adds the last assert.NoError(t, rs.RefreshForSearch(ctx, at.Id, Sessions, opts...)) cachedSessions, err = r.ListSessions(ctx, at.Id) assert.NoError(t, err) - assert.ElementsMatch(t, retSess[2:], cachedSessions) + assert.ElementsMatch(t, retSess[2:], cachedSessions.Sessions) }) t.Run("sessions forced refreshed for searching", func(t *testing.T) { @@ -654,25 +669,28 @@ func TestRefreshForSearch(t *testing.T) { // First call doesn't sync anything because no sessions were already synced yet assert.NoError(t, rs.RefreshForSearch(ctx, at.Id, Sessions, opts...)) cachedSessions, err := r.ListSessions(ctx, at.Id) - assert.NoError(t, err) - assert.Empty(t, cachedSessions) + require.NoError(t, err) + assert.Empty(t, cachedSessions.Targets) + assert.Empty(t, cachedSessions.ResolvableAliases) + assert.Empty(t, cachedSessions.Sessions) + assert.False(t, cachedSessions.Incomplete) assert.NoError(t, rs.Refresh(ctx, opts...)) cachedSessions, err = r.ListSessions(ctx, at.Id) assert.NoError(t, err) - assert.ElementsMatch(t, retSess[:3], cachedSessions) + assert.ElementsMatch(t, retSess[:3], cachedSessions.Sessions) // Refresh for search doesn't refresh anything because it isn't stale assert.NoError(t, rs.RefreshForSearch(ctx, at.Id, Sessions, opts...)) cachedSessions, err = r.ListSessions(ctx, at.Id) assert.NoError(t, err) - assert.ElementsMatch(t, retSess[:3], cachedSessions) + assert.ElementsMatch(t, retSess[:3], cachedSessions.Sessions) // Now force the refresh and see things get updated assert.NoError(t, rs.RefreshForSearch(ctx, at.Id, Sessions, append(opts, WithIgnoreSearchStaleness(true))...)) cachedSessions, err = r.ListSessions(ctx, at.Id) assert.NoError(t, err) - assert.ElementsMatch(t, retSess[2:], cachedSessions) + assert.ElementsMatch(t, retSess[2:], cachedSessions.Sessions) }) t.Run("aliases refreshed for searching", func(t *testing.T) { @@ -708,19 +726,22 @@ func TestRefreshForSearch(t *testing.T) { // First call doesn't sync anything because no aliases were already synced yet assert.NoError(t, rs.RefreshForSearch(ctx, at.Id, ResolvableAliases, opts...)) cachedAliases, err := r.ListResolvableAliases(ctx, at.Id) - assert.NoError(t, err) - assert.Empty(t, cachedAliases) + require.NoError(t, err) + assert.Empty(t, cachedAliases.Targets) + assert.Empty(t, cachedAliases.ResolvableAliases) + assert.Empty(t, cachedAliases.Sessions) + assert.False(t, cachedAliases.Incomplete) assert.NoError(t, rs.Refresh(ctx, opts...)) cachedAliases, err = r.ListResolvableAliases(ctx, at.Id) - assert.NoError(t, err) - assert.ElementsMatch(t, retAl[:3], cachedAliases) + require.NoError(t, err) + assert.ElementsMatch(t, retAl[:3], cachedAliases.ResolvableAliases) // Second call removes the first 2 resources from the cache and adds the last assert.NoError(t, rs.RefreshForSearch(ctx, at.Id, ResolvableAliases, opts...)) cachedAliases, err = r.ListResolvableAliases(ctx, at.Id) - assert.NoError(t, err) - assert.ElementsMatch(t, retAl[2:], cachedAliases) + require.NoError(t, err) + assert.ElementsMatch(t, retAl[2:], cachedAliases.ResolvableAliases) }) t.Run("aliases forced refreshed for searching", func(t *testing.T) { @@ -757,25 +778,28 @@ func TestRefreshForSearch(t *testing.T) { // First call doesn't sync anything because no aliases were already synced yet assert.NoError(t, rs.RefreshForSearch(ctx, at.Id, ResolvableAliases, opts...)) cachedAliases, err := r.ListResolvableAliases(ctx, at.Id) - assert.NoError(t, err) - assert.Empty(t, cachedAliases) + require.NoError(t, err) + assert.Empty(t, cachedAliases.Targets) + assert.Empty(t, cachedAliases.ResolvableAliases) + assert.Empty(t, cachedAliases.Sessions) + assert.False(t, cachedAliases.Incomplete) assert.NoError(t, rs.Refresh(ctx, opts...)) cachedAliases, err = r.ListResolvableAliases(ctx, at.Id) - assert.NoError(t, err) - assert.ElementsMatch(t, retAls[:3], cachedAliases) + require.NoError(t, err) + assert.ElementsMatch(t, retAls[:3], cachedAliases.ResolvableAliases) // Refresh for search doesn't refresh anything because it isn't stale assert.NoError(t, rs.RefreshForSearch(ctx, at.Id, ResolvableAliases, opts...)) cachedAliases, err = r.ListResolvableAliases(ctx, at.Id) - assert.NoError(t, err) - assert.ElementsMatch(t, retAls[:3], cachedAliases) + require.NoError(t, err) + assert.ElementsMatch(t, retAls[:3], cachedAliases.ResolvableAliases) // Now force the refresh and see things get updated assert.NoError(t, rs.RefreshForSearch(ctx, at.Id, ResolvableAliases, append(opts, WithIgnoreSearchStaleness(true))...)) cachedAliases, err = r.ListResolvableAliases(ctx, at.Id) - assert.NoError(t, err) - assert.ElementsMatch(t, retAls[2:], cachedAliases) + require.NoError(t, err) + assert.ElementsMatch(t, retAls[2:], cachedAliases.ResolvableAliases) }) } @@ -829,13 +853,13 @@ func TestRefresh(t *testing.T) { cachedTargets, err := r.ListTargets(ctx, at.Id) assert.NoError(t, err) - assert.ElementsMatch(t, retTargets[:3], cachedTargets) + assert.ElementsMatch(t, retTargets[:3], cachedTargets.Targets) // Second call removes the first 2 resources from the cache and adds the last assert.NoError(t, rs.Refresh(ctx, opts...)) cachedTargets, err = r.ListTargets(ctx, at.Id) assert.NoError(t, err) - assert.ElementsMatch(t, retTargets[2:], cachedTargets) + assert.ElementsMatch(t, retTargets[2:], cachedTargets.Targets) }) t.Run("set sessions", func(t *testing.T) { @@ -870,13 +894,13 @@ func TestRefresh(t *testing.T) { assert.NoError(t, rs.Refresh(ctx, opts...)) cachedSessions, err := r.ListSessions(ctx, at.Id) assert.NoError(t, err) - assert.ElementsMatch(t, retSess[:3], cachedSessions) + assert.ElementsMatch(t, retSess[:3], cachedSessions.Sessions) // Second call removes the first 2 resources from the cache and adds the last assert.NoError(t, rs.Refresh(ctx, opts...)) cachedSessions, err = r.ListSessions(ctx, at.Id) assert.NoError(t, err) - assert.ElementsMatch(t, retSess[2:], cachedSessions) + assert.ElementsMatch(t, retSess[2:], cachedSessions.Sessions) }) t.Run("set aliases", func(t *testing.T) { @@ -911,13 +935,13 @@ func TestRefresh(t *testing.T) { assert.NoError(t, rs.Refresh(ctx, opts...)) cachedAliases, err := r.ListResolvableAliases(ctx, at.Id) assert.NoError(t, err) - assert.ElementsMatch(t, retAls[:3], cachedAliases) + assert.ElementsMatch(t, retAls[:3], cachedAliases.ResolvableAliases) // Second call removes the first 2 resources from the cache and adds the last assert.NoError(t, rs.Refresh(ctx, opts...)) cachedAliases, err = r.ListResolvableAliases(ctx, at.Id) assert.NoError(t, err) - assert.ElementsMatch(t, retAls[2:], cachedAliases) + assert.ElementsMatch(t, retAls[2:], cachedAliases.ResolvableAliases) }) t.Run("error propagates up", func(t *testing.T) { @@ -1022,8 +1046,11 @@ func TestRecheckCachingSupport(t *testing.T) { WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)))) got, err := r.ListTargets(ctx, at.Id) - assert.NoError(t, err) - assert.Empty(t, got) + require.NoError(t, err) + assert.Empty(t, got.Targets) + assert.Empty(t, got.ResolvableAliases) + assert.Empty(t, got.Sessions) + assert.False(t, got.Incomplete) err = rs.Refresh(ctx, WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), @@ -1032,8 +1059,11 @@ func TestRecheckCachingSupport(t *testing.T) { assert.ErrorIs(t, err, ErrRefreshNotSupported) got, err = r.ListTargets(ctx, at.Id) - assert.NoError(t, err) - assert.Empty(t, got) + require.NoError(t, err) + assert.Empty(t, got.Targets) + assert.Empty(t, got.ResolvableAliases) + assert.Empty(t, got.Sessions) + assert.False(t, got.Incomplete) // now a full fetch will work since the user has resources and no refresh token assert.NoError(t, rs.RecheckCachingSupport(ctx, @@ -1057,8 +1087,11 @@ func TestRecheckCachingSupport(t *testing.T) { WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))) got, err := r.ListSessions(ctx, at.Id) - assert.NoError(t, err) - assert.Empty(t, got) + require.NoError(t, err) + assert.Empty(t, got.Targets) + assert.Empty(t, got.ResolvableAliases) + assert.Empty(t, got.Sessions) + assert.False(t, got.Incomplete) err = rs.Refresh(ctx, WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), @@ -1067,16 +1100,22 @@ func TestRecheckCachingSupport(t *testing.T) { assert.ErrorIs(t, err, ErrRefreshNotSupported) got, err = r.ListSessions(ctx, at.Id) - assert.NoError(t, err) - assert.Empty(t, got) + require.NoError(t, err) + assert.Empty(t, got.Targets) + assert.Empty(t, got.ResolvableAliases) + assert.Empty(t, got.Sessions) + assert.False(t, got.Incomplete) assert.NoError(t, rs.RecheckCachingSupport(ctx, WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)), WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))) got, err = r.ListSessions(ctx, at.Id) - assert.NoError(t, err) - assert.Empty(t, got) + require.NoError(t, err) + assert.Empty(t, got.Targets) + assert.Empty(t, got.ResolvableAliases) + assert.Empty(t, got.Sessions) + assert.False(t, got.Incomplete) }) t.Run("aliases", func(t *testing.T) { @@ -1094,8 +1133,11 @@ func TestRecheckCachingSupport(t *testing.T) { WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))) got, err := r.ListResolvableAliases(ctx, at.Id) - assert.NoError(t, err) - assert.Empty(t, got) + require.NoError(t, err) + assert.Empty(t, got.Targets) + assert.Empty(t, got.ResolvableAliases) + assert.Empty(t, got.Sessions) + assert.False(t, got.Incomplete) err = rs.Refresh(ctx, WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), @@ -1104,16 +1146,22 @@ func TestRecheckCachingSupport(t *testing.T) { assert.ErrorIs(t, err, ErrRefreshNotSupported) got, err = r.ListResolvableAliases(ctx, at.Id) - assert.NoError(t, err) - assert.Empty(t, got) + require.NoError(t, err) + assert.Empty(t, got.Targets) + assert.Empty(t, got.ResolvableAliases) + assert.Empty(t, got.Sessions) + assert.False(t, got.Incomplete) assert.NoError(t, rs.RecheckCachingSupport(ctx, WithAliasRetrievalFunc(testNoRefreshRetrievalFuncForId[*aliases.Alias](t)), WithTargetRetrievalFunc(testNoRefreshRetrievalFunc[*targets.Target](t)), WithSessionRetrievalFunc(testNoRefreshRetrievalFunc[*sessions.Session](t)))) got, err = r.ListResolvableAliases(ctx, at.Id) - assert.NoError(t, err) - assert.Empty(t, got) + require.NoError(t, err) + assert.Empty(t, got.Targets) + assert.Empty(t, got.ResolvableAliases) + assert.Empty(t, got.Sessions) + assert.False(t, got.Incomplete) }) t.Run("error propagates up", func(t *testing.T) { diff --git a/internal/clientcache/internal/cache/repository_resolvable_aliases.go b/internal/clientcache/internal/cache/repository_resolvable_aliases.go index 1bec59e9ae..39e3bf5f36 100644 --- a/internal/clientcache/internal/cache/repository_resolvable_aliases.go +++ b/internal/clientcache/internal/cache/repository_resolvable_aliases.go @@ -288,20 +288,20 @@ func upsertResolvableAliases(ctx context.Context, w db.Writer, u *user, in []*al return nil } -func (r *Repository) ListResolvableAliases(ctx context.Context, authTokenId string) ([]*aliases.Alias, error) { +func (r *Repository) ListResolvableAliases(ctx context.Context, authTokenId string, opt ...Option) (*SearchResult, error) { const op = "cache.(Repository).ListResolvableAliases" switch { case authTokenId == "": return nil, errors.New(ctx, errors.InvalidParameter, op, "auth token id is missing") } - ret, err := r.searchResolvableAliases(ctx, "true", nil, withAuthTokenId(authTokenId)) + ret, err := r.searchResolvableAliases(ctx, "true", nil, append(opt, withAuthTokenId(authTokenId))...) if err != nil { return nil, errors.Wrap(ctx, err, op) } return ret, nil } -func (r *Repository) QueryResolvableAliases(ctx context.Context, authTokenId, query string) ([]*aliases.Alias, error) { +func (r *Repository) QueryResolvableAliases(ctx context.Context, authTokenId, query string, opt ...Option) (*SearchResult, error) { const op = "cache.(Repository).QueryResolvableAliases" switch { case authTokenId == "": @@ -314,14 +314,14 @@ func (r *Repository) QueryResolvableAliases(ctx context.Context, authTokenId, qu if err != nil { return nil, errors.Wrap(ctx, err, op, errors.WithCode(errors.InvalidParameter)) } - ret, err := r.searchResolvableAliases(ctx, w.Condition, w.Args, withAuthTokenId(authTokenId)) + ret, err := r.searchResolvableAliases(ctx, w.Condition, w.Args, append(opt, withAuthTokenId(authTokenId))...) if err != nil { return nil, errors.Wrap(ctx, err, op) } return ret, nil } -func (r *Repository) searchResolvableAliases(ctx context.Context, condition string, searchArgs []any, opt ...Option) ([]*aliases.Alias, error) { +func (r *Repository) searchResolvableAliases(ctx context.Context, condition string, searchArgs []any, opt ...Option) (*SearchResult, error) { const op = "cache.(Repository).searchResolvableAliases" switch { case condition == "": @@ -346,7 +346,7 @@ func (r *Repository) searchResolvableAliases(ctx context.Context, condition stri } var cachedResolvableAliases []*ResolvableAlias - if err := r.rw.SearchWhere(ctx, &cachedResolvableAliases, condition, searchArgs, db.WithLimit(-1)); err != nil { + if err := r.rw.SearchWhere(ctx, &cachedResolvableAliases, condition, searchArgs, db.WithLimit(opts.withMaxResultSetSize+1)); err != nil { return nil, errors.Wrap(ctx, err, op) } @@ -358,7 +358,15 @@ func (r *Repository) searchResolvableAliases(ctx context.Context, condition stri } retAliases = append(retAliases, &a) } - return retAliases, nil + + sr := &SearchResult{ + ResolvableAliases: retAliases, + } + if opts.withMaxResultSetSize > 0 && len(sr.ResolvableAliases) > opts.withMaxResultSetSize { + sr.ResolvableAliases = sr.ResolvableAliases[:opts.withMaxResultSetSize] + sr.Incomplete = true + } + return sr, nil } type ResolvableAlias struct { diff --git a/internal/clientcache/internal/cache/repository_resolvable_aliases_test.go b/internal/clientcache/internal/cache/repository_resolvable_aliases_test.go index 948241ef61..086bf17c4c 100644 --- a/internal/clientcache/internal/cache/repository_resolvable_aliases_test.go +++ b/internal/clientcache/internal/cache/repository_resolvable_aliases_test.go @@ -6,6 +6,7 @@ package cache import ( "context" "encoding/json" + "strconv" "sync" "testing" @@ -222,7 +223,7 @@ func TestRepository_RefreshAliases_withRefreshTokens(t *testing.T) { got, err := r.ListResolvableAliases(ctx, at.Id) require.NoError(t, err) - assert.Len(t, got, 2) + assert.Len(t, got.ResolvableAliases, 2) // Refreshing again uses the refresh token and get additional aliases, appending // them to the response @@ -232,7 +233,7 @@ func TestRepository_RefreshAliases_withRefreshTokens(t *testing.T) { got, err = r.ListResolvableAliases(ctx, at.Id) require.NoError(t, err) - assert.Len(t, got, 3) + assert.Len(t, got.ResolvableAliases, 3) // Refreshing again wont return any more resources, but also none should be // removed @@ -242,7 +243,7 @@ func TestRepository_RefreshAliases_withRefreshTokens(t *testing.T) { got, err = r.ListResolvableAliases(ctx, at.Id) require.NoError(t, err) - assert.Len(t, got, 3) + assert.Len(t, got.ResolvableAliases, 3) // Refresh again with the refresh token being reported as invalid. require.NoError(t, r.refreshResolvableAliases(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"}, @@ -251,7 +252,7 @@ func TestRepository_RefreshAliases_withRefreshTokens(t *testing.T) { got, err = r.ListResolvableAliases(ctx, at.Id) require.NoError(t, err) - assert.Len(t, got, 2) + assert.Len(t, got.ResolvableAliases, 2) } func TestRepository_ListAliases(t *testing.T) { @@ -332,13 +333,73 @@ func TestRepository_ListAliases(t *testing.T) { t.Run("wrong user gets no aliases", func(t *testing.T) { l, err := r.ListResolvableAliases(ctx, kt2.AuthTokenId) assert.NoError(t, err) - assert.Empty(t, l) + assert.Empty(t, l.ResolvableAliases) }) t.Run("correct token gets aliases", func(t *testing.T) { l, err := r.ListResolvableAliases(ctx, kt1.AuthTokenId) assert.NoError(t, err) - assert.Len(t, l, len(ss)) - assert.ElementsMatch(t, l, ss) + assert.Len(t, l.ResolvableAliases, len(ss)) + assert.ElementsMatch(t, l.ResolvableAliases, ss) + }) +} + +func TestRepository_ListAliasesLimiting(t *testing.T) { + ctx := context.Background() + s, err := cachedb.Open(ctx) + require.NoError(t, err) + + addr := "address" + u := &user{ + Id: "u", + Address: addr, + } + at := &authtokens.AuthToken{ + Id: "at", + Token: "at_token", + UserId: u.Id, + } + kt := KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id} + + atMap := map[ringToken]*authtokens.AuthToken{ + {"k", "t"}: at, + } + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(maps.Values(atMap))) + require.NoError(t, err) + require.NoError(t, r.AddKeyringToken(ctx, addr, kt)) + + var ts []*aliases.Alias + for i := 0; i < defaultLimitedResultSetSize*2; i++ { + ts = append(ts, alias("s"+strconv.Itoa(i))) + } + require.NoError(t, r.refreshResolvableAliases(ctx, u, map[AuthToken]string{{Id: "id"}: "something"}, + WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId(t, [][]*aliases.Alias{ts}, [][]string{nil})))) + + searchService, err := NewSearchService(ctx, r) + require.NoError(t, err) + params := SearchParams{ + Resource: ResolvableAliases, + AuthTokenId: kt.AuthTokenId, + } + + t.Run("default limit", func(t *testing.T) { + searchResult, err := searchService.Search(ctx, params) + require.NoError(t, err) + assert.Len(t, searchResult.ResolvableAliases, defaultLimitedResultSetSize) + assert.True(t, searchResult.Incomplete) + }) + t.Run("custom limit", func(t *testing.T) { + params.MaxResultSetSize = 20 + searchResult, err := searchService.Search(ctx, params) + require.NoError(t, err) + assert.Len(t, searchResult.ResolvableAliases, params.MaxResultSetSize) + assert.True(t, searchResult.Incomplete) + }) + t.Run("no limit", func(t *testing.T) { + params.MaxResultSetSize = -1 + searchResult, err := searchService.Search(ctx, params) + require.NoError(t, err) + assert.Len(t, searchResult.ResolvableAliases, defaultLimitedResultSetSize*2) + assert.False(t, searchResult.Incomplete) }) } @@ -442,13 +503,74 @@ func TestRepository_QueryAliases(t *testing.T) { t.Run("wrong token gets no aliases", func(t *testing.T) { l, err := r.QueryResolvableAliases(ctx, kt2.AuthTokenId, query) assert.NoError(t, err) - assert.Empty(t, l) + assert.Empty(t, l.ResolvableAliases) }) t.Run("correct token gets aliases", func(t *testing.T) { l, err := r.QueryResolvableAliases(ctx, kt1.AuthTokenId, query) assert.NoError(t, err) - assert.Len(t, l, 2) - assert.ElementsMatch(t, l, ss[0:2]) + assert.Len(t, l.ResolvableAliases, 2) + assert.ElementsMatch(t, l.ResolvableAliases, ss[0:2]) + }) +} + +func TestRepository_QueryResolvableAliasesLimiting(t *testing.T) { + ctx := context.Background() + s, err := cachedb.Open(ctx) + require.NoError(t, err) + + addr := "address" + u := &user{ + Id: "u", + Address: addr, + } + at := &authtokens.AuthToken{ + Id: "at", + Token: "at_token", + UserId: u.Id, + } + kt := KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id} + + atMap := map[ringToken]*authtokens.AuthToken{ + {"k", "t"}: at, + } + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(maps.Values(atMap))) + require.NoError(t, err) + require.NoError(t, r.AddKeyringToken(ctx, addr, kt)) + + var ts []*aliases.Alias + for i := 0; i < defaultLimitedResultSetSize*2; i++ { + ts = append(ts, alias("s"+strconv.Itoa(i))) + } + require.NoError(t, r.refreshResolvableAliases(ctx, u, map[AuthToken]string{{Id: "id"}: "something"}, + WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId(t, [][]*aliases.Alias{ts}, [][]string{nil})))) + + searchService, err := NewSearchService(ctx, r) + require.NoError(t, err) + params := SearchParams{ + Resource: ResolvableAliases, + AuthTokenId: kt.AuthTokenId, + Query: `(type % 'target')`, + } + + t.Run("default limit", func(t *testing.T) { + searchResult, err := searchService.Search(ctx, params) + require.NoError(t, err) + assert.Len(t, searchResult.ResolvableAliases, defaultLimitedResultSetSize) + assert.True(t, searchResult.Incomplete) + }) + t.Run("custom limit", func(t *testing.T) { + params.MaxResultSetSize = 20 + searchResult, err := searchService.Search(ctx, params) + require.NoError(t, err) + assert.Len(t, searchResult.ResolvableAliases, params.MaxResultSetSize) + assert.True(t, searchResult.Incomplete) + }) + t.Run("no limit", func(t *testing.T) { + params.MaxResultSetSize = -1 + searchResult, err := searchService.Search(ctx, params) + require.NoError(t, err) + assert.Len(t, searchResult.ResolvableAliases, defaultLimitedResultSetSize*2) + assert.False(t, searchResult.Incomplete) }) } diff --git a/internal/clientcache/internal/cache/repository_sessions.go b/internal/clientcache/internal/cache/repository_sessions.go index 99f5bad1c3..904202252a 100644 --- a/internal/clientcache/internal/cache/repository_sessions.go +++ b/internal/clientcache/internal/cache/repository_sessions.go @@ -290,20 +290,20 @@ func upsertSessions(ctx context.Context, w db.Writer, u *user, in []*sessions.Se return nil } -func (r *Repository) ListSessions(ctx context.Context, authTokenId string) ([]*sessions.Session, error) { +func (r *Repository) ListSessions(ctx context.Context, authTokenId string, opt ...Option) (*SearchResult, error) { const op = "cache.(Repository).ListSessions" switch { case authTokenId == "": return nil, errors.New(ctx, errors.InvalidParameter, op, "auth token id is missing") } - ret, err := r.searchSessions(ctx, "true", nil, withAuthTokenId(authTokenId)) + ret, err := r.searchSessions(ctx, "true", nil, append(opt, withAuthTokenId(authTokenId))...) if err != nil { return nil, errors.Wrap(ctx, err, op) } return ret, nil } -func (r *Repository) QuerySessions(ctx context.Context, authTokenId, query string) ([]*sessions.Session, error) { +func (r *Repository) QuerySessions(ctx context.Context, authTokenId, query string, opt ...Option) (*SearchResult, error) { const op = "cache.(Repository).QuerySessions" switch { case authTokenId == "": @@ -316,14 +316,14 @@ func (r *Repository) QuerySessions(ctx context.Context, authTokenId, query strin if err != nil { return nil, errors.Wrap(ctx, err, op, errors.WithCode(errors.InvalidParameter)) } - ret, err := r.searchSessions(ctx, w.Condition, w.Args, withAuthTokenId(authTokenId)) + ret, err := r.searchSessions(ctx, w.Condition, w.Args, append(opt, withAuthTokenId(authTokenId))...) if err != nil { return nil, errors.Wrap(ctx, err, op) } return ret, nil } -func (r *Repository) searchSessions(ctx context.Context, condition string, searchArgs []any, opt ...Option) ([]*sessions.Session, error) { +func (r *Repository) searchSessions(ctx context.Context, condition string, searchArgs []any, opt ...Option) (*SearchResult, error) { const op = "cache.(Repository).searchSessions" switch { case condition == "": @@ -348,7 +348,7 @@ func (r *Repository) searchSessions(ctx context.Context, condition string, searc } var cachedSessions []*Session - if err := r.rw.SearchWhere(ctx, &cachedSessions, condition, searchArgs, db.WithLimit(-1)); err != nil { + if err := r.rw.SearchWhere(ctx, &cachedSessions, condition, searchArgs, db.WithLimit(opts.withMaxResultSetSize+1)); err != nil { return nil, errors.Wrap(ctx, err, op) } @@ -360,7 +360,15 @@ func (r *Repository) searchSessions(ctx context.Context, condition string, searc } retSessions = append(retSessions, &sess) } - return retSessions, nil + + sr := &SearchResult{ + Sessions: retSessions, + } + if opts.withMaxResultSetSize > 0 && len(sr.Sessions) > opts.withMaxResultSetSize { + sr.Sessions = sr.Sessions[:opts.withMaxResultSetSize] + sr.Incomplete = true + } + return sr, nil } type Session struct { diff --git a/internal/clientcache/internal/cache/repository_sessions_test.go b/internal/clientcache/internal/cache/repository_sessions_test.go index 45e6847951..4627f8cd88 100644 --- a/internal/clientcache/internal/cache/repository_sessions_test.go +++ b/internal/clientcache/internal/cache/repository_sessions_test.go @@ -6,6 +6,7 @@ package cache import ( "context" "encoding/json" + "strconv" "sync" "testing" "time" @@ -239,7 +240,7 @@ func TestRepository_RefreshSessions_withRefreshTokens(t *testing.T) { got, err := r.ListSessions(ctx, at.Id) require.NoError(t, err) - assert.Len(t, got, 2) + assert.Len(t, got.Sessions, 2) // Refreshing again uses the refresh token and get additional sessions, appending // them to the response @@ -249,7 +250,7 @@ func TestRepository_RefreshSessions_withRefreshTokens(t *testing.T) { got, err = r.ListSessions(ctx, at.Id) require.NoError(t, err) - assert.Len(t, got, 3) + assert.Len(t, got.Sessions, 3) // Refreshing again wont return any more resources, but also none should be // removed @@ -259,7 +260,7 @@ func TestRepository_RefreshSessions_withRefreshTokens(t *testing.T) { got, err = r.ListSessions(ctx, at.Id) require.NoError(t, err) - assert.Len(t, got, 3) + assert.Len(t, got.Sessions, 3) // Refresh again with the refresh token being reported as invalid. require.NoError(t, r.refreshSessions(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"}, @@ -268,7 +269,7 @@ func TestRepository_RefreshSessions_withRefreshTokens(t *testing.T) { got, err = r.ListSessions(ctx, at.Id) require.NoError(t, err) - assert.Len(t, got, 2) + assert.Len(t, got.Sessions, 2) } func TestRepository_ListSessions(t *testing.T) { @@ -355,13 +356,73 @@ func TestRepository_ListSessions(t *testing.T) { t.Run("wrong user gets no sessions", func(t *testing.T) { l, err := r.ListSessions(ctx, kt2.AuthTokenId) assert.NoError(t, err) - assert.Empty(t, l) + assert.Empty(t, l.Sessions) }) t.Run("correct token gets sessions", func(t *testing.T) { l, err := r.ListSessions(ctx, kt1.AuthTokenId) assert.NoError(t, err) - assert.Len(t, l, len(ss)) - assert.ElementsMatch(t, l, ss) + assert.Len(t, l.Sessions, len(ss)) + assert.ElementsMatch(t, l.Sessions, ss) + }) +} + +func TestRepository_ListSessionsLimiting(t *testing.T) { + ctx := context.Background() + s, err := cachedb.Open(ctx) + require.NoError(t, err) + + addr := "address" + u := &user{ + Id: "u", + Address: addr, + } + at := &authtokens.AuthToken{ + Id: "at", + Token: "at_token", + UserId: u.Id, + } + kt := KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id} + + atMap := map[ringToken]*authtokens.AuthToken{ + {"k", "t"}: at, + } + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(maps.Values(atMap))) + require.NoError(t, err) + require.NoError(t, r.AddKeyringToken(ctx, addr, kt)) + + var ts []*sessions.Session + for i := 0; i < defaultLimitedResultSetSize*2; i++ { + ts = append(ts, session("s"+strconv.Itoa(i))) + } + require.NoError(t, r.refreshSessions(ctx, u, map[AuthToken]string{{Id: "id"}: "something"}, + WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{ts}, [][]string{nil})))) + + searchService, err := NewSearchService(ctx, r) + require.NoError(t, err) + params := SearchParams{ + Resource: Sessions, + AuthTokenId: kt.AuthTokenId, + } + + t.Run("default limit", func(t *testing.T) { + searchResult, err := searchService.Search(ctx, params) + require.NoError(t, err) + assert.Len(t, searchResult.Sessions, defaultLimitedResultSetSize) + assert.True(t, searchResult.Incomplete) + }) + t.Run("custom limit", func(t *testing.T) { + params.MaxResultSetSize = 20 + searchResult, err := searchService.Search(ctx, params) + require.NoError(t, err) + assert.Len(t, searchResult.Sessions, params.MaxResultSetSize) + assert.True(t, searchResult.Incomplete) + }) + t.Run("no limit", func(t *testing.T) { + params.MaxResultSetSize = -1 + searchResult, err := searchService.Search(ctx, params) + require.NoError(t, err) + assert.Len(t, searchResult.Sessions, defaultLimitedResultSetSize*2) + assert.False(t, searchResult.Incomplete) }) } @@ -471,13 +532,74 @@ func TestRepository_QuerySessions(t *testing.T) { t.Run("wrong token gets no sessions", func(t *testing.T) { l, err := r.QuerySessions(ctx, kt2.AuthTokenId, query) assert.NoError(t, err) - assert.Empty(t, l) + assert.Empty(t, l.Sessions) }) t.Run("correct token gets sessions", func(t *testing.T) { l, err := r.QuerySessions(ctx, kt1.AuthTokenId, query) assert.NoError(t, err) - assert.Len(t, l, 2) - assert.ElementsMatch(t, l, ss[0:2]) + assert.Len(t, l.Sessions, 2) + assert.ElementsMatch(t, l.Sessions, ss[0:2]) + }) +} + +func TestRepository_QuerySessionsLimiting(t *testing.T) { + ctx := context.Background() + s, err := cachedb.Open(ctx) + require.NoError(t, err) + + addr := "address" + u := &user{ + Id: "u", + Address: addr, + } + at := &authtokens.AuthToken{ + Id: "at", + Token: "at_token", + UserId: u.Id, + } + kt := KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id} + + atMap := map[ringToken]*authtokens.AuthToken{ + {"k", "t"}: at, + } + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(maps.Values(atMap))) + require.NoError(t, err) + require.NoError(t, r.AddKeyringToken(ctx, addr, kt)) + + var ts []*sessions.Session + for i := 0; i < defaultLimitedResultSetSize*2; i++ { + ts = append(ts, session("t"+strconv.Itoa(i))) + } + require.NoError(t, r.refreshSessions(ctx, u, map[AuthToken]string{{Id: "id"}: "something"}, + WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{ts}, [][]string{nil})))) + + searchService, err := NewSearchService(ctx, r) + require.NoError(t, err) + params := SearchParams{ + Resource: Sessions, + AuthTokenId: kt.AuthTokenId, + Query: `(id % 'session')`, + } + + t.Run("default limit", func(t *testing.T) { + searchResult, err := searchService.Search(ctx, params) + require.NoError(t, err) + assert.Len(t, searchResult.Sessions, defaultLimitedResultSetSize) + assert.True(t, searchResult.Incomplete) + }) + t.Run("custom limit", func(t *testing.T) { + params.MaxResultSetSize = 20 + searchResult, err := searchService.Search(ctx, params) + require.NoError(t, err) + assert.Len(t, searchResult.Sessions, params.MaxResultSetSize) + assert.True(t, searchResult.Incomplete) + }) + t.Run("no limit", func(t *testing.T) { + params.MaxResultSetSize = -1 + searchResult, err := searchService.Search(ctx, params) + require.NoError(t, err) + assert.Len(t, searchResult.Sessions, defaultLimitedResultSetSize*2) + assert.False(t, searchResult.Incomplete) }) } diff --git a/internal/clientcache/internal/cache/repository_targets.go b/internal/clientcache/internal/cache/repository_targets.go index f52c9c0785..e5cfa6e1f9 100644 --- a/internal/clientcache/internal/cache/repository_targets.go +++ b/internal/clientcache/internal/cache/repository_targets.go @@ -291,20 +291,20 @@ func upsertTargets(ctx context.Context, w db.Writer, u *user, in []*targets.Targ return nil } -func (r *Repository) ListTargets(ctx context.Context, authTokenId string) ([]*targets.Target, error) { +func (r *Repository) ListTargets(ctx context.Context, authTokenId string, opt ...Option) (*SearchResult, error) { const op = "cache.(Repository).ListTargets" switch { case authTokenId == "": return nil, errors.New(ctx, errors.InvalidParameter, op, "auth token id is missing") } - ret, err := r.searchTargets(ctx, "true", nil, withAuthTokenId(authTokenId)) + ret, err := r.searchTargets(ctx, "true", nil, append(opt, withAuthTokenId(authTokenId))...) if err != nil { return nil, errors.Wrap(ctx, err, op) } return ret, nil } -func (r *Repository) QueryTargets(ctx context.Context, authTokenId, query string) ([]*targets.Target, error) { +func (r *Repository) QueryTargets(ctx context.Context, authTokenId, query string, opt ...Option) (*SearchResult, error) { const op = "cache.(Repository).QueryTargets" switch { case authTokenId == "": @@ -317,14 +317,14 @@ func (r *Repository) QueryTargets(ctx context.Context, authTokenId, query string if err != nil { return nil, errors.Wrap(ctx, err, op, errors.WithCode(errors.InvalidParameter)) } - ret, err := r.searchTargets(ctx, w.Condition, w.Args, withAuthTokenId(authTokenId)) + ret, err := r.searchTargets(ctx, w.Condition, w.Args, append(opt, withAuthTokenId(authTokenId))...) if err != nil { return nil, errors.Wrap(ctx, err, op) } return ret, nil } -func (r *Repository) searchTargets(ctx context.Context, condition string, searchArgs []any, opt ...Option) ([]*targets.Target, error) { +func (r *Repository) searchTargets(ctx context.Context, condition string, searchArgs []any, opt ...Option) (*SearchResult, error) { const op = "cache.(Repository).searchTargets" switch { case condition == "": @@ -349,7 +349,7 @@ func (r *Repository) searchTargets(ctx context.Context, condition string, search } var cachedTargets []*Target - if err := r.rw.SearchWhere(ctx, &cachedTargets, condition, searchArgs, db.WithLimit(-1)); err != nil { + if err := r.rw.SearchWhere(ctx, &cachedTargets, condition, searchArgs, db.WithLimit(opts.withMaxResultSetSize+1)); err != nil { return nil, errors.Wrap(ctx, err, op) } @@ -361,7 +361,15 @@ func (r *Repository) searchTargets(ctx context.Context, condition string, search } retTargets = append(retTargets, &tar) } - return retTargets, nil + + sr := &SearchResult{ + Targets: retTargets, + } + if opts.withMaxResultSetSize > 0 && len(sr.Targets) > opts.withMaxResultSetSize { + sr.Targets = sr.Targets[:opts.withMaxResultSetSize] + sr.Incomplete = true + } + return sr, nil } type Target struct { diff --git a/internal/clientcache/internal/cache/repository_targets_test.go b/internal/clientcache/internal/cache/repository_targets_test.go index a6923ac815..a3c705b513 100644 --- a/internal/clientcache/internal/cache/repository_targets_test.go +++ b/internal/clientcache/internal/cache/repository_targets_test.go @@ -6,6 +6,7 @@ package cache import ( "context" "encoding/json" + "strconv" "sync" "testing" @@ -264,7 +265,7 @@ func TestRepository_RefreshTargets_withRefreshTokens(t *testing.T) { got, err := r.ListTargets(ctx, at.Id) require.NoError(t, err) - assert.Len(t, got, 2) + assert.Len(t, got.Targets, 2) // Refreshing again uses the refresh token and get additional sessions, appending // them to the response @@ -274,7 +275,7 @@ func TestRepository_RefreshTargets_withRefreshTokens(t *testing.T) { got, err = r.ListTargets(ctx, at.Id) require.NoError(t, err) - assert.Len(t, got, 3) + assert.Len(t, got.Targets, 3) // Refreshing again wont return any more resources, but also none should be // removed @@ -284,7 +285,7 @@ func TestRepository_RefreshTargets_withRefreshTokens(t *testing.T) { got, err = r.ListTargets(ctx, at.Id) require.NoError(t, err) - assert.Len(t, got, 3) + assert.Len(t, got.Targets, 3) // Refresh again with the refresh token being reported as invalid. require.NoError(t, r.refreshTargets(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"}, @@ -293,7 +294,7 @@ func TestRepository_RefreshTargets_withRefreshTokens(t *testing.T) { got, err = r.ListTargets(ctx, at.Id) require.NoError(t, err) - assert.Len(t, got, 2) + assert.Len(t, got.Targets, 2) } func TestRepository_ListTargets(t *testing.T) { @@ -349,13 +350,73 @@ func TestRepository_ListTargets(t *testing.T) { t.Run("wrong user gets no targets", func(t *testing.T) { l, err := r.ListTargets(ctx, kt2.AuthTokenId) assert.NoError(t, err) - assert.Empty(t, l) + assert.Empty(t, l.Targets) }) t.Run("correct token gets targets", func(t *testing.T) { l, err := r.ListTargets(ctx, kt1.AuthTokenId) assert.NoError(t, err) - assert.Len(t, l, len(ts)) - assert.ElementsMatch(t, l, ts) + assert.Len(t, l.Targets, len(ts)) + assert.ElementsMatch(t, l.Targets, ts) + }) +} + +func TestRepository_ListTargetsLimiting(t *testing.T) { + ctx := context.Background() + s, err := cachedb.Open(ctx) + require.NoError(t, err) + + addr := "address" + u := &user{ + Id: "u", + Address: addr, + } + at := &authtokens.AuthToken{ + Id: "at", + Token: "at_token", + UserId: u.Id, + } + kt := KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id} + + atMap := map[ringToken]*authtokens.AuthToken{ + {"k", "t"}: at, + } + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(maps.Values(atMap))) + require.NoError(t, err) + require.NoError(t, r.AddKeyringToken(ctx, addr, kt)) + + var ts []*targets.Target + for i := 0; i < defaultLimitedResultSetSize*2; i++ { + ts = append(ts, target("t"+strconv.Itoa(i))) + } + require.NoError(t, r.refreshTargets(ctx, u, map[AuthToken]string{{Id: "id"}: "something"}, + WithTargetRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*targets.Target{ts}, [][]string{nil})))) + + searchService, err := NewSearchService(ctx, r) + require.NoError(t, err) + params := SearchParams{ + Resource: Targets, + AuthTokenId: kt.AuthTokenId, + } + + t.Run("default limit", func(t *testing.T) { + searchResult, err := searchService.Search(ctx, params) + require.NoError(t, err) + assert.Len(t, searchResult.Targets, defaultLimitedResultSetSize) + assert.True(t, searchResult.Incomplete) + }) + t.Run("custom limit", func(t *testing.T) { + params.MaxResultSetSize = 20 + searchResult, err := searchService.Search(ctx, params) + require.NoError(t, err) + assert.Len(t, searchResult.Targets, params.MaxResultSetSize) + assert.True(t, searchResult.Incomplete) + }) + t.Run("no limit", func(t *testing.T) { + params.MaxResultSetSize = -1 + searchResult, err := searchService.Search(ctx, params) + require.NoError(t, err) + assert.Len(t, searchResult.Targets, defaultLimitedResultSetSize*2) + assert.False(t, searchResult.Incomplete) }) } @@ -456,13 +517,74 @@ func TestRepository_QueryTargets(t *testing.T) { t.Run("wrong token gets no targets", func(t *testing.T) { l, err := r.QueryTargets(ctx, kt2.AuthTokenId, query) assert.NoError(t, err) - assert.Empty(t, l) + assert.Empty(t, l.Targets) }) t.Run("correct token gets targets", func(t *testing.T) { l, err := r.QueryTargets(ctx, kt1.AuthTokenId, query) assert.NoError(t, err) - assert.Len(t, l, 2) - assert.ElementsMatch(t, l, ts[0:2]) + assert.Len(t, l.Targets, 2) + assert.ElementsMatch(t, l.Targets, ts[0:2]) + }) +} + +func TestRepository_QueryTargetsLimiting(t *testing.T) { + ctx := context.Background() + s, err := cachedb.Open(ctx) + require.NoError(t, err) + + addr := "address" + u := &user{ + Id: "u", + Address: addr, + } + at := &authtokens.AuthToken{ + Id: "at", + Token: "at_token", + UserId: u.Id, + } + kt := KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id} + + atMap := map[ringToken]*authtokens.AuthToken{ + {"k", "t"}: at, + } + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(maps.Values(atMap))) + require.NoError(t, err) + require.NoError(t, r.AddKeyringToken(ctx, addr, kt)) + + var ts []*targets.Target + for i := 0; i < defaultLimitedResultSetSize*2; i++ { + ts = append(ts, target("t"+strconv.Itoa(i))) + } + require.NoError(t, r.refreshTargets(ctx, u, map[AuthToken]string{{Id: "id"}: "something"}, + WithTargetRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*targets.Target{ts}, [][]string{nil})))) + + searchService, err := NewSearchService(ctx, r) + require.NoError(t, err) + params := SearchParams{ + Resource: Targets, + AuthTokenId: kt.AuthTokenId, + Query: `(name % 'name')`, + } + + t.Run("default limit", func(t *testing.T) { + searchResult, err := searchService.Search(ctx, params) + require.NoError(t, err) + assert.Len(t, searchResult.Targets, defaultLimitedResultSetSize) + assert.True(t, searchResult.Incomplete) + }) + t.Run("custom limit", func(t *testing.T) { + params.MaxResultSetSize = 20 + searchResult, err := searchService.Search(ctx, params) + require.NoError(t, err) + assert.Len(t, searchResult.Targets, params.MaxResultSetSize) + assert.True(t, searchResult.Incomplete) + }) + t.Run("no limit", func(t *testing.T) { + params.MaxResultSetSize = -1 + searchResult, err := searchService.Search(ctx, params) + require.NoError(t, err) + assert.Len(t, searchResult.Targets, defaultLimitedResultSetSize*2) + assert.False(t, searchResult.Incomplete) }) } diff --git a/internal/clientcache/internal/cache/search.go b/internal/clientcache/internal/cache/search.go index 2239030228..a26aa1a9c9 100644 --- a/internal/clientcache/internal/cache/search.go +++ b/internal/clientcache/internal/cache/search.go @@ -55,6 +55,8 @@ type SearchParams struct { Query string // the optional bexpr filter string that all results will be filtered by Filter string + // Max result set size is an override to the default max result set size + MaxResultSetSize int } // SearchResult returns the results from searching the cache. @@ -62,6 +64,10 @@ type SearchResult struct { ResolvableAliases []*aliases.Alias `json:"resolvable_aliases,omitempty"` Targets []*targets.Target `json:"targets,omitempty"` Sessions []*sessions.Session `json:"sessions,omitempty"` + + // Incomplete is true if the search results are incomplete, that is, we are + // returning only a subset based on the max result set size + Incomplete bool `json:"incomplete,omitempty"` } // SearchService is a domain service that can search across all resources in the @@ -83,22 +89,40 @@ func NewSearchService(ctx context.Context, repo *Repository) (*SearchService, er ResolvableAliases: &resourceSearchFns[*aliases.Alias]{ list: repo.ListResolvableAliases, query: repo.QueryResolvableAliases, - searchResult: func(a []*aliases.Alias) *SearchResult { - return &SearchResult{ResolvableAliases: a} + filter: func(in *SearchResult, e *bexpr.Evaluator) { + finalResults := make([]*aliases.Alias, 0, len(in.ResolvableAliases)) + for _, item := range in.ResolvableAliases { + if m, err := e.Evaluate(filterItem{item}); err == nil && m { + finalResults = append(finalResults, item) + } + } + in.ResolvableAliases = finalResults }, }, Targets: &resourceSearchFns[*targets.Target]{ list: repo.ListTargets, query: repo.QueryTargets, - searchResult: func(t []*targets.Target) *SearchResult { - return &SearchResult{Targets: t} + filter: func(in *SearchResult, e *bexpr.Evaluator) { + finalResults := make([]*targets.Target, 0, len(in.Targets)) + for _, item := range in.Targets { + if m, err := e.Evaluate(filterItem{item}); err == nil && m { + finalResults = append(finalResults, item) + } + } + in.Targets = finalResults }, }, Sessions: &resourceSearchFns[*sessions.Session]{ list: repo.ListSessions, query: repo.QuerySessions, - searchResult: func(s []*sessions.Session) *SearchResult { - return &SearchResult{Sessions: s} + filter: func(in *SearchResult, e *bexpr.Evaluator) { + finalResults := make([]*sessions.Session, 0, len(in.Sessions)) + for _, item := range in.Sessions { + if m, err := e.Evaluate(filterItem{item}); err == nil && m { + finalResults = append(finalResults, item) + } + } + in.Sessions = finalResults }, }, }, @@ -160,19 +184,15 @@ type resourceSearchFns[T any] struct { // list takes a context and an auth token and returns all resources for the // user of that auth token. If the provided auth token is not in the cache // an empty slice and no error is returned. - list func(context.Context, string) ([]T, error) + list func(context.Context, string, ...Option) (*SearchResult, error) // query takes a context, an auth token, and a query string and returns all // resources for that auth token that matches the provided query parameter. // If the provided auth token is not in the cache an empty slice and no // error is returned. - query func(context.Context, string, string) ([]T, error) - // searchResult is a function which provides a SearchResult based on the - // type of T. SearchResult contains different fields for the different - // resource types returned, so for example if T is *targets.Target the - // returned SearchResult will have it's "Targets" field populated so the - // searchResult should take the passed in paramater and assign it to the - // appropriate field in the SearchResult. - searchResult func([]T) *SearchResult + query func(context.Context, string, string, ...Option) (*SearchResult, error) + // filter takes results and a ready-to-use evaluator and filters the items + // in the result + filter func(*SearchResult, *bexpr.Evaluator) } // resourceSearcher is an interface that only resourceSearchFns[T] is expected @@ -191,33 +211,29 @@ type resourceSearcher interface { func (l *resourceSearchFns[T]) search(ctx context.Context, p SearchParams) (*SearchResult, error) { const op = "cache.(resourceSearchFns).search" - var found []T + var found *SearchResult var err error switch p.Query { case "": - found, err = l.list(ctx, p.AuthTokenId) + found, err = l.list(ctx, p.AuthTokenId, WithMaxResultSetSize(p.MaxResultSetSize)) default: - found, err = l.query(ctx, p.AuthTokenId, p.Query) + found, err = l.query(ctx, p.AuthTokenId, p.Query, WithMaxResultSetSize(p.MaxResultSetSize)) } if err != nil { return nil, errors.Wrap(ctx, err, op) } if p.Filter == "" { - return l.searchResult(found), nil + return found, nil } e, err := bexpr.CreateEvaluator(p.Filter, bexpr.WithTagName("json")) if err != nil { return nil, errors.Wrap(ctx, err, op, errors.WithMsg("couldn't build filter"), errors.WithCode(errors.InvalidParameter)) } - finalResults := make([]T, 0, len(found)) - for _, item := range found { - if m, err := e.Evaluate(filterItem{item}); err == nil && m { - finalResults = append(finalResults, item) - } - } - return l.searchResult(finalResults), nil + + l.filter(found, e) + return found, nil } type filterItem struct { diff --git a/internal/clientcache/internal/daemon/404_handler.go b/internal/clientcache/internal/daemon/404_handler.go index e373319afe..4216cb9ef4 100644 --- a/internal/clientcache/internal/daemon/404_handler.go +++ b/internal/clientcache/internal/daemon/404_handler.go @@ -9,7 +9,7 @@ import ( ) // new404Func creates a handler that returns a custom 404 error message. -func new404Func(ctx context.Context) http.HandlerFunc { +func new404Func(_ context.Context) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { writeError(w, "Not found", http.StatusNotFound) } diff --git a/internal/clientcache/internal/daemon/search_handler.go b/internal/clientcache/internal/daemon/search_handler.go index d79a74b781..f450a76491 100644 --- a/internal/clientcache/internal/daemon/search_handler.go +++ b/internal/clientcache/internal/daemon/search_handler.go @@ -26,14 +26,16 @@ type SearchResult struct { ResolvableAliases []*aliases.Alias `json:"resolvable_aliases,omitempty"` Targets []*targets.Target `json:"targets,omitempty"` Sessions []*sessions.Session `json:"sessions,omitempty"` + Incomplete bool `json:"incomplete,omitempty"` } const ( - filterKey = "filter" - queryKey = "query" - resourceKey = "resource" - forceRefreshKey = "force_refresh" - authTokenIdKey = "auth_token_id" + filterKey = "filter" + queryKey = "query" + resourceKey = "resource" + forceRefreshKey = "force_refresh" + authTokenIdKey = "auth_token_id" + maxResultSetSizeKey = "max_result_set_size" ) func newSearchHandlerFunc(ctx context.Context, repo *cache.Repository, refreshService *cache.RefreshService, logger hclog.Logger) (http.HandlerFunc, error) { @@ -54,8 +56,11 @@ func newSearchHandlerFunc(ctx context.Context, repo *cache.Repository, refreshSe return func(w http.ResponseWriter, r *http.Request) { reqCtx := r.Context() - resource := r.URL.Query().Get(resourceKey) - authTokenId := r.URL.Query().Get(authTokenIdKey) + q := r.URL.Query() + resource := q.Get(resourceKey) + authTokenId := q.Get(authTokenIdKey) + maxResultSetSizeStr := q.Get(maxResultSetSizeKey) + maxResultSetSizeInt, maxResultSetSizeIntErr := strconv.Atoi(maxResultSetSizeStr) searchableResource := cache.ToSearchableResource(resource) switch { @@ -71,6 +76,14 @@ func newSearchHandlerFunc(ctx context.Context, repo *cache.Repository, refreshSe event.WriteError(ctx, op, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("%s is a required field but was empty", authTokenIdKey))) writeError(w, fmt.Sprintf("%s is a required field but was empty", authTokenIdKey), http.StatusBadRequest) return + case maxResultSetSizeStr != "" && maxResultSetSizeIntErr != nil: + event.WriteError(ctx, op, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("%s is not able to be parsed as an integer", maxResultSetSizeStr))) + writeError(w, fmt.Sprintf("%s is not able to be parsed as an integer", maxResultSetSizeStr), http.StatusBadRequest) + return + case maxResultSetSizeInt < -1: + event.WriteError(ctx, op, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("%s must be greater than or equal to -1", maxResultSetSizeStr))) + writeError(w, fmt.Sprintf("%s must be greater than or equal to -1", maxResultSetSizeStr), http.StatusBadRequest) + return } t, err := repo.LookupToken(reqCtx, authTokenId, cache.WithUpdateLastAccessedTime(true)) @@ -114,10 +127,11 @@ func newSearchHandlerFunc(ctx context.Context, repo *cache.Repository, refreshSe filter := r.URL.Query().Get(filterKey) res, err := s.Search(reqCtx, cache.SearchParams{ - AuthTokenId: authTokenId, - Resource: searchableResource, - Query: query, - Filter: filter, + AuthTokenId: authTokenId, + Resource: searchableResource, + Query: query, + Filter: filter, + MaxResultSetSize: maxResultSetSizeInt, }) if err != nil { event.WriteError(ctx, op, err, event.WithInfoMsg("when performing search", "auth_token_id", authTokenId, "resource", searchableResource, "query", query, "filter", filter)) @@ -152,6 +166,7 @@ func toApiResult(sr *cache.SearchResult) *SearchResult { ResolvableAliases: sr.ResolvableAliases, Targets: sr.Targets, Sessions: sr.Sessions, + Incomplete: sr.Incomplete, } }