diff --git a/chanutils/batch_writer.go b/chanutils/batch_writer.go
new file mode 100644
index 00000000..70654a22
--- /dev/null
+++ b/chanutils/batch_writer.go
@@ -0,0 +1,154 @@
+package chanutils
+
+import (
+	"sync"
+	"time"
+
+	"github.com/btcsuite/btclog"
+)
+
+// BatchWriterConfig holds the configuration options for BatchWriter.
+type BatchWriterConfig[T any] struct {
+	// QueueBufferSize sets the buffer size of the output channel of the
+	// concurrent queue used by the BatchWriter.
+	QueueBufferSize int
+
+	// MaxBatch is the maximum number of items to be persisted to the DB
+	// in one go.
+	MaxBatch int
+
+	// DBWritesTickerDuration is the time after receiving an item that the
+	// writer will wait for more items before writing the current batch
+	// to the DB.
+	DBWritesTickerDuration time.Duration
+
+	// Logger is the logger that the BatchWriter should use for any logs.
+	Logger btclog.Logger
+
+	// PutItems will be used by the BatchWriter to persist items in
+	// batches.
+	PutItems func(...T) error
+}
+
+// BatchWriter manages writing items to the DB and tries to batch the writes
+// as much as possible.
+type BatchWriter[T any] struct {
+	started sync.Once
+	stopped sync.Once
+
+	cfg *BatchWriterConfig[T]
+
+	queue *ConcurrentQueue[T]
+
+	quit chan struct{}
+	wg   sync.WaitGroup
+}
+
+// NewBatchWriter constructs a new BatchWriter using the given
+// BatchWriterConfig.
+func NewBatchWriter[T any](cfg *BatchWriterConfig[T]) *BatchWriter[T] {
+	return &BatchWriter[T]{
+		cfg:   cfg,
+		queue: NewConcurrentQueue[T](cfg.QueueBufferSize),
+		quit:  make(chan struct{}),
+	}
+}
+
+// Start starts the BatchWriter's queue and its item managing goroutine.
+func (b *BatchWriter[T]) Start() {
+	b.started.Do(func() {
+		b.queue.Start()
+
+		b.wg.Add(1)
+		go b.manageNewItems()
+	})
+}
+
+// Stop stops the BatchWriter's goroutine and then its queue.
+func (b *BatchWriter[T]) Stop() {
+	b.stopped.Do(func() {
+		close(b.quit)
+		b.wg.Wait()
+
+		b.queue.Stop()
+	})
+}
+
+// AddItem adds a given item to the BatchWriter queue.
+func (b *BatchWriter[T]) AddItem(item T) {
+	b.queue.ChanIn() <- item
+}
+
+// manageNewItems manages collecting items and persisting them to the DB.
+// There are two conditions for writing a batch of items to the DB: the first
+// is if a certain threshold (MaxBatch) of items has been collected and the
+// other is if at least one item has been collected and a timeout has been
+// reached.
+//
+// NOTE: this must be run in a goroutine.
+func (b *BatchWriter[T]) manageNewItems() {
+	defer b.wg.Done()
+
+	batch := make([]T, 0, b.cfg.MaxBatch)
+
+	// writeBatch persists the current batch via PutItems. A failure is only
+	// logged, and only if a Logger was configured (it may be left nil).
+	writeBatch := func() {
+		if len(batch) == 0 {
+			return
+		}
+
+		err := b.cfg.PutItems(batch...)
+		if err != nil && b.cfg.Logger != nil {
+			b.cfg.Logger.Warnf("Couldn't write filters to "+
+				"filterDB: %v", err)
+		}
+
+		// Empty the batch slice.
+		batch = make([]T, 0, b.cfg.MaxBatch)
+	}
+
+	ticker := time.NewTicker(b.cfg.DBWritesTickerDuration)
+	defer ticker.Stop()
+
+	// Stop the ticker since we don't want it to tick unless there is at
+	// least one item in the batch.
+	ticker.Stop()
+
+	for {
+		select {
+		case item, ok := <-b.queue.ChanOut():
+			if !ok {
+				return
+			}
+
+			batch = append(batch, item)
+
+			switch len(batch) {
+			// If the batch slice is full, we stop the ticker and
+			// write the batch contents to disk.
+			case b.cfg.MaxBatch:
+				ticker.Stop()
+				writeBatch()
+
+			// If an item is added to the batch, we reset the ticker.
+			// This ensures that if the batch threshold is not met
+			// then items are still persisted in a timely manner.
+			default:
+				ticker.Reset(b.cfg.DBWritesTickerDuration)
+			}
+
+		case <-ticker.C:
+			// If the ticker ticks, then we stop it and write the
+			// current batch contents to the db. If any more items
+			// are added, the ticker will be reset.
+			ticker.Stop()
+			writeBatch()
+
+		case <-b.quit:
+			writeBatch()
+
+			return
+		}
+	}
+}
diff --git a/chanutils/batch_writer_test.go b/chanutils/batch_writer_test.go
new file mode 100644
index 00000000..a2bb12d3
--- /dev/null
+++ b/chanutils/batch_writer_test.go
@@ -0,0 +1,210 @@
+package chanutils
+
+import (
+	"fmt"
+	"math/rand"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+)
+
+const waitTime = time.Second * 5
+
+// TestBatchWriter tests that the BatchWriter behaves as expected.
+func TestBatchWriter(t *testing.T) {
+	t.Parallel()
+	rand.Seed(time.Now().UnixNano())
+
+	// waitForItems is a helper function that will wait for a given set of
+	// items to appear in the db.
+	waitForItems := func(db *mockItemsDB, items ...*item) {
+		err := waitFor(func() bool {
+			return db.hasItems(items...)
+		}, waitTime)
+		require.NoError(t, err)
+	}
+
+	t.Run("filters persisted after ticker", func(t *testing.T) {
+		t.Parallel()
+
+		// Create a mock items DB.
+		db := newMockItemsDB()
+
+		// Construct a new BatchWriter backed by the mock db.
+		b := NewBatchWriter[*item](&BatchWriterConfig[*item]{
+			QueueBufferSize:        10,
+			MaxBatch:               20,
+			DBWritesTickerDuration: time.Millisecond * 500,
+			PutItems:               db.PutItems,
+		})
+		b.Start()
+		t.Cleanup(b.Stop)
+
+		fs := genFilterSet(5)
+		for _, f := range fs {
+			b.AddItem(f)
+		}
+		waitForItems(db, fs...)
+	})
+
+	t.Run("write once threshold is reached", func(t *testing.T) {
+		t.Parallel()
+
+		// Create a mock items DB.
+		db := newMockItemsDB()
+
+		// Construct a new BatchWriter backed by the mock db.
+		// Make the DB writes ticker duration extra long so that we
+		// can explicitly test that the batch gets persisted if the
+		// MaxBatch threshold is reached.
+		b := NewBatchWriter[*item](&BatchWriterConfig[*item]{
+			QueueBufferSize:        10,
+			MaxBatch:               20,
+			DBWritesTickerDuration: time.Hour,
+			PutItems:               db.PutItems,
+		})
+		b.Start()
+		t.Cleanup(b.Stop)
+
+		// Generate 30 items and add each one to the batch writer.
+		fs := genFilterSet(30)
+		for _, f := range fs {
+			b.AddItem(f)
+		}
+
+		// Since the MaxBatch threshold has been reached, we expect the
+		// first 20 items to be persisted.
+		waitForItems(db, fs[:20]...)
+
+		// Since the last 10 items don't reach the threshold and since
+		// the ticker has definitely not ticked yet, we don't expect the
+		// last 10 items (fs[20:]) to be in the db yet.
+		require.False(t, db.hasItems(fs[20:]...))
+	})
+
+	t.Run("stress test", func(t *testing.T) {
+		t.Parallel()
+
+		// Create a mock items DB.
+		db := newMockItemsDB()
+
+		// Construct a new BatchWriter backed by the mock db.
+		// Use a small MaxBatch and a very short DB writes ticker
+		// duration so that both the threshold write and the timeout
+		// write paths get exercised.
+		b := NewBatchWriter[*item](&BatchWriterConfig[*item]{
+			QueueBufferSize:        5,
+			MaxBatch:               5,
+			DBWritesTickerDuration: time.Millisecond * 2,
+			PutItems:               db.PutItems,
+		})
+		b.Start()
+		t.Cleanup(b.Stop)
+
+		// Generate lots of items and add each to the batch writer.
+		// Sleep for a bit between each item to ensure that we
+		// sometimes hit the timeout write and sometimes the threshold
+		// write.
+		fs := genFilterSet(1000)
+		for _, f := range fs {
+			b.AddItem(f)
+
+			n := rand.Intn(3)
+			time.Sleep(time.Duration(n) * time.Millisecond)
+		}
+
+		// Regardless of which write path persisted them, we expect
+		// all of the items to eventually end up in the db.
+		waitForItems(db, fs...)
+	})
+}
+
+type item struct {
+	i int
+}
+
+// mockItemsDB is a mock DB that holds a set of items.
+type mockItemsDB struct {
+	items map[int]bool
+	mu    sync.Mutex
+}
+
+// newMockItemsDB constructs a new mockItemsDB.
+func newMockItemsDB() *mockItemsDB {
+	return &mockItemsDB{
+		items: make(map[int]bool),
+	}
+}
+
+// hasItems returns true if the db contains all the given items.
+func (m *mockItemsDB) hasItems(items ...*item) bool {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	for _, i := range items {
+		_, ok := m.items[i.i]
+		if !ok {
+			return false
+		}
+	}
+
+	return true
+}
+
+// PutItems adds a set of items to the db. It always returns a nil error.
+func (m *mockItemsDB) PutItems(items ...*item) error {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	for _, i := range items {
+		m.items[i.i] = true
+	}
+
+	return nil
+}
+
+// genFilterSet generates numFilters items with the ids 0 to numFilters-1.
+func genFilterSet(numFilters int) []*item {
+	res := make([]*item, numFilters)
+	for i := 0; i < numFilters; i++ {
+		res[i] = &item{i: i}
+	}
+
+	return res
+}
+
+// pollInterval is a constant specifying a 200 ms interval.
+const pollInterval = 200 * time.Millisecond
+
+// waitFor is a helper test function that polls the passed predicate every
+// pollInterval until it returns true or the timeout period elapses. Note that
+// the first poll only happens after one full pollInterval has passed. A nil
+// error is returned as soon as the predicate is satisfied, otherwise an error
+// is returned once the timeout is hit.
+func waitFor(pred func() bool, timeout time.Duration) error {
+	exitTimer := time.After(timeout)
+	result := make(chan bool, 1)
+
+	for {
+		<-time.After(pollInterval)
+
+		go func() {
+			result <- pred()
+		}()
+
+		// Wait for either the predicate's result or the overall
+		// timeout, whichever comes first.
+		select {
+		case <-exitTimer:
+			return fmt.Errorf("predicate not satisfied after " +
+				"time out")
+
+		case succeed := <-result:
+			if succeed {
+				return nil
+			}
+		}
+	}
+}