diff --git a/chanutils/batch_writer.go b/chanutils/batch_writer.go
new file mode 100644
index 00000000..b7828232
--- /dev/null
+++ b/chanutils/batch_writer.go
@@ -0,0 +1,149 @@
+package chanutils
+
+import (
+	"sync"
+	"time"
+)
+
+// BatchWriterConfig holds the configuration options for BatchWriter.
+type BatchWriterConfig[T any] struct {
+	// QueueBufferSize sets the buffer size of the output channel of the
+	// concurrent queue used by the BatchWriter.
+	QueueBufferSize int
+
+	// MaxBatch is the maximum number of items to be persisted to the DB
+	// in one go.
+	MaxBatch int
+
+	// DBWritesTickerDuration is the time after receiving an item that the
+	// writer will wait for more items before writing the current batch
+	// to the DB.
+	DBWritesTickerDuration time.Duration
+
+	// PutItems will be used by the BatchWriter to persist items in
+	// batches.
+	PutItems func(...T) error
+}
+
+// BatchWriter manages writing items to the DB and tries to batch the writes
+// as much as possible.
+type BatchWriter[T any] struct {
+	started sync.Once
+	stopped sync.Once
+
+	cfg *BatchWriterConfig[T]
+
+	queue *ConcurrentQueue[T]
+
+	quit chan struct{}
+	wg   sync.WaitGroup
+}
+
+// NewBatchWriter constructs a new BatchWriter using the given
+// BatchWriterConfig.
+func NewBatchWriter[T any](cfg *BatchWriterConfig[T]) *BatchWriter[T] {
+	return &BatchWriter[T]{
+		cfg:   cfg,
+		queue: NewConcurrentQueue[T](cfg.QueueBufferSize),
+		quit:  make(chan struct{}),
+	}
+}
+
+// Start starts the BatchWriter.
+func (b *BatchWriter[T]) Start() {
+	b.started.Do(func() {
+		b.queue.Start()
+
+		b.wg.Add(1)
+		go b.manageNewItems()
+	})
+}
+
+// Stop stops the BatchWriter.
+func (b *BatchWriter[T]) Stop() {
+	b.stopped.Do(func() {
+		close(b.quit)
+		b.wg.Wait()
+
+		b.queue.Stop()
+	})
+}
+
+// AddItem adds a given item to the BatchWriter queue.
+func (b *BatchWriter[T]) AddItem(item T) {
+	b.queue.ChanIn() <- item
+}
+
+// manageNewItems manages collecting items and persisting them to the DB.
+// There are two conditions for writing a batch of items to the DB: the first
+// is if a certain threshold (MaxBatch) of items has been collected and the
+// other is if at least one item has been collected and a timeout has been
+// reached.
+//
+// NOTE: this must be run in a goroutine.
+func (b *BatchWriter[T]) manageNewItems() {
+	defer b.wg.Done()
+
+	batch := make([]T, 0, b.cfg.MaxBatch)
+
+	// writeBatch writes the current contents of the batch slice to the
+	// DB.
+	writeBatch := func() {
+		if len(batch) == 0 {
+			return
+		}
+
+		err := b.cfg.PutItems(batch...)
+		if err != nil {
+			log.Errorf("Could not write filters to filterDB: %v",
+				err)
+		}
+
+		// Empty the batch slice.
+		batch = make([]T, 0, b.cfg.MaxBatch)
+	}
+
+	ticker := time.NewTicker(b.cfg.DBWritesTickerDuration)
+	defer ticker.Stop()
+
+	// Stop the ticker since we don't want it to tick unless there is at
+	// least one item in the queue.
+	ticker.Stop()
+
+	for {
+		select {
+		case filter, ok := <-b.queue.ChanOut():
+			if !ok {
+				return
+			}
+
+			batch = append(batch, filter)
+
+			switch len(batch) {
+			// If the batch slice is full, we stop the ticker and
+			// write the batch contents to disk.
+			case b.cfg.MaxBatch:
+				ticker.Stop()
+				writeBatch()
+
+			// If an item is added to the batch, we reset the timer.
+			// This ensures that if the batch threshold is not met
+			// then items are still persisted in a timely manner.
+			default:
+				ticker.Reset(b.cfg.DBWritesTickerDuration)
+			}
+
+		case <-ticker.C:
+			// If the ticker ticks, then we stop it and write the
+			// current batch contents to the db. If any more items
+			// are added, the ticker will be reset.
+			ticker.Stop()
+			writeBatch()
+
+		case <-b.quit:
+			writeBatch()
+
+			return
+		}
+	}
+}
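The following is a minimal usage sketch of the new BatchWriter, not part of the patch. It only relies on the API added above; the persistItems callback, the int item type and the batch sizes are arbitrary placeholders for a real batched DB write such as FilterDB.PutFilters.

package main

import (
	"fmt"
	"time"

	"github.com/lightninglabs/neutrino/chanutils"
)

// persistItems stands in for a real batched DB write such as
// FilterDB.PutFilters.
func persistItems(items ...int) error {
	fmt.Printf("persisting a batch of %d items\n", len(items))
	return nil
}

func main() {
	w := chanutils.NewBatchWriter[int](&chanutils.BatchWriterConfig[int]{
		QueueBufferSize:        chanutils.DefaultQueueSize,
		MaxBatch:               100,
		DBWritesTickerDuration: 500 * time.Millisecond,
		PutItems:               persistItems,
	})
	w.Start()
	defer w.Stop()

	// Queued items are flushed either once MaxBatch of them have
	// accumulated or once DBWritesTickerDuration elapses after the most
	// recently received item.
	for i := 0; i < 250; i++ {
		w.AddItem(i)
	}

	// Give the ticker a chance to flush the final partial batch before
	// shutting down.
	time.Sleep(time.Second)
}
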
diff --git a/chanutils/batch_writer_test.go b/chanutils/batch_writer_test.go
new file mode 100644
index 00000000..a2bb12d3
--- /dev/null
+++ b/chanutils/batch_writer_test.go
@@ -0,0 +1,210 @@
+package chanutils
+
+import (
+	"fmt"
+	"math/rand"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+)
+
+const waitTime = time.Second * 5
+
+// TestBatchWriter tests that the BatchWriter behaves as expected.
+func TestBatchWriter(t *testing.T) {
+	t.Parallel()
+	rand.Seed(time.Now().UnixNano())
+
+	// waitForItems is a helper function that will wait for a given set of
+	// items to appear in the db.
+	waitForItems := func(db *mockItemsDB, items ...*item) {
+		err := waitFor(func() bool {
+			return db.hasItems(items...)
+		}, waitTime)
+		require.NoError(t, err)
+	}
+
+	t.Run("filters persisted after ticker", func(t *testing.T) {
+		t.Parallel()
+
+		// Create a mock filters DB.
+		db := newMockItemsDB()
+
+		// Construct a new BatchWriter backed by the mock db.
+		b := NewBatchWriter[*item](&BatchWriterConfig[*item]{
+			QueueBufferSize:        10,
+			MaxBatch:               20,
+			DBWritesTickerDuration: time.Millisecond * 500,
+			PutItems:               db.PutItems,
+		})
+		b.Start()
+		t.Cleanup(b.Stop)
+
+		fs := genFilterSet(5)
+		for _, f := range fs {
+			b.AddItem(f)
+		}
+		waitForItems(db, fs...)
+	})
+
+	t.Run("write once threshold is reached", func(t *testing.T) {
+		t.Parallel()
+
+		// Create a mock filters DB.
+		db := newMockItemsDB()
+
+		// Construct a new BatchWriter backed by the mock db.
+		// Make the DB writes ticker duration extra long so that we
+		// can explicitly test that the batch gets persisted if the
+		// MaxBatch threshold is reached.
+		b := NewBatchWriter[*item](&BatchWriterConfig[*item]{
+			QueueBufferSize:        10,
+			MaxBatch:               20,
+			DBWritesTickerDuration: time.Hour,
+			PutItems:               db.PutItems,
+		})
+		b.Start()
+		t.Cleanup(b.Stop)
+
+		// Generate 30 filters and add each one to the batch writer.
+		fs := genFilterSet(30)
+		for _, f := range fs {
+			b.AddItem(f)
+		}
+
+		// Since the MaxBatch threshold has been reached, we expect the
+		// first 20 filters to be persisted.
+		waitForItems(db, fs[:20]...)
+
+		// Since the last 10 filters don't reach the threshold and since
+		// the ticker has definitely not ticked yet, we don't expect the
+		// last 10 filters to be in the db yet.
+		require.False(t, db.hasItems(fs[21:]...))
+	})
+
+	t.Run("stress test", func(t *testing.T) {
+		t.Parallel()
+
+		// Create a mock filters DB.
+		db := newMockItemsDB()
+
+		// Construct a new BatchWriter backed by the mock db.
+		// Use a small MaxBatch and a short DB writes ticker duration
+		// so that both the threshold write path and the timeout write
+		// path are exercised during the test.
+		b := NewBatchWriter[*item](&BatchWriterConfig[*item]{
+			QueueBufferSize:        5,
+			MaxBatch:               5,
+			DBWritesTickerDuration: time.Millisecond * 2,
+			PutItems:               db.PutItems,
+		})
+		b.Start()
+		t.Cleanup(b.Stop)
+
+		// Generate lots of filters and add each to the batch writer.
+		// Sleep for a bit between each filter to ensure that we
+		// sometimes hit the timeout write and sometimes the threshold
+		// write.
+		fs := genFilterSet(1000)
+		for _, f := range fs {
+			b.AddItem(f)
+
+			n := rand.Intn(3)
+			time.Sleep(time.Duration(n) * time.Millisecond)
+		}
+
+		// Eventually all the filters should be persisted to the DB via
+		// either the threshold write path or the timeout write path.
+		waitForItems(db, fs...)
+	})
+}
+
+type item struct {
+	i int
+}
+
+// mockItemsDB is a mock DB that holds a set of items.
+type mockItemsDB struct {
+	items map[int]bool
+	mu    sync.Mutex
+}
+
+// newMockItemsDB constructs a new mockItemsDB.
+func newMockItemsDB() *mockItemsDB {
+	return &mockItemsDB{
+		items: make(map[int]bool),
+	}
+}
+
+// hasItems returns true if the db contains all the given items.
+func (m *mockItemsDB) hasItems(items ...*item) bool {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	for _, i := range items {
+		_, ok := m.items[i.i]
+		if !ok {
+			return false
+		}
+	}
+
+	return true
+}
+
+// PutItems adds a set of items to the db.
+func (m *mockItemsDB) PutItems(items ...*item) error {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	for _, i := range items {
+		m.items[i.i] = true
+	}
+
+	return nil
+}
+
+// genFilterSet generates a set of numFilters items.
+func genFilterSet(numFilters int) []*item {
+	res := make([]*item, numFilters)
+	for i := 0; i < numFilters; i++ {
+		res[i] = &item{i: i}
+	}
+
+	return res
+}
+
+// pollInterval is a constant specifying a 200 ms interval.
+const pollInterval = 200 * time.Millisecond
+
+// waitFor is a helper test function that will wait for a timeout period of
+// time until the passed predicate returns true. This function is helpful as
+// timing doesn't always line up well when running tests that depend on
+// asynchronous events. This function gives callers a way to assert that
+// some property is upheld within a particular time frame.
+func waitFor(pred func() bool, timeout time.Duration) error {
+	exitTimer := time.After(timeout)
+	result := make(chan bool, 1)

+	for {
+		<-time.After(pollInterval)
+
+		go func() {
+			result <- pred()
+		}()
+
+		// Each time we call the pred(), we expect a result to be
+		// returned otherwise it will timeout.
+		select {
+		case <-exitTimer:
+			return fmt.Errorf("predicate not satisfied after " +
+				"time out")
+
+		case succeed := <-result:
+			if succeed {
+				return nil
+			}
+		}
+	}
+}
diff --git a/chanutils/log.go b/chanutils/log.go
new file mode 100644
index 00000000..2a8629ee
--- /dev/null
+++ b/chanutils/log.go
@@ -0,0 +1,26 @@
+package chanutils
+
+import "github.com/btcsuite/btclog"
+
+// log is a logger that is initialized with no output filters. This
+// means the package will not perform any logging by default until the caller
+// requests it.
+var log btclog.Logger
+
+// The default amount of logging is none.
+func init() {
+	DisableLog()
+}
+
+// DisableLog disables all library log output. Logging output is disabled
+// by default until either UseLogger or SetLogWriter are called.
+func DisableLog() {
+	UseLogger(btclog.Disabled)
+}
+
+// UseLogger uses a specified Logger to output package logging info.
+// This should be used in preference to SetLogWriter if the caller is also
+// using btclog.
+func UseLogger(logger btclog.Logger) {
+	log = logger
+}
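As with the other neutrino subsystems, the new package stays silent until a logger is injected. Inside neutrino itself this is wired up by the root log.go change further down in this diff; the sketch below shows how an external caller might enable the output directly, assuming the standard btclog backend (the "CHUT" subsystem tag is an arbitrary choice).

package main

import (
	"os"

	"github.com/btcsuite/btclog"
	"github.com/lightninglabs/neutrino/chanutils"
)

func main() {
	// Route chanutils log output to stdout at trace level.
	backend := btclog.NewBackend(os.Stdout)
	logger := backend.Logger("CHUT")
	logger.SetLevel(btclog.LevelTrace)
	chanutils.UseLogger(logger)
}
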
diff --git a/chanutils/queue.go b/chanutils/queue.go
new file mode 100644
index 00000000..3dc49dd9
--- /dev/null
+++ b/chanutils/queue.go
@@ -0,0 +1,143 @@
+package chanutils
+
+import (
+	"sync"
+
+	"github.com/lightninglabs/neutrino/cache/lru"
+)
+
+const (
+	// DefaultQueueSize is the default size to use for concurrent queues.
+	DefaultQueueSize = 10
+)
+
+// ConcurrentQueue is a typed concurrent-safe FIFO queue with unbounded
+// capacity. Clients interact with the queue by pushing items into the in
+// channel and popping items from the out channel. There is a goroutine that
+// manages moving items from the in channel to the out channel in the correct
+// order that must be started by calling Start().
+type ConcurrentQueue[T any] struct {
+	started sync.Once
+	stopped sync.Once
+
+	chanIn   chan T
+	chanOut  chan T
+	overflow *lru.List[T]
+
+	wg   sync.WaitGroup
+	quit chan struct{}
+}
+
+// NewConcurrentQueue constructs a ConcurrentQueue. The bufferSize parameter is
+// the capacity of the output channel. When the size of the queue is below this
+// threshold, pushes do not incur the overhead of the less efficient overflow
+// structure.
+func NewConcurrentQueue[T any](bufferSize int) *ConcurrentQueue[T] {
+	return &ConcurrentQueue[T]{
+		chanIn:   make(chan T),
+		chanOut:  make(chan T, bufferSize),
+		overflow: lru.NewList[T](),
+		quit:     make(chan struct{}),
+	}
+}
+
+// ChanIn returns a channel that can be used to push new items into the queue.
+func (cq *ConcurrentQueue[T]) ChanIn() chan<- T {
+	return cq.chanIn
+}
+
+// ChanOut returns a channel that can be used to pop items from the queue.
+func (cq *ConcurrentQueue[T]) ChanOut() <-chan T {
+	return cq.chanOut
+}
+
+// Start begins a goroutine that manages moving items from the in channel to
+// the out channel. The queue tries to move items directly to the out channel
+// to minimize overhead, but if the out channel is full it pushes items to an
+// overflow queue. This must be called before using the queue.
+func (cq *ConcurrentQueue[T]) Start() {
+	cq.started.Do(cq.start)
+}
+
+func (cq *ConcurrentQueue[T]) start() {
+	cq.wg.Add(1)
+	go func() {
+		defer cq.wg.Done()
+
+	readLoop:
+		for {
+			nextElement := cq.overflow.Front()
+			if nextElement == nil {
+				// Overflow queue is empty so incoming items can
+				// be pushed directly to the output channel. If
+				// output channel is full though, push to
+				// overflow.
+				select {
+				case item, ok := <-cq.chanIn:
+					if !ok {
+						log.Warnf("ConcurrentQueue " +
+							"has exited due to " +
+							"the input channel " +
+							"being closed")
+
+						break readLoop
+					}
+					select {
+					case cq.chanOut <- item:
+						// Optimistically push directly
+						// to chanOut.
+					default:
+						cq.overflow.PushBack(item)
+					}
+				case <-cq.quit:
+					return
+				}
+			} else {
+				// Overflow queue is not empty, so any new items
+				// get pushed to the back to preserve order.
+				select {
+				case item, ok := <-cq.chanIn:
+					if !ok {
+						log.Warnf("ConcurrentQueue " +
+							"has exited due to " +
+							"the input channel " +
+							"being closed")
+
+						break readLoop
+					}
+					cq.overflow.PushBack(item)
+				case cq.chanOut <- nextElement.Value:
+					cq.overflow.Remove(nextElement)
+				case <-cq.quit:
+					return
+				}
+			}
+		}
+
+		// Incoming channel has been closed. Empty overflow queue into
+		// the outgoing channel.
+		nextElement := cq.overflow.Front()
+		for nextElement != nil {
+			select {
+			case cq.chanOut <- nextElement.Value:
+				cq.overflow.Remove(nextElement)
+			case <-cq.quit:
+				return
+			}
+			nextElement = cq.overflow.Front()
+		}
+
+		// Close outgoing channel.
+		close(cq.chanOut)
+	}()
+}
+
+// Stop ends the goroutine that moves items from the in channel to the out
+// channel. This does not clear the queue state, so the queue can be restarted
+// without dropping items.
+func (cq *ConcurrentQueue[T]) Stop() {
+	cq.stopped.Do(func() {
+		close(cq.quit)
+		cq.wg.Wait()
+	})
+}
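A small sketch, not part of the patch, of the ConcurrentQueue behaviour the BatchWriter relies on: once Start has been called, pushes on ChanIn do not block on a slow consumer, because anything the output buffer cannot hold spills into the overflow list, and FIFO order is preserved end to end. The buffer size and item count below are arbitrary.

package main

import (
	"fmt"

	"github.com/lightninglabs/neutrino/chanutils"
)

func main() {
	q := chanutils.NewConcurrentQueue[string](2)
	q.Start()
	defer q.Stop()

	// Push more items than the output buffer can hold. The extra items
	// are parked in the overflow list rather than blocking the producer.
	for i := 0; i < 5; i++ {
		q.ChanIn() <- fmt.Sprintf("item-%d", i)
	}

	// Items are popped in the same order they were pushed.
	for i := 0; i < 5; i++ {
		fmt.Println(<-q.ChanOut())
	}
}
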
diff --git a/filterdb/db.go b/filterdb/db.go
index a4d4bb13..5539ba27 100644
--- a/filterdb/db.go
+++ b/filterdb/db.go
@@ -18,9 +18,13 @@ var (
 	// regBucket is the bucket that stores the regular filters.
 	regBucket = []byte("regular")
+
+	// ErrFilterNotFound is returned when a filter for a target block hash
+	// is unable to be located.
+	ErrFilterNotFound = fmt.Errorf("unable to find filter")
 )
 
-// FilterType is a enum-like type that represents the various filter types
+// FilterType is an enum-like type that represents the various filter types
 // currently defined.
 type FilterType uint8
@@ -30,21 +34,27 @@ const (
 	RegularFilter FilterType = iota
 )
 
-var (
-	// ErrFilterNotFound is returned when a filter for a target block hash is
-	// unable to be located.
-	ErrFilterNotFound = fmt.Errorf("unable to find filter")
-)
+// FilterData holds all the info about a filter required to store it.
+type FilterData struct {
+	// Filter is the actual filter to be stored.
+	Filter *gcs.Filter
+
+	// BlockHash is the block header hash of the block associated with the
+	// Filter.
+	BlockHash *chainhash.Hash
+
+	// Type is the filter type.
+	Type FilterType
+}
 
 // FilterDatabase is an interface which represents an object that is capable of
-// storing and retrieving filters according to their corresponding block hash and
-// also their filter type.
+// storing and retrieving filters according to their corresponding block hash
+// and also their filter type.
 //
 // TODO(roasbeef): similar interface for headerfs?
 type FilterDatabase interface {
-	// PutFilter stores a filter with the given hash and type to persistent
-	// storage.
-	PutFilter(*chainhash.Hash, *gcs.Filter, FilterType) error
+	// PutFilters stores a set of filters to persistent storage.
+	PutFilters(...*FilterData) error
 
 	// FetchFilter attempts to fetch a filter with the given hash and type
 	// from persistent storage. In the case that a filter matching the
@@ -52,7 +62,8 @@ type FilterDatabase interface {
 	// returned.
 	FetchFilter(*chainhash.Hash, FilterType) (*gcs.Filter, error)
 
-	// PurgeFilters purge all filters with a given type from persistent storage.
+	// PurgeFilters purges all filters with a given type from persistent
+	// storage.
 	PurgeFilters(FilterType) error
 }
 
@@ -117,10 +128,13 @@ func (f *FilterStore) PurgeFilters(fType FilterType) error {
 
 	switch fType {
 	case RegularFilter:
-		if err := filters.DeleteNestedBucket(regBucket); err != nil {
+		err := filters.DeleteNestedBucket(regBucket)
+		if err != nil {
 			return err
 		}
-		if _, err := filters.CreateBucket(regBucket); err != nil {
+
+		_, err = filters.CreateBucket(regBucket)
+		if err != nil {
 			return err
 		}
 	default:
@@ -149,35 +163,46 @@ func putFilter(bucket walletdb.ReadWriteBucket, hash *chainhash.Hash,
 	return bucket.Put(hash[:], bytes)
 }
 
-// PutFilter stores a filter with the given hash and type to persistent
-// storage.
+// PutFilters stores a set of filters to persistent storage.
 //
 // NOTE: This method is a part of the FilterDatabase interface.
-func (f *FilterStore) PutFilter(hash *chainhash.Hash,
-	filter *gcs.Filter, fType FilterType) error {
-
-	return walletdb.Update(f.db, func(tx walletdb.ReadWriteTx) error {
+func (f *FilterStore) PutFilters(filterList ...*FilterData) error {
+	var updateErr error
+	err := walletdb.Batch(f.db, func(tx walletdb.ReadWriteTx) error {
 		filters := tx.ReadWriteBucket(filterBucket)
+		regularFilterBkt := filters.NestedReadWriteBucket(regBucket)
+
+		for _, filterData := range filterList {
+			var targetBucket walletdb.ReadWriteBucket
+			switch filterData.Type {
+			case RegularFilter:
+				targetBucket = regularFilterBkt
+			default:
+				updateErr = fmt.Errorf("unknown filter "+
+					"type: %v", filterData.Type)
+
+				return nil
+			}
 
-		var targetBucket walletdb.ReadWriteBucket
-		switch fType {
-		case RegularFilter:
-			targetBucket = filters.NestedReadWriteBucket(regBucket)
-		default:
-			return fmt.Errorf("unknown filter type: %v", fType)
-		}
-
-		if filter == nil {
-			return targetBucket.Put(hash[:], nil)
-		}
+			err := putFilter(
+				targetBucket, filterData.BlockHash,
+				filterData.Filter,
+			)
+			if err != nil {
+				return err
+			}
 
-		bytes, err := filter.NBytes()
-		if err != nil {
-			return err
+			log.Tracef("Wrote filter for block %s, type %d",
+				filterData.BlockHash, filterData.Type)
 		}
 
-		return targetBucket.Put(hash[:], bytes)
+		return nil
 	})
+	if err != nil {
+		return err
+	}
+
+	return updateErr
 }
 
 // FetchFilter attempts to fetch a filter with the given hash and type from
diff --git a/filterdb/db_test.go b/filterdb/db_test.go
index 6afd97cf..6d444d55 100644
--- a/filterdb/db_test.go
+++ b/filterdb/db_test.go
@@ -1,10 +1,7 @@
 package filterdb
 
 import (
-	"io/ioutil"
 	"math/rand"
-	"os"
-	"reflect"
 	"testing"
 	"time"
 
@@ -14,116 +11,93 @@ import (
 	"github.com/btcsuite/btcd/chaincfg/chainhash"
 	"github.com/btcsuite/btcwallet/walletdb"
 	_ "github.com/btcsuite/btcwallet/walletdb/bdb"
+	"github.com/stretchr/testify/require"
 )
 
-func createTestDatabase() (func(), FilterDatabase, error) {
-	tempDir, err := ioutil.TempDir("", "neutrino")
-	if err != nil {
-		return nil, nil, err
-	}
+func createTestDatabase(t *testing.T) FilterDatabase {
+	tempDir := t.TempDir()
 
 	db, err := walletdb.Create(
 		"bdb", tempDir+"/test.db", true, time.Second*10,
 	)
-	if err != nil {
-		return nil, nil, err
-	}
-
-	cleanUp := func() {
-		os.RemoveAll(tempDir)
-		db.Close()
-	}
+	require.NoError(t, err)
+	t.Cleanup(func() {
+		require.NoError(t, db.Close())
+	})
 
 	filterDB, err := New(db, chaincfg.SimNetParams)
-	if err != nil {
-		return nil, nil, err
-	}
+	require.NoError(t, err)
 
-	return cleanUp, filterDB, nil
+	return filterDB
 }
 
+// TestGenesisFilterCreation tests the fetching of the genesis block filter.
 func TestGenesisFilterCreation(t *testing.T) {
-	cleanUp, database, err := createTestDatabase()
-	defer cleanUp()
-	if err != nil {
-		t.Fatalf("unable to create test db: %v", err)
-	}
-
-	genesisHash := chaincfg.SimNetParams.GenesisHash
+	var (
+		database    = createTestDatabase(t)
+		genesisHash = chaincfg.SimNetParams.GenesisHash
+	)
 
 	// With the database initialized, we should be able to fetch the
 	// regular filter for the genesis block.
-	regGenesisFilter, err := database.FetchFilter(genesisHash, RegularFilter)
-	if err != nil {
-		t.Fatalf("unable to fetch regular genesis filter: %v", err)
-	}
+	regGenesisFilter, err := database.FetchFilter(
		genesisHash, RegularFilter,
+	)
+	require.NoError(t, err)
 
 	// The regular filter should be non-nil as the gensis block's output
 	// and the coinbase txid should be indexed.
-	if regGenesisFilter == nil {
-		t.Fatalf("regular genesis filter is nil")
-	}
+	require.NotNil(t, regGenesisFilter)
 }
 
-func genRandFilter(numElements uint32) (*gcs.Filter, error) {
+func genRandFilter(t *testing.T, numElements uint32) *gcs.Filter {
 	elements := make([][]byte, numElements)
 	for i := uint32(0); i < numElements; i++ {
 		var elem [20]byte
-		if _, err := rand.Read(elem[:]); err != nil {
-			return nil, err
-		}
+		_, err := rand.Read(elem[:])
+		require.NoError(t, err)
 
 		elements[i] = elem[:]
 	}
 
 	var key [16]byte
-	if _, err := rand.Read(key[:]); err != nil {
-		return nil, err
-	}
+	_, err := rand.Read(key[:])
+	require.NoError(t, err)
 
 	filter, err := gcs.BuildGCSFilter(
 		builder.DefaultP, builder.DefaultM, key, elements,
 	)
-	if err != nil {
-		return nil, err
-	}
+	require.NoError(t, err)
 
-	return filter, nil
+	return filter
 }
 
+// TestFilterStorage tests writing to and reading from the filter DB.
 func TestFilterStorage(t *testing.T) {
-	// TODO(roasbeef): use testing.Quick
-	cleanUp, database, err := createTestDatabase()
-	defer cleanUp()
-	if err != nil {
-		t.Fatalf("unable to create test db: %v", err)
-	}
+	database := createTestDatabase(t)
 
 	// We'll generate a random block hash to create our test filters
 	// against.
 	var randHash chainhash.Hash
-	if _, err := rand.Read(randHash[:]); err != nil {
-		t.Fatalf("unable to generate random hash: %v", err)
-	}
+	_, err := rand.Read(randHash[:])
+	require.NoError(t, err)
 
-	// First, we'll create and store a random fitler for the regular filter
+	// First, we'll create and store a random filter for the regular filter
 	// type for the block hash generate above.
-	regFilter, err := genRandFilter(100)
-	if err != nil {
-		t.Fatalf("unable to create random filter: %v", err)
-	}
-	err = database.PutFilter(&randHash, regFilter, RegularFilter)
-	if err != nil {
-		t.Fatalf("unable to store regular filter: %v", err)
+	regFilter := genRandFilter(t, 100)
+
+	filter := &FilterData{
+		Filter:    regFilter,
+		BlockHash: &randHash,
+		Type:      RegularFilter,
 	}
+	err = database.PutFilters(filter)
+	require.NoError(t, err)
+
 	// With the filter stored, we should be able to retrieve the filter
 	// without any issue, and it should match the stored filter exactly.
 	regFilterDB, err := database.FetchFilter(&randHash, RegularFilter)
-	if err != nil {
-		t.Fatalf("unable to retrieve reg filter: %v", err)
-	}
-	if !reflect.DeepEqual(regFilter, regFilterDB) {
-		t.Fatalf("regular filter doesn't match!")
-	}
+	require.NoError(t, err)
+	require.Equal(t, regFilter, regFilterDB)
 }
diff --git a/filterdb/log.go b/filterdb/log.go
new file mode 100644
index 00000000..bd726dbf
--- /dev/null
+++ b/filterdb/log.go
@@ -0,0 +1,26 @@
+package filterdb
+
+import "github.com/btcsuite/btclog"
+
+// log is a logger that is initialized with no output filters. This
+// means the package will not perform any logging by default until the caller
+// requests it.
+var log btclog.Logger
+
+// The default amount of logging is none.
+func init() {
+	DisableLog()
+}
+
+// DisableLog disables all library log output. Logging output is disabled
+// by default until either UseLogger or SetLogWriter are called.
+func DisableLog() {
+	UseLogger(btclog.Disabled)
+}
+
+// UseLogger uses a specified Logger to output package logging info.
+// This should be used in preference to SetLogWriter if the caller is also
+// using btclog.
+func UseLogger(logger btclog.Logger) {
+	log = logger
+}
diff --git a/log.go b/log.go
index 748a50e4..5e57fbbc 100644
--- a/log.go
+++ b/log.go
@@ -8,6 +8,8 @@ import (
 	"github.com/btcsuite/btcd/txscript"
 	"github.com/btcsuite/btclog"
 	"github.com/lightninglabs/neutrino/blockntfns"
+	"github.com/lightninglabs/neutrino/chanutils"
+	"github.com/lightninglabs/neutrino/filterdb"
 	"github.com/lightninglabs/neutrino/pushtx"
 	"github.com/lightninglabs/neutrino/query"
 )
@@ -41,4 +43,6 @@ func UseLogger(logger btclog.Logger) {
 	pushtx.UseLogger(logger)
 	connmgr.UseLogger(logger)
 	query.UseLogger(logger)
+	filterdb.UseLogger(logger)
+	chanutils.UseLogger(logger)
 }
diff --git a/neutrino.go b/neutrino.go
index 9e1e1e1a..7ee45edd 100644
--- a/neutrino.go
+++ b/neutrino.go
@@ -25,6 +25,7 @@ import (
 	"github.com/lightninglabs/neutrino/banman"
 	"github.com/lightninglabs/neutrino/blockntfns"
 	"github.com/lightninglabs/neutrino/cache/lru"
+	"github.com/lightninglabs/neutrino/chanutils"
 	"github.com/lightninglabs/neutrino/filterdb"
 	"github.com/lightninglabs/neutrino/headerfs"
 	"github.com/lightninglabs/neutrino/pushtx"
@@ -661,6 +662,7 @@ type ChainService struct { // nolint:maligned
 	broadcaster       *pushtx.Broadcaster
 	banStore          banman.Store
 	workManager       query.WorkManager
+	filterBatchWriter *chanutils.BatchWriter[*filterdb.FilterData]
 
 	// peerSubscribers is a slice of active peer subscriptions, that we
 	// will notify each time a new peer is connected.
@@ -748,6 +750,21 @@ func NewChainService(cfg Config) (*ChainService, error) {
 		return nil, err
 	}
 
+	if s.persistToDisk {
+		cfg := &chanutils.BatchWriterConfig[*filterdb.FilterData]{
+			QueueBufferSize:        chanutils.DefaultQueueSize,
+			MaxBatch:               1000,
+			DBWritesTickerDuration: time.Millisecond * 500,
+			PutItems:               s.FilterDB.PutFilters,
+		}
+
+		batchWriter := chanutils.NewBatchWriter[*filterdb.FilterData](
+			cfg,
+		)
+
+		s.filterBatchWriter = batchWriter
+	}
+
 	filterCacheSize := DefaultFilterCacheSize
 	if cfg.FilterCacheSize != 0 {
 		filterCacheSize = cfg.FilterCacheSize
@@ -1606,6 +1623,10 @@ func (s *ChainService) Start() error {
 			err)
 	}
 
+	if s.persistToDisk {
+		s.filterBatchWriter.Start()
+	}
+
 	go s.connManager.Start()
 
 	// Start the peer handler which in turn starts the address and block
@@ -1645,6 +1666,10 @@ func (s *ChainService) Stop() error {
 		returnErr = err
 	}
 
+	if s.persistToDisk {
+		s.filterBatchWriter.Stop()
+	}
+
 	// Signal the remaining goroutines to quit.
 	close(s.quit)
 	s.wg.Wait()
diff --git a/query.go b/query.go
index 3bbfcb1b..fb5deb13 100644
--- a/query.go
+++ b/query.go
@@ -534,15 +534,11 @@ func (q *cfiltersQuery) handleResponse(req, resp wire.Message,
 	}
 
 	if q.cs.persistToDisk {
-		err = q.cs.FilterDB.PutFilter(
-			&response.BlockHash, filter, dbFilterType,
-		)
-		if err != nil {
-			log.Warnf("Couldn't write filter to filterDB: %v", err)
-		}
-
-		log.Tracef("Wrote filter for block %s, type %d",
-			&response.BlockHash, dbFilterType)
+		q.cs.filterBatchWriter.AddItem(&filterdb.FilterData{
+			Filter:    filter,
+			BlockHash: &response.BlockHash,
+			Type:      dbFilterType,
+		})
 	}
 
 	// We delete the entry for this filter from the headerIndex to indicate
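End to end, an application does not need to change anything to benefit from the batching: with filter persistence enabled, the ChainService now feeds fetched filters through the BatchWriter instead of writing them one DB transaction at a time. A rough sketch of such a setup, assuming the existing Config fields (DataDir, Database, ChainParams and the PersistToDisk flag that backs persistToDisk) keep their current meaning; the paths are placeholders.

package main

import (
	"log"
	"os"
	"time"

	"github.com/btcsuite/btcd/chaincfg"
	"github.com/btcsuite/btcwallet/walletdb"
	_ "github.com/btcsuite/btcwallet/walletdb/bdb"
	"github.com/lightninglabs/neutrino"
)

func main() {
	dataDir := "/tmp/neutrino-example"
	if err := os.MkdirAll(dataDir, 0700); err != nil {
		log.Fatal(err)
	}

	db, err := walletdb.Create(
		"bdb", dataDir+"/neutrino.db", true, time.Second*10,
	)
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// With PersistToDisk set, compact filters fetched from peers are
	// handed to the internal BatchWriter and persisted in batches of up
	// to 1000 filters (or after a 500ms lull), rather than one DB
	// transaction per filter.
	svc, err := neutrino.NewChainService(neutrino.Config{
		DataDir:       dataDir,
		Database:      db,
		ChainParams:   chaincfg.TestNet3Params,
		PersistToDisk: true,
	})
	if err != nil {
		log.Fatal(err)
	}

	if err := svc.Start(); err != nil {
		log.Fatal(err)
	}
	defer svc.Stop()
}
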