diff --git a/internal/cache/disk.go b/internal/cache/disk.go index 6c21882..cef70f2 100644 --- a/internal/cache/disk.go +++ b/internal/cache/disk.go @@ -32,7 +32,7 @@ type DiskConfig struct { type Disk struct { logger *slog.Logger config DiskConfig - ttl *ttlStorage + db *diskMetaDB size atomic.Int64 runEviction chan struct{} stop context.CancelFunc @@ -67,7 +67,7 @@ func NewDisk(ctx context.Context, config DiskConfig) (*Disk, error) { } // Open TTL storage - ttl, err := newTTLStorage(filepath.Join(config.Root, "metadata.db")) + db, err := newDiskMetaDB(filepath.Join(config.Root, "metadata.db")) if err != nil { return nil, errors.Errorf("failed to create TTL storage: %w", err) } @@ -99,7 +99,7 @@ func NewDisk(ctx context.Context, config DiskConfig) (*Disk, error) { disk := &Disk{ logger: logger, config: config, - ttl: ttl, + db: db, runEviction: make(chan struct{}), stop: stop, } @@ -114,8 +114,8 @@ func (d *Disk) String() string { return "disk:" + d.config.Root } func (d *Disk) Close() error { d.stop() - if d.ttl != nil { - return d.ttl.close() + if d.db != nil { + return d.db.close() } return nil } @@ -163,7 +163,7 @@ func (d *Disk) Delete(_ context.Context, key Key) error { // Check if file is expired expired := false - expiresAt, _, err := d.ttl.get(key) + expiresAt, err := d.db.getTTL(key) if err == nil && time.Now().After(expiresAt) { expired = true } @@ -178,7 +178,7 @@ func (d *Disk) Delete(_ context.Context, key Key) error { } // Remove TTL metadata - if err := d.ttl.delete(key); err != nil { + if err := d.db.delete(key); err != nil { return errors.Errorf("failed to delete TTL metadata: %w", err) } @@ -198,15 +198,20 @@ func (d *Disk) Stat(ctx context.Context, key Key) (textproto.MIMEHeader, error) return nil, errors.Errorf("failed to stat file: %w", err) } - expiresAt, headers, err := d.ttl.get(key) + expiresAt, err := d.db.getTTL(key) if err != nil { - return nil, errors.Errorf("failed to get metadata: %w", err) + return nil, errors.Errorf("failed to get TTL: %w", err) } if time.Now().After(expiresAt) { return nil, errors.Join(fs.ErrNotExist, d.Delete(ctx, key)) } + headers, err := d.db.getHeaders(key) + if err != nil { + return nil, errors.Errorf("failed to get headers: %w", err) + } + return headers, nil } @@ -219,9 +224,9 @@ func (d *Disk) Open(ctx context.Context, key Key) (io.ReadCloser, textproto.MIME return nil, nil, errors.Errorf("failed to open file: %w", err) } - expiresAt, headers, err := d.ttl.get(key) + expiresAt, err := d.db.getTTL(key) if err != nil { - return nil, nil, errors.Join(errors.Errorf("failed to get metadata: %w", err), f.Close()) + return nil, nil, errors.Join(errors.Errorf("failed to get TTL: %w", err), f.Close()) } now := time.Now() @@ -229,11 +234,16 @@ func (d *Disk) Open(ctx context.Context, key Key) (io.ReadCloser, textproto.MIME return nil, nil, errors.Join(fs.ErrNotExist, f.Close(), d.Delete(ctx, key)) } + headers, err := d.db.getHeaders(key) + if err != nil { + return nil, nil, errors.Join(errors.Errorf("failed to get headers: %w", err), f.Close()) + } + // Reset expiration time to implement LRU ttl := min(expiresAt.Sub(now), d.config.MaxTTL) newExpiresAt := now.Add(ttl) - if err := d.ttl.set(key, newExpiresAt, headers); err != nil { + if err := d.db.setTTL(key, newExpiresAt); err != nil { return nil, nil, errors.Join(errors.Errorf("failed to update expiration time: %w", err), f.Close()) } @@ -279,7 +289,7 @@ func (d *Disk) evict() error { var expiredKeys []Key now := time.Now() - err := d.ttl.walk(func(key Key, expiresAt time.Time) error { + err := d.db.walk(func(key Key, expiresAt time.Time) error { path := d.keyToPath(key) fullPath := filepath.Join(d.config.Root, path) @@ -312,7 +322,7 @@ func (d *Disk) evict() error { return errors.Errorf("failed to walk TTL entries: %w", err) } - if err := d.ttl.deleteAll(expiredKeys); err != nil { + if err := d.db.deleteAll(expiredKeys); err != nil { return errors.Errorf("failed to delete TTL metadata: %w", err) } @@ -340,7 +350,7 @@ func (d *Disk) evict() error { d.size.Add(-f.size) } - if err := d.ttl.deleteAll(sizeEvictedKeys); err != nil { + if err := d.db.deleteAll(sizeEvictedKeys); err != nil { return errors.Errorf("failed to delete TTL metadata: %w", err) } @@ -380,7 +390,7 @@ func (w *diskWriter) Close() error { return errors.Errorf("failed to rename temp file: %w", err) } - if err := w.disk.ttl.set(w.key, w.expiresAt, w.headers); err != nil { + if err := w.disk.db.set(w.key, w.expiresAt, w.headers); err != nil { return errors.Join(errors.Errorf("failed to set metadata: %w", err), os.Remove(w.path)) } diff --git a/internal/cache/disk_metadb.go b/internal/cache/disk_metadb.go new file mode 100644 index 0000000..673658c --- /dev/null +++ b/internal/cache/disk_metadb.go @@ -0,0 +1,164 @@ +package cache + +import ( + "encoding/json" + "net/textproto" + "time" + + "github.com/alecthomas/errors" + "go.etcd.io/bbolt" +) + +var ( + ttlBucketName = []byte("ttl") + headersBucketName = []byte("headers") +) + +// diskMetaDB manages expiration times and headers for cache entries using bbolt. +type diskMetaDB struct { + db *bbolt.DB +} + +// newDiskMetaDB creates a new bbolt-backed metadata storage for the disk cache. +func newDiskMetaDB(dbPath string) (*diskMetaDB, error) { + db, err := bbolt.Open(dbPath, 0600, &bbolt.Options{ + Timeout: 5 * time.Second, + }) + if err != nil { + return nil, errors.Errorf("failed to open bbolt database: %w", err) + } + + if err := db.Update(func(tx *bbolt.Tx) error { + if _, err := tx.CreateBucketIfNotExists(ttlBucketName); err != nil { + return errors.WithStack(err) + } + if _, err := tx.CreateBucketIfNotExists(headersBucketName); err != nil { + return errors.WithStack(err) + } + return nil + }); err != nil { + return nil, errors.Join(errors.Errorf("failed to create buckets: %w", err), db.Close()) + } + + return &diskMetaDB{db: db}, nil +} + +func (s *diskMetaDB) setTTL(key Key, expiresAt time.Time) error { + ttlBytes, err := expiresAt.MarshalBinary() + if err != nil { + return errors.Errorf("failed to marshal TTL: %w", err) + } + + return errors.WithStack(s.db.Update(func(tx *bbolt.Tx) error { + ttlBucket := tx.Bucket(ttlBucketName) + return errors.WithStack(ttlBucket.Put(key[:], ttlBytes)) + })) +} + +func (s *diskMetaDB) set(key Key, expiresAt time.Time, headers textproto.MIMEHeader) error { + ttlBytes, err := expiresAt.MarshalBinary() + if err != nil { + return errors.Errorf("failed to marshal TTL: %w", err) + } + + headersBytes, err := json.Marshal(headers) + if err != nil { + return errors.Errorf("failed to encode headers: %w", err) + } + + return errors.WithStack(s.db.Update(func(tx *bbolt.Tx) error { + ttlBucket := tx.Bucket(ttlBucketName) + if err := ttlBucket.Put(key[:], ttlBytes); err != nil { + return errors.WithStack(err) + } + + headersBucket := tx.Bucket(headersBucketName) + return errors.WithStack(headersBucket.Put(key[:], headersBytes)) + })) +} + +func (s *diskMetaDB) getTTL(key Key) (time.Time, error) { + var expiresAt time.Time + err := s.db.View(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(ttlBucketName) + ttlBytes := bucket.Get(key[:]) + if ttlBytes == nil { + return errors.New("key not found") + } + return errors.WithStack(expiresAt.UnmarshalBinary(ttlBytes)) + }) + return expiresAt, errors.WithStack(err) +} + +func (s *diskMetaDB) getHeaders(key Key) (textproto.MIMEHeader, error) { + var headers textproto.MIMEHeader + err := s.db.View(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(headersBucketName) + headersBytes := bucket.Get(key[:]) + if headersBytes == nil { + return errors.New("key not found") + } + return errors.WithStack(json.Unmarshal(headersBytes, &headers)) + }) + return headers, errors.WithStack(err) +} + +func (s *diskMetaDB) delete(key Key) error { + return errors.WithStack(s.db.Update(func(tx *bbolt.Tx) error { + ttlBucket := tx.Bucket(ttlBucketName) + if err := ttlBucket.Delete(key[:]); err != nil { + return errors.WithStack(err) + } + + headersBucket := tx.Bucket(headersBucketName) + return errors.WithStack(headersBucket.Delete(key[:])) + })) +} + +func (s *diskMetaDB) deleteAll(keys []Key) error { + if len(keys) == 0 { + return nil + } + return errors.WithStack(s.db.Update(func(tx *bbolt.Tx) error { + ttlBucket := tx.Bucket(ttlBucketName) + headersBucket := tx.Bucket(headersBucketName) + + for _, key := range keys { + if err := ttlBucket.Delete(key[:]); err != nil { + return errors.Errorf("failed to delete TTL: %w", err) + } + if err := headersBucket.Delete(key[:]); err != nil { + return errors.Errorf("failed to delete headers: %w", err) + } + } + return nil + })) +} + +func (s *diskMetaDB) walk(fn func(key Key, expiresAt time.Time) error) error { + return errors.WithStack(s.db.View(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(ttlBucketName) + if bucket == nil { + return nil + } + return bucket.ForEach(func(k, v []byte) error { + if len(k) != 32 { + return nil + } + var key Key + copy(key[:], k) + var expiresAt time.Time + if err := expiresAt.UnmarshalBinary(v); err != nil { + return nil //nolint:nilerr + } + return fn(key, expiresAt) + }) + })) +} + +func (s *diskMetaDB) close() error { + if err := s.db.Close(); err != nil { + return errors.Errorf("failed to close bbolt database: %w", err) + } + return nil +} diff --git a/internal/cache/ttl_storage.go b/internal/cache/ttl_storage.go deleted file mode 100644 index 016c6db..0000000 --- a/internal/cache/ttl_storage.go +++ /dev/null @@ -1,123 +0,0 @@ -package cache - -import ( - "encoding/json" - "net/textproto" - "time" - - "github.com/alecthomas/errors" - "go.etcd.io/bbolt" -) - -var metadataBucketName = []byte("metadata") - -// metadata stores expiration time and headers for a cache entry. -type metadata struct { - ExpiresAt time.Time `json:"expires_at"` - Headers textproto.MIMEHeader `json:"headers"` -} - -// ttlStorage manages expiration times and headers for cache entries using bbolt. -type ttlStorage struct { - db *bbolt.DB -} - -// newTTLStorage creates a new bbolt-backed TTL storage. -func newTTLStorage(dbPath string) (*ttlStorage, error) { - db, err := bbolt.Open(dbPath, 0600, &bbolt.Options{ - Timeout: 5 * time.Second, - }) - if err != nil { - return nil, errors.Errorf("failed to open bbolt database: %w", err) - } - - // Create the bucket if it doesn't exist - if err := db.Update(func(tx *bbolt.Tx) error { - _, err := tx.CreateBucketIfNotExists(metadataBucketName) - return errors.WithStack(err) - }); err != nil { - return nil, errors.Join(errors.Errorf("failed to create metadata bucket: %w", err), db.Close()) - } - - return &ttlStorage{db: db}, nil -} - -func (s *ttlStorage) set(key Key, expiresAt time.Time, headers textproto.MIMEHeader) error { - md := metadata{ - ExpiresAt: expiresAt, - Headers: headers, - } - - mdBytes, err := json.Marshal(md) - if err != nil { - return errors.Errorf("failed to encode metadata: %w", err) - } - - return errors.WithStack(s.db.Update(func(tx *bbolt.Tx) error { - bucket := tx.Bucket(metadataBucketName) - return bucket.Put(key[:], mdBytes) - })) -} - -func (s *ttlStorage) get(key Key) (time.Time, textproto.MIMEHeader, error) { - var md metadata - err := s.db.View(func(tx *bbolt.Tx) error { - bucket := tx.Bucket(metadataBucketName) - mdBytes := bucket.Get(key[:]) - if mdBytes == nil { - return errors.New("key not found") - } - return errors.WithStack(json.Unmarshal(mdBytes, &md)) - }) - return md.ExpiresAt, md.Headers, errors.WithStack(err) -} - -func (s *ttlStorage) delete(key Key) error { - return errors.WithStack(s.db.Update(func(tx *bbolt.Tx) error { - bucket := tx.Bucket(metadataBucketName) - return bucket.Delete(key[:]) - })) -} - -func (s *ttlStorage) deleteAll(keys []Key) error { - if len(keys) == 0 { - return nil - } - return errors.WithStack(s.db.Update(func(tx *bbolt.Tx) error { - bucket := tx.Bucket(metadataBucketName) - for _, key := range keys { - if err := bucket.Delete(key[:]); err != nil { - return errors.Errorf("failed to delete metadata: %w", err) - } - } - return nil - })) -} - -func (s *ttlStorage) walk(fn func(key Key, expiresAt time.Time) error) error { - return errors.WithStack(s.db.View(func(tx *bbolt.Tx) error { - bucket := tx.Bucket(metadataBucketName) - if bucket == nil { - return nil - } - return bucket.ForEach(func(k, v []byte) error { - if len(k) != 32 { - return nil - } - var key Key - copy(key[:], k) - var md metadata - if err := json.Unmarshal(v, &md); err != nil { - return nil //nolint:nilerr - } - return fn(key, md.ExpiresAt) - }) - })) -} - -func (s *ttlStorage) close() error { - if err := s.db.Close(); err != nil { - return errors.Errorf("failed to close bbolt database: %w", err) - } - return nil -}