From f9794677efe6eb669116f97b8555d33bf1aebdda Mon Sep 17 00:00:00 2001 From: Wout Slakhorst Date: Tue, 3 Oct 2023 11:52:14 +0200 Subject: [PATCH] move session store to storage.SessionDatabase (#2475) --- network/dag/consistency_test.go | 4 +- .../transport/grpc/connection_manager_test.go | 2 +- network/transport/v2/gossip/manager_test.go | 2 +- network/transport/v2/protocol_test.go | 2 +- .../v2/transactionlist_handler_test.go | 4 +- pki/validator_test.go | 2 +- storage/engine.go | 25 +- storage/interface.go | 32 +++ storage/mock.go | 148 +++++++++++ storage/session.go | 169 +++++++++++++ storage/session_test.go | 239 ++++++++++++++++++ storage/test.go | 8 + vcr/issuer/openid.go | 24 +- vcr/issuer/openid_store.go | 149 ++--------- vcr/issuer/openid_store_test.go | 163 +----------- vcr/issuer/openid_test.go | 11 +- vcr/vcr.go | 9 +- 17 files changed, 676 insertions(+), 317 deletions(-) create mode 100644 storage/session.go create mode 100644 storage/session_test.go diff --git a/network/dag/consistency_test.go b/network/dag/consistency_test.go index 09f9614cf9..a13387b9df 100644 --- a/network/dag/consistency_test.go +++ b/network/dag/consistency_test.go @@ -35,9 +35,7 @@ import ( ) func TestXorTreeRepair(t *testing.T) { - t.Cleanup(func() { - goleak.VerifyNone(t) - }) + defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) tx, _, _ := CreateTestTransaction(1) t.Run("xor tree repaired after 2 signals", func(t *testing.T) { diff --git a/network/transport/grpc/connection_manager_test.go b/network/transport/grpc/connection_manager_test.go index acc5080394..2db69baa0b 100644 --- a/network/transport/grpc/connection_manager_test.go +++ b/network/transport/grpc/connection_manager_test.go @@ -215,7 +215,7 @@ func Test_grpcConnectionManager_hasActiveConnection(t *testing.T) { func Test_grpcConnectionManager_dialerLoop(t *testing.T) { // make sure connectLoop only returns after all of its goroutines are closed - defer goleak.VerifyNone(t) + defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) targetAddress := "bootstrap" var capturedAddress string diff --git a/network/transport/v2/gossip/manager_test.go b/network/transport/v2/gossip/manager_test.go index d92d8a56a9..85c9b6e5d4 100644 --- a/network/transport/v2/gossip/manager_test.go +++ b/network/transport/v2/gossip/manager_test.go @@ -97,7 +97,7 @@ func TestManager_PeerDisconnected(t *testing.T) { t.Run("stops ticker", func(t *testing.T) { // Use uber/goleak to assert the goroutine started by PeerConnected is stopped when PeerDisconnected is called - defer goleak.VerifyNone(t) + defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) gMan := giveMeAgMan(t) gMan.interval = time.Millisecond diff --git a/network/transport/v2/protocol_test.go b/network/transport/v2/protocol_test.go index 1912e6b273..00a28f43c7 100644 --- a/network/transport/v2/protocol_test.go +++ b/network/transport/v2/protocol_test.go @@ -202,7 +202,7 @@ func TestProtocol_Start(t *testing.T) { func TestProtocol_Stop(t *testing.T) { t.Run("waits until goroutines have finished", func(t *testing.T) { - defer goleak.VerifyNone(t) + defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) // Use waitgroup to make sure the goroutine that blocks has started wg := &sync.WaitGroup{} diff --git a/network/transport/v2/transactionlist_handler_test.go b/network/transport/v2/transactionlist_handler_test.go index 16142ce84a..0971900cbc 100644 --- a/network/transport/v2/transactionlist_handler_test.go +++ b/network/transport/v2/transactionlist_handler_test.go @@ -37,9 +37,7 @@ import ( ) func TestTransactionListHandler(t *testing.T) { - t.Cleanup(func() { - goleak.VerifyNone(t) - }) + defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) t.Run("fn is called", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) diff --git a/pki/validator_test.go b/pki/validator_test.go index 8115e426d2..c2581cece7 100644 --- a/pki/validator_test.go +++ b/pki/validator_test.go @@ -57,7 +57,7 @@ var crlPathMap = map[string]string{ } func TestValidator_Start(t *testing.T) { - defer goleak.VerifyNone(t) + defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) store, err := core.LoadTrustStore(truststorePKIo) require.NoError(t, err) ctx, cancel := context.WithCancel(context.Background()) diff --git a/storage/engine.go b/storage/engine.go index 5c8fdd57bb..3a00578de3 100644 --- a/storage/engine.go +++ b/storage/engine.go @@ -36,17 +36,19 @@ const storeShutdownTimeout = 5 * time.Second // New creates a new instance of the storage engine. func New() Engine { return &engine{ - storesMux: &sync.Mutex{}, - stores: map[string]stoabs.Store{}, + storesMux: &sync.Mutex{}, + stores: map[string]stoabs.Store{}, + sessionDatabase: NewInMemorySessionDatabase(), } } type engine struct { - datadir string - storesMux *sync.Mutex - stores map[string]stoabs.Store - databases []database - config Config + datadir string + storesMux *sync.Mutex + stores map[string]stoabs.Store + databases []database + sessionDatabase SessionDatabase + config Config } func (e *engine) Config() interface{} { @@ -84,9 +86,13 @@ func (e engine) Shutdown() error { failures = true } } + if failures { return errors.New("one or more stores failed to close") } + + e.sessionDatabase.close() + return nil } @@ -108,6 +114,7 @@ func (e *engine) Configure(config core.ServerConfig) error { return fmt.Errorf("unable to configure BBolt database: %w", err) } e.databases = append(e.databases, bboltDB) + return nil } @@ -118,6 +125,10 @@ func (e *engine) GetProvider(moduleName string) Provider { } } +func (e *engine) GetSessionDatabase() SessionDatabase { + return e.sessionDatabase +} + type provider struct { moduleName string engine *engine diff --git a/storage/interface.go b/storage/interface.go index f9b3f9b662..e23888542f 100644 --- a/storage/interface.go +++ b/storage/interface.go @@ -19,6 +19,7 @@ package storage import ( + "errors" "github.com/nuts-foundation/go-stoabs" "github.com/nuts-foundation/nuts-node/core" "time" @@ -34,6 +35,8 @@ type Engine interface { // GetProvider returns the Provider for the given module. GetProvider(moduleName string) Provider + // GetSessionDatabase returns the SessionDatabase + GetSessionDatabase() SessionDatabase } // Provider lets callers get access to stores. @@ -59,3 +62,32 @@ type database interface { getClass() Class close() } + +var ErrNotFound = errors.New("not found") + +// SessionDatabase is a non-persistent database that holds session data on a KV basis. +// Keys could be access tokens, nonce's, authorization codes, etc. +// All entries are stored with a TTL, so they will be removed automatically. +type SessionDatabase interface { + // GetStore returns a SessionStore with the given keys as key prefixes. + // The keys are used to logically partition the store, eg: tenants and/or flows that are not allowed to overlap like credential issuance and verification. + // The TTL is the time-to-live for the entries in the store. + GetStore(ttl time.Duration, keys ...string) SessionStore + // close stops any background processes and closes the database. + close() +} + +// SessionStore is a key-value store that holds session data. +// The SessionStore is an abstraction for underlying storage, it automatically adds prefixes for logical partitions. +type SessionStore interface { + // Delete deletes the entry for the given key. + // It does not return an error if the key does not exist. + Delete(key string) error + // Exists returns true if the key exists. + Exists(key string) bool + // Get returns the value for the given key. + // Returns ErrNotFound if the key does not exist. + Get(key string, target interface{}) error + // Put stores the given value for the given key. + Put(key string, value interface{}) error +} diff --git a/storage/mock.go b/storage/mock.go index f67a88e420..de364cd577 100644 --- a/storage/mock.go +++ b/storage/mock.go @@ -10,6 +10,7 @@ package storage import ( reflect "reflect" + time "time" stoabs "github.com/nuts-foundation/go-stoabs" core "github.com/nuts-foundation/nuts-node/core" @@ -67,6 +68,20 @@ func (mr *MockEngineMockRecorder) GetProvider(moduleName any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProvider", reflect.TypeOf((*MockEngine)(nil).GetProvider), moduleName) } +// GetSessionDatabase mocks base method. +func (m *MockEngine) GetSessionDatabase() SessionDatabase { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSessionDatabase") + ret0, _ := ret[0].(SessionDatabase) + return ret0 +} + +// GetSessionDatabase indicates an expected call of GetSessionDatabase. +func (mr *MockEngineMockRecorder) GetSessionDatabase() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSessionDatabase", reflect.TypeOf((*MockEngine)(nil).GetSessionDatabase)) +} + // Shutdown mocks base method. func (m *MockEngine) Shutdown() error { m.ctrl.T.Helper() @@ -196,3 +211,136 @@ func (mr *MockdatabaseMockRecorder) getClass() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getClass", reflect.TypeOf((*Mockdatabase)(nil).getClass)) } + +// MockSessionDatabase is a mock of SessionDatabase interface. +type MockSessionDatabase struct { + ctrl *gomock.Controller + recorder *MockSessionDatabaseMockRecorder +} + +// MockSessionDatabaseMockRecorder is the mock recorder for MockSessionDatabase. +type MockSessionDatabaseMockRecorder struct { + mock *MockSessionDatabase +} + +// NewMockSessionDatabase creates a new mock instance. +func NewMockSessionDatabase(ctrl *gomock.Controller) *MockSessionDatabase { + mock := &MockSessionDatabase{ctrl: ctrl} + mock.recorder = &MockSessionDatabaseMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSessionDatabase) EXPECT() *MockSessionDatabaseMockRecorder { + return m.recorder +} + +// GetStore mocks base method. +func (m *MockSessionDatabase) GetStore(ttl time.Duration, keys ...string) SessionStore { + m.ctrl.T.Helper() + varargs := []interface{}{ttl} + for _, a := range keys { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "GetStore", varargs...) + ret0, _ := ret[0].(SessionStore) + return ret0 +} + +// GetStore indicates an expected call of GetStore. +func (mr *MockSessionDatabaseMockRecorder) GetStore(ttl interface{}, keys ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ttl}, keys...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStore", reflect.TypeOf((*MockSessionDatabase)(nil).GetStore), varargs...) +} + +// close mocks base method. +func (m *MockSessionDatabase) close() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "close") +} + +// close indicates an expected call of close. +func (mr *MockSessionDatabaseMockRecorder) close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "close", reflect.TypeOf((*MockSessionDatabase)(nil).close)) +} + +// MockSessionStore is a mock of SessionStore interface. +type MockSessionStore struct { + ctrl *gomock.Controller + recorder *MockSessionStoreMockRecorder +} + +// MockSessionStoreMockRecorder is the mock recorder for MockSessionStore. +type MockSessionStoreMockRecorder struct { + mock *MockSessionStore +} + +// NewMockSessionStore creates a new mock instance. +func NewMockSessionStore(ctrl *gomock.Controller) *MockSessionStore { + mock := &MockSessionStore{ctrl: ctrl} + mock.recorder = &MockSessionStoreMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSessionStore) EXPECT() *MockSessionStoreMockRecorder { + return m.recorder +} + +// Delete mocks base method. +func (m *MockSessionStore) Delete(key string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Delete", key) + ret0, _ := ret[0].(error) + return ret0 +} + +// Delete indicates an expected call of Delete. +func (mr *MockSessionStoreMockRecorder) Delete(key interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockSessionStore)(nil).Delete), key) +} + +// Exists mocks base method. +func (m *MockSessionStore) Exists(key string) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Exists", key) + ret0, _ := ret[0].(bool) + return ret0 +} + +// Exists indicates an expected call of Exists. +func (mr *MockSessionStoreMockRecorder) Exists(key interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exists", reflect.TypeOf((*MockSessionStore)(nil).Exists), key) +} + +// Get mocks base method. +func (m *MockSessionStore) Get(key string, target interface{}) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", key, target) + ret0, _ := ret[0].(error) + return ret0 +} + +// Get indicates an expected call of Get. +func (mr *MockSessionStoreMockRecorder) Get(key, target interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockSessionStore)(nil).Get), key, target) +} + +// Put mocks base method. +func (m *MockSessionStore) Put(key string, value interface{}) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Put", key, value) + ret0, _ := ret[0].(error) + return ret0 +} + +// Put indicates an expected call of Put. +func (mr *MockSessionStoreMockRecorder) Put(key, value interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockSessionStore)(nil).Put), key, value) +} diff --git a/storage/session.go b/storage/session.go new file mode 100644 index 0000000000..9ec851284b --- /dev/null +++ b/storage/session.go @@ -0,0 +1,169 @@ +/* + * Copyright (C) 2023 Nuts community + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package storage + +import ( + "encoding/json" + "github.com/nuts-foundation/nuts-node/storage/log" + "strings" + "sync" + "time" +) + +var _ SessionDatabase = (*InMemorySessionDatabase)(nil) +var _ SessionStore = (*InMemorySessionStore)(nil) + +var sessionStorePruneInterval = 10 * time.Minute + +type expiringEntry struct { + // Value stores the actual value as JSON + Value string + Expiry time.Time +} + +// InMemorySessionDatabase is an in memory database that holds session data on a KV basis. +// Keys could be access tokens, nonce's, authorization codes, etc. +// All entries are stored with a TTL, so they will be removed automatically. +type InMemorySessionDatabase struct { + done chan struct{} + mux sync.RWMutex + routines sync.WaitGroup + entries map[string]expiringEntry +} + +// NewInMemorySessionDatabase creates a new in memory session database. +func NewInMemorySessionDatabase() *InMemorySessionDatabase { + result := &InMemorySessionDatabase{ + entries: map[string]expiringEntry{}, + done: make(chan struct{}, 10), + } + result.startPruning(sessionStorePruneInterval) + return result +} + +func (i *InMemorySessionDatabase) GetStore(ttl time.Duration, keys ...string) SessionStore { + return InMemorySessionStore{ + ttl: ttl, + prefixes: keys, + db: i, + } +} + +func (i *InMemorySessionDatabase) close() { + // Signal pruner to stop and wait for it to finish + i.done <- struct{}{} +} + +func (i *InMemorySessionDatabase) startPruning(interval time.Duration) { + ticker := time.NewTicker(interval) + i.routines.Add(1) + go func() { + defer i.routines.Done() + for { + select { + case <-i.done: + ticker.Stop() + return + case <-ticker.C: + valsPruned := i.prune() + if valsPruned > 0 { + log.Logger().Debugf("Pruned %d expired session variables", valsPruned) + } + } + } + }() +} + +func (i *InMemorySessionDatabase) prune() int { + i.mux.Lock() + defer i.mux.Unlock() + + moment := time.Now() + + // Find expired flows and delete them + var count int + for key, entry := range i.entries { + if entry.Expiry.Before(moment) { + count++ + delete(i.entries, key) + } + } + + return count +} + +type InMemorySessionStore struct { + ttl time.Duration + prefixes []string + db *InMemorySessionDatabase +} + +func (i InMemorySessionStore) Delete(key string) error { + i.db.mux.Lock() + defer i.db.mux.Unlock() + + delete(i.db.entries, i.getFullKey(key)) + return nil +} + +func (i InMemorySessionStore) Exists(key string) bool { + i.db.mux.Lock() + defer i.db.mux.Unlock() + + _, ok := i.db.entries[i.getFullKey(key)] + return ok +} + +func (i InMemorySessionStore) Get(key string, target interface{}) error { + i.db.mux.Lock() + defer i.db.mux.Unlock() + + fullKey := i.getFullKey(key) + entry, ok := i.db.entries[fullKey] + if !ok { + return ErrNotFound + } + if entry.Expiry.Before(time.Now()) { + delete(i.db.entries, fullKey) + return ErrNotFound + } + + return json.Unmarshal([]byte(entry.Value), target) +} + +func (i InMemorySessionStore) Put(key string, value interface{}) error { + i.db.mux.Lock() + defer i.db.mux.Unlock() + + bytes, err := json.Marshal(value) + if err != nil { + return err + } + entry := expiringEntry{ + Value: string(bytes), + Expiry: time.Now().Add(i.ttl), + } + + i.db.entries[i.getFullKey(key)] = entry + return nil +} + +func (i InMemorySessionStore) getFullKey(key string) string { + return strings.Join(append(i.prefixes, key), "/") +} diff --git a/storage/session_test.go b/storage/session_test.go new file mode 100644 index 0000000000..473b63c775 --- /dev/null +++ b/storage/session_test.go @@ -0,0 +1,239 @@ +/* + * Copyright (C) 2023 Nuts community + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package storage + +import ( + "github.com/nuts-foundation/nuts-node/test" + "go.uber.org/goleak" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewInMemorySessionDatabase(t *testing.T) { + db := createDatabase(t) + + assert.NotNil(t, db) +} + +func TestInMemorySessionDatabase_GetStore(t *testing.T) { + db := createDatabase(t) + + store := db.GetStore(time.Minute, "key1", "key2").(InMemorySessionStore) + + require.NotNil(t, store) + assert.Equal(t, time.Minute, store.ttl) + assert.Equal(t, []string{"key1", "key2"}, store.prefixes) +} + +func TestInMemorySessionStore_Put(t *testing.T) { + db := createDatabase(t) + store := db.GetStore(time.Minute, "prefix").(InMemorySessionStore) + + t.Run("string value is stored", func(t *testing.T) { + err := store.Put("key", "value") + + require.NoError(t, err) + assert.Equal(t, `"value"`, store.db.entries["prefix/key"].Value) + }) + + t.Run("float value is stored", func(t *testing.T) { + err := store.Put("key", 1.23) + + require.NoError(t, err) + assert.Equal(t, "1.23", store.db.entries["prefix/key"].Value) + }) + + t.Run("struct value is stored", func(t *testing.T) { + value := testStruct{ + Field1: "value", + } + + err := store.Put("key", value) + + require.NoError(t, err) + assert.Equal(t, "{\"field1\":\"value\"}", store.db.entries["prefix/key"].Value) + }) + + t.Run("value is not JSON", func(t *testing.T) { + err := store.Put("key", make(chan int)) + + assert.Error(t, err) + }) +} + +func TestInMemorySessionStore_Get(t *testing.T) { + db := createDatabase(t) + store := db.GetStore(time.Minute, "prefix").(InMemorySessionStore) + + t.Run("string value is retrieved correctly", func(t *testing.T) { + _ = store.Put(t.Name(), "value") + var actual string + + err := store.Get(t.Name(), &actual) + + require.NoError(t, err) + assert.Equal(t, "value", actual) + }) + + t.Run("float value is retrieved correctly", func(t *testing.T) { + _ = store.Put(t.Name(), 1.23) + var actual float64 + + err := store.Get(t.Name(), &actual) + + require.NoError(t, err) + assert.Equal(t, 1.23, actual) + }) + + t.Run("struct value is retrieved correctly", func(t *testing.T) { + value := testStruct{ + Field1: "value", + } + _ = store.Put(t.Name(), value) + var actual testStruct + + err := store.Get(t.Name(), &actual) + + require.NoError(t, err) + assert.Equal(t, value, actual) + }) + + t.Run("value is not found", func(t *testing.T) { + var actual string + + err := store.Get(t.Name(), &actual) + + assert.Equal(t, ErrNotFound, err) + }) + + t.Run("value is expired", func(t *testing.T) { + store.db.entries["prefix/key"] = expiringEntry{ + Value: "", + Expiry: time.Now().Add(-time.Minute), + } + var actual string + + err := store.Get("key", &actual) + + assert.Equal(t, ErrNotFound, err) + }) + + t.Run("value is not JSON", func(t *testing.T) { + store.db.entries["prefix/key"] = expiringEntry{ + Value: "not JSON", + Expiry: time.Now().Add(time.Minute), + } + var actual string + + err := store.Get("key", &actual) + + assert.Error(t, err) + }) + + t.Run("value is not a pointer", func(t *testing.T) { + _ = store.Put(t.Name(), "value") + + err := store.Get(t.Name(), "not a pointer") + + assert.Error(t, err) + }) +} + +func TestInMemorySessionStore_Delete(t *testing.T) { + db := createDatabase(t) + store := db.GetStore(time.Minute, "prefix").(InMemorySessionStore) + + t.Run("value is deleted", func(t *testing.T) { + _ = store.Put(t.Name(), "value") + + err := store.Delete(t.Name()) + + require.NoError(t, err) + _, ok := store.db.entries["prefix/key"] + assert.False(t, ok) + }) + + t.Run("value is not found", func(t *testing.T) { + err := store.Delete(t.Name()) + + assert.NoError(t, err) + }) +} + +func TestInMemorySessionDatabase_Close(t *testing.T) { + defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + + t.Run("assert Close() waits for pruning to finish to avoid leaking goroutines", func(t *testing.T) { + sessionStorePruneInterval = 10 * time.Millisecond + defer func() { + sessionStorePruneInterval = 10 * time.Minute + }() + store := NewInMemorySessionDatabase() + time.Sleep(50 * time.Millisecond) // make sure pruning is running + store.close() + }) +} + +func Test_memoryStore_prune(t *testing.T) { + t.Run("automatic", func(t *testing.T) { + store := createDatabase(t) + // we call startPruning a second time ourselves to speed things up, make sure not to leak the original goroutine + defer func() { + store.done <- struct{}{} + }() + store.startPruning(10 * time.Millisecond) + + err := store.GetStore(time.Millisecond).Put("key", "value") + require.NoError(t, err) + + test.WaitFor(t, func() (bool, error) { + store.mux.Lock() + defer store.mux.Unlock() + _, exists := store.entries["key"] + return !exists, nil + }, time.Second, "time-out waiting for entry to be pruned") + }) + t.Run("prunes expired flows", func(t *testing.T) { + store := createDatabase(t) + defer store.close() + + _ = store.GetStore(0).Put("key1", "value") + _ = store.GetStore(time.Minute).Put("key2", "value") + + count := store.prune() + + assert.Equal(t, 1, count) + + // Second round to assert there's nothing to prune now + count = store.prune() + + assert.Equal(t, 0, count) + }) +} + +type testStruct struct { + Field1 string `json:"field1"` +} + +func createDatabase(t *testing.T) *InMemorySessionDatabase { + return NewTestInMemorySessionDatabase(t) +} diff --git a/storage/test.go b/storage/test.go index eba2f9a337..d1c6c07116 100644 --- a/storage/test.go +++ b/storage/test.go @@ -67,3 +67,11 @@ func (p *StaticKVStoreProvider) GetKVStore(_ string, _ Class) (stoabs.KVStore, e } return p.Store, nil } + +func NewTestInMemorySessionDatabase(t *testing.T) *InMemorySessionDatabase { + db := NewInMemorySessionDatabase() + t.Cleanup(func() { + db.close() + }) + return db +} diff --git a/vcr/issuer/openid.go b/vcr/issuer/openid.go index a8ee1ec298..18b3d68752 100644 --- a/vcr/issuer/openid.go +++ b/vcr/issuer/openid.go @@ -34,6 +34,7 @@ import ( "github.com/nuts-foundation/nuts-node/audit" "github.com/nuts-foundation/nuts-node/core" "github.com/nuts-foundation/nuts-node/crypto" + "github.com/nuts-foundation/nuts-node/storage" "github.com/nuts-foundation/nuts-node/vcr/issuer/assets" "github.com/nuts-foundation/nuts-node/vcr/log" "github.com/nuts-foundation/nuts-node/vcr/openid4vci" @@ -57,14 +58,6 @@ type Flow struct { // Credentials is the list of Verifiable Credentials that be issued to the wallet through this flow. // It might be pre-determined (in the issuer-initiated flow) or determined during the flow execution (in the wallet-initiated flow). Credentials []vc.VerifiableCredential `json:"credentials"` - Expiry time.Time `json:"exp"` -} - -// Nonce is a nonce that has been issued for an OpenID4VCI flow, to be used by the wallet when requesting credentials. -// A nonce can only be used once (doh), and is only valid for a certain period of time. -type Nonce struct { - Nonce string `json:"nonce"` - Expiry time.Time `json:"exp"` } // Grant is a grant that has been issued for an OAuth2 state. @@ -75,8 +68,6 @@ type Grant struct { Params map[string]interface{} `json:"params"` } -// ErrUnknownIssuer is returned when the given issuer is unknown. -var ErrUnknownIssuer = errors.New("unknown OpenID4VCI issuer") var _ OpenIDHandler = (*openidHandler)(nil) // TokenTTL is the time-to-live for issuance flows, access tokens and nonces. @@ -105,7 +96,7 @@ type OpenIDHandler interface { } // NewOpenIDHandler creates a new OpenIDHandler instance. The identifier is the Credential Issuer Identifier, e.g. https://example.com/issuer/ -func NewOpenIDHandler(issuerDID did.DID, issuerIdentifierURL string, definitionsDIR string, httpClient core.HTTPRequestDoer, keyResolver resolver.KeyResolver, store OpenIDStore) (OpenIDHandler, error) { +func NewOpenIDHandler(issuerDID did.DID, issuerIdentifierURL string, definitionsDIR string, httpClient core.HTTPRequestDoer, keyResolver resolver.KeyResolver, sessionDatabase storage.SessionDatabase) (OpenIDHandler, error) { i := &openidHandler{ issuerIdentifierURL: issuerIdentifierURL, issuerDID: issuerDID, @@ -113,7 +104,7 @@ func NewOpenIDHandler(issuerDID did.DID, issuerIdentifierURL string, definitions httpClient: httpClient, keyResolver: keyResolver, walletClientCreator: openid4vci.NewWalletAPIClient, - store: store, + store: NewOpenIDMemoryStore(sessionDatabase), } // load the credential definitions. This is done to halt startup procedure if needed. @@ -174,12 +165,12 @@ func (i *openidHandler) HandleAccessTokenRequest(ctx context.Context, preAuthori } } accessToken := generateCode() - err = i.store.StoreReference(ctx, flow.ID, accessTokenRefType, accessToken, time.Now().Add(TokenTTL)) + err = i.store.StoreReference(ctx, flow.ID, accessTokenRefType, accessToken) if err != nil { return "", "", err } cNonce := generateCode() - err = i.store.StoreReference(ctx, flow.ID, cNonceRefType, cNonce, time.Now().Add(TokenTTL)) + err = i.store.StoreReference(ctx, flow.ID, cNonceRefType, cNonce) if err != nil { return "", "", err } @@ -294,7 +285,7 @@ func (i *openidHandler) validateProof(ctx context.Context, flow *Flow, request o // augment invalid_proof errors according to ยง7.3.2 of openid4vci spec generateProofError := func(err openid4vci.Error) error { cnonce := generateCode() - if err := i.store.StoreReference(ctx, flow.ID, cNonceRefType, cnonce, time.Now().Add(TokenTTL)); err != nil { + if err := i.store.StoreReference(ctx, flow.ID, cNonceRefType, cnonce); err != nil { return err } expiry := int(TokenTTL.Seconds()) @@ -438,7 +429,6 @@ func (i *openidHandler) createOffer(ctx context.Context, credential vc.Verifiabl ID: uuid.NewString(), IssuerID: credential.Issuer.String(), WalletID: subjectDID.String(), - Expiry: time.Now().Add(TokenTTL), Credentials: []vc.VerifiableCredential{credential}, Grants: []Grant{ { @@ -449,7 +439,7 @@ func (i *openidHandler) createOffer(ctx context.Context, credential vc.Verifiabl } err := i.store.Store(ctx, flow) if err == nil { - err = i.store.StoreReference(ctx, flow.ID, preAuthCodeRefType, preAuthorizedCode, time.Now().Add(TokenTTL)) + err = i.store.StoreReference(ctx, flow.ID, preAuthCodeRefType, preAuthorizedCode) } if err != nil { return nil, fmt.Errorf("unable to store credential offer: %w", err) diff --git a/vcr/issuer/openid_store.go b/vcr/issuer/openid_store.go index 689556cae8..0471301164 100644 --- a/vcr/issuer/openid_store.go +++ b/vcr/issuer/openid_store.go @@ -21,9 +21,7 @@ package issuer import ( "context" "errors" - "github.com/nuts-foundation/nuts-node/vcr/log" - "sync" - "time" + "github.com/nuts-foundation/nuts-node/storage" ) // OpenIDStore defines the storage API for OpenID Credential Issuance flows. @@ -35,164 +33,71 @@ type OpenIDStore interface { // like a database index. The reference must be unique for all flows. // The expiry is the time-to-live for the reference. After this time, the reference is automatically deleted. // If the flow does not exist, or the reference does already exist, it returns an error. - StoreReference(ctx context.Context, flowID string, refType string, reference string, expiry time.Time) error + StoreReference(ctx context.Context, flowID string, refType string, reference string) error // FindByReference finds a Flow by its reference. // If the flow does not exist, it returns nil. FindByReference(ctx context.Context, refType string, reference string) (*Flow, error) // DeleteReference deletes the reference from the store. // It does not return an error if it doesn't exist anymore. DeleteReference(ctx context.Context, refType string, reference string) error - // Close signals the store to close any owned resources. - Close() } var _ OpenIDStore = (*openidMemoryStore)(nil) -var openidStorePruneInterval = 10 * time.Minute - type openidMemoryStore struct { - mux *sync.RWMutex - flows map[string]Flow - refs map[string]map[string]referenceValue - routines *sync.WaitGroup - ctx context.Context - cancel context.CancelFunc + sessionDatabase storage.SessionDatabase } // NewOpenIDMemoryStore creates a new in-memory OpenIDStore. -func NewOpenIDMemoryStore() OpenIDStore { - result := &openidMemoryStore{ - mux: &sync.RWMutex{}, - flows: map[string]Flow{}, - refs: map[string]map[string]referenceValue{}, - routines: &sync.WaitGroup{}, +func NewOpenIDMemoryStore(sessionDatabase storage.SessionDatabase) OpenIDStore { + return &openidMemoryStore{ + sessionDatabase: sessionDatabase, } - result.ctx, result.cancel = context.WithCancel(context.Background()) - result.startPruning(openidStorePruneInterval) - return result -} - -type referenceValue struct { - FlowID string `json:"flow_id"` - Expiry time.Time `json:"exp"` } func (o *openidMemoryStore) Store(_ context.Context, flow Flow) error { if len(flow.ID) == 0 { return errors.New("invalid flow ID") } - o.mux.Lock() - defer o.mux.Unlock() - if o.flows[flow.ID].ID != "" { + store := o.sessionDatabase.GetStore(TokenTTL, "openid4vci", "flow") + if store.Exists(flow.ID) { return errors.New("OAuth2 flow with this ID already exists") } - o.flows[flow.ID] = flow - return nil + return store.Put(flow.ID, flow) } -func (o *openidMemoryStore) StoreReference(_ context.Context, flowID string, refType string, reference string, expiry time.Time) error { +func (o *openidMemoryStore) StoreReference(_ context.Context, flowID string, refType string, reference string) error { if len(reference) == 0 { return errors.New("invalid reference") } - o.mux.Lock() - defer o.mux.Unlock() - if o.flows[flowID].ID == "" { + refStore := o.sessionDatabase.GetStore(TokenTTL, "openid4vci", refType) + flowStore := o.sessionDatabase.GetStore(TokenTTL, "openid4vci", "flow") + if !flowStore.Exists(flowID) { return errors.New("OAuth2 flow with this ID does not exist") } - if o.refs[refType] == nil { - o.refs[refType] = map[string]referenceValue{} - } - if _, ok := o.refs[refType][reference]; ok { + if refStore.Exists(reference) { return errors.New("reference already exists") } - o.refs[refType][reference] = referenceValue{FlowID: flowID, Expiry: expiry} - return nil + return refStore.Put(reference, flowID) } func (o *openidMemoryStore) FindByReference(_ context.Context, refType string, reference string) (*Flow, error) { - o.mux.RLock() - defer o.mux.RUnlock() - - refMap := o.refs[refType] - if refMap == nil { - return nil, nil - } - value, ok := refMap[reference] - if !ok { - return nil, nil - } - if value.Expiry.Before(time.Now()) { + refStore := o.sessionDatabase.GetStore(TokenTTL, "openid4vci", refType) + flowStore := o.sessionDatabase.GetStore(TokenTTL, "openid4vci", "flow") + if !refStore.Exists(reference) { return nil, nil } - - flow := o.flows[value.FlowID] - if flow.Expiry.Before(time.Now()) { - return nil, nil + var flowID string + err := refStore.Get(reference, &flowID) + if err != nil { + return nil, err } - return &flow, nil + var flow Flow + err = flowStore.Get(flowID, &flow) + return &flow, err } func (o *openidMemoryStore) DeleteReference(_ context.Context, refType string, reference string) error { - o.mux.Lock() - defer o.mux.Unlock() - - if o.refs[refType] == nil { - return nil - } - delete(o.refs[refType], reference) - return nil -} - -func (o *openidMemoryStore) Close() { - // Signal pruner to stop and wait for it to finish - o.cancel() - o.routines.Wait() -} - -func (o *openidMemoryStore) startPruning(interval time.Duration) { - ticker := time.NewTicker(interval) - o.routines.Add(1) - go func(ctx context.Context) { - defer o.routines.Done() - for { - select { - case <-ctx.Done(): - ticker.Stop() - return - case <-ticker.C: - flowsPruned, refsPruned := o.prune() - if flowsPruned > 0 || refsPruned > 0 { - log.Logger().Debugf("Pruned %d expired OpenID4VCI flows and %d expired refs", flowsPruned, refsPruned) - } - } - } - }(o.ctx) -} - -func (o *openidMemoryStore) prune() (int, int) { - o.mux.Lock() - defer o.mux.Unlock() - - moment := time.Now() - - // Find expired flows and delete them - var flowCount int - for id, flow := range o.flows { - if flow.Expiry.Before(moment) { - flowCount++ - delete(o.flows, id) - } - } - // Find expired refs and delete them - var refCount int - for _, refMap := range o.refs { - for reference, value := range refMap { - if value.Expiry.Before(moment) { - refCount++ - delete(refMap, reference) - } - } - } - - return flowCount, refCount + refStore := o.sessionDatabase.GetStore(TokenTTL, "openid4vci", refType) + return refStore.Delete(reference) } diff --git a/vcr/issuer/openid_store_test.go b/vcr/issuer/openid_store_test.go index a779e949ce..9fcbf80109 100644 --- a/vcr/issuer/openid_store_test.go +++ b/vcr/issuer/openid_store_test.go @@ -20,11 +20,9 @@ package issuer import ( "context" - "github.com/nuts-foundation/nuts-node/test" + "github.com/nuts-foundation/nuts-node/storage" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "testing" - "time" ) const refType = "ref-type" @@ -34,12 +32,11 @@ func Test_memoryStore_DeleteReference(t *testing.T) { t.Run("ok", func(t *testing.T) { store := createStore(t) expected := Flow{ - ID: "flow-id", - Expiry: futureExpiry(), + ID: "flow-id", } err := store.Store(context.Background(), expected) assert.NoError(t, err) - err = store.StoreReference(context.Background(), expected.ID, refType, ref, futureExpiry()) + err = store.StoreReference(context.Background(), expected.ID, refType, ref) assert.NoError(t, err) err = store.DeleteReference(context.Background(), refType, ref) @@ -63,49 +60,31 @@ func Test_memoryStore_FindByReference(t *testing.T) { t.Run("reference already exists", func(t *testing.T) { store := createStore(t) expected := Flow{ - ID: "flow-id", - Expiry: futureExpiry(), + ID: "flow-id", } err := store.Store(context.Background(), expected) assert.NoError(t, err) - err = store.StoreReference(context.Background(), expected.ID, refType, ref, futureExpiry()) + err = store.StoreReference(context.Background(), expected.ID, refType, ref) assert.NoError(t, err) - err = store.StoreReference(context.Background(), expected.ID, refType, ref, futureExpiry()) + err = store.StoreReference(context.Background(), expected.ID, refType, ref) assert.EqualError(t, err, "reference already exists") }) t.Run("invalid reference", func(t *testing.T) { store := createStore(t) - err := store.StoreReference(context.Background(), "unknown", refType, "", futureExpiry()) + err := store.StoreReference(context.Background(), "unknown", refType, "") assert.EqualError(t, err, "invalid reference") }) t.Run("unknown flow", func(t *testing.T) { store := createStore(t) - err := store.StoreReference(context.Background(), "unknown", refType, ref, futureExpiry()) + err := store.StoreReference(context.Background(), "unknown", refType, ref) assert.EqualError(t, err, "OAuth2 flow with this ID does not exist") }) - t.Run("reference has expired", func(t *testing.T) { - store := createStore(t) - expected := Flow{ - ID: "flow-id", - Expiry: futureExpiry(), - } - - err := store.Store(context.Background(), expected) - assert.NoError(t, err) - // We need a reference to resolve it - err = store.StoreReference(context.Background(), expected.ID, refType, ref, pastExpiry()) - assert.NoError(t, err) - - actual, err := store.FindByReference(context.Background(), refType, ref) - assert.NoError(t, err) - assert.Nil(t, actual) - }) } func Test_memoryStore_Store(t *testing.T) { @@ -113,14 +92,13 @@ func Test_memoryStore_Store(t *testing.T) { t.Run("write, then read", func(t *testing.T) { store := createStore(t) expected := Flow{ - ID: "flow-id", - Expiry: futureExpiry(), + ID: "flow-id", } err := store.Store(ctx, expected) assert.NoError(t, err) // We need a reference to resolve it - err = store.StoreReference(ctx, expected.ID, refType, ref, futureExpiry()) + err = store.StoreReference(ctx, expected.ID, refType, ref) assert.NoError(t, err) actual, err := store.FindByReference(ctx, refType, ref) @@ -130,8 +108,7 @@ func Test_memoryStore_Store(t *testing.T) { t.Run("already exists", func(t *testing.T) { store := createStore(t) expected := Flow{ - ID: "flow-id", - Expiry: futureExpiry(), + ID: "flow-id", } err := store.Store(ctx, expected) @@ -140,124 +117,10 @@ func Test_memoryStore_Store(t *testing.T) { assert.EqualError(t, err, "OAuth2 flow with this ID already exists") }) - t.Run("flow has expired", func(t *testing.T) { - store := createStore(t) - expected := Flow{ - ID: "flow-id", - Expiry: pastExpiry(), - } - - err := store.Store(ctx, expected) - assert.NoError(t, err) - // We need a reference to resolve it - err = store.StoreReference(ctx, expected.ID, refType, ref, futureExpiry()) - assert.NoError(t, err) - - actual, err := store.FindByReference(ctx, refType, ref) - assert.NoError(t, err) - assert.Nil(t, actual) - }) -} - -func Test_memoryStore_Close(t *testing.T) { - t.Run("assert Close() waits for pruning to finish to avoid leaking goroutines", func(t *testing.T) { - openidStorePruneInterval = 10 * time.Millisecond - store := createStore(t) - time.Sleep(50 * time.Millisecond) // make sure pruning is running - store.Close() - }) -} - -func Test_memoryStore_prune(t *testing.T) { - ctx := context.Background() - t.Run("automatic", func(t *testing.T) { - store := createStore(t) - // we call startPruning a second time ourselves, make sure not to leak the original goroutine - cancelFunc := store.cancel - defer cancelFunc() - store.startPruning(10 * time.Millisecond) - - // Feed it something to prune - expiredFlow := Flow{ - ID: "expired", - } - err := store.Store(ctx, expiredFlow) - require.NoError(t, err) - - test.WaitFor(t, func() (bool, error) { - store.mux.Lock() - defer store.mux.Unlock() - _, exists := store.flows[expiredFlow.ID] - return !exists, nil - }, time.Second, "time-out waiting for flow to be pruned") - }) - t.Run("prunes expired flows", func(t *testing.T) { - store := createStore(t) - - expiredFlow := Flow{ - ID: "expired", - } - unexpiredFlow := Flow{ - ID: "unexpired", - Expiry: futureExpiry(), - } - _ = store.Store(ctx, expiredFlow) - _ = store.Store(ctx, unexpiredFlow) - - flows, refs := store.prune() - - assert.Equal(t, 1, flows) - assert.Equal(t, 0, refs) - - // Second round to assert there's nothing to prune now - flows, refs = store.prune() - - assert.Equal(t, 0, flows) - assert.Equal(t, 0, refs) - }) - t.Run("prunes expired refs", func(t *testing.T) { - store := createStore(t) - - flow := Flow{ - ID: "f", - Expiry: futureExpiry(), - } - err := store.Store(ctx, flow) - require.NoError(t, err) - err = store.StoreReference(ctx, flow.ID, refType, "expired", pastExpiry()) - require.NoError(t, err) - err = store.StoreReference(ctx, flow.ID, refType, "unexpired", futureExpiry()) - require.NoError(t, err) - - flows, refs := store.prune() - - assert.Equal(t, 0, flows) - assert.Equal(t, 1, refs) - - // Second round to assert there's nothing to prune now - flows, refs = store.prune() - - assert.NoError(t, err) - assert.Equal(t, 0, flows) - assert.Equal(t, 0, refs) - }) } func createStore(t *testing.T) *openidMemoryStore { - store := NewOpenIDMemoryStore().(*openidMemoryStore) - t.Cleanup(store.Close) + storageDatabase := storage.NewTestInMemorySessionDatabase(t) + store := NewOpenIDMemoryStore(storageDatabase).(*openidMemoryStore) return store } - -func moment() time.Time { - return time.Now().In(time.UTC) -} - -func pastExpiry() time.Time { - return moment().Add(-time.Hour) -} - -func futureExpiry() time.Time { - // truncating makes assertion easier - return moment().Add(time.Hour).Truncate(time.Second) -} diff --git a/vcr/issuer/openid_test.go b/vcr/issuer/openid_test.go index c19199c7e5..281eeaaa0f 100644 --- a/vcr/issuer/openid_test.go +++ b/vcr/issuer/openid_test.go @@ -28,6 +28,7 @@ import ( "github.com/nuts-foundation/nuts-node/audit" "github.com/nuts-foundation/nuts-node/core" "github.com/nuts-foundation/nuts-node/crypto" + "github.com/nuts-foundation/nuts-node/storage" "github.com/nuts-foundation/nuts-node/vcr/openid4vci" "github.com/nuts-foundation/nuts-node/vdr/resolver" "github.com/stretchr/testify/assert" @@ -65,21 +66,21 @@ var issuedVC = vc.VerifiableCredential{ func TestNew(t *testing.T) { t.Run("custom definitions", func(t *testing.T) { - iss, err := NewOpenIDHandler(issuerDID, issuerIdentifier, "./test/valid", nil, nil, NewOpenIDMemoryStore()) + iss, err := NewOpenIDHandler(issuerDID, issuerIdentifier, "./test/valid", nil, nil, storage.NewTestInMemorySessionDatabase(t)) require.NoError(t, err) assert.Len(t, iss.(*openidHandler).credentialsSupported, 3) }) t.Run("error - invalid json", func(t *testing.T) { - _, err := NewOpenIDHandler(issuerDID, issuerIdentifier, "./test/invalid", nil, nil, NewOpenIDMemoryStore()) + _, err := NewOpenIDHandler(issuerDID, issuerIdentifier, "./test/invalid", nil, nil, storage.NewTestInMemorySessionDatabase(t)) require.Error(t, err) assert.EqualError(t, err, "failed to parse credential definition from test/invalid/invalid.json: unexpected end of JSON input") }) t.Run("error - invalid directory", func(t *testing.T) { - _, err := NewOpenIDHandler(issuerDID, issuerIdentifier, "./test/non_existing", nil, nil, NewOpenIDMemoryStore()) + _, err := NewOpenIDHandler(issuerDID, issuerIdentifier, "./test/non_existing", nil, nil, storage.NewTestInMemorySessionDatabase(t)) require.Error(t, err) assert.EqualError(t, err, "failed to load credential definitions: lstat ./test/non_existing: no such file or directory") @@ -396,7 +397,7 @@ func Test_memoryIssuer_HandleAccessTokenRequest(t *testing.T) { assert.NotEmpty(t, accessToken) }) t.Run("pre-authorized code issued by other issuer", func(t *testing.T) { - store := NewOpenIDMemoryStore() + store := storage.NewTestInMemorySessionDatabase(t) service, err := NewOpenIDHandler(issuerDID, issuerIdentifier, definitionsDIR, &http.Client{}, nil, store) require.NoError(t, err) _, err = service.(*openidHandler).createOffer(ctx, issuedVC, "code") @@ -435,7 +436,7 @@ func assertProtocolError(t *testing.T, err error, statusCode int, message string } func requireNewTestHandler(t *testing.T, keyResolver resolver.KeyResolver) *openidHandler { - service, err := NewOpenIDHandler(issuerDID, issuerIdentifier, definitionsDIR, &http.Client{}, keyResolver, NewOpenIDMemoryStore()) + service, err := NewOpenIDHandler(issuerDID, issuerIdentifier, definitionsDIR, &http.Client{}, keyResolver, storage.NewTestInMemorySessionDatabase(t)) require.NoError(t, err) return service.(*openidHandler) } diff --git a/vcr/vcr.go b/vcr/vcr.go index 81fe82d4b5..ea44364c24 100644 --- a/vcr/vcr.go +++ b/vcr/vcr.go @@ -99,7 +99,7 @@ type vcr struct { jsonldManager jsonld.JSONLD eventManager events.Event storageClient storage.Engine - openidIsssuerStore issuer.OpenIDStore + openidSessionStore storage.SessionDatabase localWalletResolver openid4vci.IdentifierResolver issuerHttpClient core.HTTPRequestDoer walletHttpClient core.HTTPRequestDoer @@ -112,7 +112,7 @@ func (c *vcr) GetOpenIDIssuer(ctx context.Context, id did.DID) (issuer.OpenIDHan if err != nil { return nil, err } - return issuer.NewOpenIDHandler(id, identifier, c.config.OpenID4VCI.DefinitionsDIR, c.issuerHttpClient, c.keyResolver, c.openidIsssuerStore) + return issuer.NewOpenIDHandler(id, identifier, c.config.OpenID4VCI.DefinitionsDIR, c.issuerHttpClient, c.keyResolver, c.openidSessionStore) } func (c *vcr) GetOpenIDHolder(ctx context.Context, id did.DID) (holder.OpenIDHandler, error) { @@ -269,7 +269,7 @@ func (c *vcr) Configure(config core.ServerConfig) error { Timeout: c.config.OpenID4VCI.Timeout, Transport: walletTransport, }) - c.openidIsssuerStore = issuer.NewOpenIDMemoryStore() + c.openidSessionStore = c.storageClient.GetSessionDatabase() } c.issuer = issuer.NewIssuer(c.issuerStore, c, networkPublisher, openidHandlerFn, didResolver, c.keyStore, c.jsonldManager, c.trustConfig) c.verifier = verifier.NewVerifier(c.verifierStore, didResolver, c.keyResolver, c.jsonldManager, c.trustConfig) @@ -329,9 +329,6 @@ func (c *vcr) Start() error { } func (c *vcr) Shutdown() error { - if c.openidIsssuerStore != nil { - c.openidIsssuerStore.Close() - } err := c.issuerStore.Close() if err != nil { log.Logger().