Skip to content

Commit ea54414

Browse files
committed
allow for custom ttl in session store
1 parent 9ed0647 commit ea54414

File tree

10 files changed

+220
-69
lines changed

10 files changed

+220
-69
lines changed

auth/api/iam/api.go

Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,10 @@ type httpRequestContextKey struct{}
7474
// TODO: Might want to make this configurable at some point
7575
const accessTokenValidity = 15 * time.Minute
7676

77+
// accessTokenCacheOffset is used to reduce the ttl of the access token to ensure it is still valid when the client receives it.
78+
// this to offset clock skew and roundtrip times
79+
const accessTokenCacheOffset = 30 * time.Second
80+
7781
// cacheControlMaxAgeURLs holds API endpoints that should have a max-age cache control header set.
7882
var cacheControlMaxAgeURLs = []string{
7983
"/oauth2/:subjectID/presentation_definition",
@@ -725,22 +729,16 @@ func (r Wrapper) RequestServiceAccessToken(ctx context.Context, request RequestS
725729
}
726730

727731
tokenCache := r.accessTokenCache()
728-
cacheKey, err := accessTokenRequestCacheKey(request)
729-
cacheToken := true
730-
if err != nil {
731-
cacheToken = false
732+
cacheKey := accessTokenRequestCacheKey(request)
733+
734+
// try to retrieve token from cache
735+
tokenResult := new(TokenResponse)
736+
err = tokenCache.Get(cacheKey, tokenResult)
737+
if err == nil {
738+
return RequestServiceAccessToken200JSONResponse(*tokenResult), nil
739+
} else if !errors.Is(err, storage.ErrNotFound) {
732740
// only log error, don't fail
733-
log.Logger().WithError(err).Warnf("Failed to create cache key for access token request: %s", err.Error())
734-
} else {
735-
// try to retrieve token from cache
736-
tokenResult := new(TokenResponse)
737-
err = tokenCache.Get(cacheKey, tokenResult)
738-
if err == nil {
739-
return RequestServiceAccessToken200JSONResponse(*tokenResult), nil
740-
} else if !errors.Is(err, storage.ErrNotFound) {
741-
// only log error, don't fail
742-
log.Logger().WithError(err).Warnf("Failed to retrieve access token from cache: %s", err.Error())
743-
}
741+
log.Logger().WithError(err).Warnf("Failed to retrieve access token from cache: %s", err.Error())
744742
}
745743

746744
var credentials []VerifiableCredential
@@ -753,17 +751,21 @@ func (r Wrapper) RequestServiceAccessToken(ctx context.Context, request RequestS
753751
useDPoP = false
754752
}
755753
clientID := r.subjectToBaseURL(request.SubjectID)
756-
tokenResult, err := r.auth.IAMClient().RequestRFC021AccessToken(ctx, clientID.String(), request.SubjectID, request.Body.AuthorizationServer, request.Body.Scope, useDPoP, credentials)
754+
tokenResult, err = r.auth.IAMClient().RequestRFC021AccessToken(ctx, clientID.String(), request.SubjectID, request.Body.AuthorizationServer, request.Body.Scope, useDPoP, credentials)
757755
if err != nil {
758756
// this can be an internal server error, a 400 oauth error or a 412 precondition failed if the wallet does not contain the required credentials
759757
return nil, err
760758
}
761-
if cacheToken {
762-
err = tokenCache.Put(cacheKey, tokenResult)
763-
if err != nil {
764-
// only log error, don't fail
765-
log.Logger().WithError(err).Warnf("Failed to cache access token: %s", err.Error())
766-
}
759+
ttl := accessTokenValidity
760+
if tokenResult.ExpiresIn != nil {
761+
ttl = time.Second * time.Duration(*tokenResult.ExpiresIn)
762+
}
763+
// we reduce the ttl by accessTokenCacheOffset to make sure the token is expired when the cache expires
764+
ttl -= accessTokenCacheOffset
765+
err = tokenCache.Put(cacheKey, tokenResult, storage.WithTTL(ttl))
766+
if err != nil {
767+
// only log error, don't fail
768+
log.Logger().WithError(err).Warnf("Failed to cache access token: %s", err.Error())
767769
}
768770
return RequestServiceAccessToken200JSONResponse(*tokenResult), nil
769771
}
@@ -928,7 +930,7 @@ func (r Wrapper) accessTokenServerStore() storage.SessionStore {
928930
// accessTokenClientStore is used by the client to cache access tokens
929931
func (r Wrapper) accessTokenCache() storage.SessionStore {
930932
// we use a slightly reduced validity to prevent the cache from being used after the token has expired
931-
return r.storageEngine.GetSessionDatabase().GetStore(accessTokenValidity-30*time.Second, "accesstokencache")
933+
return r.storageEngine.GetSessionDatabase().GetStore(accessTokenValidity-accessTokenCacheOffset, "accesstokencache")
932934
}
933935

934936
// accessTokenServerStore is used by the Auth server to store issued access tokens
@@ -983,12 +985,8 @@ func (r Wrapper) determineClientDID(ctx context.Context, authServerMetadata oaut
983985

984986
// accessTokenRequestCacheKey creates a cache key for the access token request.
985987
// it writes the JSON to a sha256 hash and returns the hex encoded hash.
986-
func accessTokenRequestCacheKey(request RequestServiceAccessTokenRequestObject) (string, error) {
987-
// create a hash of the request
988+
func accessTokenRequestCacheKey(request RequestServiceAccessTokenRequestObject) string {
988989
hash := sha256.New()
989-
err := json.NewEncoder(hash).Encode(request)
990-
if err != nil {
991-
return "", err
992-
}
993-
return hex.EncodeToString(hash.Sum(nil)), nil
990+
_ = json.NewEncoder(hash).Encode(request)
991+
return hex.EncodeToString(hash.Sum(nil))
994992
}

auth/api/iam/api_test.go

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -880,7 +880,7 @@ func TestWrapper_RequestServiceAccessToken(t *testing.T) {
880880
})
881881

882882
t.Run("cache expired", func(t *testing.T) {
883-
cacheKey, _ := accessTokenRequestCacheKey(request)
883+
cacheKey := accessTokenRequestCacheKey(request)
884884
_ = ctx.client.accessTokenCache().Delete(cacheKey)
885885
ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(&oauth.TokenResponse{AccessToken: "other"}, nil)
886886

@@ -905,6 +905,16 @@ func TestWrapper_RequestServiceAccessToken(t *testing.T) {
905905

906906
require.NoError(t, err)
907907
})
908+
t.Run("ok with expired cache by ttl", func(t *testing.T) {
909+
ctx := newTestClient(t)
910+
request := RequestServiceAccessTokenRequestObject{SubjectID: holderSubjectID, Body: body}
911+
ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(&oauth.TokenResponse{ExpiresIn: to.Ptr(5)}, nil)
912+
913+
_, err := ctx.client.RequestServiceAccessToken(nil, request)
914+
915+
require.NoError(t, err)
916+
assert.False(t, ctx.client.accessTokenCache().Exists(accessTokenRequestCacheKey(request)))
917+
})
908918
t.Run("error - no matching credentials", func(t *testing.T) {
909919
ctx := newTestClient(t)
910920
ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(nil, pe.ErrNoCredentials)
@@ -915,6 +925,23 @@ func TestWrapper_RequestServiceAccessToken(t *testing.T) {
915925
assert.Equal(t, err, pe.ErrNoCredentials)
916926
assert.Equal(t, http.StatusPreconditionFailed, statusCodeFrom(err))
917927
})
928+
t.Run("broken cache", func(t *testing.T) {
929+
ctx := newTestClient(t)
930+
mockStorage := storage.NewMockEngine(ctx.ctrl)
931+
mockStorage.EXPECT().GetSessionDatabase().Return(errorSessionDatabase{err: assert.AnError}).AnyTimes()
932+
ctx.client.storageEngine = mockStorage
933+
934+
request := RequestServiceAccessTokenRequestObject{SubjectID: holderSubjectID, Body: body}
935+
ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(&oauth.TokenResponse{AccessToken: "first"}, nil)
936+
ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(&oauth.TokenResponse{AccessToken: "second"}, nil)
937+
938+
token1, err := ctx.client.RequestServiceAccessToken(nil, request)
939+
require.NoError(t, err)
940+
token2, err := ctx.client.RequestServiceAccessToken(nil, request)
941+
require.NoError(t, err)
942+
943+
assert.NotEqual(t, token1, token2)
944+
})
918945
}
919946

920947
func TestWrapper_RequestUserAccessToken(t *testing.T) {
@@ -1340,6 +1367,15 @@ func TestWrapper_subjectOwns(t *testing.T) {
13401367
})
13411368
}
13421369

1370+
func TestWrapper_accessTokenRequestCacheKey(t *testing.T) {
1371+
expected := "0cc6fbbd972c72de7bc86c6147347bdd54bcb41fe23cea3d8f61d6ddd75dbf86"
1372+
key := accessTokenRequestCacheKey(RequestServiceAccessTokenRequestObject{SubjectID: holderSubjectID, Body: &RequestServiceAccessTokenJSONRequestBody{Scope: "test"}})
1373+
other := accessTokenRequestCacheKey(RequestServiceAccessTokenRequestObject{SubjectID: holderSubjectID, Body: &RequestServiceAccessTokenJSONRequestBody{Scope: "test2"}})
1374+
1375+
assert.Equal(t, expected, key)
1376+
assert.NotEqual(t, key, other)
1377+
}
1378+
13431379
func createIssuerCredential(issuerDID did.DID, holderDID did.DID) *vc.VerifiableCredential {
13441380
privateKey, _ := spi.GenerateKeyPair()
13451381
credType := ssi.MustParseURI("ExampleType")
@@ -1509,3 +1545,42 @@ func newCustomTestClient(t testing.TB, publicURL *url.URL, authEndpointEnabled b
15091545
client: client,
15101546
}
15111547
}
1548+
1549+
var _ storage.SessionDatabase = (*errorSessionDatabase)(nil)
1550+
var _ storage.SessionStore = (*errorSessionStore)(nil)
1551+
1552+
type errorSessionDatabase struct {
1553+
err error
1554+
}
1555+
1556+
type errorSessionStore struct {
1557+
err error
1558+
}
1559+
1560+
func (e errorSessionDatabase) GetStore(ttl time.Duration, keys ...string) storage.SessionStore {
1561+
return errorSessionStore{err: e.err}
1562+
}
1563+
1564+
func (e errorSessionDatabase) Close() {
1565+
// nop
1566+
}
1567+
1568+
func (e errorSessionStore) Delete(key string) error {
1569+
return e.err
1570+
}
1571+
1572+
func (e errorSessionStore) Exists(key string) bool {
1573+
return false
1574+
}
1575+
1576+
func (e errorSessionStore) Get(key string, target interface{}) error {
1577+
return e.err
1578+
}
1579+
1580+
func (e errorSessionStore) Put(key string, value interface{}, options ...storage.SessionOption) error {
1581+
return e.err
1582+
}
1583+
1584+
func (e errorSessionStore) GetAndDelete(key string, target interface{}) error {
1585+
return e.err
1586+
}

storage/engine.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ func (e *engine) Shutdown() error {
135135
}
136136

137137
// Close session database
138-
e.sessionDatabase.close()
138+
e.sessionDatabase.Close()
139139
// Close SQL db
140140
if e.sqlDB != nil {
141141
underlyingDB, err := e.sqlDB.DB()

storage/interface.go

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,8 @@ type SessionDatabase interface {
7878
// The keys are used to logically partition the store, eg: tenants and/or flows that are not allowed to overlap like credential issuance and verification.
7979
// The TTL is the time-to-live for the entries in the store.
8080
GetStore(ttl time.Duration, keys ...string) SessionStore
81-
// close stops any background processes and closes the database.
82-
close()
81+
// Close stops any background processes and closes the database.
82+
Close()
8383
}
8484

8585
// SessionStore is a key-value store that holds session data.
@@ -94,10 +94,25 @@ type SessionStore interface {
9494
// Returns ErrNotFound if the key does not exist.
9595
Get(key string, target interface{}) error
9696
// Put stores the given value for the given key.
97-
Put(key string, value interface{}) error
97+
// options can be used to fine-tune the storage of the item.
98+
Put(key string, value interface{}, options ...SessionOption) error
9899
// GetAndDelete combines Get and Delete as a convenience for burning nonce entries.
99100
GetAndDelete(key string, target interface{}) error
100101
}
101102

102103
// TransactionKey is the key used to store the SQL transaction in the context.
103104
type TransactionKey struct{}
105+
106+
// SessionOption is an option that can be given when storing items.
107+
type SessionOption func(target *sessionOptions)
108+
109+
type sessionOptions struct {
110+
ttl time.Duration
111+
}
112+
113+
// WithTTL sets the time-to-live for the stored item.
114+
func WithTTL(ttl time.Duration) SessionOption {
115+
return func(target *sessionOptions) {
116+
target.ttl = ttl
117+
}
118+
}

storage/mock.go

Lines changed: 21 additions & 16 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

storage/session_inmemory.go

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ func (i *InMemorySessionDatabase) GetStore(ttl time.Duration, keys ...string) Se
6666
}
6767
}
6868

69-
func (i *InMemorySessionDatabase) close() {
69+
func (i *InMemorySessionDatabase) Close() {
7070
// Signal pruner to stop and wait for it to finish
7171
i.done <- struct{}{}
7272
}
@@ -127,8 +127,14 @@ func (i InMemorySessionStore) Exists(key string) bool {
127127
i.db.mux.Lock()
128128
defer i.db.mux.Unlock()
129129

130-
_, ok := i.db.entries[i.getFullKey(key)]
131-
return ok
130+
entry, ok := i.db.entries[i.getFullKey(key)]
131+
if !ok {
132+
return false
133+
}
134+
if entry.Expiry.Before(time.Now()) {
135+
return false
136+
}
137+
return true
132138
}
133139

134140
func (i InMemorySessionStore) Get(key string, target interface{}) error {
@@ -151,7 +157,12 @@ func (i InMemorySessionStore) get(key string, target interface{}) error {
151157
return json.Unmarshal([]byte(entry.Value), target)
152158
}
153159

154-
func (i InMemorySessionStore) Put(key string, value interface{}) error {
160+
func (i InMemorySessionStore) Put(key string, value interface{}, options ...SessionOption) error {
161+
defaultOptions := i.defaultOptions()
162+
for _, option := range options {
163+
option(&defaultOptions)
164+
}
165+
155166
i.db.mux.Lock()
156167
defer i.db.mux.Unlock()
157168

@@ -161,7 +172,7 @@ func (i InMemorySessionStore) Put(key string, value interface{}) error {
161172
}
162173
entry := expiringEntry{
163174
Value: string(bytes),
164-
Expiry: time.Now().Add(i.ttl),
175+
Expiry: time.Now().Add(defaultOptions.ttl),
165176
}
166177

167178
i.db.entries[i.getFullKey(key)] = entry
@@ -180,3 +191,9 @@ func (i InMemorySessionStore) GetAndDelete(key string, target interface{}) error
180191
func (i InMemorySessionStore) getFullKey(key string) string {
181192
return strings.Join(append(i.prefixes, key), "/")
182193
}
194+
195+
func (i InMemorySessionStore) defaultOptions() sessionOptions {
196+
return sessionOptions{
197+
ttl: i.ttl,
198+
}
199+
}

0 commit comments

Comments
 (0)