diff --git a/blockmanager.go b/blockmanager.go
index c05dc33c..9c9e6f4c 100644
--- a/blockmanager.go
+++ b/blockmanager.go
@@ -28,10 +28,6 @@ import (
 )
 
 const (
-	// maxTimeOffset is the maximum duration a block time is allowed to be
-	// ahead of the current time. This is currently 2 hours.
-	maxTimeOffset = 2 * time.Hour
-
 	// numMaxMemHeaders is the max number of headers to store in memory for
 	// a particular peer. By bounding this value, we're able to closely
 	// control our effective memory usage during initial sync and re-org
@@ -2374,15 +2370,19 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) {
 		return
 	}
 
-	// For checking to make sure blocks aren't too far in the future as of
-	// the time we receive the headers message.
-	maxTimestamp := b.cfg.TimeSource.AdjustedTime().
-		Add(maxTimeOffset)
-
 	// We'll attempt to write the entire batch of validated headers
-	// atomically in order to improve peformance.
+	// atomically in order to improve performance.
 	headerWriteBatch := make([]headerfs.BlockHeader, 0, len(msg.Headers))
 
+	// Explicitly check that each header in msg.Headers builds off of the
+	// previous one. This is a quick sanity check to avoid doing the more
+	// expensive checks below if we know the headers are invalid.
+	if !areHeadersConnected(msg.Headers) {
+		log.Warnf("Headers received from peer don't connect")
+		hmsg.peer.Disconnect()
+		return
+	}
+
 	// Process all of the received headers ensuring each one connects to
 	// the previous and that checkpoints match.
 	receivedCheckpoint := false
@@ -2411,8 +2411,12 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) {
 		prevNode := prevNodeEl
 		prevHash := prevNode.Header.BlockHash()
 		if prevHash.IsEqual(&blockHeader.PrevBlock) {
-			err := b.checkHeaderSanity(blockHeader, maxTimestamp,
-				false)
+			prevNodeHeight := prevNode.Height
+			prevNodeHeader := prevNode.Header
+			err := b.checkHeaderSanity(
+				blockHeader, false, prevNodeHeight,
+				&prevNodeHeader,
+			)
 			if err != nil {
 				log.Warnf("Header doesn't pass sanity check: "+
 					"%s -- disconnecting peer", err)
@@ -2425,10 +2429,12 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) {
 
 			// This header checks out, so we'll add it to our write
 			// batch.
-			headerWriteBatch = append(headerWriteBatch, headerfs.BlockHeader{
-				BlockHeader: blockHeader,
-				Height:      uint32(node.Height),
-			})
+			headerWriteBatch = append(
+				headerWriteBatch, headerfs.BlockHeader{
+					BlockHeader: blockHeader,
+					Height:      uint32(node.Height),
+				},
+			)
 
 			hmsg.peer.UpdateLastBlockHeight(node.Height)
 
@@ -2520,8 +2526,27 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) {
 			})
 			totalWork := big.NewInt(0)
 			for j, reorgHeader := range msg.Headers[i:] {
-				err = b.checkHeaderSanity(reorgHeader,
-					maxTimestamp, true)
+				// We have to get the parent's height and
+				// header to be able to contextually validate
+				// this header.
+				prevNodeHeight := backHeight + uint32(j)
+
+				var prevNodeHeader *wire.BlockHeader
+				if i+j == 0 {
+					// Use backHead if we are using the
+					// first header in the Headers slice.
+					prevNodeHeader = backHead
+				} else {
+					// We can find the parent in the
+					// Headers slice by getting the header
+					// at index i+j-1.
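+					// This lookup is safe: the
+					// areHeadersConnected check above
+					// verified that every header in
+					// msg.Headers connects to the one
+					// before it.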
+					prevNodeHeader = msg.Headers[i+j-1]
+				}
+
+				err = b.checkHeaderSanity(
+					reorgHeader, true,
+					int32(prevNodeHeight), prevNodeHeader,
+				)
 				if err != nil {
 					log.Warnf("Header doesn't pass sanity"+
 						" check: %s -- disconnecting "+
@@ -2706,180 +2731,70 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) {
 	b.newHeadersSignal.Broadcast()
 }
 
-// checkHeaderSanity checks the PoW, and timestamp of a block header.
-func (b *blockManager) checkHeaderSanity(blockHeader *wire.BlockHeader,
-	maxTimestamp time.Time, reorgAttempt bool) error {
+// areHeadersConnected returns true if the passed block headers are connected to
+// each other correctly.
+func areHeadersConnected(headers []*wire.BlockHeader) bool {
+	var (
+		lastHeader chainhash.Hash
+		emptyHash  chainhash.Hash
+	)
+	for _, blockHeader := range headers {
+		blockHash := blockHeader.BlockHash()
 
-	diff, err := b.calcNextRequiredDifficulty(
-		blockHeader.Timestamp, reorgAttempt)
-	if err != nil {
-		return err
-	}
-	stubBlock := btcutil.NewBlock(&wire.MsgBlock{
-		Header: *blockHeader,
-	})
-	err = blockchain.CheckProofOfWork(stubBlock,
-		blockchain.CompactToBig(diff))
-	if err != nil {
-		return err
-	}
-	// Ensure the block time is not too far in the future.
-	if blockHeader.Timestamp.After(maxTimestamp) {
-		return fmt.Errorf("block timestamp of %v is too far in the "+
-			"future", blockHeader.Timestamp)
+		// If we haven't yet set lastHeader, set it now.
+		if lastHeader == emptyHash {
+			lastHeader = blockHash
+
+			continue
+		}
+
+		// Ensure that blockHeader.PrevBlock matches lastHeader.
+		if blockHeader.PrevBlock != lastHeader {
+			return false
+		}
+
+		lastHeader = blockHash
 	}
-	return nil
+
+	return true
 }
 
-// calcNextRequiredDifficulty calculates the required difficulty for the block
-// after the passed previous block node based on the difficulty retarget rules.
-func (b *blockManager) calcNextRequiredDifficulty(newBlockTime time.Time,
-	reorgAttempt bool) (uint32, error) {
+// checkHeaderSanity performs contextual and context-less checks on the passed
+// wire.BlockHeader. This function calls blockchain.CheckBlockHeaderContext for
+// the contextual check and blockchain.CheckBlockHeaderSanity for context-less
+// checks.
+func (b *blockManager) checkHeaderSanity(blockHeader *wire.BlockHeader,
+	reorgAttempt bool, prevNodeHeight int32,
+	prevNodeHeader *wire.BlockHeader) error {
+
+	// Create the lightHeaderCtx for the blockHeader's parent.
 	hList := b.headerList
 	if reorgAttempt {
 		hList = b.reorgList
 	}
 
-	lastNode := hList.Back()
-
-	// Genesis block.
-	if lastNode == nil {
-		return b.cfg.ChainParams.PowLimitBits, nil
-	}
-
-	// Return the previous block's difficulty requirements if this block
-	// is not at a difficulty retarget interval.
-	if (lastNode.Height+1)%b.blocksPerRetarget != 0 {
-		// For networks that support it, allow special reduction of the
-		// required difficulty once too much time has elapsed without
-		// mining a block.
-		if b.cfg.ChainParams.ReduceMinDifficulty {
-			// Return minimum difficulty when more than the desired
-			// amount of time has elapsed without mining a block.
-			reductionTime := int64(
-				b.cfg.ChainParams.MinDiffReductionTime /
-					time.Second)
-			allowMinTime := lastNode.Header.Timestamp.Unix() +
-				reductionTime
-			if newBlockTime.Unix() > allowMinTime {
-				return b.cfg.ChainParams.PowLimitBits, nil
-			}
-
-			// The block was mined within the desired timeframe, so
-			// return the difficulty for the last block which did
-			// not have the special minimum difficulty rule applied.
-			prevBits, err := b.findPrevTestNetDifficulty(hList)
-			if err != nil {
-				return 0, err
-			}
-			return prevBits, nil
-		}
+	parentHeaderCtx := newLightHeaderCtx(
+		prevNodeHeight, prevNodeHeader, b.cfg.BlockHeaders, hList,
+	)
 
-		// For the main network (or any unrecognized networks), simply
-		// return the previous block's difficulty requirements.
-		return lastNode.Header.Bits, nil
-	}
+	// Create a lightChainCtx as well.
+	chainCtx := newLightChainCtx(
+		&b.cfg.ChainParams, b.blocksPerRetarget, b.minRetargetTimespan,
+		b.maxRetargetTimespan,
+	)
 
-	// Get the block node at the previous retarget (targetTimespan days
-	// worth of blocks).
-	firstNode, err := b.cfg.BlockHeaders.FetchHeaderByHeight(
-		uint32(lastNode.Height + 1 - b.blocksPerRetarget),
+	var emptyFlags blockchain.BehaviorFlags
+	err := blockchain.CheckBlockHeaderContext(
+		blockHeader, parentHeaderCtx, emptyFlags, chainCtx, true,
 	)
 	if err != nil {
-		return 0, err
-	}
-
-	// Limit the amount of adjustment that can occur to the previous
-	// difficulty.
-	actualTimespan := lastNode.Header.Timestamp.Unix() -
-		firstNode.Timestamp.Unix()
-	adjustedTimespan := actualTimespan
-	if actualTimespan < b.minRetargetTimespan {
-		adjustedTimespan = b.minRetargetTimespan
-	} else if actualTimespan > b.maxRetargetTimespan {
-		adjustedTimespan = b.maxRetargetTimespan
-	}
-
-	// Calculate new target difficulty as:
-	//  currentDifficulty * (adjustedTimespan / targetTimespan)
-	// The result uses integer division which means it will be slightly
-	// rounded down. Bitcoind also uses integer division to calculate this
-	// result.
-	oldTarget := blockchain.CompactToBig(lastNode.Header.Bits)
-	newTarget := new(big.Int).Mul(oldTarget, big.NewInt(adjustedTimespan))
-	targetTimeSpan := int64(b.cfg.ChainParams.TargetTimespan /
-		time.Second)
-	newTarget.Div(newTarget, big.NewInt(targetTimeSpan))
-
-	// Limit new value to the proof of work limit.
-	if newTarget.Cmp(b.cfg.ChainParams.PowLimit) > 0 {
-		newTarget.Set(b.cfg.ChainParams.PowLimit)
-	}
-
-	// Log new target difficulty and return it. The new target logging is
-	// intentionally converting the bits back to a number instead of using
-	// newTarget since conversion to the compact representation loses
-	// precision.
-	newTargetBits := blockchain.BigToCompact(newTarget)
-	log.Debugf("Difficulty retarget at block height %d", lastNode.Height+1)
-	log.Debugf("Old target %08x (%064x)", lastNode.Header.Bits, oldTarget)
-	log.Debugf("New target %08x (%064x)", newTargetBits,
-		blockchain.CompactToBig(newTargetBits))
-	log.Debugf("Actual timespan %v, adjusted timespan %v, target timespan %v",
-		time.Duration(actualTimespan)*time.Second,
-		time.Duration(adjustedTimespan)*time.Second,
-		b.cfg.ChainParams.TargetTimespan)
-
-	return newTargetBits, nil
-}
-
-// findPrevTestNetDifficulty returns the difficulty of the previous block which
-// did not have the special testnet minimum difficulty rule applied.
-func (b *blockManager) findPrevTestNetDifficulty(hList headerlist.Chain) (uint32, error) {
-	startNode := hList.Back()
-
-	// Genesis block.
-	if startNode == nil {
-		return b.cfg.ChainParams.PowLimitBits, nil
-	}
-
-	// Search backwards through the chain for the last block without
-	// the special rule applied.
-	iterEl := startNode
-	iterNode := &startNode.Header
-	iterHeight := startNode.Height
-	for iterNode != nil && iterHeight%b.blocksPerRetarget != 0 &&
-		iterNode.Bits == b.cfg.ChainParams.PowLimitBits { // nolint
-
-		// Get the previous block node. This function is used over
-		// simply accessing iterNode.parent directly as it will
-		// dynamically create previous block nodes as needed. This
-		// helps allow only the pieces of the chain that are needed
-		// to remain in memory.
-		iterHeight--
-		el := iterEl.Prev()
-		if el != nil {
-			iterNode = &el.Header
-		} else {
-			node, err := b.cfg.BlockHeaders.FetchHeaderByHeight(
-				uint32(iterHeight),
-			)
-			if err != nil {
-				log.Errorf("GetBlockByHeight: %s", err)
-				return 0, err
-			}
-			iterNode = node
-		}
+		return err
 	}
 
-	// Return the found difficulty or the minimum difficulty if no
-	// appropriate block was found.
-	lastBits := b.cfg.ChainParams.PowLimitBits
-	if iterNode != nil {
-		lastBits = iterNode.Bits
-	}
-	return lastBits, nil
+	return blockchain.CheckBlockHeaderSanity(
+		blockHeader, b.cfg.ChainParams.PowLimit, b.cfg.TimeSource,
+		emptyFlags,
+	)
 }
 
 // onBlockConnected queues a block notification that extends the current chain.
@@ -2953,3 +2868,179 @@ func (b *blockManager) NotificationsSinceHeight(
 
 	return blocks, bestHeight, nil
 }
+
+// lightChainCtx is an implementation of the blockchain.ChainCtx interface and
+// gives a neutrino node the ability to contextually validate headers it
+// receives.
+type lightChainCtx struct {
+	params              *chaincfg.Params
+	blocksPerRetarget   int32
+	minRetargetTimespan int64
+	maxRetargetTimespan int64
+}
+
+// newLightChainCtx returns a new lightChainCtx instance from the passed
+// arguments.
+func newLightChainCtx(params *chaincfg.Params, blocksPerRetarget int32,
+	minRetargetTimespan, maxRetargetTimespan int64) *lightChainCtx {
+
+	return &lightChainCtx{
+		params:              params,
+		blocksPerRetarget:   blocksPerRetarget,
+		minRetargetTimespan: minRetargetTimespan,
+		maxRetargetTimespan: maxRetargetTimespan,
+	}
+}
+
+// ChainParams returns the configured chain parameters.
+//
+// NOTE: Part of the blockchain.ChainCtx interface.
+func (l *lightChainCtx) ChainParams() *chaincfg.Params {
+	return l.params
+}
+
+// BlocksPerRetarget returns the number of blocks before retargeting occurs.
+//
+// NOTE: Part of the blockchain.ChainCtx interface.
+func (l *lightChainCtx) BlocksPerRetarget() int32 {
+	return l.blocksPerRetarget
+}
+
+// MinRetargetTimespan returns the minimum amount of time used in the
+// difficulty calculation.
+//
+// NOTE: Part of the blockchain.ChainCtx interface.
+func (l *lightChainCtx) MinRetargetTimespan() int64 {
+	return l.minRetargetTimespan
+}
+
+// MaxRetargetTimespan returns the maximum amount of time used in the
+// difficulty calculation.
+//
+// NOTE: Part of the blockchain.ChainCtx interface.
+func (l *lightChainCtx) MaxRetargetTimespan() int64 {
+	return l.maxRetargetTimespan
+}
+
+// VerifyCheckpoint returns false as the lightChainCtx does not need to validate
+// checkpoints. This is already done inside the handleHeadersMsg function.
+//
+// NOTE: Part of the blockchain.ChainCtx interface.
+func (l *lightChainCtx) VerifyCheckpoint(int32, *chainhash.Hash) bool {
+	return false
+}
+
+// FindPreviousCheckpoint returns nil values since the lightChainCtx does not
+// need to validate against checkpoints. This is already done inside the
+// handleHeadersMsg function.
+//
+// NOTE: Part of the blockchain.ChainCtx interface.
+func (l *lightChainCtx) FindPreviousCheckpoint() (blockchain.HeaderCtx, error) {
+	return nil, nil
+}
+
+// lightHeaderCtx is an implementation of the blockchain.HeaderCtx interface.
+// It is used so neutrino can perform contextual header validation checks.
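+// It carries only the parent fields the contextual checks need (height,
+// difficulty bits, and timestamp), and it can fall back to the header store
+// for ancestors that are no longer held in the in-memory header list.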
+type lightHeaderCtx struct {
+	height    int32
+	bits      uint32
+	timestamp int64
+
+	store      headerfs.BlockHeaderStore
+	headerList headerlist.Chain
+}
+
+// newLightHeaderCtx returns an instance of a lightHeaderCtx to be used when
+// contextually validating headers.
+func newLightHeaderCtx(height int32, header *wire.BlockHeader,
+	store headerfs.BlockHeaderStore,
+	headerList headerlist.Chain) *lightHeaderCtx {
+
+	return &lightHeaderCtx{
+		height:     height,
+		bits:       header.Bits,
+		timestamp:  header.Timestamp.Unix(),
+		store:      store,
+		headerList: headerList,
+	}
+}
+
+// Height returns the height for the underlying header this context was created
+// from.
+//
+// NOTE: Part of the blockchain.HeaderCtx interface.
+func (l *lightHeaderCtx) Height() int32 {
+	return l.height
+}
+
+// Bits returns the difficulty bits for the underlying header this context was
+// created from.
+//
+// NOTE: Part of the blockchain.HeaderCtx interface.
+func (l *lightHeaderCtx) Bits() uint32 {
+	return l.bits
+}
+
+// Timestamp returns the timestamp for the underlying header this context was
+// created from.
+//
+// NOTE: Part of the blockchain.HeaderCtx interface.
+func (l *lightHeaderCtx) Timestamp() int64 {
+	return l.timestamp
+}
+
+// Parent returns the parent of the underlying header this context was created
+// from.
+//
+// NOTE: Part of the blockchain.HeaderCtx interface.
+func (l *lightHeaderCtx) Parent() blockchain.HeaderCtx {
+	// The parent is just an ancestor with distance 1.
+	return l.RelativeAncestorCtx(1)
+}
+
+// RelativeAncestorCtx returns the ancestor that is distance blocks before the
+// underlying header in the chain.
+//
+// NOTE: Part of the blockchain.HeaderCtx interface.
+func (l *lightHeaderCtx) RelativeAncestorCtx(
+	distance int32) blockchain.HeaderCtx {
+
+	ancestorHeight := l.height - distance
+
+	var (
+		ancestor *wire.BlockHeader
+		err      error
+	)
+
+	// We'll first consult the headerList to see if the ancestor can be
+	// found there. If that fails, we'll look up the header in the header
+	// store.
+	iterNode := l.headerList.Back()
+
+	// Keep looping until iterNode is nil or the ancestor height is
+	// encountered.
+	for iterNode != nil {
+		if iterNode.Height == ancestorHeight {
+			// We've found the ancestor.
+			ancestor = &iterNode.Header
+			break
+		}
+
+		// We haven't hit the ancestor header yet, so we'll go back one.
+		iterNode = iterNode.Prev()
+	}
+
+	if ancestor == nil {
+		// Lookup the ancestor in the header store.
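+		// The bounded in-memory list only retains the most recent
+		// headers, so older ancestors may only be available on disk.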
+		ancestor, err = l.store.FetchHeaderByHeight(
+			uint32(ancestorHeight),
+		)
+		if err != nil {
+			return nil
+		}
+	}
+
+	return newLightHeaderCtx(
+		ancestorHeight, ancestor, l.store, l.headerList,
+	)
+}
diff --git a/blockmanager_test.go b/blockmanager_test.go
index e45cca86..45554b7c 100644
--- a/blockmanager_test.go
+++ b/blockmanager_test.go
@@ -3,17 +3,18 @@ package neutrino
 import (
 	"encoding/binary"
 	"fmt"
-	"io/ioutil"
 	"math/rand"
-	"os"
+	"reflect"
 	"strings"
 	"testing"
 	"time"
 
+	"github.com/btcsuite/btcd/blockchain"
 	"github.com/btcsuite/btcd/btcutil/gcs"
 	"github.com/btcsuite/btcd/btcutil/gcs/builder"
 	"github.com/btcsuite/btcd/chaincfg"
 	"github.com/btcsuite/btcd/chaincfg/chainhash"
+	"github.com/btcsuite/btcd/integration/rpctest"
 	"github.com/btcsuite/btcd/peer"
 	"github.com/btcsuite/btcd/txscript"
 	"github.com/btcsuite/btcd/wire"
@@ -22,13 +23,14 @@ import (
 	"github.com/lightninglabs/neutrino/blockntfns"
 	"github.com/lightninglabs/neutrino/headerfs"
 	"github.com/lightninglabs/neutrino/query"
+	"github.com/stretchr/testify/require"
 )
 
 const (
-	// maxHeight is the height we will generate filter headers up to. We use an odd
-	// number of checkpoints to ensure we can test cases where the block manager is
-	// only able to fetch filter headers for one checkpoint interval rather than
-	// two.
+	// maxHeight is the height we will generate filter headers up to. We use
+	// an odd number of checkpoints to ensure we can test cases where the
+	// block manager is only able to fetch filter headers for one checkpoint
+	// interval rather than two.
 	maxHeight = 21 * uint32(wire.CFCheckptInterval)
 
 	dbOpenTimeout = time.Second * 10
@@ -52,36 +54,27 @@ func (m *mockDispatcher) Query(requests []*query.Request,
 }
 
 // setupBlockManager initialises a blockManager to be used in tests.
-func setupBlockManager() (*blockManager, headerfs.BlockHeaderStore,
-	*headerfs.FilterHeaderStore, func(), error) {
+func setupBlockManager(t *testing.T) (*blockManager, headerfs.BlockHeaderStore,
+	*headerfs.FilterHeaderStore, error) {
 
 	// Set up the block and filter header stores.
-	tempDir, err := ioutil.TempDir("", "neutrino")
-	if err != nil {
-		return nil, nil, nil, nil, fmt.Errorf("Failed to create "+
-			"temporary directory: %s", err)
-	}
-
+	tempDir := t.TempDir()
 	db, err := walletdb.Create(
 		"bdb", tempDir+"/weks.db", true, dbOpenTimeout,
 	)
 	if err != nil {
-		os.RemoveAll(tempDir)
-		return nil, nil, nil, nil, fmt.Errorf("Error opening DB: %s",
-			err)
+		return nil, nil, nil, fmt.Errorf("error opening DB: %s", err)
 	}
 
-	cleanUp := func() {
-		db.Close()
-		os.RemoveAll(tempDir)
-	}
+	t.Cleanup(func() {
+		require.NoError(t, db.Close())
+	})
 
 	hdrStore, err := headerfs.NewBlockHeaderStore(
 		tempDir, db, &chaincfg.SimNetParams,
 	)
 	if err != nil {
-		cleanUp()
-		return nil, nil, nil, nil, fmt.Errorf("Error creating block "+
+		return nil, nil, nil, fmt.Errorf("error creating block "+
 			"header store: %s", err)
 	}
 
@@ -90,8 +83,7 @@ func setupBlockManager(t *testing.T) (*blockManager, headerfs.BlockHeaderStore,
 		nil,
 	)
 	if err != nil {
-		cleanUp()
-		return nil, nil, nil, nil, fmt.Errorf("Error creating filter "+
+		return nil, nil, nil, fmt.Errorf("error creating filter "+
 			"header store: %s", err)
 	}
 
@@ -101,14 +93,17 @@ func setupBlockManager(t *testing.T) (*blockManager, headerfs.BlockHeaderStore,
 		BlockHeaders:     hdrStore,
 		RegFilterHeaders: cfStore,
 		QueryDispatcher:  &mockDispatcher{},
-		BanPeer:          func(string, banman.Reason) error { return nil },
+		TimeSource:       blockchain.NewMedianTime(),
+		BanPeer: func(string, banman.Reason) error {
+			return nil
+		},
 	})
 	if err != nil {
-		return nil, nil, nil, nil, fmt.Errorf("unable to create "+
+		return nil, nil, nil, fmt.Errorf("unable to create "+
 			"blockmanager: %v", err)
 	}
 
-	return bm, hdrStore, cfStore, cleanUp, nil
+	return bm, hdrStore, cfStore, nil
 }
 
 // headers wraps the different headers and filters used throughout the tests.
@@ -315,36 +310,29 @@ func TestBlockManagerInitialInterval(t *testing.T) {
 		testDesc := fmt.Sprintf("permute=%v, partial=%v, repeat=%v",
 			test.permute, test.partialInterval, test.repeat)
 
-		bm, hdrStore, cfStore, cleanUp, err := setupBlockManager()
+		bm, hdrStore, cfStore, err := setupBlockManager(t)
 		if err != nil {
 			t.Fatalf("unable to set up ChainService: %v", err)
 		}
-		defer cleanUp()
 
 		// Keep track of the filter headers and block headers. Since
 		// the genesis headers are written automatically when the store
 		// is created, we query it to add to the slices.
 		genesisBlockHeader, _, err := hdrStore.ChainTip()
-		if err != nil {
-			t.Fatal(err)
-		}
+		require.NoError(t, err)
 
 		genesisFilterHeader, _, err := cfStore.ChainTip()
-		if err != nil {
-			t.Fatal(err)
-		}
+		require.NoError(t, err)
 
-		headers, err := generateHeaders(genesisBlockHeader,
-			genesisFilterHeader, nil)
-		if err != nil {
-			t.Fatalf("unable to generate headers: %v", err)
-		}
+		headers, err := generateHeaders(
+			genesisBlockHeader, genesisFilterHeader, nil,
+		)
+		require.NoError(t, err)
 
 		// Write all block headers but the genesis, since it is already
 		// in the store.
-		if err = hdrStore.WriteHeaders(headers.blockHeaders[1:]...); err != nil {
-			t.Fatalf("Error writing batch of headers: %s", err)
-		}
+		err = hdrStore.WriteHeaders(headers.blockHeaders[1:]...)
+		require.NoError(t, err)
 
 		// We emulate the case where a few filter headers are already
 		// written to the store by writing 1/3 of the first interval.
@@ -352,10 +340,7 @@
 			err = cfStore.WriteHeaders(
 				headers.cfHeaders[1 : wire.CFCheckptInterval/3]...,
 			)
-			if err != nil {
-				t.Fatalf("Error writing batch of headers: %s",
-					err)
-			}
+			require.NoError(t, err)
 		}
 
 		// We set up a custom query batch method for this test, as we
@@ -371,10 +356,7 @@
 			}
 
 			responses, err := generateResponses(msgs, headers)
-			if err != nil {
-				t.Fatalf("unable to generate responses: %v",
-					err)
-			}
+			require.NoError(t, err)
 
 			// We permute the response order if the test signals
 			// that.
@@ -459,20 +441,12 @@
 
 		// Finally make sure the filter header tip is what we expect.
 		tip, tipHeight, err := cfStore.ChainTip()
-		if err != nil {
-			t.Fatal(err)
-		}
+		require.NoError(t, err)
 
-		if tipHeight != maxHeight {
-			t.Fatalf("expected tip height to be %v, was %v",
-				maxHeight, tipHeight)
-		}
+		require.Equal(t, maxHeight, tipHeight, "tip height")
 
 		lastCheckpoint := headers.checkpoints[len(headers.checkpoints)-1]
-		if *tip != *lastCheckpoint {
-			t.Fatalf("expected tip to be %v, was %v",
-				lastCheckpoint, tip)
-		}
+		require.Equal(t, *lastCheckpoint, *tip, "tip")
 	}
 }
 
@@ -547,11 +521,8 @@ func TestBlockManagerInvalidInterval(t *testing.T) {
 	for _, test := range testCases {
 		test := test
 
-		bm, hdrStore, cfStore, cleanUp, err := setupBlockManager()
-		if err != nil {
-			t.Fatalf("unable to set up ChainService: %v", err)
-		}
-		defer cleanUp()
+		bm, hdrStore, cfStore, err := setupBlockManager(t)
+		require.NoError(t, err)
 
 		// Create a mock peer to prevent panics when attempting to ban
 		// a peer that served an invalid filter header.
@@ -559,22 +530,17 @@
 		mockPeer.Peer, err = peer.NewOutboundPeer(
 			NewPeerConfig(mockPeer), "127.0.0.1:8333",
 		)
-		if err != nil {
-			t.Fatal(err)
-		}
+		require.NoError(t, err)
 
 		// Keep track of the filter headers and block headers. Since
 		// the genesis headers are written automatically when the store
 		// is created, we query it to add to the slices.
 		genesisBlockHeader, _, err := hdrStore.ChainTip()
-		if err != nil {
-			t.Fatal(err)
-		}
+		require.NoError(t, err)
 
 		genesisFilterHeader, _, err := cfStore.ChainTip()
-		if err != nil {
-			t.Fatal(err)
-		}
+		require.NoError(t, err)
+
 		// To emulate a full node serving us filter headers derived
 		// from different genesis than what we have, we flip a bit in
 		// the genesis filter header.
@@ -582,8 +548,8 @@
 			genesisFilterHeader[0] ^= 1
 		}
 
-		headers, err := generateHeaders(genesisBlockHeader,
-			genesisFilterHeader,
+		headers, err := generateHeaders(
+			genesisBlockHeader, genesisFilterHeader,
 			func(currentCFHeader *chainhash.Hash) {
 				// If we are testing that each interval doesn't
 				// line up properly with the previous, we flip
@@ -592,16 +558,14 @@
 				if test.intervalMisaligned {
 					currentCFHeader[0] ^= 1
 				}
-			})
-		if err != nil {
-			t.Fatalf("unable to generate headers: %v", err)
-		}
+			},
+		)
+		require.NoError(t, err)
 
 		// Write all block headers but the genesis, since it is already
 		// in the store.
-		if err = hdrStore.WriteHeaders(headers.blockHeaders[1:]...); err != nil {
-			t.Fatalf("Error writing batch of headers: %s", err)
-		}
+		err = hdrStore.WriteHeaders(headers.blockHeaders[1:]...)
+		require.NoError(t, err)
 
 		// We emulate the case where a few filter headers are already
 		// written to the store by writing 1/3 of the first interval.
@@ -609,10 +573,7 @@
 			err = cfStore.WriteHeaders(
 				headers.cfHeaders[1 : wire.CFCheckptInterval/3]...,
 			)
-			if err != nil {
-				t.Fatalf("Error writing batch of headers: %s",
-					err)
-			}
+			require.NoError(t, err)
 		}
 
 		bm.cfg.QueryDispatcher.(*mockDispatcher).query = func(
@@ -624,10 +585,7 @@
 				msgs = append(msgs, q.Req)
 			}
 
 			responses, err := generateResponses(msgs, headers)
-			if err != nil {
-				t.Fatalf("unable to generate responses: %v",
-					err)
-			}
+			require.NoError(t, err)
 
 			// Since we used the generated checkpoints when
 			// creating the responses, we must flip the
@@ -644,7 +602,7 @@
 			}
 
 			// If we are testing for intervals with invalid prev
-			// hashes, we flip a bit to corrup them, regardless of
+			// hashes, we flip a bit to corrupt them, regardless of
 			// whether we are testing misaligned intervals.
 			if test.invalidPrevHash {
 				for i := range responses {
@@ -760,12 +718,12 @@ func assertBadPeers(expBad map[string]struct{}, badPeers []string) error {
 	for p := range expBad {
 		remBad[p] = struct{}{}
 	}
-	for _, peer := range badPeers {
-		_, ok := remBad[peer]
+	for _, p := range badPeers {
+		_, ok := remBad[p]
 		if !ok {
-			return fmt.Errorf("did not expect %v to be bad", peer)
+			return fmt.Errorf("did not expect %v to be bad", p)
 		}
-		delete(remBad, peer)
+		delete(remBad, p)
 	}
 
 	if len(remBad) != 0 {
@@ -855,7 +813,9 @@ func TestBlockManagerDetectBadPeers(t *testing.T) {
 		options ...QueryOption) {
 
 		for p, resp := range answers {
-			pp, err := peer.NewOutboundPeer(&peer.Config{}, p)
+			pp, err := peer.NewOutboundPeer(
+				&peer.Config{}, p,
+			)
 			if err != nil {
 				panic(err)
 			}
@@ -863,12 +823,15 @@
 			sp := &ServerPeer{
 				Peer: pp,
 			}
-			checkResponse(sp, resp, make(chan struct{}), make(chan struct{}))
+			checkResponse(
+				sp, resp, make(chan struct{}),
+				make(chan struct{}),
+			)
 		}
 	}
 
-	for _, peer := range peers {
-		test.filterAnswers(peer, answers)
+	for _, p := range peers {
+		test.filterAnswers(p, answers)
 	}
 
 	// For the CFHeaders, we pretend all peers responded with the same
@@ -884,8 +847,8 @@
 	}
 
 	headers := make(map[string]*wire.MsgCFHeaders)
-	for _, peer := range peers {
-		headers[peer] = msg
+	for _, p := range peers {
+		headers[p] = msg
 	}
 
 	bm := &blockManager{
@@ -900,12 +863,85 @@
 		badPeers, err := bm.detectBadPeers(
 			headers, targetIndex, badIndex, fType,
 		)
-		if err != nil {
-			t.Fatalf("failed to detect bad peers: %v", err)
-		}
+		require.NoError(t, err)
+
+		err = assertBadPeers(expBad, badPeers)
+		require.NoError(t, err)
+	}
+}
+
+// TestHandleHeaders checks that we handle headers correctly, and that we
+// disconnect peers that serve us bad headers (headers that don't connect to
+// each other properly).
+func TestHandleHeaders(t *testing.T) {
+	t.Parallel()
+
+	// First, we set up a block manager and a fake peer that will act as the
+	// test's remote peer.
+	bm, _, _, err := setupBlockManager(t)
+	require.NoError(t, err)
+
+	fakePeer, err := peer.NewOutboundPeer(&peer.Config{}, "fake:123")
+	require.NoError(t, err)
 
-		if err := assertBadPeers(expBad, badPeers); err != nil {
-			t.Fatal(err)
+	assertPeerDisconnected := func(shouldBeDisconnected bool) {
+		// This is quite hacky but works: We expect the peer to be
+		// disconnected, which sets the unexported "disconnect" field
+		// to 1.
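+		// The field is read through reflection since it isn't
+		// exported; if btcd ever renames it, this helper must be
+		// updated.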
+		refValue := reflect.ValueOf(fakePeer).Elem()
+		disconnected := refValue.FieldByName("disconnect").Int()
+
+		if shouldBeDisconnected {
+			require.EqualValues(t, 1, disconnected)
+		} else {
+			require.EqualValues(t, 0, disconnected)
 		}
 	}
+
+	// We'll want to use actual, real blocks, so we take a miner harness
+	// that we can use to generate some.
+	harness, err := rpctest.New(
+		&chaincfg.SimNetParams, nil, []string{"--txindex"}, "",
+	)
+	require.NoError(t, err)
+	t.Cleanup(func() {
+		require.NoError(t, harness.TearDown())
+	})
+
+	err = harness.SetUp(false, 0)
+	require.NoError(t, err)
+
+	// Generate 200 valid blocks that we then feed to the block manager.
+	blockHashes, err := harness.Client.Generate(200)
+	require.NoError(t, err)
+
+	hmsg := &headersMsg{
+		headers: &wire.MsgHeaders{
+			Headers: make([]*wire.BlockHeader, len(blockHashes)),
+		},
+		peer: &ServerPeer{
+			Peer: fakePeer,
+		},
+	}
+
+	for i := range blockHashes {
+		header, err := harness.Client.GetBlockHeader(blockHashes[i])
+		require.NoError(t, err)
+
+		hmsg.headers.Headers[i] = header
+	}
+
+	// Let's feed in the correct headers. This should work fine and the peer
+	// should not be disconnected.
+	bm.handleHeadersMsg(hmsg)
+	assertPeerDisconnected(false)
+
+	// Now scramble the headers and feed them in again. This should cause
+	// the peer to be disconnected.
+	rand.Shuffle(len(hmsg.headers.Headers), func(i, j int) {
+		hmsg.headers.Headers[i], hmsg.headers.Headers[j] =
+			hmsg.headers.Headers[j], hmsg.headers.Headers[i]
+	})
+	bm.handleHeadersMsg(hmsg)
+	assertPeerDisconnected(true)
 }
diff --git a/go.mod b/go.mod
index 1710d3b6..9d5f1b92 100644
--- a/go.mod
+++ b/go.mod
@@ -1,7 +1,7 @@
 module github.com/lightninglabs/neutrino
 
 require (
-	github.com/btcsuite/btcd v0.23.3
+	github.com/btcsuite/btcd v0.23.5-0.20230711222809-7faa9b266231
 	github.com/btcsuite/btcd/btcec/v2 v2.1.3
 	github.com/btcsuite/btcd/btcutil v1.1.1
 	github.com/btcsuite/btcd/chaincfg/chainhash v1.0.1
@@ -29,8 +29,8 @@ require (
 	github.com/lightningnetwork/lnd/ticker v1.0.0 // indirect
 	github.com/pmezard/go-difflib v1.0.0 // indirect
 	go.etcd.io/bbolt v1.3.5-0.20200615073812-232d8fc87f50 // indirect
-	golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 // indirect
-	golang.org/x/sys v0.0.0-20200814200057-3d37ad5750ed // indirect
+	golang.org/x/crypto v0.1.0 // indirect
+	golang.org/x/sys v0.1.0 // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 )
diff --git a/go.sum b/go.sum
index 6f7cc611..aec8781e 100644
--- a/go.sum
+++ b/go.sum
@@ -5,8 +5,8 @@ github.com/btcsuite/btcd v0.20.1-beta/go.mod h1:wVuoA8VJLEcwgqHBwHmzLRazpKxTv13P
 github.com/btcsuite/btcd v0.22.0-beta.0.20220111032746-97732e52810c/go.mod h1:tjmYdS6MLJ5/s0Fj4DbLgSbDHbEqLJrtnHecBFkdz5M=
 github.com/btcsuite/btcd v0.22.0-beta.0.20220204213055-eaf0459ff879/go.mod h1:osu7EoKiL36UThEgzYPqdRaxeo0NU8VoXqgcnwpey0g=
 github.com/btcsuite/btcd v0.22.0-beta.0.20220316175102-8d5c75c28923/go.mod h1:taIcYprAW2g6Z9S0gGUxyR+zDwimyDMK5ePOX+iJ2ds=
-github.com/btcsuite/btcd v0.23.3 h1:4KH/JKy9WiCd+iUS9Mu0Zp7Dnj17TGdKrg9xc/FGj24=
-github.com/btcsuite/btcd v0.23.3/go.mod h1:0QJIIN1wwIXF/3G/m87gIwGniDMDQqjVn4SZgnFpsYY=
+github.com/btcsuite/btcd v0.23.5-0.20230711222809-7faa9b266231 h1:FZR6mILlSI/GDx8ydNVBZAlXlRXsoRBWX2Un64mpfsI=
+github.com/btcsuite/btcd v0.23.5-0.20230711222809-7faa9b266231/go.mod h1:0QJIIN1wwIXF/3G/m87gIwGniDMDQqjVn4SZgnFpsYY=
 github.com/btcsuite/btcd/btcec/v2 v2.1.0/go.mod h1:2VzYrv4Gm4apmbVVsSq5bqf1Ec8v56E48Vt0Y/umPgA=
 github.com/btcsuite/btcd/btcec/v2 v2.1.1/go.mod h1:ctjw4H1kknNJmRN4iP1R7bTQ+v3GJkZBd6mui8ZsAZE=
 github.com/btcsuite/btcd/btcec/v2 v2.1.3 h1:xM/n3yIhHAhHy04z4i43C8p4ehixJZMsnrVJkgl+MTE=
@@ -110,8 +110,9 @@ go.etcd.io/bbolt v1.3.5-0.20200615073812-232d8fc87f50/go.mod h1:G5EMThwa9y8QZGBC
 golang.org/x/crypto v0.0.0-20170930174604-9419663f5a44/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
 golang.org/x/crypto v0.0.0-20190211182817-74369b46fc67/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
-golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI=
 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
+golang.org/x/crypto v0.1.0 h1:MDRAIl0xIo9Io2xV565hzXHw3zVseKrJKodhohM5CjU=
+golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw=
 golang.org/x/net v0.0.0-20180719180050-a680a1efc54d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/net v0.0.0-20190206173232-65e2d4e15006/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -129,8 +130,9 @@ golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7w
 golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200814200057-3d37ad5750ed h1:J22ig1FUekjjkmZUM7pTKixYm8DvrYsvrBZdunYeIuQ=
 golang.org/x/sys v0.0.0-20200814200057-3d37ad5750ed/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U=
+golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
diff --git a/query.go b/query.go
index fb5deb13..66a506dd 100644
--- a/query.go
+++ b/query.go
@@ -885,8 +885,20 @@ func (s *ChainService) GetBlock(blockHash chainhash.Hash,
 			return noProgress
 		}
 
-		// TODO(roasbeef): modify CheckBlockSanity to also check witness
-		// commitment
+		if err := blockchain.ValidateWitnessCommitment(
+			block,
+		); err != nil {
+			log.Warnf("Invalid block for %s received from %s: %v "+
+				"-- disconnecting peer", blockHash, peer, err)
+
+			err = s.BanPeer(peer, banman.InvalidBlock)
+			if err != nil {
+				log.Errorf("Unable to ban peer %v: %v", peer,
+					err)
+			}
+
+			return noProgress
+		}
 
 		// At this point, the block matches what we know about it, and
 		// we declare it sane. We can kill the query and pass the