diff --git a/cache/cache.go b/cache/cache.go index 78c8ea966..2a3091439 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -13,6 +13,11 @@ import ( "github.com/prometheus/client_golang/prometheus" ) +var ( + ErrNotStored = errors.New("item not stored") + ErrInvalidTTL = errors.New("invalid TTL") +) + // Cache is a high level interface to interact with a cache. type Cache interface { // GetMulti fetches multiple keys at once from a cache. In case of error, @@ -28,6 +33,14 @@ type Cache interface { // any underlying async operations fail, the errors will be tracked/logged. SetMultiAsync(data map[string][]byte, ttl time.Duration) + // Set stores a key and value into a cache. + Set(ctx context.Context, key string, value []byte, ttl time.Duration) error + + // Add stores a key and value into a cache only if it does not already exist. If the + // item was not stored because an entry already exists in the cache, ErrNotStored will + // be returned. + Add(ctx context.Context, key string, value []byte, ttl time.Duration) error + // Delete deletes a key from a cache. This is a synchronous operation. If an asynchronous // set operation for key is still pending to be processed, it will wait for it to complete // before performing deletion. diff --git a/cache/client.go b/cache/client.go index 033d2add6..1d3c87c56 100644 --- a/cache/client.go +++ b/cache/client.go @@ -17,6 +17,7 @@ import ( // Common functionality shared between the Memcached and Redis Cache implementations const ( + opAdd = "add" opSet = "set" opGetMulti = "getmulti" opDelete = "delete" @@ -29,6 +30,8 @@ const ( reasonMaxItemSize = "max-item-size" reasonAsyncBufferFull = "async-buffer-full" reasonMalformedKey = "malformed-key" + reasonInvalidTTL = "invalid-ttl" + reasonNotStored = "not-stored" reasonConnectTimeout = "connect-timeout" reasonTimeout = "request-timeout" reasonServerError = "server-error" @@ -85,10 +88,12 @@ func newClientMetrics(reg prometheus.Registerer) *clientMetrics { Name: "operation_failures_total", Help: "Total number of operations against cache that failed.", }, []string{"operation", "reason"}) - for _, op := range []string{opGetMulti, opSet, opDelete, opIncrement, opFlush, opTouch, opCompareAndSwap} { + for _, op := range []string{opGetMulti, opAdd, opSet, opDelete, opIncrement, opFlush, opTouch, opCompareAndSwap} { cm.failures.WithLabelValues(op, reasonConnectTimeout) cm.failures.WithLabelValues(op, reasonTimeout) cm.failures.WithLabelValues(op, reasonMalformedKey) + cm.failures.WithLabelValues(op, reasonInvalidTTL) + cm.failures.WithLabelValues(op, reasonNotStored) cm.failures.WithLabelValues(op, reasonServerError) cm.failures.WithLabelValues(op, reasonNetworkError) cm.failures.WithLabelValues(op, reasonOther) @@ -99,6 +104,7 @@ func newClientMetrics(reg prometheus.Registerer) *clientMetrics { Help: "Total number of operations against cache that have been skipped.", }, []string{"operation", "reason"}) cm.skipped.WithLabelValues(opGetMulti, reasonMaxItemSize) + cm.skipped.WithLabelValues(opAdd, reasonMaxItemSize) cm.skipped.WithLabelValues(opSet, reasonMaxItemSize) cm.skipped.WithLabelValues(opSet, reasonAsyncBufferFull) @@ -112,6 +118,7 @@ func newClientMetrics(reg prometheus.Registerer) *clientMetrics { NativeHistogramMinResetDuration: time.Hour, }, []string{"operation"}) cm.duration.WithLabelValues(opGetMulti) + cm.duration.WithLabelValues(opAdd) cm.duration.WithLabelValues(opSet) cm.duration.WithLabelValues(opDelete) cm.duration.WithLabelValues(opIncrement) @@ -129,6 +136,7 @@ func newClientMetrics(reg prometheus.Registerer) *clientMetrics { []string{"operation"}, ) cm.dataSize.WithLabelValues(opGetMulti) + cm.dataSize.WithLabelValues(opAdd) cm.dataSize.WithLabelValues(opSet) cm.dataSize.WithLabelValues(opCompareAndSwap) @@ -172,22 +180,12 @@ func (c *baseClient) setAsync(key string, value []byte, ttl time.Duration, f fun } err := c.asyncQueue.submit(func() { - start := time.Now() - c.metrics.operations.WithLabelValues(opSet).Inc() - - err := f(key, value, ttl) - if err != nil { - level.Debug(c.logger).Log( - "msg", "failed to store item to cache", - "key", key, - "sizeBytes", len(value), - "err", err, - ) - c.trackError(opSet, err) - } - - c.metrics.dataSize.WithLabelValues(opSet).Observe(float64(len(value))) - c.metrics.duration.WithLabelValues(opSet).Observe(time.Since(start).Seconds()) + // Because this operation is executed in a separate goroutine: We run the operation without + // a context (it is expected to keep running no matter what happens) and we don't return the + // error (it will be tracked via metrics instead of being returned to the caller). + _ = c.storeOperation(context.Background(), key, value, ttl, opSet, func(_ context.Context, key string, value []byte, ttl time.Duration) error { + return f(key, value, ttl) + }) }) if err != nil { @@ -196,6 +194,32 @@ func (c *baseClient) setAsync(key string, value []byte, ttl time.Duration, f fun } } +func (c *baseClient) storeOperation(ctx context.Context, key string, value []byte, ttl time.Duration, operation string, f func(ctx context.Context, key string, value []byte, ttl time.Duration) error) error { + if c.maxItemSize > 0 && uint64(len(value)) > c.maxItemSize { + c.metrics.skipped.WithLabelValues(operation, reasonMaxItemSize).Inc() + return nil + } + + start := time.Now() + c.metrics.operations.WithLabelValues(operation).Inc() + + err := f(ctx, key, value, ttl) + if err != nil { + level.Debug(c.logger).Log( + "msg", "failed to store item to cache", + "operation", operation, + "key", key, + "sizeBytes", len(value), + "err", err, + ) + c.trackError(operation, err) + } + + c.metrics.dataSize.WithLabelValues(operation).Observe(float64(len(value))) + c.metrics.duration.WithLabelValues(operation).Observe(time.Since(start).Seconds()) + return err +} + // wait submits an async task and blocks until it completes. This can be used during // tests to ensure that async "sets" have completed before attempting to read them. func (c *baseClient) wait() error { @@ -255,6 +279,10 @@ func (c *baseClient) trackError(op string, err error) { } else { c.metrics.failures.WithLabelValues(op, reasonNetworkError).Inc() } + case errors.Is(err, ErrNotStored): + c.metrics.failures.WithLabelValues(op, reasonNotStored).Inc() + case errors.Is(err, ErrInvalidTTL): + c.metrics.failures.WithLabelValues(op, reasonInvalidTTL).Inc() case errors.Is(err, memcache.ErrMalformedKey): c.metrics.failures.WithLabelValues(op, reasonMalformedKey).Inc() case errors.Is(err, memcache.ErrServerError): diff --git a/cache/compression.go b/cache/compression.go index 4146d2597..460ac5849 100644 --- a/cache/compression.go +++ b/cache/compression.go @@ -85,6 +85,14 @@ func (s *SnappyCache) SetMultiAsync(data map[string][]byte, ttl time.Duration) { s.next.SetMultiAsync(encoded, ttl) } +func (s *SnappyCache) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error { + return s.next.Set(ctx, key, snappy.Encode(nil, value), ttl) +} + +func (s *SnappyCache) Add(ctx context.Context, key string, value []byte, ttl time.Duration) error { + return s.next.Add(ctx, key, snappy.Encode(nil, value), ttl) +} + // GetMulti implements Cache. func (s *SnappyCache) GetMulti(ctx context.Context, keys []string, opts ...Option) map[string][]byte { found := s.next.GetMulti(ctx, keys, opts...) diff --git a/cache/lru.go b/cache/lru.go index b75e5a4e2..c0b7a032f 100644 --- a/cache/lru.go +++ b/cache/lru.go @@ -103,6 +103,41 @@ func (l *LRUCache) SetMultiAsync(data map[string][]byte, ttl time.Duration) { } } +func (l *LRUCache) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error { + err := l.c.Set(ctx, key, value, ttl) + + l.mtx.Lock() + defer l.mtx.Unlock() + + expires := time.Now().Add(ttl) + l.lru.Add(key, &Item{ + Data: value, + ExpiresAt: expires, + }) + + return err +} + +func (l *LRUCache) Add(ctx context.Context, key string, value []byte, ttl time.Duration) error { + err := l.c.Add(ctx, key, value, ttl) + + // When a caller uses the Add method, the presence of absence of an entry in the cache + // has significance. In order to maintain the semantics of that, we only add an entry to + // the LRU when it was able to be successfully added to the shared cache. + if err == nil { + l.mtx.Lock() + defer l.mtx.Unlock() + + expires := time.Now().Add(ttl) + l.lru.Add(key, &Item{ + Data: value, + ExpiresAt: expires, + }) + } + + return err +} + func (l *LRUCache) GetMulti(ctx context.Context, keys []string, opts ...Option) (result map[string][]byte) { l.requests.Add(float64(len(keys))) l.mtx.Lock() diff --git a/cache/memcached_client.go b/cache/memcached_client.go index a22f80354..0c4e9e6c5 100644 --- a/cache/memcached_client.go +++ b/cache/memcached_client.go @@ -28,6 +28,7 @@ import ( const ( dnsProviderUpdateInterval = 30 * time.Second + maxTTL = 30 * 24 * time.Hour ) var ( @@ -43,6 +44,7 @@ var ( type memcachedClientBackend interface { GetMulti(keys []string, opts ...memcache.Option) (map[string]*memcache.Item, error) Set(item *memcache.Item) error + Add(item *memcache.Item) error Delete(key string) error Decrement(key string, delta uint64) (uint64, error) Increment(key string, delta uint64) (uint64, error) @@ -322,14 +324,47 @@ func (c *MemcachedClient) SetAsync(key string, value []byte, ttl time.Duration) c.setAsync(key, value, ttl, c.setSingleItem) } +func (c *MemcachedClient) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error { + return c.storeOperation(ctx, key, value, ttl, opSet, func(ctx context.Context, key string, value []byte, ttl time.Duration) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + return c.setSingleItem(key, value, ttl) + } + }) +} + +func (c *MemcachedClient) Add(ctx context.Context, key string, value []byte, ttl time.Duration) error { + return c.storeOperation(ctx, key, value, ttl, opAdd, func(ctx context.Context, key string, value []byte, ttl time.Duration) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + ttlSeconds, ok := toSeconds(ttl) + if !ok { + return fmt.Errorf("%w: for set operation on %s %s", ErrInvalidTTL, key, ttl) + } + + err := c.client.Add(&memcache.Item{ + Key: key, + Value: value, + Expiration: ttlSeconds, + }) + + if errors.Is(err, memcache.ErrNotStored) { + return fmt.Errorf("%w: for add operation on %s", ErrNotStored, key) + } + + return err + } + }) +} + func (c *MemcachedClient) setSingleItem(key string, value []byte, ttl time.Duration) error { - ttlSeconds := int32(ttl.Seconds()) - // If a TTL of exactly 0 is passed, we honor it and pass it to Memcached which will - // interpret it as an infinite TTL. However, if we get a non-zero TTL that is truncated - // to 0 seconds, we discard the update since the caller didn't intend to set an infinite - // TTL. - if ttl != 0 && ttlSeconds <= 0 { - return nil + ttlSeconds, ok := toSeconds(ttl) + if !ok { + return fmt.Errorf("%w: for set operation on %s %s", ErrInvalidTTL, key, ttl) } return c.client.Set(&memcache.Item{ @@ -339,6 +374,20 @@ func (c *MemcachedClient) setSingleItem(key string, value []byte, ttl time.Durat }) } +// TODO: Docs +func toSeconds(d time.Duration) (int32, bool) { + if d > maxTTL { + return 0, false + } + + secs := int32(d.Seconds()) + if d != 0 && secs <= 0 { + return 0, false + } + + return secs, true +} + func toMemcacheOptions(opts ...Option) []memcache.Option { if len(opts) == 0 { return nil diff --git a/cache/memcached_client_test.go b/cache/memcached_client_test.go index 9b71cec1d..98dc3386f 100644 --- a/cache/memcached_client_test.go +++ b/cache/memcached_client_test.go @@ -96,6 +96,60 @@ func TestMemcachedClient_SetAsync(t *testing.T) { }) } +func TestMemcachedClient_Set(t *testing.T) { + t.Run("with non-zero TTL", func(t *testing.T) { + ctx := context.Background() + client, _, err := setupDefaultMemcachedClient() + require.NoError(t, err) + require.NoError(t, client.Set(ctx, "foo", []byte("bar"), time.Minute)) + + res := client.GetMulti(ctx, []string{"foo"}) + require.Equal(t, map[string][]byte{"foo": []byte("bar")}, res) + }) + + t.Run("with truncated TTL", func(t *testing.T) { + ctx := context.Background() + client, _, err := setupDefaultMemcachedClient() + require.NoError(t, err) + err = client.Set(ctx, "foo", []byte("bar"), 100*time.Millisecond) + require.ErrorIs(t, err, ErrInvalidTTL) + }) + + t.Run("with zero TTL", func(t *testing.T) { + ctx := context.Background() + client, _, err := setupDefaultMemcachedClient() + require.NoError(t, err) + require.NoError(t, client.Set(ctx, "foo", []byte("bar"), 0)) + + res := client.GetMulti(ctx, []string{"foo"}) + require.Equal(t, map[string][]byte{"foo": []byte("bar")}, res) + }) +} + +func TestMemcachedClient_Add(t *testing.T) { + t.Run("item does not exist", func(t *testing.T) { + ctx := context.Background() + client, _, err := setupDefaultMemcachedClient() + require.NoError(t, err) + require.NoError(t, client.Add(ctx, "foo", []byte("bar"), time.Minute)) + }) + + t.Run("item already exists", func(t *testing.T) { + ctx := context.Background() + client, mock, err := setupDefaultMemcachedClient() + require.NoError(t, err) + + require.NoError(t, mock.Set(&memcache.Item{ + Key: "foo", + Value: []byte("baz"), + Expiration: 60, + })) + + err = client.Add(ctx, "foo", []byte("bar"), time.Minute) + require.ErrorIs(t, err, ErrNotStored) + }) +} + func TestMemcachedClient_GetMulti(t *testing.T) { t.Run("no allocator", func(t *testing.T) { client, backend, err := setupDefaultMemcachedClient() @@ -334,6 +388,15 @@ func (m *mockMemcachedClientBackend) Set(item *memcache.Item) error { return nil } +func (m *mockMemcachedClientBackend) Add(item *memcache.Item) error { + if _, ok := m.values[item.Key]; ok { + return memcache.ErrNotStored + } + + m.values[item.Key] = item + return nil +} + func (m *mockMemcachedClientBackend) Delete(key string) error { delete(m.values, key) return nil diff --git a/cache/mock.go b/cache/mock.go index b06123272..f2d84fa12 100644 --- a/cache/mock.go +++ b/cache/mock.go @@ -43,6 +43,14 @@ func (m *MockCache) SetMultiAsync(data map[string][]byte, ttl time.Duration) { } } +func (m *MockCache) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error { + return nil +} + +func (m *MockCache) Add(ctx context.Context, key string, value []byte, ttl time.Duration) error { + return nil +} + func (m *MockCache) GetMulti(_ context.Context, keys []string, _ ...Option) map[string][]byte { m.mu.Lock() defer m.mu.Unlock() @@ -121,6 +129,16 @@ func (m *InstrumentedMockCache) SetMultiAsync(data map[string][]byte, ttl time.D m.cache.SetMultiAsync(data, ttl) } +func (m *InstrumentedMockCache) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error { + m.storeCount.Inc() + return m.cache.Set(ctx, key, value, ttl) +} + +func (m *InstrumentedMockCache) Add(ctx context.Context, key string, value []byte, ttl time.Duration) error { + m.storeCount.Inc() + return m.cache.Add(ctx, key, value, ttl) +} + func (m *InstrumentedMockCache) GetMulti(ctx context.Context, keys []string, opts ...Option) map[string][]byte { m.fetchCount.Inc() return m.cache.GetMulti(ctx, keys, opts...) diff --git a/cache/redis_client.go b/cache/redis_client.go index cd9efb4c1..a0feb4e15 100644 --- a/cache/redis_client.go +++ b/cache/redis_client.go @@ -226,7 +226,7 @@ func NewRedisClient(logger log.Logger, name string, config RedisClientConfig, re return c, nil } -// SetMultiAsync implements RemoteCacheClient. +// SetMultiAsync implements Cache. func (c *RedisClient) SetMultiAsync(data map[string][]byte, ttl time.Duration) { c.setMultiAsync(data, ttl, func(key string, value []byte, ttl time.Duration) error { _, err := c.client.Set(context.Background(), key, value, ttl).Result() @@ -234,7 +234,7 @@ func (c *RedisClient) SetMultiAsync(data map[string][]byte, ttl time.Duration) { }) } -// SetAsync implements RemoteCacheClient. +// SetAsync implements Cache. func (c *RedisClient) SetAsync(key string, value []byte, ttl time.Duration) { c.setAsync(key, value, ttl, func(key string, buf []byte, ttl time.Duration) error { _, err := c.client.Set(context.Background(), key, buf, ttl).Result() @@ -242,7 +242,30 @@ func (c *RedisClient) SetAsync(key string, value []byte, ttl time.Duration) { }) } -// GetMulti implements RemoteCacheClient. +// Set implements Cache. +func (c *RedisClient) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error { + return c.storeOperation(ctx, key, value, ttl, opSet, func(ctx context.Context, key string, value []byte, ttl time.Duration) error { + _, err := c.client.Set(ctx, key, value, ttl).Result() + return err + }) +} + +// Add implements Cache. +func (c *RedisClient) Add(ctx context.Context, key string, value []byte, ttl time.Duration) error { + return c.storeOperation(ctx, key, value, ttl, opAdd, func(ctx context.Context, key string, value []byte, ttl time.Duration) error { + stored, err := c.client.SetNX(ctx, key, value, ttl).Result() + if err != nil { + return err + } + if !stored { + return fmt.Errorf("%w: for Set NX operation on %s", ErrNotStored, key) + } + + return nil + }) +} + +// GetMulti implements Cache. func (c *RedisClient) GetMulti(ctx context.Context, keys []string, _ ...Option) map[string][]byte { if len(keys) == 0 { return nil diff --git a/cache/redis_client_test.go b/cache/redis_client_test.go index 344db4211..ff1c38cbe 100644 --- a/cache/redis_client_test.go +++ b/cache/redis_client_test.go @@ -156,7 +156,7 @@ func TestRedisClient(t *testing.T) { } } -func TestRedisClientDelete(t *testing.T) { +func TestRedisClient_Delete(t *testing.T) { s, err := miniredis.Run() require.NoError(t, err) defer s.Close() diff --git a/cache/tracing.go b/cache/tracing.go index 68f29d140..148c708f8 100644 --- a/cache/tracing.go +++ b/cache/tracing.go @@ -31,6 +31,14 @@ func (t *SpanlessTracingCache) SetMultiAsync(data map[string][]byte, ttl time.Du t.next.SetMultiAsync(data, ttl) } +func (t *SpanlessTracingCache) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error { + return t.next.Set(ctx, key, value, ttl) +} + +func (t *SpanlessTracingCache) Add(ctx context.Context, key string, value []byte, ttl time.Duration) error { + return t.next.Add(ctx, key, value, ttl) +} + func (t *SpanlessTracingCache) GetMulti(ctx context.Context, keys []string, opts ...Option) (result map[string][]byte) { var ( bytes int diff --git a/cache/versioned.go b/cache/versioned.go index 782cd8597..a4eef5170 100644 --- a/cache/versioned.go +++ b/cache/versioned.go @@ -36,6 +36,14 @@ func (c *Versioned) SetMultiAsync(data map[string][]byte, ttl time.Duration) { c.cache.SetMultiAsync(versioned, ttl) } +func (c *Versioned) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error { + return c.cache.Set(ctx, c.addVersion(key), value, ttl) +} + +func (c *Versioned) Add(ctx context.Context, key string, value []byte, ttl time.Duration) error { + return c.cache.Add(ctx, c.addVersion(key), value, ttl) +} + func (c *Versioned) GetMulti(ctx context.Context, keys []string, opts ...Option) map[string][]byte { versionedKeys := make([]string, len(keys)) for i, k := range keys {