diff --git a/cli/global.go b/cli/global.go index 0a2f618cd..56b8e8a7c 100644 --- a/cli/global.go +++ b/cli/global.go @@ -77,6 +77,13 @@ func (a *AwsVault) Keyring() (keyring.Keyring, error) { if err != nil { return nil, err } + if a.KeyringBackend == string(keyring.KeychainBackend) { + lockKey := a.KeyringConfig.KeychainName + if lockKey == "" { + lockKey = "aws-vault" + } + a.keyringImpl = vault.NewKeychainLockedKeyring(a.keyringImpl, lockKey) + } } return a.keyringImpl, nil diff --git a/go.mod b/go.mod index fd6ca3f94..d8d91eacb 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/byteness/keyring v1.7.2 github.com/charmbracelet/huh v0.8.0 github.com/charmbracelet/lipgloss v1.1.0 + github.com/gofrs/flock v0.8.1 github.com/google/go-cmp v0.7.0 github.com/mattn/go-isatty v0.0.20 github.com/mattn/go-tty v0.0.7 diff --git a/go.sum b/go.sum index 30ea33351..e5e4f9d07 100644 --- a/go.sum +++ b/go.sum @@ -114,6 +114,8 @@ github.com/extism/go-sdk v1.7.1 h1:lWJos6uY+tRFdlIHR+SJjwFDApY7OypS/2nMhiVQ9Sw= github.com/extism/go-sdk v1.7.1/go.mod h1:IT+Xdg5AZM9hVtpFUA+uZCJMge/hbvshl8bwzLtFyKA= github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= +github.com/gofrs/flock v0.8.1 h1:+gYjHKf32LDeiEEFhQaotPbLuUXjY5ZqxKgXy7n59aw= +github.com/gofrs/flock v0.8.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU= github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ= github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= diff --git a/vault/cachedsessionprovider.go b/vault/cachedsessionprovider.go index 1a382d6b3..3b4aca65e 100644 --- a/vault/cachedsessionprovider.go +++ b/vault/cachedsessionprovider.go @@ -2,7 +2,9 @@ package vault import ( "context" + "fmt" "log" + "os" "time" "github.com/aws/aws-sdk-go-v2/aws" @@ -21,26 +23,141 @@ type CachedSessionProvider struct { SessionProvider StsSessionProvider Keyring *SessionKeyring ExpiryWindow time.Duration + sessionLock SessionCacheLock + sessionLockWait time.Duration + sessionLockLog time.Duration + sessionNow func() time.Time + sessionSleep func(context.Context, time.Duration) error + sessionLogf func(string, ...any) +} + +const ( + defaultSessionLockWaitDelay = 100 * time.Millisecond + defaultSessionLockLogEvery = 15 * time.Second + defaultSessionLockWarnAfter = 5 * time.Second +) + +func defaultSessionSleep(ctx context.Context, d time.Duration) error { + timer := time.NewTimer(d) + defer timer.Stop() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} + +func (p *CachedSessionProvider) ensureSessionDependencies() { + if p.sessionLock == nil { + p.sessionLock = NewDefaultSessionCacheLock(p.SessionKey.StringForMatching()) + } + if p.sessionLockWait == 0 { + p.sessionLockWait = defaultSessionLockWaitDelay + } + if p.sessionLockLog == 0 { + p.sessionLockLog = defaultSessionLockLogEvery + } + if p.sessionNow == nil { + p.sessionNow = time.Now + } + if p.sessionSleep == nil { + p.sessionSleep = defaultSessionSleep + } + if p.sessionLogf == nil { + p.sessionLogf = log.Printf + } } func (p *CachedSessionProvider) RetrieveStsCredentials(ctx context.Context) (*ststypes.Credentials, error) { - creds, err := p.Keyring.Get(p.SessionKey) + p.ensureSessionDependencies() + + creds, cached, err := p.getCachedSession() + if err == nil && cached { + return creds, nil + } - if err != nil || time.Until(*creds.Expiration) < p.ExpiryWindow { - // lookup missed, we need to create a new one. - creds, err = p.SessionProvider.RetrieveStsCredentials(ctx) + return p.getSessionWithLock(ctx) +} + +func (p *CachedSessionProvider) getCachedSession() (creds *ststypes.Credentials, cached bool, err error) { + creds, err = p.Keyring.Get(p.SessionKey) + if err != nil { + return nil, false, err + } + if time.Until(*creds.Expiration) < p.ExpiryWindow { + return nil, false, nil + } + log.Printf("Re-using cached credentials %s from %s, expires in %s", FormatKeyForDisplay(*creds.AccessKeyId), p.SessionKey.Type, time.Until(*creds.Expiration).String()) + return creds, true, nil +} + +func (p *CachedSessionProvider) getSessionWithLock(ctx context.Context) (*ststypes.Credentials, error) { + waiter := newLockWaiter( + p.sessionLock, + "Waiting for session lock at %s\n", + "Waiting for session lock at %s", + p.sessionLockWait, + p.sessionLockLog, + defaultSessionLockWarnAfter, + p.sessionNow, + p.sessionSleep, + p.sessionLogf, + func(format string, args ...any) { + fmt.Fprintf(os.Stderr, format, args...) + }, + ) + + for { + creds, cached, err := p.getCachedSession() + if err == nil && cached { + return creds, nil + } + if ctx.Err() != nil { + return nil, ctx.Err() + } + + locked, err := p.sessionLock.TryLock() if err != nil { return nil, err } - err = p.Keyring.Set(p.SessionKey, creds) - if err != nil { + if locked { + creds, cached, err = p.getCachedSession() + if err == nil && cached { + unlockErr := p.sessionLock.Unlock() + if unlockErr != nil { + return nil, unlockErr + } + return creds, nil + } + + creds, err = p.SessionProvider.RetrieveStsCredentials(ctx) + if err != nil { + unlockErr := p.sessionLock.Unlock() + if unlockErr != nil { + return nil, unlockErr + } + return nil, err + } + if err = p.Keyring.Set(p.SessionKey, creds); err != nil { + unlockErr := p.sessionLock.Unlock() + if unlockErr != nil { + return nil, unlockErr + } + return nil, err + } + + if err = p.sessionLock.Unlock(); err != nil { + return nil, err + } + + return creds, nil + } + if err = waiter.sleepAfterMiss(ctx); err != nil { return nil, err } - } else { - log.Printf("Re-using cached credentials %s from %s, expires in %s", FormatKeyForDisplay(*creds.AccessKeyId), p.SessionKey.Type, time.Until(*creds.Expiration).String()) } - - return creds, nil } // Retrieve returns cached credentials from the keyring, or if no credentials are cached diff --git a/vault/cachedsessionprovider_lock_test.go b/vault/cachedsessionprovider_lock_test.go new file mode 100644 index 000000000..f34f5c214 --- /dev/null +++ b/vault/cachedsessionprovider_lock_test.go @@ -0,0 +1,251 @@ +package vault + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sts/types" + "github.com/byteness/keyring" +) + +type testSessionProvider struct { + creds *types.Credentials + calls int + onRetrieve func() +} + +func (p *testSessionProvider) RetrieveStsCredentials(context.Context) (*types.Credentials, error) { + p.calls++ + if p.onRetrieve != nil { + p.onRetrieve() + } + return p.creds, nil +} + +func (p *testSessionProvider) Retrieve(context.Context) (aws.Credentials, error) { + return aws.Credentials{}, nil +} + +type lockCheckingKeyring struct { + keyring.Keyring + setCalls int + setLock *testLock +} + +func (k *lockCheckingKeyring) Set(item keyring.Item) error { + k.setCalls++ + if k.setLock != nil && !k.setLock.locked { + return fmt.Errorf("lock not held during cache set") + } + return k.Keyring.Set(item) +} + +func newTestSessionKey() SessionMetadata { + return SessionMetadata{ + Type: "sso.GetRoleCredentials", + ProfileName: "test-profile", + MfaSerial: "https://sso.example", + } +} + +func newTestCreds(expires time.Time) *types.Credentials { + return &types.Credentials{ + AccessKeyId: aws.String("AKIATEST"), + SecretAccessKey: aws.String("secret"), + SessionToken: aws.String("token"), + Expiration: aws.Time(expires), + } +} + +func TestCachedSession_CacheHit_NoLock(t *testing.T) { + key := newTestSessionKey() + creds := newTestCreds(time.Now().Add(time.Hour)) + kr := keyring.NewArrayKeyring(nil) + sk := &SessionKeyring{Keyring: kr} + if err := sk.Set(key, creds); err != nil { + t.Fatalf("set cache: %v", err) + } + + lock := &testLock{} + provider := &testSessionProvider{ + onRetrieve: func() { t.Fatal("RetrieveStsCredentials should not be called on cache hit") }, + } + + p := &CachedSessionProvider{ + SessionKey: key, + SessionProvider: provider, + Keyring: sk, + ExpiryWindow: 0, + sessionLock: lock, + } + + got, err := p.RetrieveStsCredentials(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if aws.ToString(got.AccessKeyId) != aws.ToString(creds.AccessKeyId) { + t.Fatalf("unexpected credentials returned") + } + if lock.tryCalls != 0 { + t.Fatalf("expected no lock attempts, got %d", lock.tryCalls) + } + if provider.calls != 0 { + t.Fatalf("expected no provider calls, got %d", provider.calls) + } +} + +func TestCachedSession_LockMiss_ThenCacheHit_NoRefresh(t *testing.T) { + key := newTestSessionKey() + creds := newTestCreds(time.Now().Add(time.Hour)) + kr := keyring.NewArrayKeyring(nil) + sk := &SessionKeyring{Keyring: kr} + lock := &testLock{tryResults: []bool{false}} + + provider := &testSessionProvider{ + onRetrieve: func() { t.Fatal("RetrieveStsCredentials should not be called when cache fills while waiting") }, + } + + p := &CachedSessionProvider{ + SessionKey: key, + SessionProvider: provider, + Keyring: sk, + ExpiryWindow: 0, + sessionLock: lock, + sessionLockWait: 5 * time.Second, + } + p.sessionSleep = func(ctx context.Context, d time.Duration) error { + return sk.Set(key, creds) + } + + got, err := p.RetrieveStsCredentials(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if aws.ToString(got.AccessKeyId) != aws.ToString(creds.AccessKeyId) { + t.Fatalf("unexpected credentials returned") + } + if lock.tryCalls != 1 { + t.Fatalf("expected 1 lock attempt, got %d", lock.tryCalls) + } + if provider.calls != 0 { + t.Fatalf("expected no provider calls, got %d", provider.calls) + } +} + +func TestCachedSession_LockAcquired_RecheckCache(t *testing.T) { + key := newTestSessionKey() + creds := newTestCreds(time.Now().Add(time.Hour)) + kr := keyring.NewArrayKeyring(nil) + sk := &SessionKeyring{Keyring: kr} + lock := &testLock{tryResults: []bool{true}} + lock.onTry = func(l *testLock) { + if l.locked { + _ = sk.Set(key, creds) + } + } + + provider := &testSessionProvider{ + onRetrieve: func() { t.Fatal("RetrieveStsCredentials should not be called when cache fills after lock") }, + } + + p := &CachedSessionProvider{ + SessionKey: key, + SessionProvider: provider, + Keyring: sk, + ExpiryWindow: 0, + sessionLock: lock, + } + + got, err := p.RetrieveStsCredentials(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if aws.ToString(got.AccessKeyId) != aws.ToString(creds.AccessKeyId) { + t.Fatalf("unexpected credentials returned") + } + if lock.unlockCalls != 1 { + t.Fatalf("expected 1 unlock, got %d", lock.unlockCalls) + } + if provider.calls != 0 { + t.Fatalf("expected no provider calls, got %d", provider.calls) + } +} + +func TestCachedSession_LockHeldThroughCacheSet(t *testing.T) { + key := newTestSessionKey() + creds := newTestCreds(time.Now().Add(time.Hour)) + lock := &testLock{tryResults: []bool{true}} + wrappedKeyring := &lockCheckingKeyring{ + Keyring: keyring.NewArrayKeyring(nil), + setLock: lock, + } + sk := &SessionKeyring{Keyring: wrappedKeyring} + provider := &testSessionProvider{creds: creds} + + p := &CachedSessionProvider{ + SessionKey: key, + SessionProvider: provider, + Keyring: sk, + ExpiryWindow: 0, + sessionLock: lock, + } + + _, err := p.RetrieveStsCredentials(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if wrappedKeyring.setCalls != 1 { + t.Fatalf("expected cache set once, got %d", wrappedKeyring.setCalls) + } + if lock.unlockCalls != 1 { + t.Fatalf("expected 1 unlock, got %d", lock.unlockCalls) + } + if provider.calls != 1 { + t.Fatalf("expected 1 provider call, got %d", provider.calls) + } +} + +func TestCachedSession_LockWaitLogs(t *testing.T) { + lock := &testLock{tryResults: []bool{false, false, false, false}} + kr := keyring.NewArrayKeyring(nil) + sk := &SessionKeyring{Keyring: kr} + key := newTestSessionKey() + provider := &testSessionProvider{} + + ctx, cancel := context.WithCancel(context.Background()) + clock := &testClock{now: time.Unix(0, 0), cancel: cancel, cancelAfter: 4} + var logTimes []time.Time + + p := &CachedSessionProvider{ + SessionKey: key, + SessionProvider: provider, + Keyring: sk, + ExpiryWindow: 0, + sessionLock: lock, + sessionLockWait: 5 * time.Second, + sessionLockLog: 15 * time.Second, + sessionNow: clock.Now, + } + p.sessionSleep = clock.Sleep + p.sessionLogf = func(string, ...any) { + logTimes = append(logTimes, clock.Now()) + } + + _, err := p.RetrieveStsCredentials(ctx) + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context cancellation, got %v", err) + } + if len(logTimes) != 2 { + t.Fatalf("expected 2 log entries, got %d", len(logTimes)) + } + if !logTimes[0].Equal(time.Unix(0, 0)) { + t.Fatalf("unexpected first log time: %s", logTimes[0]) + } + if !logTimes[1].Equal(time.Unix(15, 0)) { + t.Fatalf("unexpected second log time: %s", logTimes[1]) + } +} diff --git a/vault/keychain_lock.go b/vault/keychain_lock.go new file mode 100644 index 000000000..55bbd013e --- /dev/null +++ b/vault/keychain_lock.go @@ -0,0 +1,21 @@ +package vault + +const keychainLockFilenamePrefix = "aws-vault.keychain" + +// KeychainLock coordinates keychain access across processes. +type KeychainLock = ProcessLock + +// NewDefaultKeychainLock creates a lock in the system temp directory. +// This only coordinates processes that share the same temp dir; differing TMPDIRs/users are out of scope. +func NewDefaultKeychainLock(lockKey string) KeychainLock { + return NewKeychainLock(defaultLockPath(keychainLockFilename(lockKey))) +} + +// NewKeychainLock creates a lock at the provided path. +func NewKeychainLock(path string) KeychainLock { + return NewFileLock(path) +} + +func keychainLockFilename(lockKey string) string { + return hashedLockFilename(keychainLockFilenamePrefix, lockKey) +} diff --git a/vault/lock_test.go b/vault/lock_test.go new file mode 100644 index 000000000..18a2607b7 --- /dev/null +++ b/vault/lock_test.go @@ -0,0 +1,66 @@ +package vault + +import ( + "context" + "time" +) + +type testLock struct { + tryResults []bool + tryCalls int + unlockCalls int + locked bool + path string + onTry func(*testLock) +} + +func (l *testLock) TryLock() (bool, error) { + l.tryCalls++ + locked := false + if l.tryCalls <= len(l.tryResults) { + locked = l.tryResults[l.tryCalls-1] + } + if locked { + l.locked = true + } + if l.onTry != nil { + l.onTry(l) + } + return locked, nil +} + +func (l *testLock) Unlock() error { + l.unlockCalls++ + l.locked = false + return nil +} + +func (l *testLock) Path() string { + if l.path != "" { + return l.path + } + return "/tmp/aws-vault.lock" +} + +type testClock struct { + now time.Time + sleepCalls int + cancelAfter int + cancel context.CancelFunc +} + +func (c *testClock) Now() time.Time { + return c.now +} + +func (c *testClock) Sleep(ctx context.Context, d time.Duration) error { + c.sleepCalls++ + c.now = c.now.Add(d) + if c.cancel != nil && c.cancelAfter > 0 && c.sleepCalls >= c.cancelAfter { + c.cancel() + } + if ctx.Err() != nil { + return ctx.Err() + } + return nil +} diff --git a/vault/lock_waiter.go b/vault/lock_waiter.go new file mode 100644 index 000000000..5cdac42ae --- /dev/null +++ b/vault/lock_waiter.go @@ -0,0 +1,85 @@ +package vault + +import ( + "context" + "time" +) + +type lockLogger func(string, ...any) + +type lockWaiter struct { + lock ProcessLock + waitDelay time.Duration + logEvery time.Duration + warnAfter time.Duration + now func() time.Time + sleep func(context.Context, time.Duration) error + logf lockLogger + warnf lockLogger + warnMsg string + logMsg string + + lastLog time.Time + waitStart time.Time + warned bool +} + +func newLockWaiter( + lock ProcessLock, + warnMsg string, + logMsg string, + waitDelay time.Duration, + logEvery time.Duration, + warnAfter time.Duration, + now func() time.Time, + sleep func(context.Context, time.Duration) error, + logf lockLogger, + warnf lockLogger, +) *lockWaiter { + if now == nil { + now = time.Now + } + if sleep == nil { + sleep = func(ctx context.Context, d time.Duration) error { + timer := time.NewTimer(d) + defer timer.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } + } + } + return &lockWaiter{ + lock: lock, + waitDelay: waitDelay, + logEvery: logEvery, + warnAfter: warnAfter, + now: now, + sleep: sleep, + logf: logf, + warnf: warnf, + warnMsg: warnMsg, + logMsg: logMsg, + } +} + +func (w *lockWaiter) sleepAfterMiss(ctx context.Context) error { + now := w.now() + if w.waitStart.IsZero() { + w.waitStart = now + } + if !w.warned && now.Sub(w.waitStart) >= w.warnAfter { + if w.warnf != nil { + w.warnf(w.warnMsg, w.lock.Path()) + } + w.warned = true + } + if w.logf != nil && (w.lastLog.IsZero() || now.Sub(w.lastLog) >= w.logEvery) { + w.logf(w.logMsg, w.lock.Path()) + w.lastLog = now + } + + return w.sleep(ctx, w.waitDelay) +} diff --git a/vault/locked_keyring.go b/vault/locked_keyring.go new file mode 100644 index 000000000..7e53a6a27 --- /dev/null +++ b/vault/locked_keyring.go @@ -0,0 +1,163 @@ +package vault + +import ( + "context" + "fmt" + "log" + "os" + "sync" + "time" + + "github.com/byteness/keyring" +) + +type lockedKeyring struct { + inner keyring.Keyring + lock KeychainLock + mu sync.Mutex + + lockKey string + lockWait time.Duration + lockLog time.Duration + warnAfter time.Duration + lockNow func() time.Time + lockSleep func(context.Context, time.Duration) error + lockLogf func(string, ...any) +} + +const ( + defaultKeychainLockWaitDelay = 100 * time.Millisecond + defaultKeychainLockLogEvery = 15 * time.Second + defaultKeychainLockWarnAfter = 5 * time.Second +) + +// NewKeychainLockedKeyring wraps the provided keyring with a cross-process lock +// to serialize keychain operations. +func NewKeychainLockedKeyring(kr keyring.Keyring, lockKey string) keyring.Keyring { + return &lockedKeyring{ + inner: kr, + lock: NewDefaultKeychainLock(lockKey), + lockKey: lockKey, + } +} + +func (k *lockedKeyring) ensureLockDependencies() { + if k.lock == nil { + lockKey := k.lockKey + if lockKey == "" { + lockKey = "aws-vault" + } + k.lock = NewDefaultKeychainLock(lockKey) + } + if k.lockWait == 0 { + k.lockWait = defaultKeychainLockWaitDelay + } + if k.lockLog == 0 { + k.lockLog = defaultKeychainLockLogEvery + } + if k.warnAfter == 0 { + k.warnAfter = defaultKeychainLockWarnAfter + } + if k.lockNow == nil { + k.lockNow = time.Now + } + if k.lockSleep == nil { + k.lockSleep = func(_ context.Context, d time.Duration) error { + time.Sleep(d) + return nil + } + } + if k.lockLogf == nil { + k.lockLogf = log.Printf + } +} + +func (k *lockedKeyring) withLock(fn func() error) error { + k.ensureLockDependencies() + + k.mu.Lock() + defer k.mu.Unlock() + + waiter := newLockWaiter( + k.lock, + "Waiting for keychain lock at %s\n", + "Waiting for keychain lock at %s", + k.lockWait, + k.lockLog, + k.warnAfter, + k.lockNow, + k.lockSleep, + k.lockLogf, + func(format string, args ...any) { + fmt.Fprintf(os.Stderr, format, args...) + }, + ) + + ctx := context.Background() + for { + locked, err := k.lock.TryLock() + if err != nil { + return err + } + if locked { + err = fn() + unlockErr := k.lock.Unlock() + if unlockErr != nil { + return unlockErr + } + return err + } + + if err = waiter.sleepAfterMiss(ctx); err != nil { + return err + } + } +} + +func (k *lockedKeyring) Get(key string) (keyring.Item, error) { + var item keyring.Item + if err := k.withLock(func() error { + var err error + item, err = k.inner.Get(key) + return err + }); err != nil { + return keyring.Item{}, err + } + return item, nil +} + +func (k *lockedKeyring) GetMetadata(key string) (keyring.Metadata, error) { + var meta keyring.Metadata + if err := k.withLock(func() error { + var err error + meta, err = k.inner.GetMetadata(key) + return err + }); err != nil { + return keyring.Metadata{}, err + } + return meta, nil +} + +func (k *lockedKeyring) Set(item keyring.Item) error { + return k.withLock(func() error { + return k.inner.Set(item) + }) +} + +func (k *lockedKeyring) Remove(key string) error { + return k.withLock(func() error { + return k.inner.Remove(key) + }) +} + +func (k *lockedKeyring) Keys() ([]string, error) { + var keys []string + if err := k.withLock(func() error { + var err error + keys, err = k.inner.Keys() + return err + }); err != nil { + return nil, err + } + return keys, nil +} diff --git a/vault/process_lock.go b/vault/process_lock.go new file mode 100644 index 000000000..0582fe3c1 --- /dev/null +++ b/vault/process_lock.go @@ -0,0 +1,47 @@ +package vault + +import ( + "crypto/sha256" + "fmt" + "os" + "path/filepath" + + "github.com/gofrs/flock" +) + +// ProcessLock coordinates work across processes. +type ProcessLock interface { + TryLock() (bool, error) + Unlock() error + Path() string +} + +type fileProcessLock struct { + lock *flock.Flock +} + +// NewFileLock creates a lock at the provided path. +func NewFileLock(path string) ProcessLock { + return &fileProcessLock{lock: flock.New(path)} +} + +func (l *fileProcessLock) TryLock() (bool, error) { + return l.lock.TryLock() +} + +func (l *fileProcessLock) Unlock() error { + return l.lock.Unlock() +} + +func (l *fileProcessLock) Path() string { + return l.lock.Path() +} + +func defaultLockPath(filename string) string { + return filepath.Join(os.TempDir(), filename) +} + +func hashedLockFilename(prefix, key string) string { + sum := sha256.Sum256([]byte(key)) + return fmt.Sprintf("%s.%x.lock", prefix, sum) +} diff --git a/vault/session_lock.go b/vault/session_lock.go new file mode 100644 index 000000000..5d77e2407 --- /dev/null +++ b/vault/session_lock.go @@ -0,0 +1,21 @@ +package vault + +const sessionLockFilenamePrefix = "aws-vault.session" + +// SessionCacheLock coordinates session cache refreshes across processes. +type SessionCacheLock = ProcessLock + +// NewDefaultSessionCacheLock creates a lock in the system temp directory. +// This only coordinates processes that share the same temp dir; differing TMPDIRs/users are out of scope. +func NewDefaultSessionCacheLock(lockKey string) SessionCacheLock { + return NewSessionCacheLock(defaultLockPath(sessionLockFilename(lockKey))) +} + +// NewSessionCacheLock creates a lock at the provided path. +func NewSessionCacheLock(path string) SessionCacheLock { + return NewFileLock(path) +} + +func sessionLockFilename(lockKey string) string { + return hashedLockFilename(sessionLockFilenamePrefix, lockKey) +} diff --git a/vault/sso_lock.go b/vault/sso_lock.go new file mode 100644 index 000000000..d90a879b7 --- /dev/null +++ b/vault/sso_lock.go @@ -0,0 +1,17 @@ +package vault + +const defaultSSOLockFilename = "aws-vault.sso.lock" + +// SSOTokenLock coordinates the SSO device flow across processes. +type SSOTokenLock = ProcessLock + +// NewDefaultSSOTokenLock creates a lock in the system temp directory. +// This only coordinates processes that share the same temp dir; differing TMPDIRs/users are out of scope. +func NewDefaultSSOTokenLock() SSOTokenLock { + return NewSSOTokenLock(defaultLockPath(defaultSSOLockFilename)) +} + +// NewSSOTokenLock creates a lock at the provided path. +func NewSSOTokenLock(path string) SSOTokenLock { + return NewFileLock(path) +} diff --git a/vault/ssorolecredentialsprovider.go b/vault/ssorolecredentialsprovider.go index 9854f4ee3..c672dd4fc 100644 --- a/vault/ssorolecredentialsprovider.go +++ b/vault/ssorolecredentialsprovider.go @@ -5,8 +5,11 @@ import ( "errors" "fmt" "log" + "math/rand" "net/http" "os" + "strconv" + "strings" "time" "github.com/byteness/keyring" @@ -35,12 +38,67 @@ type SSORoleCredentialsProvider struct { AccountID string RoleName string UseStdout bool + ssoTokenLock SSOTokenLock + ssoLockWait time.Duration + ssoLockLog time.Duration + ssoNow func() time.Time + ssoSleep func(context.Context, time.Duration) error + ssoLogf func(string, ...any) + newOIDCTokenFn func(context.Context) (*ssooidc.CreateTokenOutput, error) } func millisecondsTimeValue(v int64) time.Time { return time.Unix(0, v*int64(time.Millisecond)) } +const ( + defaultSSOLockWaitDelay = 100 * time.Millisecond + defaultSSOLockLogEvery = 15 * time.Second + defaultSSOLockWarnAfter = 5 * time.Second + // 0 means retry indefinitely (caller is expected to use context cancellation). + ssoMaxAttempts = 0 + ssoRetryBase = 200 * time.Millisecond + ssoRetryMax = 5 * time.Second + ssoRetryAfterJitterMin = 1.1 + ssoRetryAfterJitterMax = 1.3 +) + +func defaultSSOSleep(ctx context.Context, d time.Duration) error { + timer := time.NewTimer(d) + defer timer.Stop() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} + +func (p *SSORoleCredentialsProvider) ensureSSODependencies() { + if p.ssoTokenLock == nil && !p.UseStdout { + p.ssoTokenLock = NewDefaultSSOTokenLock() + } + if p.ssoLockWait == 0 { + p.ssoLockWait = defaultSSOLockWaitDelay + } + if p.ssoLockLog == 0 { + p.ssoLockLog = defaultSSOLockLogEvery + } + if p.ssoNow == nil { + p.ssoNow = time.Now + } + if p.ssoSleep == nil { + p.ssoSleep = defaultSSOSleep + } + if p.ssoLogf == nil { + p.ssoLogf = log.Printf + } + if p.newOIDCTokenFn == nil { + p.newOIDCTokenFn = p.newOIDCToken + } +} + // Retrieve generates a new set of temporary credentials using SSO GetRoleCredentials. func (p *SSORoleCredentialsProvider) Retrieve(ctx context.Context) (aws.Credentials, error) { creds, err := p.getRoleCredentials(ctx) @@ -58,39 +116,69 @@ func (p *SSORoleCredentialsProvider) Retrieve(ctx context.Context) (aws.Credenti } func (p *SSORoleCredentialsProvider) getRoleCredentials(ctx context.Context) (*ssotypes.RoleCredentials, error) { + p.ensureSSODependencies() + token, cached, err := p.getOIDCToken(ctx) if err != nil { return nil, err } - resp, err := p.SSOClient.GetRoleCredentials(ctx, &sso.GetRoleCredentialsInput{ - AccessToken: token.AccessToken, - AccountId: aws.String(p.AccountID), - RoleName: aws.String(p.RoleName), - }) - if err != nil { + maxAttempts, baseDelay, maxDelay := ssoRetrySettings() + attempt := 0 + for { + attempt++ + resp, err := p.SSOClient.GetRoleCredentials(ctx, &sso.GetRoleCredentialsInput{ + AccessToken: token.AccessToken, + AccountId: aws.String(p.AccountID), + RoleName: aws.String(p.RoleName), + }) + if err == nil { + log.Printf("Got credentials %s for SSO role %s (account: %s), expires in %s", FormatKeyForDisplay(*resp.RoleCredentials.AccessKeyId), p.RoleName, p.AccountID, time.Until(millisecondsTimeValue(resp.RoleCredentials.Expiration)).String()) + return resp.RoleCredentials, nil + } + if cached && p.OIDCTokenCache != nil { var rspError *awshttp.ResponseError - if !errors.As(err, &rspError) { - return nil, err + if errors.As(err, &rspError) && rspError.HTTPStatusCode() == http.StatusUnauthorized { + // Cached token rejected: drop it and retry with a fresh access token. + // This should only happen once because the cache is cleared before retrying. + if err = p.OIDCTokenCache.Remove(p.StartURL); err != nil { + return nil, err + } + token, cached, err = p.getOIDCToken(ctx) + if err != nil { + return nil, err + } + attempt = 0 + continue } + } - // If the error is a 401, remove the cached oidc token and try - // again. This is a recursive call but it should only happen once - // due to the cache being cleared before retrying. - if rspError.HTTPStatusCode() == http.StatusUnauthorized { - err = p.OIDCTokenCache.Remove(p.StartURL) - if err != nil { + if isSSORateLimitError(err) { + if maxAttempts == 0 || attempt < maxAttempts { + attemptInfo := fmt.Sprintf("%d/%d", attempt, maxAttempts) + if maxAttempts == 0 { + attemptInfo = fmt.Sprintf("%d/inf", attempt) + } + if retryAfter, ok := retryAfterFromError(err); ok { + delay := jitterRetryAfter(retryAfter) + log.Printf("SSO rate limited for role %s (account: %s); retry-after %s (jittered %s), attempt %s", p.RoleName, p.AccountID, retryAfter, delay, attemptInfo) + if err = p.ssoSleep(ctx, delay); err != nil { + return nil, err + } + continue + } + delay := jitteredBackoff(baseDelay, maxDelay, attempt) + log.Printf("SSO rate limited for role %s (account: %s); backing off %s (synthetic), attempt %s", p.RoleName, p.AccountID, delay, attemptInfo) + if err = p.ssoSleep(ctx, delay); err != nil { return nil, err } - return p.getRoleCredentials(ctx) + continue } } + return nil, err } - log.Printf("Got credentials %s for SSO role %s (account: %s), expires in %s", FormatKeyForDisplay(*resp.RoleCredentials.AccessKeyId), p.RoleName, p.AccountID, time.Until(millisecondsTimeValue(resp.RoleCredentials.Expiration)).String()) - - return resp.RoleCredentials, nil } func (p *SSORoleCredentialsProvider) RetrieveStsCredentials(ctx context.Context) (*ststypes.Credentials, error) { @@ -113,27 +201,120 @@ func (p *SSORoleCredentialsProvider) getRoleCredentialsAsStsCredemtials(ctx cont } func (p *SSORoleCredentialsProvider) getOIDCToken(ctx context.Context) (token *ssooidc.CreateTokenOutput, cached bool, err error) { - if p.OIDCTokenCache != nil { - token, err = p.OIDCTokenCache.Get(p.StartURL) - if err != nil && err != keyring.ErrKeyNotFound { - return nil, false, err - } - if token != nil { - return token, true, nil - } + p.ensureSSODependencies() + + token, cached, err = p.getCachedOIDCToken() + if err != nil || token != nil { + return token, cached, err } - token, err = p.newOIDCToken(ctx) + + if p.UseStdout { + return p.createAndCacheOIDCToken(ctx) + } + + return p.getOIDCTokenWithLock(ctx) +} + +func (p *SSORoleCredentialsProvider) getCachedOIDCToken() (token *ssooidc.CreateTokenOutput, cached bool, err error) { + if p.OIDCTokenCache == nil { + return nil, false, nil + } + + token, err = p.OIDCTokenCache.Get(p.StartURL) + if err != nil && err != keyring.ErrKeyNotFound { + return nil, false, err + } + if token != nil { + return token, true, nil + } + + return nil, false, nil +} + +func (p *SSORoleCredentialsProvider) createAndCacheOIDCToken(ctx context.Context) (token *ssooidc.CreateTokenOutput, cached bool, err error) { + token, err = p.newOIDCTokenFn(ctx) if err != nil { return nil, false, err } if p.OIDCTokenCache != nil { - err = p.OIDCTokenCache.Set(p.StartURL, token) + if err = p.OIDCTokenCache.Set(p.StartURL, token); err != nil { + return nil, false, err + } + } + + return token, false, nil +} + +func (p *SSORoleCredentialsProvider) getOIDCTokenWithLock(ctx context.Context) (token *ssooidc.CreateTokenOutput, cached bool, err error) { + waiter := newLockWaiter( + p.ssoTokenLock, + "Waiting for SSO lock at %s\n", + "Waiting for SSO lock at %s", + p.ssoLockWait, + p.ssoLockLog, + defaultSSOLockWarnAfter, + p.ssoNow, + p.ssoSleep, + p.ssoLogf, + func(format string, args ...any) { + fmt.Fprintf(os.Stderr, format, args...) + }, + ) + + for { + token, cached, err = p.getCachedOIDCToken() + if err != nil || token != nil { + return token, cached, err + } + if ctx.Err() != nil { + return nil, false, ctx.Err() + } + + locked, err := p.ssoTokenLock.TryLock() if err != nil { return nil, false, err } + if locked { + token, cached, err = p.getCachedOIDCToken() + if err != nil || token != nil { + unlockErr := p.ssoTokenLock.Unlock() + if unlockErr != nil { + return nil, false, unlockErr + } + return token, cached, err + } + + token, err = p.newOIDCTokenFn(ctx) + if err != nil { + unlockErr := p.ssoTokenLock.Unlock() + if unlockErr != nil { + return nil, false, unlockErr + } + return nil, false, err + } + + if p.OIDCTokenCache != nil { + if err = p.OIDCTokenCache.Set(p.StartURL, token); err != nil { + unlockErr := p.ssoTokenLock.Unlock() + if unlockErr != nil { + return nil, false, unlockErr + } + return nil, false, err + } + } + + if err = p.ssoTokenLock.Unlock(); err != nil { + return nil, false, err + } + + return token, false, nil + } + + if err = waiter.sleepAfterMiss(ctx); err != nil { + return nil, false, err + } } - return token, false, err } func (p *SSORoleCredentialsProvider) newOIDCToken(ctx context.Context) (*ssooidc.CreateTokenOutput, error) { @@ -201,3 +382,87 @@ func (p *SSORoleCredentialsProvider) newOIDCToken(ctx context.Context) (*ssooidc return t, nil } } + +func ssoRetrySettings() (int, time.Duration, time.Duration) { + return ssoMaxAttempts, ssoRetryBase, ssoRetryMax +} + +func retryAfterFromError(err error) (time.Duration, bool) { + var rspError *awshttp.ResponseError + if errors.As(err, &rspError) { + if rspError.Response != nil { + if d, ok := parseRetryAfter(rspError.Response.Header.Get("Retry-After")); ok { + return d, true + } + } + } + return 0, false +} + +func parseRetryAfter(value string) (time.Duration, bool) { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return 0, false + } + if secs, err := strconv.Atoi(trimmed); err == nil { + if secs < 0 { + return 0, false + } + return time.Duration(secs) * time.Second, true + } + if t, err := http.ParseTime(trimmed); err == nil { + d := time.Until(t) + if d < 0 { + d = 0 + } + return d, true + } + return 0, false +} + +func isSSORateLimitError(err error) bool { + var tooMany *ssotypes.TooManyRequestsException + if errors.As(err, &tooMany) { + return true + } + var rspError *awshttp.ResponseError + if errors.As(err, &rspError) && rspError.HTTPStatusCode() == http.StatusTooManyRequests { + return true + } + return false +} + +func jitterRetryAfter(base time.Duration) time.Duration { + if base <= 0 { + return 0 + } + return jitterDelay(base) +} + +func jitteredBackoff(base, max time.Duration, attempt int) time.Duration { + if attempt < 1 { + attempt = 1 + } + capDelay := base << uint(attempt-1) + if capDelay > max { + capDelay = max + } + if capDelay < base { + capDelay = base + } + return jitterDelay(capDelay) +} + +func jitterDelay(base time.Duration) time.Duration { + if base <= 0 { + return 0 + } + min := ssoRetryAfterJitterMin + max := ssoRetryAfterJitterMax + if max < min { + max = min + } + r := rand.New(rand.NewSource(time.Now().UnixNano())) + factor := min + r.Float64()*(max-min) + return time.Duration(float64(base) * factor) +} diff --git a/vault/ssorolecredentialsprovider_lock_test.go b/vault/ssorolecredentialsprovider_lock_test.go new file mode 100644 index 000000000..347e0844b --- /dev/null +++ b/vault/ssorolecredentialsprovider_lock_test.go @@ -0,0 +1,261 @@ +package vault + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ssooidc" + "github.com/byteness/keyring" +) + +type testTokenCache struct { + token *ssooidc.CreateTokenOutput + setCalls int + setLock *testLock +} + +func (c *testTokenCache) Get(string) (*ssooidc.CreateTokenOutput, error) { + if c.token == nil { + return nil, keyring.ErrKeyNotFound + } + return c.token, nil +} + +func (c *testTokenCache) Set(_ string, token *ssooidc.CreateTokenOutput) error { + c.setCalls++ + if c.setLock != nil && !c.setLock.locked { + return fmt.Errorf("lock not held during cache set") + } + c.token = token + return nil +} + +func (c *testTokenCache) Remove(string) error { + c.token = nil + return nil +} + +func TestGetOIDCToken_CacheHit_NoLock(t *testing.T) { + cachedToken := &ssooidc.CreateTokenOutput{AccessToken: aws.String("cached")} + cache := &testTokenCache{token: cachedToken} + lock := &testLock{} + + p := &SSORoleCredentialsProvider{ + OIDCTokenCache: cache, + StartURL: "https://sso.example", + ssoTokenLock: lock, + UseStdout: false, + } + p.newOIDCTokenFn = func(context.Context) (*ssooidc.CreateTokenOutput, error) { + t.Fatal("newOIDCToken should not be called on cache hit") + return nil, nil + } + p.ssoLogf = func(string, ...any) {} + p.ssoSleep = func(context.Context, time.Duration) error { return nil } + + token, cached, err := p.getOIDCToken(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !cached { + t.Fatalf("expected cached token") + } + if token != cachedToken { + t.Fatalf("unexpected token returned") + } + if lock.tryCalls != 0 { + t.Fatalf("expected no lock attempts, got %d", lock.tryCalls) + } +} + +func TestGetOIDCToken_LockMiss_ThenCacheHit_NoLock(t *testing.T) { + cachedToken := &ssooidc.CreateTokenOutput{AccessToken: aws.String("cached")} + cache := &testTokenCache{} + lock := &testLock{tryResults: []bool{false}} + clock := &testClock{now: time.Unix(0, 0)} + + p := &SSORoleCredentialsProvider{ + OIDCTokenCache: cache, + StartURL: "https://sso.example", + ssoTokenLock: lock, + UseStdout: false, + ssoLockWait: 5 * time.Second, + ssoNow: clock.Now, + } + p.newOIDCTokenFn = func(context.Context) (*ssooidc.CreateTokenOutput, error) { + t.Fatal("newOIDCToken should not be called when cache fills while waiting") + return nil, nil + } + p.ssoLogf = func(string, ...any) {} + p.ssoSleep = func(ctx context.Context, d time.Duration) error { + clock.now = clock.now.Add(d) + cache.token = cachedToken + return nil + } + + token, cached, err := p.getOIDCToken(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !cached { + t.Fatalf("expected cached token") + } + if token != cachedToken { + t.Fatalf("unexpected token returned") + } + if lock.tryCalls != 1 { + t.Fatalf("expected 1 lock attempt, got %d", lock.tryCalls) + } + if lock.unlockCalls != 0 { + t.Fatalf("expected no unlocks, got %d", lock.unlockCalls) + } +} + +func TestGetOIDCToken_LockAcquired_RecheckCache(t *testing.T) { + cachedToken := &ssooidc.CreateTokenOutput{AccessToken: aws.String("cached")} + cache := &testTokenCache{} + lock := &testLock{tryResults: []bool{true}} + lock.onTry = func(l *testLock) { + if l.locked { + cache.token = cachedToken + } + } + + p := &SSORoleCredentialsProvider{ + OIDCTokenCache: cache, + StartURL: "https://sso.example", + ssoTokenLock: lock, + UseStdout: false, + } + p.newOIDCTokenFn = func(context.Context) (*ssooidc.CreateTokenOutput, error) { + t.Fatal("newOIDCToken should not be called when cache is filled after lock") + return nil, nil + } + p.ssoLogf = func(string, ...any) {} + p.ssoSleep = func(context.Context, time.Duration) error { return nil } + + token, cached, err := p.getOIDCToken(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !cached { + t.Fatalf("expected cached token") + } + if token != cachedToken { + t.Fatalf("unexpected token returned") + } + if lock.unlockCalls != 1 { + t.Fatalf("expected 1 unlock, got %d", lock.unlockCalls) + } +} + +func TestGetOIDCToken_LockHeldThroughCacheSet(t *testing.T) { + freshToken := &ssooidc.CreateTokenOutput{AccessToken: aws.String("fresh")} + lock := &testLock{tryResults: []bool{true}} + cache := &testTokenCache{setLock: lock} + + p := &SSORoleCredentialsProvider{ + OIDCTokenCache: cache, + StartURL: "https://sso.example", + ssoTokenLock: lock, + UseStdout: false, + } + p.newOIDCTokenFn = func(context.Context) (*ssooidc.CreateTokenOutput, error) { + return freshToken, nil + } + p.ssoLogf = func(string, ...any) {} + p.ssoSleep = func(context.Context, time.Duration) error { return nil } + + token, cached, err := p.getOIDCToken(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cached { + t.Fatalf("expected non-cached token") + } + if token != freshToken { + t.Fatalf("unexpected token returned") + } + if cache.setCalls != 1 { + t.Fatalf("expected cache set once, got %d", cache.setCalls) + } + if lock.unlockCalls != 1 { + t.Fatalf("expected 1 unlock, got %d", lock.unlockCalls) + } +} + +func TestGetOIDCToken_UseStdout_SkipsLock(t *testing.T) { + freshToken := &ssooidc.CreateTokenOutput{AccessToken: aws.String("fresh")} + lock := &testLock{tryResults: []bool{true}} + cache := &testTokenCache{} + + p := &SSORoleCredentialsProvider{ + OIDCTokenCache: cache, + StartURL: "https://sso.example", + ssoTokenLock: lock, + UseStdout: true, + } + p.newOIDCTokenFn = func(context.Context) (*ssooidc.CreateTokenOutput, error) { + return freshToken, nil + } + p.ssoLogf = func(string, ...any) {} + p.ssoSleep = func(context.Context, time.Duration) error { return nil } + + token, cached, err := p.getOIDCToken(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cached { + t.Fatalf("expected non-cached token") + } + if token != freshToken { + t.Fatalf("unexpected token returned") + } + if lock.tryCalls != 0 { + t.Fatalf("expected no lock attempts, got %d", lock.tryCalls) + } +} + +func TestGetOIDCToken_LockWaitLogs(t *testing.T) { + lock := &testLock{tryResults: []bool{false, false, false, false}} + cache := &testTokenCache{} + ctx, cancel := context.WithCancel(context.Background()) + clock := &testClock{now: time.Unix(0, 0), cancel: cancel, cancelAfter: 4} + var logTimes []time.Time + + p := &SSORoleCredentialsProvider{ + OIDCTokenCache: cache, + StartURL: "https://sso.example", + ssoTokenLock: lock, + UseStdout: false, + ssoLockWait: 5 * time.Second, + ssoLockLog: 15 * time.Second, + ssoNow: clock.Now, + } + p.newOIDCTokenFn = func(context.Context) (*ssooidc.CreateTokenOutput, error) { + t.Fatal("newOIDCToken should not be called when lock never acquired") + return nil, nil + } + p.ssoSleep = clock.Sleep + p.ssoLogf = func(string, ...any) { + logTimes = append(logTimes, clock.Now()) + } + + _, _, err := p.getOIDCToken(ctx) + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context cancellation, got %v", err) + } + if len(logTimes) != 2 { + t.Fatalf("expected 2 log entries, got %d", len(logTimes)) + } + if !logTimes[0].Equal(time.Unix(0, 0)) { + t.Fatalf("unexpected first log time: %s", logTimes[0]) + } + if !logTimes[1].Equal(time.Unix(15, 0)) { + t.Fatalf("unexpected second log time: %s", logTimes[1]) + } +} diff --git a/vault/ssorolecredentialsprovider_retry_test.go b/vault/ssorolecredentialsprovider_retry_test.go new file mode 100644 index 000000000..71bc9a205 --- /dev/null +++ b/vault/ssorolecredentialsprovider_retry_test.go @@ -0,0 +1,78 @@ +package vault + +import ( + "errors" + "net/http" + "testing" + "time" + + awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" + ssotypes "github.com/aws/aws-sdk-go-v2/service/sso/types" + smithyhttp "github.com/aws/smithy-go/transport/http" +) + +func TestRetryAfterFromErrorSeconds(t *testing.T) { + header := http.Header{} + header.Set("Retry-After", "120") + resp := &http.Response{StatusCode: http.StatusTooManyRequests, Header: header} + err := &awshttp.ResponseError{ + ResponseError: &smithyhttp.ResponseError{ + Response: &smithyhttp.Response{Response: resp}, + }, + } + + delay, ok := retryAfterFromError(err) + if !ok { + t.Fatal("expected retry-after delay to be detected") + } + if delay != 120*time.Second { + t.Fatalf("expected 120s retry-after, got %s", delay) + } +} + +func TestRetryAfterFromErrorMissingHeader(t *testing.T) { + resp := &http.Response{StatusCode: http.StatusTooManyRequests, Header: http.Header{}} + err := &awshttp.ResponseError{ + ResponseError: &smithyhttp.ResponseError{ + Response: &smithyhttp.Response{Response: resp}, + }, + } + + delay, ok := retryAfterFromError(err) + if ok { + t.Fatalf("expected retry-after to be absent, got %s", delay) + } +} + +func TestIsSSORateLimitError(t *testing.T) { + if !isSSORateLimitError(&ssotypes.TooManyRequestsException{}) { + t.Fatal("expected TooManyRequestsException to be rate limit error") + } + + resp := &http.Response{StatusCode: http.StatusTooManyRequests} + err := &awshttp.ResponseError{ + ResponseError: &smithyhttp.ResponseError{ + Response: &smithyhttp.Response{Response: resp}, + }, + } + if !isSSORateLimitError(err) { + t.Fatal("expected HTTP 429 response error to be rate limit error") + } + + if isSSORateLimitError(errors.New("boom")) { + t.Fatal("expected non-rate-limit error to be false") + } +} + +func TestJitterDelayRange(t *testing.T) { + base := 10 * time.Second + min := time.Duration(float64(base) * ssoRetryAfterJitterMin) + max := time.Duration(float64(base) * ssoRetryAfterJitterMax) + + for i := 0; i < 10; i++ { + delay := jitterDelay(base) + if delay < min || delay > max { + t.Fatalf("expected delay in range %s-%s, got %s", min, max, delay) + } + } +}