Ensure only one refresh is running for a user/resource at a time (#5054)
This ensures we don't have competing update processes running. When doing the
inline RefreshForSearch, a specific error will be returned so we can hint in
the API that we're in the middle of a refresh.
jefferai committed Sep 11, 2024
1 parent fe5ed5b commit da24897
Showing 6 changed files with 370 additions and 25 deletions.
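
As a hedged illustration of the "hint in the API" described above (not code from this commit), a caller of RefreshForSearch might map the new error onto the RefreshStatus field that the updated search tests below expect. The refreshService variable, cache.Targets argument, and daemon.Refreshing constant are assumptions made for the sketch; only ErrRefreshInProgress, RefreshStatus, and NotRefreshing appear in this change.

// Hypothetical caller-side handling; not part of this diff.
func searchResultWithStatus(ctx context.Context, authTokenId string) (*daemon.SearchResult, error) {
	res := &daemon.SearchResult{RefreshStatus: daemon.NotRefreshing}
	err := refreshService.RefreshForSearch(ctx, authTokenId, cache.Targets)
	switch {
	case errors.Is(err, cache.ErrRefreshInProgress):
		// Another refresh already holds the per-user/resource semaphore, so
		// serve the current cache contents and hint that a refresh is running.
		res.RefreshStatus = daemon.Refreshing // assumed constant
	case err != nil:
		return nil, err
	}
	return res, nil
}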
8 changes: 6 additions & 2 deletions internal/clientcache/cmd/search/search_test.go
@@ -136,7 +136,9 @@ func TestSearch(t *testing.T) {
assert.Nil(t, apiErr)
assert.NotNil(t, resp)
assert.NotNil(t, r)
assert.EqualValues(t, r, &daemon.SearchResult{})
assert.EqualValues(t, r, &daemon.SearchResult{
RefreshStatus: daemon.NotRefreshing,
})
})

t.Run("empty response from query", func(t *testing.T) {
@@ -150,7 +152,9 @@ func TestSearch(t *testing.T) {
assert.Nil(t, apiErr)
assert.NotNil(t, resp)
assert.NotNil(t, r)
assert.EqualValues(t, r, &daemon.SearchResult{})
assert.EqualValues(t, r, &daemon.SearchResult{
RefreshStatus: daemon.NotRefreshing,
})
})

t.Run("unsupported boundary instance", func(t *testing.T) {
16 changes: 16 additions & 0 deletions internal/clientcache/internal/cache/options.go
@@ -9,6 +9,11 @@ import (
"github.com/hashicorp/go-dbw"
)

type testRefreshWaitChs struct {
firstSempahore chan struct{}
secondSemaphore chan struct{}
}

type options struct {
withUpdateLastAccessedTime bool
withDbType dbw.DbType
@@ -19,6 +24,7 @@ type options struct {
withSessionRetrievalFunc SessionRetrievalFunc
withIgnoreSearchStaleness bool
withMaxResultSetSize int
withTestRefreshWaitChs *testRefreshWaitChs
}

// Option - how options are passed as args
@@ -113,3 +119,13 @@ func WithMaxResultSetSize(with int) Option {
return nil
}
}

// WithTestRefreshWaitChs provides an option for specifying channels to wait on
// before proceeding. This allows testing the logic that ensures only one
// refresh is running at a time.
func WithTestRefreshWaitChs(with *testRefreshWaitChs) Option {
return func(o *options) error {
o.withTestRefreshWaitChs = with
return nil
}
}
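
The two channels form a test-only handshake: the refresh goroutine closes the first channel once it holds the per-user/resource semaphore and then blocks on the second (see the RefreshForSearch changes further down). A hedged sketch, not part of this commit, of how a test might use the option; rs, ctx, at, and WithIgnoreSearchStaleness are assumed test scaffolding:

// Hedged test sketch; not part of this commit.
waitChs := &testRefreshWaitChs{
	firstSempahore:  make(chan struct{}),
	secondSemaphore: make(chan struct{}),
}
errCh := make(chan error, 1)
go func() {
	// Closes firstSempahore once the per-user/resource semaphore is held,
	// then blocks until secondSemaphore is closed.
	errCh <- rs.RefreshForSearch(ctx, at.Id, Targets,
		WithTestRefreshWaitChs(waitChs), WithIgnoreSearchStaleness(true))
}()
<-waitChs.firstSempahore
// The semaphore is held, so a concurrent refresh for the same user and
// resource type is rejected immediately.
err := rs.RefreshForSearch(ctx, at.Id, Targets, WithIgnoreSearchStaleness(true))
require.ErrorIs(t, err, ErrRefreshInProgress)
close(waitChs.secondSemaphore)
require.NoError(t, <-errCh)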
12 changes: 12 additions & 0 deletions internal/clientcache/internal/cache/options_test.go
@@ -106,4 +106,16 @@ func Test_GetOpts(t *testing.T) {
_, err = getOpts(WithMaxResultSetSize(-2))
require.Error(t, err)
})
t.Run("withTestRefreshWaitChs", func(t *testing.T) {
waitCh := &testRefreshWaitChs{
firstSempahore: make(chan struct{}),
secondSemaphore: make(chan struct{}),
}
opts, err := getOpts(WithTestRefreshWaitChs(waitCh))
require.NoError(t, err)
testOpts := getDefaultOptions()
assert.Empty(t, testOpts.withTestRefreshWaitChs)
testOpts.withTestRefreshWaitChs = waitCh
assert.Equal(t, opts, testOpts)
})
}
118 changes: 96 additions & 22 deletions internal/clientcache/internal/cache/refresh.go
@@ -7,6 +7,8 @@ import (
"context"
stderrors "errors"
"fmt"
"sync"
"sync/atomic"
"time"

"github.com/hashicorp/boundary/api"
@@ -17,9 +19,19 @@ import (
"github.com/hashicorp/go-hclog"
)

// This is used as an internal error to indicate that a refresh is already in
// progress, as a signal that the caller may want to handle it a different way.
var ErrRefreshInProgress error = stderrors.New("cache refresh in progress")

type RefreshService struct {
repo *Repository

// syncSemaphores is used to prevent multiple refreshes from happening at
// the same time for a given user ID and resource type. The key is a string
// (the user ID + resource type) and the value is an atomic bool we can CAS
// with.
syncSemaphores sync.Map

logger hclog.Logger
// the amount of time that should have passed since the last refresh for a
// call to RefreshForSearch will cause a refresh request to be sent to a
@@ -135,21 +147,25 @@ func (r *RefreshService) cacheSupportedUsers(ctx context.Context, in []*user) ([
return ret, nil
}

// RefreshForSearch refreshes a specific resource type owned by the user associated
// with the provided auth token id, if it hasn't been refreshed in a certain
// amount of time. If the resources have been updated recently, or if it is known
// that the user's boundary instance doesn't support partial refreshing, this
// method returns without any requests to the boundary controller.
// While the criteria of whether the user will refresh or not is almost the same
// as the criteria used in Refresh, RefreshForSearch will not refresh any
// data if there is not a refresh token stored for the resource. It
// might make sense to change this in the future, but the reasoning used is
// that we should not be making an initial load of all resources while blocking
// a search query, in case we have not yet even attempted to load the resources
// for this user yet.
// Note: Currently, if the context times out we stop refreshing completely and
// return to the caller. A possible enhancement in the future would be to return
// when the context is Done, but allow the refresh to proceed in the background.
// RefreshForSearch refreshes a specific resource type owned by the user
// associated with the provided auth token id, if it hasn't been refreshed in a
// certain amount of time. If the resources have been updated recently, or if it
// is known that the user's boundary instance doesn't support partial
// refreshing, this method returns without any requests to the boundary
// controller. While the criteria of whether the user will refresh or not is
// almost the same as the criteria used in Refresh, RefreshForSearch will not
// refresh any data if there is not a refresh token stored for the resource. It
// might make sense to change this in the future, but the reasoning used is that
// we should not be making an initial load of all resources while blocking a
// search query, in case we have not yet even attempted to load the resources
// for this user yet. Note: Currently, if the context times out we stop
// refreshing completely and return to the caller. A possible enhancement in
// the future would be to return when the context is Done, but allow the refresh
// to proceed in the background.
//
// If a refresh is already running for that user and resource type we will
// return ErrRefreshInProgress and will not start another one; the caller can
// then make it clear in the response that cache filling is ongoing.
func (r *RefreshService) RefreshForSearch(ctx context.Context, authTokenid string, resourceType SearchableResource, opt ...Option) error {
const op = "cache.(RefreshService).RefreshForSearch"
if r.maxSearchRefreshTimeout > 0 {
@@ -191,13 +207,20 @@ func (r *RefreshService) RefreshForSearch(ctx context.Context, authTokenid strin
return errors.Wrap(ctx, err, op, errors.WithoutEvent())
}

cacheKey := fmt.Sprintf("%s-%s", u.Id, resourceType)

switch resourceType {
case ResolvableAliases:
rtv, err := r.repo.lookupRefreshToken(ctx, u, resolvableAliasResourceType)
if err != nil {
return errors.Wrap(ctx, err, op)
}
if opts.withIgnoreSearchStaleness || rtv != nil && time.Since(rtv.UpdateTime) > r.maxSearchStaleness {
semaphore, _ := r.syncSemaphores.LoadOrStore(cacheKey, new(atomic.Bool))
if !semaphore.(*atomic.Bool).CompareAndSwap(false, true) {
return ErrRefreshInProgress
}
defer semaphore.(*atomic.Bool).Store(false)
tokens, err := r.cleanAndPickAuthTokens(ctx, u)
if err != nil {
return errors.Wrap(ctx, err, op, errors.WithoutEvent())
@@ -206,6 +229,12 @@ func (r *RefreshService) RefreshForSearch(ctx context.Context, authTokenid strin
if rtv != nil {
args = append(args, "alias staleness", time.Since(rtv.UpdateTime))
}

if opts.withTestRefreshWaitChs != nil {
close(opts.withTestRefreshWaitChs.firstSempahore)
<-opts.withTestRefreshWaitChs.secondSemaphore
}

r.logger.Debug("refreshing aliases before performing search", args...)
if err := r.repo.refreshResolvableAliases(ctx, u, tokens, opt...); err != nil {
return errors.Wrap(ctx, err, op)
@@ -217,6 +246,12 @@ func (r *RefreshService) RefreshForSearch(ctx context.Context, authTokenid strin
return errors.Wrap(ctx, err, op, errors.WithoutEvent())
}
if opts.withIgnoreSearchStaleness || rtv != nil && time.Since(rtv.UpdateTime) > r.maxSearchStaleness {
semaphore, _ := r.syncSemaphores.LoadOrStore(cacheKey, new(atomic.Bool))
if !semaphore.(*atomic.Bool).CompareAndSwap(false, true) {
return ErrRefreshInProgress
}
defer semaphore.(*atomic.Bool).Store(false)

tokens, err := r.cleanAndPickAuthTokens(ctx, u)
if err != nil {
return errors.Wrap(ctx, err, op, errors.WithoutEvent())
@@ -225,6 +260,12 @@ func (r *RefreshService) RefreshForSearch(ctx context.Context, authTokenid strin
if rtv != nil {
args = append(args, "target staleness", time.Since(rtv.UpdateTime))
}

if opts.withTestRefreshWaitChs != nil {
close(opts.withTestRefreshWaitChs.firstSempahore)
<-opts.withTestRefreshWaitChs.secondSemaphore
}

r.logger.Debug("refreshing targets before performing search", args...)
if err := r.repo.refreshTargets(ctx, u, tokens, opt...); err != nil {
return errors.Wrap(ctx, err, op, errors.WithoutEvent())
@@ -236,6 +277,12 @@ func (r *RefreshService) RefreshForSearch(ctx context.Context, authTokenid strin
return errors.Wrap(ctx, err, op)
}
if opts.withIgnoreSearchStaleness || rtv != nil && time.Since(rtv.UpdateTime) > r.maxSearchStaleness {
semaphore, _ := r.syncSemaphores.LoadOrStore(cacheKey, new(atomic.Bool))
if !semaphore.(*atomic.Bool).CompareAndSwap(false, true) {
return ErrRefreshInProgress
}
defer semaphore.(*atomic.Bool).Store(false)

tokens, err := r.cleanAndPickAuthTokens(ctx, u)
if err != nil {
return errors.Wrap(ctx, err, op, errors.WithoutEvent())
@@ -244,6 +291,12 @@ func (r *RefreshService) RefreshForSearch(ctx context.Context, authTokenid strin
if rtv != nil {
args = append(args, "session staleness", time.Since(rtv.UpdateTime))
}

if opts.withTestRefreshWaitChs != nil {
close(opts.withTestRefreshWaitChs.firstSempahore)
<-opts.withTestRefreshWaitChs.secondSemaphore
}

r.logger.Debug("refreshing sessions before performing search", args...)
if err := r.repo.refreshSessions(ctx, u, tokens, opt...); err != nil {
return errors.Wrap(ctx, err, op, errors.WithoutEvent())
@@ -262,6 +315,11 @@ func (r *RefreshService) RefreshForSearch(ctx context.Context, authTokenid strin
// cache with the values retrieved there. Refresh accepts the options
// WithTargetRetrievalFunc and WithSessionRetrievalFunc which overwrites the
// default functions used to retrieve those resources from boundary.
//
// This shares the map of sync semaphores with RefreshForSearch, so if a refresh
// is already happening via either method, it will be skipped by the other. In
// the case of this function it won't return ErrRefreshInProgress, but it will
// log that the refresh is already in progress.
func (r *RefreshService) Refresh(ctx context.Context, opt ...Option) error {
const op = "cache.(RefreshService).Refresh"
if err := r.repo.cleanExpiredOrOrphanedAuthTokens(ctx); err != nil {
@@ -290,16 +348,32 @@ func (r *RefreshService) Refresh(ctx context.Context, opt ...Option) error {
continue
}

if err := r.repo.refreshResolvableAliases(ctx, u, tokens, opt...); err != nil {
retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("for user id %s", u.Id))))
}
if err := r.repo.refreshTargets(ctx, u, tokens, opt...); err != nil {
retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("for user id %s", u.Id))))
cacheKey := fmt.Sprintf("%s-%s", u.Id, ResolvableAliases)
semaphore, _ := r.syncSemaphores.LoadOrStore(cacheKey, new(atomic.Bool))
if semaphore.(*atomic.Bool).CompareAndSwap(false, true) {
if err := r.repo.refreshResolvableAliases(ctx, u, tokens, opt...); err != nil {
retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("for user id %s", u.Id))))
}
semaphore.(*atomic.Bool).Store(false)
}
if err := r.repo.refreshSessions(ctx, u, tokens, opt...); err != nil {
retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("for user id %s", u.Id))))

cacheKey = fmt.Sprintf("%s-%s", u.Id, Targets)
semaphore, _ = r.syncSemaphores.LoadOrStore(cacheKey, new(atomic.Bool))
if semaphore.(*atomic.Bool).CompareAndSwap(false, true) {
if err := r.repo.refreshTargets(ctx, u, tokens, opt...); err != nil {
retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("for user id %s", u.Id))))
}
semaphore.(*atomic.Bool).Store(false)
}

cacheKey = fmt.Sprintf("%s-%s", u.Id, Sessions)
semaphore, _ = r.syncSemaphores.LoadOrStore(cacheKey, new(atomic.Bool))
if semaphore.(*atomic.Bool).CompareAndSwap(false, true) {
if err := r.repo.refreshSessions(ctx, u, tokens, opt...); err != nil {
retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("for user id %s", u.Id))))
}
semaphore.(*atomic.Bool).Store(false)
}
}
return retErr
}
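
For readers outside the Boundary codebase, the concurrency pattern used above reduces to a small standalone sketch (illustrative only, not Boundary code): a sync.Map keyed by user ID plus resource type holds an *atomic.Bool that a refresh claims with CompareAndSwap and releases when it finishes, so at most one refresh runs per key without a global lock.

// Standalone illustration of the per-key semaphore pattern; not Boundary code.
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

type refresher struct {
	// syncSemaphores maps "userID-resourceType" to an *atomic.Bool semaphore.
	syncSemaphores sync.Map
}

// tryRefresh runs fn only if no other refresh currently holds the semaphore
// for key; otherwise it returns an in-progress error, mirroring ErrRefreshInProgress.
func (r *refresher) tryRefresh(key string, fn func()) error {
	sem, _ := r.syncSemaphores.LoadOrStore(key, new(atomic.Bool))
	if !sem.(*atomic.Bool).CompareAndSwap(false, true) {
		return fmt.Errorf("refresh in progress for %s", key)
	}
	defer sem.(*atomic.Bool).Store(false)
	fn()
	return nil
}

func main() {
	r := &refresher{}
	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			// Only one concurrent call for the same key runs fn; the others
			// return the in-progress error.
			err := r.tryRefresh("u_1234567890-targets", func() {})
			fmt.Printf("goroutine %d: %v\n", i, err)
		}(i)
	}
	wg.Wait()
}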