diff --git a/storage/interface.go b/storage/interface.go index 83ab900ad0..2a3d379d8c 100644 --- a/storage/interface.go +++ b/storage/interface.go @@ -20,10 +20,11 @@ package storage import ( "errors" + "time" + "github.com/nuts-foundation/go-stoabs" "github.com/nuts-foundation/nuts-node/core" "gorm.io/gorm" - "time" ) const lockAcquireTimeout = time.Second @@ -94,4 +95,6 @@ type SessionStore interface { Get(key string, target interface{}) error // Put stores the given value for the given key. Put(key string, value interface{}) error + // GetAndDelete combines Get and Delete as a convenience for burning nonce entries. + GetAndDelete(key string, target interface{}) error } diff --git a/storage/mock.go b/storage/mock.go index 241edf7804..5beb253a22 100644 --- a/storage/mock.go +++ b/storage/mock.go @@ -347,6 +347,20 @@ func (mr *MockSessionStoreMockRecorder) Get(key, target any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockSessionStore)(nil).Get), key, target) } +// GetAndDelete mocks base method. +func (m *MockSessionStore) GetAndDelete(key string, target any) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAndDelete", key, target) + ret0, _ := ret[0].(error) + return ret0 +} + +// GetAndDelete indicates an expected call of GetAndDelete. +func (mr *MockSessionStoreMockRecorder) GetAndDelete(key, target any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAndDelete", reflect.TypeOf((*MockSessionStore)(nil).GetAndDelete), key, target) +} + // Put mocks base method. func (m *MockSessionStore) Put(key string, value any) error { m.ctrl.T.Helper() diff --git a/storage/session_inmemory.go b/storage/session_inmemory.go index 9ec851284b..c5dfa96ccb 100644 --- a/storage/session_inmemory.go +++ b/storage/session_inmemory.go @@ -20,10 +20,11 @@ package storage import ( "encoding/json" - "github.com/nuts-foundation/nuts-node/storage/log" "strings" "sync" "time" + + "github.com/nuts-foundation/nuts-node/storage/log" ) var _ SessionDatabase = (*InMemorySessionDatabase)(nil) @@ -133,7 +134,10 @@ func (i InMemorySessionStore) Exists(key string) bool { func (i InMemorySessionStore) Get(key string, target interface{}) error { i.db.mux.Lock() defer i.db.mux.Unlock() + return i.get(key, target) +} +func (i InMemorySessionStore) get(key string, target interface{}) error { fullKey := i.getFullKey(key) entry, ok := i.db.entries[fullKey] if !ok { @@ -163,6 +167,15 @@ func (i InMemorySessionStore) Put(key string, value interface{}) error { i.db.entries[i.getFullKey(key)] = entry return nil } +func (i InMemorySessionStore) GetAndDelete(key string, target interface{}) error { + i.db.mux.Lock() + defer i.db.mux.Unlock() + if err := i.get(key, target); err != nil { + return err + } + delete(i.db.entries, i.getFullKey(key)) + return nil +} func (i InMemorySessionStore) getFullKey(key string) string { return strings.Join(append(i.prefixes, key), "/") diff --git a/storage/session_inmemory_test.go b/storage/session_inmemory_test.go index 473b63c775..bfc29aa77a 100644 --- a/storage/session_inmemory_test.go +++ b/storage/session_inmemory_test.go @@ -19,11 +19,12 @@ package storage import ( - "github.com/nuts-foundation/nuts-node/test" - "go.uber.org/goleak" "testing" "time" + "github.com/nuts-foundation/nuts-node/test" + "go.uber.org/goleak" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -179,6 +180,26 @@ func TestInMemorySessionStore_Delete(t *testing.T) { }) } +func TestInMemorySessionStore_GetAndDelete(t *testing.T) { + db := createDatabase(t) + store := db.GetStore(time.Minute, "prefix").(InMemorySessionStore) + + t.Run("ok", func(t *testing.T) { + _ = store.Put(t.Name(), "value") + var actual string + + err := store.GetAndDelete(t.Name(), &actual) + + require.NoError(t, err) + assert.Equal(t, "value", actual) + // is deleted + assert.ErrorIs(t, store.Get(t.Name(), new(string)), ErrNotFound) + }) + t.Run("error", func(t *testing.T) { + assert.ErrorIs(t, store.GetAndDelete(t.Name(), new(string)), ErrNotFound) + }) +} + func TestInMemorySessionDatabase_Close(t *testing.T) { defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) diff --git a/storage/session_redis.go b/storage/session_redis.go index a7417a20d6..83ca50ffa6 100644 --- a/storage/session_redis.go +++ b/storage/session_redis.go @@ -22,10 +22,11 @@ import ( "context" "encoding/json" "errors" - "github.com/nuts-foundation/nuts-node/storage/log" - "github.com/redis/go-redis/v9" "strings" "time" + + "github.com/nuts-foundation/nuts-node/storage/log" + "github.com/redis/go-redis/v9" ) func NewRedisSessionDatabase(client *redis.Client, prefix string) SessionDatabase { @@ -96,7 +97,18 @@ func (s redisSessionStore) Put(key string, value interface{}) error { return err } return s.client.Set(context.Background(), s.toRedisKey(key), marshal, s.ttl).Err() +} +func (s redisSessionStore) GetAndDelete(key string, target interface{}) error { + // GetDel requires redis-server version >= 6.2.0. + result, err := s.client.GetDel(context.Background(), s.toRedisKey(key)).Result() + if err != nil { + if errors.Is(redis.Nil, err) { + return ErrNotFound + } + return err + } + return json.Unmarshal([]byte(result), target) } func (s redisSessionStore) toRedisKey(key string) string { diff --git a/storage/session_redis_test.go b/storage/session_redis_test.go index 53d85c56f9..b957c15b55 100644 --- a/storage/session_redis_test.go +++ b/storage/session_redis_test.go @@ -21,11 +21,12 @@ package storage import ( "encoding/json" "errors" + "testing" + "time" + "github.com/go-redis/redismock/v9" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "testing" - "time" ) const testKey = "keyname" @@ -93,7 +94,6 @@ func TestRedisSessionStore_Delete(t *testing.T) { // We make sure the value exists in another store, // to test partitioning otherStore := sessions.GetStore(time.Second, "unit_other") - assert.NoError(t, otherStore.Put(testKey, testValue)) t.Run("non-existing key", func(t *testing.T) { @@ -107,6 +107,47 @@ func TestRedisSessionStore_Delete(t *testing.T) { }) } +func TestRedisSessionStore_GetAndDelete(t *testing.T) { + storageEngine, miniRedis := NewTestStorageEngineRedis(t) + require.NoError(t, storageEngine.Start()) + sessions := storageEngine.GetSessionDatabase() + defer sessions.close() + + t.Run("ok", func(t *testing.T) { + var actual testType + + store := sessions.GetStore(time.Minute, "storename") + assert.NoError(t, store.Put(testKey, testValue)) + // We make sure the value exists in another store, + // to test partitioning + otherStore := sessions.GetStore(time.Second, "unit_other") + assert.NoError(t, otherStore.Put(testKey, testValue)) + + err := store.GetAndDelete(testKey, &actual) + assert.NoError(t, err) + // deleted + assert.False(t, store.Exists(testKey)) + + // Make sure it did not delete an entry with the same key from another store + assert.True(t, otherStore.Exists(testKey)) + }) + t.Run("non-existing key", func(t *testing.T) { + store := sessions.GetStore(time.Minute, "storename") + err := store.GetAndDelete(testKey, new(testType)) + + assert.ErrorIs(t, err, ErrNotFound) + }) + t.Run("expired entry", func(t *testing.T) { + store := sessions.GetStore(time.Minute, "otherstore") + assert.NoError(t, store.Put(testKey, testValue)) + miniRedis.FastForward(2 * time.Minute) // cause the entry to expire + + err := store.GetAndDelete(testKey, new(testType)) + + assert.ErrorIs(t, err, ErrNotFound) + }) +} + func TestRedisSessionStore_Exists(t *testing.T) { store, miniRedis := NewTestStorageEngineRedis(t) require.NoError(t, store.Start())