Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 26 additions & 16 deletions internal/cache/disk.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type DiskConfig struct {
type Disk struct {
logger *slog.Logger
config DiskConfig
ttl *ttlStorage
db *diskMetaDB
size atomic.Int64
runEviction chan struct{}
stop context.CancelFunc
Expand Down Expand Up @@ -67,7 +67,7 @@ func NewDisk(ctx context.Context, config DiskConfig) (*Disk, error) {
}

// Open TTL storage
ttl, err := newTTLStorage(filepath.Join(config.Root, "metadata.db"))
db, err := newDiskMetaDB(filepath.Join(config.Root, "metadata.db"))
if err != nil {
return nil, errors.Errorf("failed to create TTL storage: %w", err)
}
Expand Down Expand Up @@ -99,7 +99,7 @@ func NewDisk(ctx context.Context, config DiskConfig) (*Disk, error) {
disk := &Disk{
logger: logger,
config: config,
ttl: ttl,
db: db,
runEviction: make(chan struct{}),
stop: stop,
}
Expand All @@ -114,8 +114,8 @@ func (d *Disk) String() string { return "disk:" + d.config.Root }

func (d *Disk) Close() error {
d.stop()
if d.ttl != nil {
return d.ttl.close()
if d.db != nil {
return d.db.close()
}
return nil
}
Expand Down Expand Up @@ -163,7 +163,7 @@ func (d *Disk) Delete(_ context.Context, key Key) error {

// Check if file is expired
expired := false
expiresAt, _, err := d.ttl.get(key)
expiresAt, err := d.db.getTTL(key)
if err == nil && time.Now().After(expiresAt) {
expired = true
}
Expand All @@ -178,7 +178,7 @@ func (d *Disk) Delete(_ context.Context, key Key) error {
}

// Remove TTL metadata
if err := d.ttl.delete(key); err != nil {
if err := d.db.delete(key); err != nil {
return errors.Errorf("failed to delete TTL metadata: %w", err)
}

Expand All @@ -198,15 +198,20 @@ func (d *Disk) Stat(ctx context.Context, key Key) (textproto.MIMEHeader, error)
return nil, errors.Errorf("failed to stat file: %w", err)
}

expiresAt, headers, err := d.ttl.get(key)
expiresAt, err := d.db.getTTL(key)
if err != nil {
return nil, errors.Errorf("failed to get metadata: %w", err)
return nil, errors.Errorf("failed to get TTL: %w", err)
}

if time.Now().After(expiresAt) {
return nil, errors.Join(fs.ErrNotExist, d.Delete(ctx, key))
}

headers, err := d.db.getHeaders(key)
if err != nil {
return nil, errors.Errorf("failed to get headers: %w", err)
}

return headers, nil
}

Expand All @@ -219,21 +224,26 @@ func (d *Disk) Open(ctx context.Context, key Key) (io.ReadCloser, textproto.MIME
return nil, nil, errors.Errorf("failed to open file: %w", err)
}

expiresAt, headers, err := d.ttl.get(key)
expiresAt, err := d.db.getTTL(key)
if err != nil {
return nil, nil, errors.Join(errors.Errorf("failed to get metadata: %w", err), f.Close())
return nil, nil, errors.Join(errors.Errorf("failed to get TTL: %w", err), f.Close())
}

now := time.Now()
if now.After(expiresAt) {
return nil, nil, errors.Join(fs.ErrNotExist, f.Close(), d.Delete(ctx, key))
}

headers, err := d.db.getHeaders(key)
if err != nil {
return nil, nil, errors.Join(errors.Errorf("failed to get headers: %w", err), f.Close())
}

// Reset expiration time to implement LRU
ttl := min(expiresAt.Sub(now), d.config.MaxTTL)
newExpiresAt := now.Add(ttl)

if err := d.ttl.set(key, newExpiresAt, headers); err != nil {
if err := d.db.setTTL(key, newExpiresAt); err != nil {
return nil, nil, errors.Join(errors.Errorf("failed to update expiration time: %w", err), f.Close())
}

Expand Down Expand Up @@ -279,7 +289,7 @@ func (d *Disk) evict() error {
var expiredKeys []Key
now := time.Now()

err := d.ttl.walk(func(key Key, expiresAt time.Time) error {
err := d.db.walk(func(key Key, expiresAt time.Time) error {
path := d.keyToPath(key)
fullPath := filepath.Join(d.config.Root, path)

Expand Down Expand Up @@ -312,7 +322,7 @@ func (d *Disk) evict() error {
return errors.Errorf("failed to walk TTL entries: %w", err)
}

if err := d.ttl.deleteAll(expiredKeys); err != nil {
if err := d.db.deleteAll(expiredKeys); err != nil {
return errors.Errorf("failed to delete TTL metadata: %w", err)
}

Expand Down Expand Up @@ -340,7 +350,7 @@ func (d *Disk) evict() error {
d.size.Add(-f.size)
}

if err := d.ttl.deleteAll(sizeEvictedKeys); err != nil {
if err := d.db.deleteAll(sizeEvictedKeys); err != nil {
return errors.Errorf("failed to delete TTL metadata: %w", err)
}

Expand Down Expand Up @@ -380,7 +390,7 @@ func (w *diskWriter) Close() error {
return errors.Errorf("failed to rename temp file: %w", err)
}

if err := w.disk.ttl.set(w.key, w.expiresAt, w.headers); err != nil {
if err := w.disk.db.set(w.key, w.expiresAt, w.headers); err != nil {
return errors.Join(errors.Errorf("failed to set metadata: %w", err), os.Remove(w.path))
}

Expand Down
164 changes: 164 additions & 0 deletions internal/cache/disk_metadb.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
package cache

import (
"encoding/json"
"net/textproto"
"time"

"github.com/alecthomas/errors"
"go.etcd.io/bbolt"
)

var (
ttlBucketName = []byte("ttl")
headersBucketName = []byte("headers")
)

// diskMetaDB manages expiration times and headers for cache entries using bbolt.
type diskMetaDB struct {
db *bbolt.DB
}

// newDiskMetaDB creates a new bbolt-backed metadata storage for the disk cache.
func newDiskMetaDB(dbPath string) (*diskMetaDB, error) {
db, err := bbolt.Open(dbPath, 0600, &bbolt.Options{
Timeout: 5 * time.Second,
})
if err != nil {
return nil, errors.Errorf("failed to open bbolt database: %w", err)
}

if err := db.Update(func(tx *bbolt.Tx) error {
if _, err := tx.CreateBucketIfNotExists(ttlBucketName); err != nil {
return errors.WithStack(err)
}
if _, err := tx.CreateBucketIfNotExists(headersBucketName); err != nil {
return errors.WithStack(err)
}
return nil
}); err != nil {
return nil, errors.Join(errors.Errorf("failed to create buckets: %w", err), db.Close())
}

return &diskMetaDB{db: db}, nil
}

func (s *diskMetaDB) setTTL(key Key, expiresAt time.Time) error {
ttlBytes, err := expiresAt.MarshalBinary()
if err != nil {
return errors.Errorf("failed to marshal TTL: %w", err)
}

return errors.WithStack(s.db.Update(func(tx *bbolt.Tx) error {
ttlBucket := tx.Bucket(ttlBucketName)
return errors.WithStack(ttlBucket.Put(key[:], ttlBytes))
}))
}

func (s *diskMetaDB) set(key Key, expiresAt time.Time, headers textproto.MIMEHeader) error {
ttlBytes, err := expiresAt.MarshalBinary()
if err != nil {
return errors.Errorf("failed to marshal TTL: %w", err)
}

headersBytes, err := json.Marshal(headers)
if err != nil {
return errors.Errorf("failed to encode headers: %w", err)
}

return errors.WithStack(s.db.Update(func(tx *bbolt.Tx) error {
ttlBucket := tx.Bucket(ttlBucketName)
if err := ttlBucket.Put(key[:], ttlBytes); err != nil {
return errors.WithStack(err)
}

headersBucket := tx.Bucket(headersBucketName)
return errors.WithStack(headersBucket.Put(key[:], headersBytes))
}))
}

func (s *diskMetaDB) getTTL(key Key) (time.Time, error) {
var expiresAt time.Time
err := s.db.View(func(tx *bbolt.Tx) error {
bucket := tx.Bucket(ttlBucketName)
ttlBytes := bucket.Get(key[:])
if ttlBytes == nil {
return errors.New("key not found")
}
return errors.WithStack(expiresAt.UnmarshalBinary(ttlBytes))
})
return expiresAt, errors.WithStack(err)
}

func (s *diskMetaDB) getHeaders(key Key) (textproto.MIMEHeader, error) {
var headers textproto.MIMEHeader
err := s.db.View(func(tx *bbolt.Tx) error {
bucket := tx.Bucket(headersBucketName)
headersBytes := bucket.Get(key[:])
if headersBytes == nil {
return errors.New("key not found")
}
return errors.WithStack(json.Unmarshal(headersBytes, &headers))
})
return headers, errors.WithStack(err)
}

func (s *diskMetaDB) delete(key Key) error {
return errors.WithStack(s.db.Update(func(tx *bbolt.Tx) error {
ttlBucket := tx.Bucket(ttlBucketName)
if err := ttlBucket.Delete(key[:]); err != nil {
return errors.WithStack(err)
}

headersBucket := tx.Bucket(headersBucketName)
return errors.WithStack(headersBucket.Delete(key[:]))
}))
}

func (s *diskMetaDB) deleteAll(keys []Key) error {
if len(keys) == 0 {
return nil
}
return errors.WithStack(s.db.Update(func(tx *bbolt.Tx) error {
ttlBucket := tx.Bucket(ttlBucketName)
headersBucket := tx.Bucket(headersBucketName)

for _, key := range keys {
if err := ttlBucket.Delete(key[:]); err != nil {
return errors.Errorf("failed to delete TTL: %w", err)
}
if err := headersBucket.Delete(key[:]); err != nil {
return errors.Errorf("failed to delete headers: %w", err)
}
}
return nil
}))
}

func (s *diskMetaDB) walk(fn func(key Key, expiresAt time.Time) error) error {
return errors.WithStack(s.db.View(func(tx *bbolt.Tx) error {
bucket := tx.Bucket(ttlBucketName)
if bucket == nil {
return nil
}
return bucket.ForEach(func(k, v []byte) error {
if len(k) != 32 {
return nil
}
var key Key
copy(key[:], k)
var expiresAt time.Time
if err := expiresAt.UnmarshalBinary(v); err != nil {
return nil //nolint:nilerr
}
return fn(key, expiresAt)
})
}))
}

func (s *diskMetaDB) close() error {
if err := s.db.Close(); err != nil {
return errors.Errorf("failed to close bbolt database: %w", err)
}
return nil
}
Loading