Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions cli/global.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ func (a *AwsVault) Keyring() (keyring.Keyring, error) {
if err != nil {
return nil, err
}
if a.KeyringBackend == string(keyring.KeychainBackend) {
lockKey := a.KeyringConfig.KeychainName
if lockKey == "" {
lockKey = "aws-vault"
}
a.keyringImpl = vault.NewKeychainLockedKeyring(a.keyringImpl, lockKey)
}
}

return a.keyringImpl, nil
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ require (
github.com/byteness/keyring v1.7.2
github.com/charmbracelet/huh v0.8.0
github.com/charmbracelet/lipgloss v1.1.0
github.com/gofrs/flock v0.8.1
github.com/google/go-cmp v0.7.0
github.com/mattn/go-isatty v0.0.20
github.com/mattn/go-tty v0.0.7
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ github.com/extism/go-sdk v1.7.1 h1:lWJos6uY+tRFdlIHR+SJjwFDApY7OypS/2nMhiVQ9Sw=
github.com/extism/go-sdk v1.7.1/go.mod h1:IT+Xdg5AZM9hVtpFUA+uZCJMge/hbvshl8bwzLtFyKA=
github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y=
github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8=
github.com/gofrs/flock v0.8.1 h1:+gYjHKf32LDeiEEFhQaotPbLuUXjY5ZqxKgXy7n59aw=
github.com/gofrs/flock v0.8.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU=
github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ=
github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
Expand Down
137 changes: 127 additions & 10 deletions vault/cachedsessionprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package vault

import (
"context"
"fmt"
"log"
"os"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
Expand All @@ -21,26 +23,141 @@ type CachedSessionProvider struct {
SessionProvider StsSessionProvider
Keyring *SessionKeyring
ExpiryWindow time.Duration
sessionLock SessionCacheLock
sessionLockWait time.Duration
sessionLockLog time.Duration
sessionNow func() time.Time
sessionSleep func(context.Context, time.Duration) error
sessionLogf func(string, ...any)
}

const (
defaultSessionLockWaitDelay = 100 * time.Millisecond
defaultSessionLockLogEvery = 15 * time.Second
defaultSessionLockWarnAfter = 5 * time.Second
)

func defaultSessionSleep(ctx context.Context, d time.Duration) error {
timer := time.NewTimer(d)
defer timer.Stop()

select {
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
return nil
}
}

func (p *CachedSessionProvider) ensureSessionDependencies() {
if p.sessionLock == nil {
p.sessionLock = NewDefaultSessionCacheLock(p.SessionKey.StringForMatching())
}
if p.sessionLockWait == 0 {
p.sessionLockWait = defaultSessionLockWaitDelay
}
if p.sessionLockLog == 0 {
p.sessionLockLog = defaultSessionLockLogEvery
}
if p.sessionNow == nil {
p.sessionNow = time.Now
}
if p.sessionSleep == nil {
p.sessionSleep = defaultSessionSleep
}
if p.sessionLogf == nil {
p.sessionLogf = log.Printf
}
}

func (p *CachedSessionProvider) RetrieveStsCredentials(ctx context.Context) (*ststypes.Credentials, error) {
creds, err := p.Keyring.Get(p.SessionKey)
p.ensureSessionDependencies()

creds, cached, err := p.getCachedSession()
if err == nil && cached {
return creds, nil
}

if err != nil || time.Until(*creds.Expiration) < p.ExpiryWindow {
// lookup missed, we need to create a new one.
creds, err = p.SessionProvider.RetrieveStsCredentials(ctx)
return p.getSessionWithLock(ctx)
}

func (p *CachedSessionProvider) getCachedSession() (creds *ststypes.Credentials, cached bool, err error) {
creds, err = p.Keyring.Get(p.SessionKey)
if err != nil {
return nil, false, err
}
if time.Until(*creds.Expiration) < p.ExpiryWindow {
return nil, false, nil
}
log.Printf("Re-using cached credentials %s from %s, expires in %s", FormatKeyForDisplay(*creds.AccessKeyId), p.SessionKey.Type, time.Until(*creds.Expiration).String())
return creds, true, nil
}

func (p *CachedSessionProvider) getSessionWithLock(ctx context.Context) (*ststypes.Credentials, error) {
waiter := newLockWaiter(
p.sessionLock,
"Waiting for session lock at %s\n",
"Waiting for session lock at %s",
p.sessionLockWait,
p.sessionLockLog,
defaultSessionLockWarnAfter,
p.sessionNow,
p.sessionSleep,
p.sessionLogf,
func(format string, args ...any) {
fmt.Fprintf(os.Stderr, format, args...)
},
)

for {
creds, cached, err := p.getCachedSession()
if err == nil && cached {
return creds, nil
}
if ctx.Err() != nil {
return nil, ctx.Err()
}

locked, err := p.sessionLock.TryLock()
if err != nil {
return nil, err
}
err = p.Keyring.Set(p.SessionKey, creds)
if err != nil {
if locked {
creds, cached, err = p.getCachedSession()
if err == nil && cached {
unlockErr := p.sessionLock.Unlock()
if unlockErr != nil {
return nil, unlockErr
}
return creds, nil
}

creds, err = p.SessionProvider.RetrieveStsCredentials(ctx)
if err != nil {
unlockErr := p.sessionLock.Unlock()
if unlockErr != nil {
return nil, unlockErr
}
return nil, err
}
if err = p.Keyring.Set(p.SessionKey, creds); err != nil {
unlockErr := p.sessionLock.Unlock()
if unlockErr != nil {
return nil, unlockErr
}
return nil, err
}

if err = p.sessionLock.Unlock(); err != nil {
return nil, err
}

return creds, nil
}
if err = waiter.sleepAfterMiss(ctx); err != nil {
return nil, err
}
} else {
log.Printf("Re-using cached credentials %s from %s, expires in %s", FormatKeyForDisplay(*creds.AccessKeyId), p.SessionKey.Type, time.Until(*creds.Expiration).String())
}

return creds, nil
}

// Retrieve returns cached credentials from the keyring, or if no credentials are cached
Expand Down
Loading