From 54776091e32af2f2f110f2bf10d0617bd5677aba Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Thu, 29 Aug 2024 17:48:10 -0400 Subject: [PATCH] Ensure only one refresh is running for a user/resource at a time (#5054) This ensures we don't have competing update processes running. When doing the inline RefreshForSearch, a specific error will be returned so we can hint in the API that we're in the middle of a refresh. --- .../clientcache/cmd/search/search_test.go | 8 +- .../clientcache/internal/cache/options.go | 16 ++ .../internal/cache/options_test.go | 12 ++ .../clientcache/internal/cache/refresh.go | 118 ++++++++-- .../internal/cache/refresh_test.go | 203 ++++++++++++++++++ .../internal/daemon/search_handler.go | 38 +++- 6 files changed, 370 insertions(+), 25 deletions(-) diff --git a/internal/clientcache/cmd/search/search_test.go b/internal/clientcache/cmd/search/search_test.go index d36ea5fa64..3718e49479 100644 --- a/internal/clientcache/cmd/search/search_test.go +++ b/internal/clientcache/cmd/search/search_test.go @@ -140,7 +140,9 @@ func TestSearch(t *testing.T) { assert.Nil(t, apiErr) assert.NotNil(t, resp) assert.NotNil(t, r) - assert.EqualValues(t, r, &daemon.SearchResult{}) + assert.EqualValues(t, r, &daemon.SearchResult{ + RefreshStatus: daemon.NotRefreshing, + }) }) t.Run("empty response from query", func(t *testing.T) { @@ -154,7 +156,9 @@ func TestSearch(t *testing.T) { assert.Nil(t, apiErr) assert.NotNil(t, resp) assert.NotNil(t, r) - assert.EqualValues(t, r, &daemon.SearchResult{}) + assert.EqualValues(t, r, &daemon.SearchResult{ + RefreshStatus: daemon.NotRefreshing, + }) }) t.Run("unsupported boundary instance", func(t *testing.T) { diff --git a/internal/clientcache/internal/cache/options.go b/internal/clientcache/internal/cache/options.go index 09e34ee9d8..36907f8833 100644 --- a/internal/clientcache/internal/cache/options.go +++ b/internal/clientcache/internal/cache/options.go @@ -9,6 +9,11 @@ import ( "github.com/hashicorp/go-dbw" ) +type 
testRefreshWaitChs struct { + firstSempahore chan struct{} + secondSemaphore chan struct{} +} + type options struct { withUpdateLastAccessedTime bool withDbType dbw.DbType @@ -19,6 +24,7 @@ type options struct { withSessionRetrievalFunc SessionRetrievalFunc withIgnoreSearchStaleness bool withMaxResultSetSize int + withTestRefreshWaitChs *testRefreshWaitChs } // Option - how options are passed as args @@ -113,3 +119,13 @@ func WithMaxResultSetSize(with int) Option { return nil } } + +// WithTestRefreshWaitChs provides an option for specifying channels to wait on +// before proceeding. This allows testing the logic that ensures only one is +// running at a time. +func WithTestRefreshWaitChs(with *testRefreshWaitChs) Option { + return func(o *options) error { + o.withTestRefreshWaitChs = with + return nil + } +} diff --git a/internal/clientcache/internal/cache/options_test.go b/internal/clientcache/internal/cache/options_test.go index aed96cd5ba..44074fdbe6 100644 --- a/internal/clientcache/internal/cache/options_test.go +++ b/internal/clientcache/internal/cache/options_test.go @@ -106,4 +106,16 @@ func Test_GetOpts(t *testing.T) { _, err = getOpts(WithMaxResultSetSize(-2)) require.Error(t, err) }) + t.Run("withTestRefreshWaitChs", func(t *testing.T) { + waitCh := &testRefreshWaitChs{ + firstSempahore: make(chan struct{}), + secondSemaphore: make(chan struct{}), + } + opts, err := getOpts(WithTestRefreshWaitChs(waitCh)) + require.NoError(t, err) + testOpts := getDefaultOptions() + assert.Empty(t, testOpts.withTestRefreshWaitChs) + testOpts.withTestRefreshWaitChs = waitCh + assert.Equal(t, opts, testOpts) + }) } diff --git a/internal/clientcache/internal/cache/refresh.go b/internal/clientcache/internal/cache/refresh.go index b9a77a5372..6cc294be09 100644 --- a/internal/clientcache/internal/cache/refresh.go +++ b/internal/clientcache/internal/cache/refresh.go @@ -7,6 +7,8 @@ import ( "context" stderrors "errors" "fmt" + "sync" + "sync/atomic" "time" 
"github.com/hashicorp/boundary/api" @@ -17,9 +19,19 @@ import ( "github.com/hashicorp/go-hclog" ) +// This is used as an internal error to indicate that a refresh is already in +// progress, as a signal that the caller may want to handle it a different way. +var ErrRefreshInProgress error = stderrors.New("cache refresh in progress") + type RefreshService struct { repo *Repository + // syncSempaphores is used to prevent multiple refreshes from happening at + // the same time for a given user ID and resource type. The key is a string + // (the user ID + resource type) and the value is an atomic bool we can CAS + // with. + syncSemaphores sync.Map + logger hclog.Logger // the amount of time that should have passed since the last refresh for a // call to RefreshForSearch will cause a refresh request to be sent to a @@ -135,21 +147,25 @@ func (r *RefreshService) cacheSupportedUsers(ctx context.Context, in []*user) ([ return ret, nil } -// RefreshForSearch refreshes a specific resource type owned by the user associated -// with the provided auth token id, if it hasn't been refreshed in a certain -// amount of time. If the resources has been updated recently, or if it is known -// that the user's boundary instance doesn't support partial refreshing, this -// method returns without any requests to the boundary controller. -// While the criteria of whether the user will refresh or not is almost the same -// as the criteria used in Refresh, RefreshForSearch will not refresh any -// data if there is not a refresh token stored for the resource. It -// might make sense to change this in the future, but the reasoning used is -// that we should not be making an initial load of all resources while blocking -// a search query, in case we have not yet even attempted to load the resources -// for this user yet. -// Note: Currently, if the context timesout we stop refreshing completely and -// return to the caller. 
A possible enhancement in the future would be to return -// when the context is Done, but allow the refresh to proceed in the background. +// RefreshForSearch refreshes a specific resource type owned by the user +// associated with the provided auth token id, if it hasn't been refreshed in a +// certain amount of time. If the resource has been updated recently, or if it +// is known that the user's boundary instance doesn't support partial +// refreshing, this method returns without any requests to the boundary +// controller. While the criteria of whether the user will refresh or not is +// almost the same as the criteria used in Refresh, RefreshForSearch will not +// refresh any data if there is not a refresh token stored for the resource. It +// might make sense to change this in the future, but the reasoning used is that +// we should not be making an initial load of all resources while blocking a +// search query, in case we have not yet even attempted to load the resources +// for this user yet. Note: Currently, if the context times out we stop +// refreshing completely and return to the caller. A possible enhancement in +// the future would be to return when the context is Done, but allow the refresh +// to proceed in the background. +// +// If a refresh is already running for that user and resource type we will +// return ErrRefreshInProgress. If so we will not start another one and we can +// make it clear in the response that cache filling is ongoing. 
func (r *RefreshService) RefreshForSearch(ctx context.Context, authTokenid string, resourceType SearchableResource, opt ...Option) error { const op = "cache.(RefreshService).RefreshForSearch" if r.maxSearchRefreshTimeout > 0 { @@ -191,6 +207,8 @@ func (r *RefreshService) RefreshForSearch(ctx context.Context, authTokenid strin return errors.Wrap(ctx, err, op, errors.WithoutEvent()) } + cacheKey := fmt.Sprintf("%s-%s", u.Id, resourceType) + switch resourceType { case ResolvableAliases: rtv, err := r.repo.lookupRefreshToken(ctx, u, resolvableAliasResourceType) @@ -198,6 +216,11 @@ func (r *RefreshService) RefreshForSearch(ctx context.Context, authTokenid strin return errors.Wrap(ctx, err, op) } if opts.withIgnoreSearchStaleness || rtv != nil && time.Since(rtv.UpdateTime) > r.maxSearchStaleness { + semaphore, _ := r.syncSemaphores.LoadOrStore(cacheKey, new(atomic.Bool)) + if !semaphore.(*atomic.Bool).CompareAndSwap(false, true) { + return ErrRefreshInProgress + } + defer semaphore.(*atomic.Bool).Store(false) tokens, err := r.cleanAndPickAuthTokens(ctx, u) if err != nil { return errors.Wrap(ctx, err, op, errors.WithoutEvent()) @@ -206,6 +229,12 @@ func (r *RefreshService) RefreshForSearch(ctx context.Context, authTokenid strin if rtv != nil { args = append(args, "alias staleness", time.Since(rtv.UpdateTime)) } + + if opts.withTestRefreshWaitChs != nil { + close(opts.withTestRefreshWaitChs.firstSempahore) + <-opts.withTestRefreshWaitChs.secondSemaphore + } + r.logger.Debug("refreshing aliases before performing search", args...) 
if err := r.repo.refreshResolvableAliases(ctx, u, tokens, opt...); err != nil { return errors.Wrap(ctx, err, op) @@ -217,6 +246,12 @@ func (r *RefreshService) RefreshForSearch(ctx context.Context, authTokenid strin return errors.Wrap(ctx, err, op, errors.WithoutEvent()) } if opts.withIgnoreSearchStaleness || rtv != nil && time.Since(rtv.UpdateTime) > r.maxSearchStaleness { + semaphore, _ := r.syncSemaphores.LoadOrStore(cacheKey, new(atomic.Bool)) + if !semaphore.(*atomic.Bool).CompareAndSwap(false, true) { + return ErrRefreshInProgress + } + defer semaphore.(*atomic.Bool).Store(false) + tokens, err := r.cleanAndPickAuthTokens(ctx, u) if err != nil { return errors.Wrap(ctx, err, op, errors.WithoutEvent()) @@ -225,6 +260,12 @@ func (r *RefreshService) RefreshForSearch(ctx context.Context, authTokenid strin if rtv != nil { args = append(args, "target staleness", time.Since(rtv.UpdateTime)) } + + if opts.withTestRefreshWaitChs != nil { + close(opts.withTestRefreshWaitChs.firstSempahore) + <-opts.withTestRefreshWaitChs.secondSemaphore + } + r.logger.Debug("refreshing targets before performing search", args...) 
if err := r.repo.refreshTargets(ctx, u, tokens, opt...); err != nil { return errors.Wrap(ctx, err, op, errors.WithoutEvent()) @@ -236,6 +277,12 @@ func (r *RefreshService) RefreshForSearch(ctx context.Context, authTokenid strin return errors.Wrap(ctx, err, op) } if opts.withIgnoreSearchStaleness || rtv != nil && time.Since(rtv.UpdateTime) > r.maxSearchStaleness { + semaphore, _ := r.syncSemaphores.LoadOrStore(cacheKey, new(atomic.Bool)) + if !semaphore.(*atomic.Bool).CompareAndSwap(false, true) { + return ErrRefreshInProgress + } + defer semaphore.(*atomic.Bool).Store(false) + tokens, err := r.cleanAndPickAuthTokens(ctx, u) if err != nil { return errors.Wrap(ctx, err, op, errors.WithoutEvent()) @@ -244,6 +291,12 @@ func (r *RefreshService) RefreshForSearch(ctx context.Context, authTokenid strin if rtv != nil { args = append(args, "session staleness", time.Since(rtv.UpdateTime)) } + + if opts.withTestRefreshWaitChs != nil { + close(opts.withTestRefreshWaitChs.firstSempahore) + <-opts.withTestRefreshWaitChs.secondSemaphore + } + r.logger.Debug("refreshing sessions before performing search", args...) if err := r.repo.refreshSessions(ctx, u, tokens, opt...); err != nil { return errors.Wrap(ctx, err, op, errors.WithoutEvent()) @@ -262,6 +315,11 @@ func (r *RefreshService) RefreshForSearch(ctx context.Context, authTokenid strin // cache with the values retrieved there. Refresh accepts the options // WithTargetRetrievalFunc and WithSessionRetrievalFunc which overwrites the // default functions used to retrieve those resources from boundary. +// +// This shares the map of sync semaphores with RefreshForSearch, so if a refresh +// is already happening via either method, it will be skipped by the other. In +// the case of this function it won't return ErrRefreshInProgress, but it will +// log that the refresh is already in progress. 
func (r *RefreshService) Refresh(ctx context.Context, opt ...Option) error { const op = "cache.(RefreshService).Refresh" if err := r.repo.cleanExpiredOrOrphanedAuthTokens(ctx); err != nil { @@ -290,16 +348,32 @@ func (r *RefreshService) Refresh(ctx context.Context, opt ...Option) error { continue } - if err := r.repo.refreshResolvableAliases(ctx, u, tokens, opt...); err != nil { - retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("for user id %s", u.Id)))) - } - if err := r.repo.refreshTargets(ctx, u, tokens, opt...); err != nil { - retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("for user id %s", u.Id)))) + cacheKey := fmt.Sprintf("%s-%s", u.Id, ResolvableAliases) + semaphore, _ := r.syncSemaphores.LoadOrStore(cacheKey, new(atomic.Bool)) + if semaphore.(*atomic.Bool).CompareAndSwap(false, true) { + if err := r.repo.refreshResolvableAliases(ctx, u, tokens, opt...); err != nil { + retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("for user id %s", u.Id)))) + } + semaphore.(*atomic.Bool).Store(false) } - if err := r.repo.refreshSessions(ctx, u, tokens, opt...); err != nil { - retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("for user id %s", u.Id)))) + + cacheKey = fmt.Sprintf("%s-%s", u.Id, Targets) + semaphore, _ = r.syncSemaphores.LoadOrStore(cacheKey, new(atomic.Bool)) + if semaphore.(*atomic.Bool).CompareAndSwap(false, true) { + if err := r.repo.refreshTargets(ctx, u, tokens, opt...); err != nil { + retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("for user id %s", u.Id)))) + } + semaphore.(*atomic.Bool).Store(false) } + cacheKey = fmt.Sprintf("%s-%s", u.Id, Sessions) + semaphore, _ = r.syncSemaphores.LoadOrStore(cacheKey, new(atomic.Bool)) + if semaphore.(*atomic.Bool).CompareAndSwap(false, true) { + if err := r.repo.refreshSessions(ctx, u, tokens, opt...); err != nil { + retErr = 
stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("for user id %s", u.Id)))) + } + semaphore.(*atomic.Bool).Store(false) + } } return retErr } diff --git a/internal/clientcache/internal/cache/refresh_test.go b/internal/clientcache/internal/cache/refresh_test.go index a24dedf198..bd0905de93 100644 --- a/internal/clientcache/internal/cache/refresh_test.go +++ b/internal/clientcache/internal/cache/refresh_test.go @@ -803,6 +803,209 @@ func TestRefreshForSearch(t *testing.T) { }) } +func TestRefreshNonBlocking(t *testing.T) { + ctx := context.Background() + + boundaryAddr := "address" + u := &user{Id: "u1", Address: boundaryAddr} + at := &authtokens.AuthToken{ + Id: "at_1", + Token: "at_1_token", + UserId: u.Id, + ExpirationTime: time.Now().Add(time.Minute), + } + + boundaryAuthTokens := []*authtokens.AuthToken{at} + atMap := make(map[ringToken]*authtokens.AuthToken) + + atMap[ringToken{"k", "t"}] = at + + t.Run("targets refreshed for searching", func(t *testing.T) { + t.Parallel() + s, err := db.Open(ctx) + require.NoError(t, err) + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(boundaryAuthTokens)) + require.NoError(t, err) + rs, err := NewRefreshService(ctx, r, hclog.NewNullLogger(), time.Millisecond, 0) + require.NoError(t, err) + require.NoError(t, r.AddKeyringToken(ctx, boundaryAddr, KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id})) + + retTargets := []*targets.Target{ + target("1"), + target("2"), + target("3"), + target("4"), + } + opts := []Option{ + WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), + WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), + WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, + [][]*targets.Target{ + retTargets[:3], + retTargets[3:], + }, + [][]string{ + nil, + {retTargets[0].Id, retTargets[1].Id}, + }, + )), + } + 
+ refreshWaitChs := &testRefreshWaitChs{ + firstSempahore: make(chan struct{}), + secondSemaphore: make(chan struct{}), + } + wg := new(sync.WaitGroup) + wg.Add(2) + extraOpts := []Option{WithTestRefreshWaitChs(refreshWaitChs), WithIgnoreSearchStaleness(true)} + go func() { + defer wg.Done() + blockingRefreshError := rs.RefreshForSearch(ctx, at.Id, Targets, append(opts, extraOpts...)...) + assert.NoError(t, blockingRefreshError) + }() + go func() { + defer wg.Done() + // Wait until the first call signals that it holds the refresh semaphore + // so that this call is made while that refresh is still in progress + <-refreshWaitChs.firstSempahore + nonblockingRefreshError := rs.RefreshForSearch(ctx, at.Id, Targets, append(opts, extraOpts...)...) + close(refreshWaitChs.secondSemaphore) + assert.ErrorIs(t, nonblockingRefreshError, ErrRefreshInProgress) + }() + wg.Wait() + + // Unlike in the TestRefreshForSearch test, since we did a force + // refresh we do expect to see values + cachedTargets, err := r.ListTargets(ctx, at.Id) + assert.NoError(t, err) + assert.ElementsMatch(t, retTargets[:3], cachedTargets.Targets) + }) + + t.Run("sessions refreshed for searching", func(t *testing.T) { + t.Parallel() + s, err := db.Open(ctx) + require.NoError(t, err) + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(boundaryAuthTokens)) + require.NoError(t, err) + rs, err := NewRefreshService(ctx, r, hclog.NewNullLogger(), 0, 0) + require.NoError(t, err) + require.NoError(t, r.AddKeyringToken(ctx, boundaryAddr, KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id})) + + retSess := []*sessions.Session{ + session("1"), + session("2"), + session("3"), + session("4"), + } + opts := []Option{ + WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, nil, nil)), + WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil)), + WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, + 
[][]*sessions.Session{ + retSess[:3], + retSess[3:], + }, + [][]string{ + nil, + {retSess[0].Id, retSess[1].Id}, + }, + )), + } + + refreshWaitChs := &testRefreshWaitChs{ + firstSempahore: make(chan struct{}), + secondSemaphore: make(chan struct{}), + } + wg := new(sync.WaitGroup) + wg.Add(2) + extraOpts := []Option{WithTestRefreshWaitChs(refreshWaitChs), WithIgnoreSearchStaleness(true)} + go func() { + defer wg.Done() + blockingRefreshError := rs.RefreshForSearch(ctx, at.Id, Sessions, append(opts, extraOpts...)...) + assert.NoError(t, blockingRefreshError) + }() + go func() { + defer wg.Done() + // Wait until the first call signals that it holds the refresh semaphore + // so that this call is made while that refresh is still in progress + <-refreshWaitChs.firstSempahore + nonblockingRefreshError := rs.RefreshForSearch(ctx, at.Id, Sessions, append(opts, extraOpts...)...) + close(refreshWaitChs.secondSemaphore) + assert.ErrorIs(t, nonblockingRefreshError, ErrRefreshInProgress) + }() + + wg.Wait() + + // Unlike in the TestRefreshForSearch test, since we did a force + // refresh we do expect to see values + cachedSessions, err := r.ListSessions(ctx, at.Id) + assert.NoError(t, err) + assert.ElementsMatch(t, retSess[:3], cachedSessions.Sessions) + }) + + t.Run("aliases refreshed for searching", func(t *testing.T) { + t.Parallel() + s, err := db.Open(ctx) + require.NoError(t, err) + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(boundaryAuthTokens)) + require.NoError(t, err) + rs, err := NewRefreshService(ctx, r, hclog.Default(), 0, 0) + require.NoError(t, err) + require.NoError(t, r.AddKeyringToken(ctx, boundaryAddr, KeyringToken{KeyringType: "k", TokenName: "t", AuthTokenId: at.Id})) + + retAl := []*aliases.Alias{ + alias("1"), + alias("2"), + alias("3"), + alias("4"), + } + opts := []Option{ + WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), + 
WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil)), + WithAliasRetrievalFunc(testStaticResourceRetrievalFuncForId[*aliases.Alias](t, + [][]*aliases.Alias{ + retAl[:3], + retAl[3:], + }, + [][]string{ + nil, + {retAl[0].Id, retAl[1].Id}, + }, + )), + } + + refreshWaitChs := &testRefreshWaitChs{ + firstSempahore: make(chan struct{}), + secondSemaphore: make(chan struct{}), + } + wg := new(sync.WaitGroup) + wg.Add(2) + extraOpts := []Option{WithTestRefreshWaitChs(refreshWaitChs), WithIgnoreSearchStaleness(true)} + go func() { + defer wg.Done() + blockingRefreshError := rs.RefreshForSearch(ctx, at.Id, ResolvableAliases, append(opts, extraOpts...)...) + assert.NoError(t, blockingRefreshError) + }() + go func() { + defer wg.Done() + // Wait until the first call signals that it holds the refresh semaphore + // so that this call is made while that refresh is still in progress + <-refreshWaitChs.firstSempahore + nonblockingRefreshError := rs.RefreshForSearch(ctx, at.Id, ResolvableAliases, append(opts, extraOpts...)...) 
+ close(refreshWaitChs.secondSemaphore) + assert.ErrorIs(t, nonblockingRefreshError, ErrRefreshInProgress) + }() + + wg.Wait() + + // Unlike in the TestRefreshForSearch test, since we did a force + // refresh we do expect to see values + cachedAliases, err := r.ListResolvableAliases(ctx, at.Id) + require.NoError(t, err) + assert.ElementsMatch(t, retAl[:3], cachedAliases.ResolvableAliases) + }) +} + func TestRefresh(t *testing.T) { ctx := context.Background() diff --git a/internal/clientcache/internal/daemon/search_handler.go b/internal/clientcache/internal/daemon/search_handler.go index 2fbdfbd300..97dffa5cc2 100644 --- a/internal/clientcache/internal/daemon/search_handler.go +++ b/internal/clientcache/internal/daemon/search_handler.go @@ -6,6 +6,7 @@ package daemon import ( "context" "encoding/json" + stderrors "errors" "fmt" "net/http" "strconv" @@ -22,6 +23,22 @@ import ( "github.com/hashicorp/go-hclog" ) +type RefreshStatus string + +const ( + // NotRefreshing means the result is complete, that is, the cache is not in + // the process of being built or refreshed + NotRefreshing RefreshStatus = "not-refreshing" + // Refreshing means that the cache is in the process of being refreshed, so + // the result may not be complete and the caller should try again later for + // more complete results + Refreshing RefreshStatus = "refreshing" + // RefreshError means that there was an error refreshing the cache. It says + // nothing about the completeness of the result, only that when attempting + // to refresh the cache in-line with the search an error was encountered. + RefreshError RefreshStatus = "refresh-error" +) + // SearchResult is the struct returned to search requests. 
type SearchResult struct { ResolvableAliases []*aliases.Alias `json:"resolvable_aliases,omitempty"` @@ -29,6 +46,8 @@ type SearchResult struct { Sessions []*sessions.Session `json:"sessions,omitempty"` ImplicitScopes []*scopes.Scope `json:"implicit_scopes,omitempty"` Incomplete bool `json:"incomplete,omitempty"` + RefreshStatus RefreshStatus `json:"refresh_status,omitempty"` + RefreshError string `json:"refresh_error,omitempty"` } const ( @@ -131,13 +150,20 @@ func newSearchHandlerFunc(ctx context.Context, repo *cache.Repository, refreshSe opts = append(opts, cache.WithIgnoreSearchStaleness(true)) } + var refreshError error // Refresh the resources for the provided user, if possible. This is best // effort, so if there is any problem refreshing, we just log the error // and move on to handling the search request. switch searchableResource { case cache.ImplicitScopes: + // This is not able to be refreshed, so continue on default: - if err := refreshService.RefreshForSearch(reqCtx, authTokenId, searchableResource, opts...); err != nil { + refreshError = refreshService.RefreshForSearch(reqCtx, authTokenId, searchableResource, opts...) 
+ switch { + case refreshError == nil, + stderrors.Is(refreshError, cache.ErrRefreshInProgress): + // Don't emit an error event in these cases + default: // we don't stop the search, we just log that the inline refresh failed event.WriteError(ctx, op, err, event.WithInfoMsg("when refreshing the resources inline for search", "auth_token_id", authTokenId, "resource", searchableResource)) } @@ -166,6 +192,16 @@ func newSearchHandlerFunc(ctx context.Context, repo *cache.Repository, refreshSe } apiRes := toApiResult(res) + switch { + case refreshError == nil: + apiRes.RefreshStatus = NotRefreshing + case stderrors.Is(refreshError, cache.ErrRefreshInProgress): + apiRes.RefreshStatus = Refreshing + default: + apiRes.RefreshStatus = RefreshError + apiRes.RefreshError = refreshError.Error() + } + j, err := json.Marshal(apiRes) if err != nil { event.WriteError(ctx, op, err, event.WithInfoMsg("when marshaling search result to JSON", "auth_token_id", authTokenId, "resource", searchableResource, "query", query, "filter", filter))