From 40670d1b4c45ba4fbfee6e70917e66d12d9709ea Mon Sep 17 00:00:00 2001
From: Maureen Ononiwu
Date: Wed, 30 Aug 2023 00:53:33 +0100
Subject: [PATCH 1/7] neutrino + query: Make Request more flexible for
 querying.

This commit makes the query package more flexible in how it queries peers.
This is done by adding a SendQuery function field to the query.Request
struct instead of using only QueueMessageWithEncoding for all requests.
This will be useful in coming commits, where we will use pushGetHeadersMsg
to fetch block headers from peers.

Consequent changes: the encoding field was removed from queryJob, as it is
not useful for all requests. Requests that need an encoding can define it
as one of the fields in their own implementation of the new interface
created as the type for the Req field in Request.

The PriorityIndex function signature will be used in coming commits to
indicate the priority a request should preferably have in a query batch.

An implementation of the interface was created for the GetCfheaders,
GetCfilter and getData requests.

Tests were updated and added to reflect these changes.

Signed-off-by: Maureen Ononiwu
---
 blockmanager.go      |  52 +++++++++++++++++++---
 blockmanager_test.go |   4 +-
 query.go             |  17 ++++++--
 query/interface.go   |  20 ++++++++-
 query/worker.go      |  35 ++++++++++-----
 query/worker_test.go | 102 ++++++++++++++++++++++++++++++++++++++++++-
 query/workmanager.go |   1 -
 query_test.go        |   4 +-
 8 files changed, 206 insertions(+), 29 deletions(-)

diff --git a/blockmanager.go b/blockmanager.go
index 9c9e6f4c6..de966fd35 100644
--- a/blockmanager.go
+++ b/blockmanager.go
@@ -5,6 +5,7 @@ package neutrino
 import (
 	"bytes"
 	"container/list"
+	"errors"
 	"fmt"
 	"math"
 	"math/big"
@@ -808,12 +809,31 @@ func (b *blockManager) getUncheckpointedCFHeaders(
 // handle a query for checkpointed filter headers.
 type checkpointedCFHeadersQuery struct {
 	blockMgr    *blockManager
-	msgs        []wire.Message
+	msgs        []*encodedQuery
 	checkpoints []*chainhash.Hash
 	stopHashes  map[chainhash.Hash]uint32
 	headerChan  chan *wire.MsgCFHeaders
 }
 
+// encodedQuery holds all the information needed to query a message that
+// pushes requests using the QueueMessageWithEncoding method.
+type encodedQuery struct {
+	message       wire.Message
+	encoding      wire.MessageEncoding
+	priorityIndex uint64
+}
+
+// Message returns the wire.Message of encodedQuery's struct.
+func (e *encodedQuery) Message() wire.Message {
+	return e.message
+}
+
+// PriorityIndex returns the specified priority the caller wants
+// the request to take.
+func (e *encodedQuery) PriorityIndex() uint64 {
+	return e.priorityIndex
+}
+
 // requests creates the query.Requests for this CF headers query.
 func (c *checkpointedCFHeadersQuery) requests() []*query.Request {
 	reqs := make([]*query.Request, len(c.msgs))
@@ -821,6 +841,7 @@ func (c *checkpointedCFHeadersQuery) requests() []*query.Request {
 		reqs[idx] = &query.Request{
 			Req:        m,
 			HandleResp: c.handleResponse,
+			SendQuery:  sendQueryMessageWithEncoding,
 		}
 	}
 	return reqs
@@ -924,6 +945,24 @@ func (c *checkpointedCFHeadersQuery) handleResponse(req, resp wire.Message,
 	}
 }
 
+// sendQueryMessageWithEncoding sends a message to the peer with encoding.
+func sendQueryMessageWithEncoding(peer query.Peer, req query.ReqMessage) error { + sp, ok := peer.(*ServerPeer) + if !ok { + err := "peer is not of type ServerPeer" + log.Errorf(err) + return errors.New(err) + } + request, ok := req.(*encodedQuery) + if !ok { + return errors.New("invalid request type") + } + + sp.QueueMessageWithEncoding(request.message, nil, request.encoding) + + return nil +} + // getCheckpointedCFHeaders catches a filter header store up with the // checkpoints we got from the network. It assumes that the filter header store // matches the checkpoints up to the tip of the store. @@ -959,7 +998,7 @@ func (b *blockManager) getCheckpointedCFHeaders(checkpoints []*chainhash.Hash, // the remaining checkpoint intervals. numCheckpts := uint32(len(checkpoints)) - startingInterval numQueries := (numCheckpts + maxCFCheckptsPerQuery - 1) / maxCFCheckptsPerQuery - queryMsgs := make([]wire.Message, 0, numQueries) + queryMsgs := make([]*encodedQuery, 0, numQueries) // We'll also create an additional set of maps that we'll use to // re-order the responses as we get them in. @@ -1004,9 +1043,12 @@ func (b *blockManager) getCheckpointedCFHeaders(checkpoints []*chainhash.Hash, // Once we have the stop hash, we can construct the query // message itself. - queryMsg := wire.NewMsgGetCFHeaders( - fType, startHeightRange, &stopHash, - ) + queryMsg := &encodedQuery{ + message: wire.NewMsgGetCFHeaders( + fType, startHeightRange, &stopHash, + ), + encoding: wire.WitnessEncoding, + } // We'll mark that the ith interval is queried by this message, // and also map the stop hash back to the index of this message. diff --git a/blockmanager_test.go b/blockmanager_test.go index 45554b7cd..97416cf1b 100644 --- a/blockmanager_test.go +++ b/blockmanager_test.go @@ -352,7 +352,7 @@ func TestBlockManagerInitialInterval(t *testing.T) { var msgs []wire.Message for _, q := range requests { - msgs = append(msgs, q.Req) + msgs = append(msgs, q.Req.Message()) } responses, err := generateResponses(msgs, headers) @@ -582,7 +582,7 @@ func TestBlockManagerInvalidInterval(t *testing.T) { var msgs []wire.Message for _, q := range requests { - msgs = append(msgs, q.Req) + msgs = append(msgs, q.Req.Message()) } responses, err := generateResponses(msgs, headers) require.NoError(t, err) diff --git a/query.go b/query.go index 66a506dd1..d8e72aae8 100644 --- a/query.go +++ b/query.go @@ -435,13 +435,17 @@ type cfiltersQuery struct { // request couples a query message with the handler to be used for the response // in a query.Request struct. func (q *cfiltersQuery) request() *query.Request { - msg := wire.NewMsgGetCFilters( - q.filterType, uint32(q.startHeight), q.stopHash, - ) + msg := &encodedQuery{ + message: wire.NewMsgGetCFilters( + q.filterType, uint32(q.startHeight), q.stopHash, + ), + encoding: wire.WitnessEncoding, + } return &query.Request{ Req: msg, HandleResp: q.handleResponse, + SendQuery: sendQueryMessageWithEncoding, } } @@ -833,6 +837,10 @@ func (s *ChainService) GetBlock(blockHash chainhash.Hash, // Construct the appropriate getdata message to fetch the target block. getData := wire.NewMsgGetData() _ = getData.AddInvVect(inv) + msg := &encodedQuery{ + message: getData, + encoding: wire.WitnessEncoding, + } var foundBlock *btcutil.Block @@ -912,8 +920,9 @@ func (s *ChainService) GetBlock(blockHash chainhash.Hash, // Prepare the query request. request := &query.Request{ - Req: getData, + Req: msg, HandleResp: handleResp, + SendQuery: sendQueryMessageWithEncoding, } // Prepare the query options. 
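
To make the intent of the new Req and SendQuery fields concrete, here is a minimal sketch of how a later request type could satisfy the query.ReqMessage interface and bring its own push logic instead of QueueMessageWithEncoding. The getHeadersQuery type, the newGetHeadersRequest helper and the use of the peer's plain QueueMessage method are illustrative assumptions, not code introduced by this patch:

// getHeadersQuery is a hypothetical request type, shown only to illustrate
// the new query.ReqMessage interface.
type getHeadersQuery struct {
	message       wire.Message
	priorityIndex uint64
}

// Message returns the wrapped wire message.
func (g *getHeadersQuery) Message() wire.Message {
	return g.message
}

// PriorityIndex returns the preferred position of this request in a batch.
func (g *getHeadersQuery) PriorityIndex() uint64 {
	return g.priorityIndex
}

// newGetHeadersRequest couples the hypothetical request with its own send
// function, so it does not have to go through QueueMessageWithEncoding the
// way encodedQuery does.
func newGetHeadersRequest(handleResp func(req, resp wire.Message,
	peer string) query.Progress) *query.Request {

	return &query.Request{
		Req:        &getHeadersQuery{message: wire.NewMsgGetHeaders()},
		HandleResp: handleResp,
		SendQuery: func(peer query.Peer, r query.ReqMessage) error {
			sp, ok := peer.(*ServerPeer)
			if !ok {
				return errors.New("peer is not of type ServerPeer")
			}

			// Push the message with the peer's default settings,
			// no custom encoding needed for this request type.
			sp.QueueMessage(r.Message(), nil)

			return nil
		},
	}
}
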
diff --git a/query/interface.go b/query/interface.go index dca5f42dc..311a2e621 100644 --- a/query/interface.go +++ b/query/interface.go @@ -124,8 +124,8 @@ type Progress struct { // Request is the main struct that defines a bitcoin network query to be sent to // connected peers. type Request struct { - // Req is the message request to send. - Req wire.Message + // Req contains the message request to send. + Req ReqMessage // HandleResp is a response handler that will be called for every // message received from the peer that the request was made to. It @@ -139,6 +139,22 @@ type Request struct { // The response should be handed off to another goroutine for // processing. HandleResp func(req, resp wire.Message, peer string) Progress + + // SendQuery handles sending request to the worker's peer. It returns an error, + // if one is encountered while sending the request. + SendQuery func(peer Peer, request ReqMessage) error +} + +// ReqMessage is an interface which all structs containing information +// required to process a message request must implement. +type ReqMessage interface { + + // Message returns the message request. + Message() wire.Message + + // PriorityIndex returns the priority the caller prefers the request + // would take. + PriorityIndex() uint64 } // WorkManager defines an API for a manager that dispatches queries to bitcoin diff --git a/query/worker.go b/query/worker.go index dc15a18cf..6b718f45c 100644 --- a/query/worker.go +++ b/query/worker.go @@ -3,8 +3,6 @@ package query import ( "errors" "time" - - "github.com/btcsuite/btcd/wire" ) var ( @@ -27,7 +25,6 @@ type queryJob struct { tries uint8 index uint64 timeout time.Duration - encoding wire.MessageEncoding cancelChan <-chan struct{} *Request } @@ -89,6 +86,7 @@ func (w *worker) Run(results chan<- *jobResult, quit <-chan struct{}) { msgChan, cancel := peer.SubscribeRecvMsg() defer cancel() +nexJobLoop: for { log.Tracef("Worker %v waiting for more work", peer.Addr()) @@ -133,7 +131,22 @@ func (w *worker) Run(results chan<- *jobResult, quit <-chan struct{}) { log.Tracef("Worker %v queuing job %T with index %v", peer.Addr(), job.Req, job.Index()) - peer.QueueMessageWithEncoding(job.Req, nil, job.encoding) + err := job.SendQuery(peer, job.Req) + + // If any error occurs while sending query, quickly send the result + // containing the error to the workmanager. + if err != nil { + select { + case results <- &jobResult{ + job: job, + peer: peer, + err: err, + }: + case <-quit: + return + } + goto nexJobLoop + } } // Wait for the correct response to be received from the peer, @@ -143,7 +156,7 @@ func (w *worker) Run(results chan<- *jobResult, quit <-chan struct{}) { timeout = time.NewTimer(job.timeout) ) - Loop: + feedbackLoop: for { select { // A message was received from the peer, use the @@ -151,7 +164,7 @@ func (w *worker) Run(results chan<- *jobResult, quit <-chan struct{}) { // our request. case resp := <-msgChan: progress := job.HandleResp( - job.Req, resp, peer.Addr(), + job.Req.Message(), resp, peer.Addr(), ) log.Tracef("Worker %v handled msg %T while "+ @@ -176,12 +189,12 @@ func (w *worker) Run(results chan<- *jobResult, quit <-chan struct{}) { job.timeout, ) } - continue Loop + continue feedbackLoop } // We did get a valid response, and can break // the loop. - break Loop + break feedbackLoop // If the timeout is reached before a valid response // has been received, we exit with an error. 
@@ -193,7 +206,7 @@ func (w *worker) Run(results chan<- *jobResult, quit <-chan struct{}) { "with job index %v", peer.Addr(), job.Req, job.Index()) - break Loop + break feedbackLoop // If the peer disconnects before giving us a valid // answer, we'll also exit with an error. @@ -203,7 +216,7 @@ func (w *worker) Run(results chan<- *jobResult, quit <-chan struct{}) { job.Index()) jobErr = ErrPeerDisconnected - break Loop + break feedbackLoop // If the job was canceled, we report this back to the // work manager. @@ -212,7 +225,7 @@ func (w *worker) Run(results chan<- *jobResult, quit <-chan struct{}) { peer.Addr(), job.Index()) jobErr = ErrJobCanceled - break Loop + break feedbackLoop case <-quit: return diff --git a/query/worker_test.go b/query/worker_test.go index 8cb5c17d2..b4a0d278f 100644 --- a/query/worker_test.go +++ b/query/worker_test.go @@ -1,6 +1,7 @@ package query import ( + "errors" "fmt" "testing" "time" @@ -8,8 +9,26 @@ import ( "github.com/btcsuite/btcd/wire" ) +type mockQueryEncoded struct { + message *wire.MsgGetData + encoding wire.MessageEncoding + index uint64 +} + +func (m *mockQueryEncoded) Message() wire.Message { + return m.message +} + +func (m *mockQueryEncoded) PriorityIndex() uint64 { + return m.index +} + var ( - req = &wire.MsgGetData{} + msg = &wire.MsgGetData{} + req = &mockQueryEncoded{ + message: msg, + encoding: wire.WitnessEncoding, + } progressResp = &wire.MsgTx{ Version: 111, } @@ -24,6 +43,7 @@ type mockPeer struct { responses chan<- wire.Message subscriptions chan chan wire.Message quit chan struct{} + err error } var _ Peer = (*mockPeer)(nil) @@ -75,11 +95,21 @@ func makeJob() *queryJob { Progressed: false, } }, + SendQuery: func(peer Peer, req ReqMessage) error { + m := peer.(*mockPeer) + + if m.err != nil { + return m.err + } + + m.requests <- req.Message() + return nil + }, } + return &queryJob{ index: 123, timeout: 30 * time.Second, - encoding: defaultQueryEncoding, cancelChan: nil, Request: q, } @@ -472,3 +502,71 @@ func TestWorkerJobCanceled(t *testing.T) { } } } + +// TestWorkerSendQueryErr will test if the result would return an error +// that would be handled by the worker if there is an error returned while +// sending a query. +func TestWorkerSendQueryErr(t *testing.T) { + t.Parallel() + + ctx, err := startWorker() + if err != nil { + t.Fatalf("unable to start worker: %v", err) + } + + cancelChan := make(chan struct{}) + + // Give the worker a new job. + taskJob := makeJob() + taskJob.cancelChan = cancelChan + + // Assign error to be returned while sending query. + ctx.peer.err = errors.New("query error") + + // Send job to worker + select { + case ctx.nextJob <- taskJob: + case <-time.After(1 * time.Second): + t.Fatalf("did not pick up job") + } + + // Request should not be sent as there should be an error while + // querying. + select { + case <-ctx.peer.requests: + t.Fatalf("request sent when query failed") + case <-time.After(time.Second): + } + + // jobResult should be sent by worker at this point. + var result *jobResult + select { + case result = <-ctx.jobResults: + case <-time.After(time.Second): + t.Fatalf("response not received") + } + + // jobResult should contain error. + if result.err != ctx.peer.err { + t.Fatalf("expected result's error to be %v, was %v", + ctx.peer.err, result.err) + } + + // Make sure the result was given for the intended task. + if result.job != taskJob { + t.Fatalf("got result for unexpected job") + } + + // And the correct peer. 
+ if result.peer != ctx.peer { + t.Fatalf("expected peer to be %v, was %v", + ctx.peer.Addr(), result.peer) + } + + // The worker should be in the nextJob Loop. + select { + case ctx.nextJob <- taskJob: + case <-time.After(1 * time.Second): + t.Fatalf("did not pick up job") + } +} diff --git a/query/workmanager.go b/query/workmanager.go index e99f57abc..5e3000b40 100644 --- a/query/workmanager.go +++ b/query/workmanager.go @@ -438,7 +438,6 @@ Loop: heap.Push(work, &queryJob{ index: queryIndex, timeout: minQueryTimeout, - encoding: batch.options.encoding, cancelChan: batch.options.cancelChan, Request: q, }) diff --git a/query_test.go b/query_test.go index cbc9a74a9..bb45deed6 100644 --- a/query_test.go +++ b/query_test.go @@ -303,9 +303,9 @@ func TestBlockCache(t *testing.T) { defer close(errChan) require.Len(t, reqs, 1) - require.IsType(t, &wire.MsgGetData{}, reqs[0].Req) + require.IsType(t, &wire.MsgGetData{}, reqs[0].Req.Message()) - getData := reqs[0].Req.(*wire.MsgGetData) + getData := reqs[0].Req.Message().(*wire.MsgGetData) require.Len(t, getData.InvList, 1) inv := getData.InvList[0] From 9ab1d8b82d5fd5c8c73fe66a59b8dc6b9917879e Mon Sep 17 00:00:00 2001 From: Maureen Ononiwu Date: Wed, 30 Aug 2023 07:49:55 +0100 Subject: [PATCH 2/7] query: Removed unused QueueMessageWithEncoding from Peer interface. Signed-off-by: Maureen Ononiwu --- query/interface.go | 5 ----- query/worker_test.go | 6 ------ 2 files changed, 11 deletions(-) diff --git a/query/interface.go b/query/interface.go index 311a2e621..394525b95 100644 --- a/query/interface.go +++ b/query/interface.go @@ -183,11 +183,6 @@ type Dispatcher interface { // Peer is the interface that defines the methods needed by the query package // to be able to make requests and receive responses from a network peer. type Peer interface { - // QueueMessageWithEncoding adds the passed bitcoin message to the peer - // send queue. - QueueMessageWithEncoding(msg wire.Message, doneChan chan<- struct{}, - encoding wire.MessageEncoding) - // SubscribeRecvMsg adds a OnRead subscription to the peer. All bitcoin // messages received from this peer will be sent on the returned // channel. A closure is also returned, that should be called to cancel diff --git a/query/worker_test.go b/query/worker_test.go index b4a0d278f..cdb9f1dad 100644 --- a/query/worker_test.go +++ b/query/worker_test.go @@ -48,12 +48,6 @@ type mockPeer struct { var _ Peer = (*mockPeer)(nil) -func (m *mockPeer) QueueMessageWithEncoding(msg wire.Message, - doneChan chan<- struct{}, encoding wire.MessageEncoding) { - - m.requests <- msg -} - func (m *mockPeer) SubscribeRecvMsg() (<-chan wire.Message, func()) { msgChan := make(chan wire.Message) m.subscriptions <- msgChan From 101c6901f383531b1057539accad794056e1dec0 Mon Sep 17 00:00:00 2001 From: Maureen Ononiwu Date: Tue, 5 Sep 2023 06:12:48 +0100 Subject: [PATCH 3/7] query: ErrQueryTimeout does not exit worker feedback loop. Worker waits after timeout for response but job is scheduled on another worker so as not to be slowed down by one worker. 
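
In rough outline, once the timed-out job result has been handed to the work manager, the worker now behaves as sketched below (a simplified summary of the worker.go hunk in this patch, not additional code):

    if jobErr == ErrQueryTimeout {
            // The work manager reschedules the job on another worker,
            // while this worker keeps waiting for a late response from
            // its peer.
            jobErr = nil

            goto feedbackLoop
    }
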
Signed-off-by: Maureen Ononiwu --- query.go | 3 ++ query/worker.go | 10 ++++ query/worker_test.go | 2 +- query/workmanager.go | 4 +- query/workmanager_test.go | 100 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 117 insertions(+), 2 deletions(-) diff --git a/query.go b/query.go index d8e72aae8..56178e663 100644 --- a/query.go +++ b/query.go @@ -430,6 +430,7 @@ type cfiltersQuery struct { headerIndex map[chainhash.Hash]int targetHash chainhash.Hash targetFilter *gcs.Filter + mtx sync.Mutex } // request couples a query message with the handler to be used for the response @@ -480,6 +481,8 @@ func (q *cfiltersQuery) handleResponse(req, resp wire.Message, // If this filter is for a block not in our index, we can ignore it, as // we either already got it, or it is out of our queried range. + q.mtx.Lock() + defer q.mtx.Unlock() i, ok := q.headerIndex[response.BlockHash] if !ok { return noProgress diff --git a/query/worker.go b/query/worker.go index 6b718f45c..5aaa6f9e5 100644 --- a/query/worker.go +++ b/query/worker.go @@ -247,6 +247,16 @@ nexJobLoop: return } + // If the error is a timeout still wait for the response as we are assured a response as long as there was a + // request but reschedule on another worker to quickly fetch a response so as not to be slowed down by this + // worker. We either get a response or the peer stalls (i.e. disconnects due to an elongated time without + // a response) + if jobErr == ErrQueryTimeout { + jobErr = nil + + goto feedbackLoop + } + // If the peer disconnected, we can exit immediately. if jobErr == ErrPeerDisconnected { return diff --git a/query/worker_test.go b/query/worker_test.go index cdb9f1dad..5515049df 100644 --- a/query/worker_test.go +++ b/query/worker_test.go @@ -278,8 +278,8 @@ func TestWorkerTimeout(t *testing.T) { // It will immediately attempt to fetch another task. select { case ctx.nextJob <- task: + t.Fatalf("worker still in feedback loop picked up job") case <-time.After(1 * time.Second): - t.Fatalf("did not pick up job") } } diff --git a/query/workmanager.go b/query/workmanager.go index 5e3000b40..74f6ae747 100644 --- a/query/workmanager.go +++ b/query/workmanager.go @@ -298,7 +298,9 @@ Loop: // Delete the job from the worker's active job, such // that the slot gets opened for more work. r := workers[result.peer.Addr()] - r.activeJob = nil + if result.err != ErrQueryTimeout { + r.activeJob = nil + } // Get the index of this query's batch, and delete it // from the map of current queries, since we don't have diff --git a/query/workmanager_test.go b/query/workmanager_test.go index b7bec809c..87bff15c7 100644 --- a/query/workmanager_test.go +++ b/query/workmanager_test.go @@ -297,6 +297,106 @@ func TestWorkManagerWorkDispatcherFailures(t *testing.T) { } } +// TestWorkManagerErrQueryTimeout tests that the workers that return query +// timeout are not sent jobs until they return a different error. +func TestWorkManagerErrQueryTimeout(t *testing.T) { + const numQueries = 2 + const numWorkers = 1 + + // Start work manager. + wm, workers := startWorkManager(t, numWorkers) + + // When the jobs gets scheduled, keep track of which worker was + // assigned the job. + type sched struct { + wk *mockWorker + job *queryJob + } + + // Schedule a batch of queries. 
+ var scheduledJobs [numQueries]chan sched + var queries [numQueries]*Request + for i := 0; i < numQueries; i++ { + q := &Request{ + Req: &mockQueryEncoded{}, + } + queries[i] = q + scheduledJobs[i] = make(chan sched) + } + + // Fot each worker, spin up a goroutine that will forward the job it + // got to our slice of scheduled jobs, such that we can handle them in + // order. + for i := 0; i < len(workers); i++ { + wk := workers[i] + go func() { + for { + job := <-wk.nextJob + scheduledJobs[int(job.index)] <- sched{ + wk: wk, + job: job, + } + } + }() + } + + // Send the batch, and Retrieve all jobs immediately. + errChan := wm.Query(queries[:]) + + var iter int + var s sched + for i := 0; i < numQueries; i++ { + select { + case s = <-scheduledJobs[i]: + if s.job.index != uint64(i) { + t.Fatalf("wrong index") + } + if iter == 1 { + t.Fatalf("Expected only one scheduled job") + } + iter++ + case <-errChan: + t.Fatalf("did not expect on errChan") + case <-time.After(time.Second): + if iter < 1 { + t.Fatalf("next job not received") + } + } + } + + select { + case s.wk.results <- &jobResult{ + job: s.job, + err: ErrQueryTimeout, + }: + case <-errChan: + t.Fatalf("did not expect on errChan") + case <-time.After(time.Second): + t.Fatalf("result not handled") + } + + // Finally, make sure the job is not retried as there are no available + // peer to retry it. + + select { + case <-scheduledJobs[0]: + t.Fatalf("did not expect job rescheduled") + case <-errChan: + t.Fatalf("did not expect on errChan") + case <-time.After(time.Second): + } + + // There should be no errChan message as query is still incomplete. + select { + case err := <-errChan: + if err != nil { + t.Fatalf("got error: %v", err) + } + t.Fatalf("expected no errChan message") + case <-time.After(time.Second): + } +} + // TestWorkManagerCancelBatch checks that we can cancel a batch query midway, // and that the jobs it contains are canceled. func TestWorkManagerCancelBatch(t *testing.T) { From f168538091d229ad5e667a1e3fde0aa16219a7fd Mon Sep 17 00:00:00 2001 From: Maureen Ononiwu Date: Tue, 5 Sep 2023 06:20:50 +0100 Subject: [PATCH 4/7] neutrino + query: Priority index implementation - The workmanager is made to check if a request has a priority index before assigning a query index. - Job index type is also changed to float64 in this commit for flexibility. Signed-off-by: Maureen Ononiwu --- blockmanager.go | 4 +- query/interface.go | 2 +- query/worker.go | 4 +- query/worker_test.go | 4 +- query/workmanager.go | 21 +++-- query/workmanager_test.go | 157 ++++++++++++++++++++++++++++++++++---- query/workqueue.go | 2 +- query/workqueue_test.go | 18 ++--- 8 files changed, 176 insertions(+), 36 deletions(-) diff --git a/blockmanager.go b/blockmanager.go index de966fd35..59b38c823 100644 --- a/blockmanager.go +++ b/blockmanager.go @@ -820,7 +820,7 @@ type checkpointedCFHeadersQuery struct { type encodedQuery struct { message wire.Message encoding wire.MessageEncoding - priorityIndex uint64 + priorityIndex float64 } // Message returns the wire.Message of encodedQuery's struct. @@ -830,7 +830,7 @@ func (e *encodedQuery) Message() wire.Message { // PriorityIndex returns the specified priority the caller wants // the request to take. 
-func (e *encodedQuery) PriorityIndex() uint64 { +func (e *encodedQuery) PriorityIndex() float64 { return e.priorityIndex } diff --git a/query/interface.go b/query/interface.go index 394525b95..ce9f751c5 100644 --- a/query/interface.go +++ b/query/interface.go @@ -154,7 +154,7 @@ type ReqMessage interface { // PriorityIndex returns the priority the caller prefers the request // would take. - PriorityIndex() uint64 + PriorityIndex() float64 } // WorkManager defines an API for a manager that dispatches queries to bitcoin diff --git a/query/worker.go b/query/worker.go index 5aaa6f9e5..ede799a3d 100644 --- a/query/worker.go +++ b/query/worker.go @@ -23,7 +23,7 @@ var ( // addition to some information about the query. type queryJob struct { tries uint8 - index uint64 + index float64 timeout time.Duration cancelChan <-chan struct{} *Request @@ -36,7 +36,7 @@ var _ Task = (*queryJob)(nil) // Index returns the queryJob's index within the work queue. // // NOTE: Part of the Task interface. -func (q *queryJob) Index() uint64 { +func (q *queryJob) Index() float64 { return q.index } diff --git a/query/worker_test.go b/query/worker_test.go index 5515049df..d4a0760ba 100644 --- a/query/worker_test.go +++ b/query/worker_test.go @@ -12,14 +12,14 @@ import ( type mockQueryEncoded struct { message *wire.MsgGetData encoding wire.MessageEncoding - index uint64 + index float64 } func (m *mockQueryEncoded) Message() wire.Message { return m.message } -func (m *mockQueryEncoded) PriorityIndex() uint64 { +func (m *mockQueryEncoded) PriorityIndex() float64 { return m.index } diff --git a/query/workmanager.go b/query/workmanager.go index 74f6ae747..847f53113 100644 --- a/query/workmanager.go +++ b/query/workmanager.go @@ -199,8 +199,8 @@ func (w *peerWorkManager) workDispatcher() { // We set up a counter that we'll increase with each incoming query, // and will serve as the priority of each. In addition we map each // query to the batch they are part of. - queryIndex := uint64(0) - currentQueries := make(map[uint64]uint64) + queryIndex := float64(0) + currentQueries := make(map[float64]uint64) workers := make(map[string]*activeWorker) @@ -437,14 +437,25 @@ Loop: "work queue", batchIndex, len(batch.requests)) for _, q := range batch.requests { + idx := queryIndex + + // If priority index is set, use that index. + if q.Req.PriorityIndex() != 0 { + idx = q.Req.PriorityIndex() + } heap.Push(work, &queryJob{ - index: queryIndex, + index: idx, timeout: minQueryTimeout, cancelChan: batch.options.cancelChan, Request: q, }) - currentQueries[queryIndex] = batchIndex - queryIndex++ + currentQueries[idx] = batchIndex + + // Only increment queryIndex if it was + // assigned to this job. + if q.Req.PriorityIndex() == 0 { + queryIndex++ + } } currentBatches[batchIndex] = &batchProgress{ diff --git a/query/workmanager_test.go b/query/workmanager_test.go index 87bff15c7..940ea2d1c 100644 --- a/query/workmanager_test.go +++ b/query/workmanager_test.go @@ -131,7 +131,10 @@ func TestWorkManagerWorkDispatcherSingleWorker(t *testing.T) { // Schedule a batch of queries. var queries []*Request for i := 0; i < numQueries; i++ { - q := &Request{} + q := &Request{ + Req: &mockQueryEncoded{}, + } + queries = append(queries, q) } @@ -141,7 +144,7 @@ func TestWorkManagerWorkDispatcherSingleWorker(t *testing.T) { // Each query should be sent on the nextJob queue, in the order they // had in their batch. 
- for i := uint64(0); i < numQueries; i++ { + for i := float64(0); i < numQueries; i++ { var job *queryJob select { case job = <-wk.nextJob: @@ -199,7 +202,9 @@ func TestWorkManagerWorkDispatcherFailures(t *testing.T) { var scheduledJobs [numQueries]chan sched var queries [numQueries]*Request for i := 0; i < numQueries; i++ { - q := &Request{} + q := &Request{ + Req: &mockQueryEncoded{}, + } queries[i] = q scheduledJobs[i] = make(chan sched) } @@ -212,7 +217,7 @@ func TestWorkManagerWorkDispatcherFailures(t *testing.T) { go func() { for { job := <-wk.nextJob - scheduledJobs[job.index] <- sched{ + scheduledJobs[int(job.index)] <- sched{ wk: wk, job: job, } @@ -224,11 +229,11 @@ func TestWorkManagerWorkDispatcherFailures(t *testing.T) { errChan := wm.Query(queries[:]) var jobs [numQueries]sched - for i := uint64(0); i < numQueries; i++ { + for i := 0; i < numQueries; i++ { var s sched select { case s = <-scheduledJobs[i]: - if s.job.index != i { + if s.job.index != float64(i) { t.Fatalf("wrong index") } @@ -238,7 +243,7 @@ func TestWorkManagerWorkDispatcherFailures(t *testing.T) { t.Fatalf("next job not received") } - jobs[s.job.index] = s + jobs[int(s.job.index)] = s } // Go backwards, and fail half of them. @@ -262,10 +267,10 @@ func TestWorkManagerWorkDispatcherFailures(t *testing.T) { // Finally, make sure the failed jobs are being retried, in the same // order as they were originally scheduled. - for i := uint64(0); i < numQueries; i += 2 { + for i := float64(0); i < numQueries; i += 2 { var s sched select { - case s = <-scheduledJobs[i]: + case s = <-scheduledJobs[int(i)]: if s.job.index != i { t.Fatalf("wrong index") } @@ -348,7 +353,7 @@ func TestWorkManagerErrQueryTimeout(t *testing.T) { for i := 0; i < numQueries; i++ { select { case s = <-scheduledJobs[i]: - if s.job.index != uint64(i) { + if s.job.index != float64(i) { t.Fatalf("wrong index") } if iter == 1 { @@ -409,7 +414,9 @@ func TestWorkManagerCancelBatch(t *testing.T) { // Schedule a batch of queries. var queries []*Request for i := 0; i < numQueries; i++ { - q := &Request{} + q := &Request{ + Req: &mockQueryEncoded{}, + } queries = append(queries, q) } @@ -499,7 +506,9 @@ func TestWorkManagerWorkRankingScheduling(t *testing.T) { // Schedule a batch of queries. var queries []*Request for i := 0; i < numQueries; i++ { - q := &Request{} + q := &Request{ + Req: &mockQueryEncoded{}, + } queries = append(queries, q) } @@ -511,7 +520,7 @@ func TestWorkManagerWorkRankingScheduling(t *testing.T) { for i := 0; i < numQueries; i++ { select { case job := <-workers[i].nextJob: - if job.index != uint64(i) { + if job.index != float64(i) { t.Fatalf("unexpected job") } jobs = append(jobs, job) @@ -562,7 +571,9 @@ func TestWorkManagerWorkRankingScheduling(t *testing.T) { // Send a new set of queries. queries = nil for i := 0; i < numQueries; i++ { - q := &Request{} + q := &Request{ + Req: &mockQueryEncoded{}, + } queries = append(queries, q) } _ = wm.Query(queries) @@ -576,3 +587,121 @@ func TestWorkManagerWorkRankingScheduling(t *testing.T) { } } } + +// TestWorkManagerSchedulePriorityIndex tests that the workmanager acknowledges +// priority index. +func TestWorkManagerSchedulePriorityIndex(t *testing.T) { + const numQueries = 3 + + // Start work manager with as many workers as queries. This is not very + // realistic, but makes the work manager able to schedule all queries + // concurrently. + wm, workers := startWorkManager(t, numQueries) + + // When the jobs gets scheduled, keep track of which worker was + // assigned the job. 
+ type sched struct { + wk *mockWorker + job *queryJob + } + + // Schedule a batch of queries. + var scheduledJobs [5]chan sched + var queries [numQueries]*Request + for i := 0; i < numQueries; i++ { + var q *Request + idx := i + if i == 0 { + q = &Request{ + Req: &mockQueryEncoded{}, + } + } else { + // Assign priority index. + idx = i + 2 + q = &Request{ + Req: &mockQueryEncoded{ + index: float64(idx), + }, + } + } + queries[i] = q + scheduledJobs[idx] = make(chan sched) + } + + // Fot each worker, spin up a goroutine that will forward the job it + // got to our slice of scheduled jobs, such that we can handle them in + // order. + for i := 0; i < len(workers); i++ { + wk := workers[i] + go func() { + for { + job := <-wk.nextJob + scheduledJobs[int(job.index)] <- sched{ + wk: wk, + job: job, + } + } + }() + } + + // Send the batch, and Retrieve all jobs immediately. + errChan := wm.Query(queries[:]) + + var jobs [numQueries]sched + for i := uint64(0); i < numQueries; i++ { + var expectedIndex float64 + + if i == 0 { + expectedIndex = float64(0) + } else { + expectedIndex = float64(i + 2) + } + var s sched + select { + case s = <-scheduledJobs[int(expectedIndex)]: + + if s.job.index != expectedIndex { + t.Fatalf("wrong index: Got %v but expected %v", s.job.index, + expectedIndex) + } + case <-errChan: + t.Fatalf("did not expect an errChan") + case <-time.After(time.Second): + t.Fatalf("next job not received") + } + + jobs[i] = s + } + + // Go backwards send results for job. + for i := numQueries - 1; i >= 0; i-- { + select { + case jobs[i].wk.results <- &jobResult{ + job: jobs[i].job, + }: + case <-errChan: + t.Fatalf("did not expect on errChan") + case <-time.After(time.Second): + t.Fatalf("result not handled") + } + } + + // Finally, make sure no jobs are retried. + for i := uint64(0); i < numQueries; i++ { + select { + case <-scheduledJobs[i]: + t.Fatalf("did not expect a retried job") + case <-time.After(time.Second): + } + } + + // The query should ultimately succeed. + select { + case err := <-errChan: + if err != nil { + t.Fatalf("got error: %v", err) + } + case <-time.After(time.Second): + t.Fatalf("nothing received on errChan") + } +} diff --git a/query/workqueue.go b/query/workqueue.go index 9a92ce8f2..cc9b2f2ef 100644 --- a/query/workqueue.go +++ b/query/workqueue.go @@ -4,7 +4,7 @@ package query // work queue. type Task interface { // Index returns this Task's index in the work queue. - Index() uint64 + Index() float64 } // workQueue is struct implementing the heap interface, and is used to keep a diff --git a/query/workqueue_test.go b/query/workqueue_test.go index d9abc66a5..b0f7bae00 100644 --- a/query/workqueue_test.go +++ b/query/workqueue_test.go @@ -6,12 +6,12 @@ import ( ) type task struct { - index uint64 + index float64 } var _ Task = (*task)(nil) -func (t *task) Index() uint64 { +func (t *task) Index() float64 { return t.index } @@ -27,7 +27,7 @@ func TestWorkQueue(t *testing.T) { // Create a simple list of tasks and add them all to the queue. var tasks []*task - for i := uint64(0); i < numTasks; i++ { + for i := float64(0); i < numTasks; i++ { tasks = append(tasks, &task{ index: i, }) @@ -40,11 +40,11 @@ func TestWorkQueue(t *testing.T) { // Check that it reports the expected number of elements. l := q.Len() if l != numTasks { - t.Fatalf("expected %d length, was %d", numTasks, l) + t.Fatalf("expected %v length, was %v", numTasks, l) } // Pop half, and make sure they arrive in the right order. 
- for i := uint64(0); i < numTasks/2; i++ { + for i := float64(0); i < numTasks/2; i++ { peek := q.Peek().(*task) pop := heap.Pop(q) @@ -54,7 +54,7 @@ func TestWorkQueue(t *testing.T) { } if peek.index != i { - t.Fatalf("wrong index: %d", peek.index) + t.Fatalf("wrong index: %v", peek.index) } } @@ -63,7 +63,7 @@ func TestWorkQueue(t *testing.T) { heap.Push(q, tasks[0]) } - for i := uint64(numTasks/2 - 3); i < numTasks; i++ { + for i := float64(numTasks/2 - 3); i < numTasks; i++ { peek := q.Peek().(*task) pop := heap.Pop(q) @@ -80,13 +80,13 @@ func TestWorkQueue(t *testing.T) { } if peek.index != exp { - t.Fatalf("wrong index: %d", peek.index) + t.Fatalf("wrong index: %v", peek.index) } } // Finally, the queue should be empty. l = q.Len() if l != 0 { - t.Fatalf("expected %d length, was %d", 0, l) + t.Fatalf("expected %v length, was %v", 0, l) } } From dd02e22bb2e30f2d5205b29e787886e4ca71b1d9 Mon Sep 17 00:00:00 2001 From: Maureen Ononiwu Date: Tue, 5 Sep 2023 06:24:57 +0100 Subject: [PATCH 5/7] query + neutrino: Clone job before sending to workmanager Adds CloneReq function field to query.Request struct. Jobs are cloned in the worker before sending to the workmanager. This would be useful in coming commits where a job's request is modified according to the response it gets. Such as in the case of block header fetching. A CloneReq function is defined in the instance of GetCFilter, GetCFHeader and GetData requests in this commit as well. Signed-off-by: Maureen Ononiwu --- blockmanager.go | 22 ++++++++++++- query.go | 35 ++++++++++++++++++++ query/interface.go | 3 ++ query/worker.go | 22 ++++++++++++- query/worker_test.go | 77 ++++++++++++++++++++++++++++++++++---------- 5 files changed, 140 insertions(+), 19 deletions(-) diff --git a/blockmanager.go b/blockmanager.go index 59b38c823..35523fa1e 100644 --- a/blockmanager.go +++ b/blockmanager.go @@ -842,6 +842,7 @@ func (c *checkpointedCFHeadersQuery) requests() []*query.Request { Req: m, HandleResp: c.handleResponse, SendQuery: sendQueryMessageWithEncoding, + CloneReq: cloneMsgCFHeaders, } } return reqs @@ -957,12 +958,31 @@ func sendQueryMessageWithEncoding(peer query.Peer, req query.ReqMessage) error { if !ok { return errors.New("invalid request type") } - sp.QueueMessageWithEncoding(request.message, nil, request.encoding) return nil } +// cloneMsgCFHeaders clones query.ReqMessage that contains the MsgGetCFHeaders message. +func cloneMsgCFHeaders(req query.ReqMessage) query.ReqMessage { + oldReq, ok := req.(*encodedQuery) + if !ok { + log.Errorf("request not of type *encodedQuery") + } + oldReqMessage, ok := oldReq.message.(*wire.MsgGetCFHeaders) + if !ok { + log.Errorf("request not of type *wire.MsgGetCFHeaders") + } + newReq := &encodedQuery{ + message: wire.NewMsgGetCFHeaders( + oldReqMessage.FilterType, oldReqMessage.StartHeight, &oldReqMessage.StopHash, + ), + encoding: oldReq.encoding, + priorityIndex: oldReq.priorityIndex, + } + return newReq +} + // getCheckpointedCFHeaders catches a filter header store up with the // checkpoints we got from the network. It assumes that the filter header store // matches the checkpoints up to the tip of the store. 
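
To show why CloneReq generally needs a deep copy of the message state rather than just a new wrapper, here is a hedged sketch for the hypothetical getHeadersQuery type from the earlier illustration. It is not part of this patch; the point is that copying the block locator slice prevents later mutations of the original request (such as a new start point between retries) from leaking into a job result that was already handed off:

// cloneGetHeadersQuery deep-copies a hypothetical getHeadersQuery so that
// later changes to the original request cannot affect a result that was
// already sent to the work manager.
func cloneGetHeadersQuery(req query.ReqMessage) query.ReqMessage {
	oldReq, ok := req.(*getHeadersQuery)
	if !ok {
		log.Errorf("request not of type *getHeadersQuery")
		return req
	}
	oldMsg, ok := oldReq.message.(*wire.MsgGetHeaders)
	if !ok {
		log.Errorf("request not of type *wire.MsgGetHeaders")
		return req
	}

	newMsg := wire.NewMsgGetHeaders()
	newMsg.HashStop = oldMsg.HashStop

	// Copy the locator hashes instead of sharing the slice with the
	// original request.
	for _, hash := range oldMsg.BlockLocatorHashes {
		hashCopy := *hash
		_ = newMsg.AddBlockLocatorHash(&hashCopy)
	}

	return &getHeadersQuery{
		message:       newMsg,
		priorityIndex: oldReq.priorityIndex,
	}
}
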
diff --git a/query.go b/query.go index 56178e663..ad75bfca2 100644 --- a/query.go +++ b/query.go @@ -447,6 +447,24 @@ func (q *cfiltersQuery) request() *query.Request { Req: msg, HandleResp: q.handleResponse, SendQuery: sendQueryMessageWithEncoding, + CloneReq: func(req query.ReqMessage) query.ReqMessage { + oldReq, ok := req.(*encodedQuery) + if !ok { + log.Errorf("request not of type *encodedQuery") + } + oldReqMessage, ok := oldReq.message.(*wire.MsgGetCFilters) + if !ok { + log.Errorf("request not of type *wire.MsgGetCFilters") + } + newReq := &encodedQuery{ + message: wire.NewMsgGetCFilters( + oldReqMessage.FilterType, oldReqMessage.StartHeight, &oldReqMessage.StopHash, + ), + encoding: oldReq.encoding, + priorityIndex: oldReq.priorityIndex, + } + return newReq + }, } } @@ -926,6 +944,23 @@ func (s *ChainService) GetBlock(blockHash chainhash.Hash, Req: msg, HandleResp: handleResp, SendQuery: sendQueryMessageWithEncoding, + CloneReq: func(req query.ReqMessage) query.ReqMessage { + newMsg := wire.NewMsgGetData() + _ = newMsg.AddInvVect(inv) + + oldReq, ok := req.(*encodedQuery) + if !ok { + log.Errorf("request not of type *encodedQuery") + } + + newReq := &encodedQuery{ + message: newMsg, + encoding: oldReq.encoding, + priorityIndex: oldReq.priorityIndex, + } + + return newReq + }, } // Prepare the query options. diff --git a/query/interface.go b/query/interface.go index ce9f751c5..d19598b28 100644 --- a/query/interface.go +++ b/query/interface.go @@ -143,6 +143,9 @@ type Request struct { // SendQuery handles sending request to the worker's peer. It returns an error, // if one is encountered while sending the request. SendQuery func(peer Peer, request ReqMessage) error + + // CloneReq clones the message. + CloneReq func(message ReqMessage) ReqMessage } // ReqMessage is an interface which all structs containing information diff --git a/query/worker.go b/query/worker.go index ede799a3d..a03dd984b 100644 --- a/query/worker.go +++ b/query/worker.go @@ -235,11 +235,31 @@ nexJobLoop: // Stop to allow garbage collection. timeout.Stop() + // This is necessary to avoid a situation where future changes to the job's request affect the current job. + // For example: suppose we want to fetch headers between checkpoints 0 and 20,000. The maximum number of headers + // that a peer can send in one message is 2000. When we receive 2000 headers for one request, + // we update the job's request, changing its startheight and blocklocator to match the next batch of headers + // that we want to fetch. Since we are not done with fetching our target of 20,000 headers, + // we will have to make more changes to the job's request in the future. This could alter previous requests, + // resulting in unwanted behaviour. + resultJob := &queryJob{ + index: job.Index(), + Request: &Request{ + Req: job.CloneReq(job.Req), + HandleResp: job.Request.HandleResp, + CloneReq: job.Request.CloneReq, + SendQuery: job.Request.SendQuery, + }, + cancelChan: job.cancelChan, + tries: job.tries, + timeout: job.timeout, + } + // We have a result ready for the query, hand it off before // getting a new job. 
select { case results <- &jobResult{ - job: job, + job: resultJob, peer: peer, err: jobErr, }: diff --git a/query/worker_test.go b/query/worker_test.go index d4a0760ba..2cf6f1751 100644 --- a/query/worker_test.go +++ b/query/worker_test.go @@ -99,6 +99,19 @@ func makeJob() *queryJob { m.requests <- req.Message() return nil }, + CloneReq: func(req ReqMessage) ReqMessage { + oldReq := req.(*mockQueryEncoded) + + newMsg := &wire.MsgGetData{ + InvList: oldReq.message.InvList, + } + + clone := &mockQueryEncoded{ + message: newMsg, + } + + return clone + }, } return &queryJob{ @@ -209,9 +222,15 @@ func TestWorkerIgnoreMsgs(t *testing.T) { t.Fatalf("response error: %v", result.err) } - // Make sure the result was given for the intended job. - if result.job != task { - t.Fatalf("got result for unexpected job") + // Make sure the QueryJob instance in the result is different from the initial one + // supplied to the worker + if result.job == task { + t.Fatalf("result's job should be different from the task's") + } + + // Make sure we are receiving the corresponding result for the given task. + if result.job.Index() != task.Index() { + t.Fatalf("result's job index should not be different from task's") } // And the correct peer. @@ -264,9 +283,15 @@ func TestWorkerTimeout(t *testing.T) { t.Fatalf("expected timeout, got: %v", result.err) } - // Make sure the result was given for the intended job. - if result.job != task { - t.Fatalf("got result for unexpected job") + // Make sure the QueryJob instance in the result is different from the initial one + // supplied to the worker + if result.job == task { + t.Fatalf("result's job should be different from the task's") + } + + // Make sure we are receiving the corresponding result for the given task. + if result.job.Index() != task.Index() { + t.Fatalf("result's job index should not be different from task's") } // And the correct peer. @@ -323,9 +348,15 @@ func TestWorkerDisconnect(t *testing.T) { t.Fatalf("expected peer disconnect, got: %v", result.err) } - // Make sure the result was given for the intended job. - if result.job != task { - t.Fatalf("got result for unexpected job") + // Make sure the QueryJob instance in the result is different from the initial one + // supplied to the worker + if result.job == task { + t.Fatalf("result's job should be different from the task's") + } + + // Make sure we are receiving the corresponding result for the given task. + if result.job.Index() != task.Index() { + t.Fatalf("result's job index should not be different from task's") } // And the correct peer. @@ -411,9 +442,15 @@ func TestWorkerProgress(t *testing.T) { t.Fatalf("expected no error, got: %v", result.err) } - // Make sure the result was given for the intended task. - if result.job != task { - t.Fatalf("got result for unexpected job") + // Make sure the QueryJob instance in the result is different from the initial one + // supplied to the worker + if result.job == task { + t.Fatalf("result's job should be different from the task's") + } + + // Make sure we are receiving the corresponding result for the given task. + if result.job.Index() != task.Index() { + t.Fatalf("result's job index should not be different from task's") } // And the correct peer. @@ -484,9 +521,15 @@ func TestWorkerJobCanceled(t *testing.T) { t.Fatalf("expected job canceled, got: %v", result.err) } - // Make sure the result was given for the intended task. 
- if result.job != task { - t.Fatalf("got result for unexpected job") + // Make sure the QueryJob instance in the result is different from the initial one + // supplied to the worker + if result.job == task { + t.Fatalf("result's job should be different from the task's") + } + + // Make sure we are receiving the corresponding result for the given task. + if result.job.Index() != task.Index() { + t.Fatalf("result's job index should not be different from task's") } // And the correct peer. @@ -546,9 +589,9 @@ func TestWorkerSendQueryErr(t *testing.T) { ctx.peer.err, result.err) } - // Make sure the result was given for the intended task. + // Make sure the QueryJob instance in the result is same as the taskJob's. if result.job != taskJob { - t.Fatalf("got result for unexpected job") + t.Fatalf("result's job should be same as the taskJob's") } // And the correct peer. From 2285ce7199891d1028537aaf3a75d890b127f1fc Mon Sep 17 00:00:00 2001 From: Maureen Ononiwu Date: Sat, 2 Sep 2023 09:31:58 +0100 Subject: [PATCH 6/7] neutrino + query: Refactored ErrChan, jobResults, HandleResp, Progress - Added Unfinished bool to jobResult to indicate successful jobs that still need to send another request to the peer to be considered complete. - Made ErrChan a query option in that way it is optional for different queries. - Refactored HandleResp, peer is now passed as query.Peer instead of using its address. - Changed type for query.Progress. Signed-off-by: Maureen Ononiwu --- blockmanager.go | 47 ++--- blockmanager_test.go | 36 ++-- query.go | 34 ++-- query/interface.go | 54 ++++-- query/worker.go | 88 +++++---- query/worker_test.go | 211 ++++++++++++++------- query/workmanager.go | 77 ++++++-- query/workmanager_test.go | 380 +++++++++++++++++++++++++++++++++----- query_test.go | 8 +- 9 files changed, 683 insertions(+), 252 deletions(-) diff --git a/blockmanager.go b/blockmanager.go index 35523fa1e..f32a9efa6 100644 --- a/blockmanager.go +++ b/blockmanager.go @@ -850,43 +850,37 @@ func (c *checkpointedCFHeadersQuery) requests() []*query.Request { // handleResponse is the internal response handler used for requests for this // CFHeaders query. -func (c *checkpointedCFHeadersQuery) handleResponse(req, resp wire.Message, - peerAddr string) query.Progress { +func (c *checkpointedCFHeadersQuery) handleResponse(request query.ReqMessage, resp wire.Message, + peer query.Peer) query.Progress { + + peerAddr := "" + if peer != nil { + peerAddr = peer.Addr() + } + req := request.Message() r, ok := resp.(*wire.MsgCFHeaders) if !ok { // We are only looking for cfheaders messages. - return query.Progress{ - Finished: false, - Progressed: false, - } + return query.NoResponse } q, ok := req.(*wire.MsgGetCFHeaders) if !ok { // We sent a getcfheaders message, so that's what we should be // comparing against. - return query.Progress{ - Finished: false, - Progressed: false, - } + return query.NoResponse } // The response doesn't match the query. if q.FilterType != r.FilterType || q.StopHash != r.StopHash { - return query.Progress{ - Finished: false, - Progressed: false, - } + return query.NoResponse } checkPointIndex, ok := c.stopHashes[r.StopHash] if !ok { // We never requested a matching stop hash. 
- return query.Progress{ - Finished: false, - Progressed: false, - } + return query.NoResponse } // Use either the genesis header or the previous checkpoint index as @@ -920,10 +914,7 @@ func (c *checkpointedCFHeadersQuery) handleResponse(req, resp wire.Message, log.Errorf("Unable to ban peer %v: %v", peerAddr, err) } - return query.Progress{ - Finished: false, - Progressed: false, - } + return query.NoResponse } // At this point, the response matches the query, and the relevant @@ -934,16 +925,10 @@ func (c *checkpointedCFHeadersQuery) handleResponse(req, resp wire.Message, select { case c.headerChan <- r: case <-c.blockMgr.quit: - return query.Progress{ - Finished: false, - Progressed: false, - } + return query.NoResponse } - return query.Progress{ - Finished: true, - Progressed: true, - } + return query.Finished } // sendQueryMessageWithEncoding sends a message to the peer with encoding. @@ -1106,7 +1091,7 @@ func (b *blockManager) getCheckpointedCFHeaders(checkpoints []*chainhash.Hash, // Hand the queries to the work manager, and consume the verified // responses as they come back. errChan := b.cfg.QueryDispatcher.Query( - q.requests(), query.Cancel(b.quit), query.NoRetryMax(), + q.requests(), query.Cancel(b.quit), query.NoRetryMax(), query.ErrChan(make(chan error, 1)), ) // Keep waiting for more headers as long as we haven't received an diff --git a/blockmanager_test.go b/blockmanager_test.go index 97416cf1b..060443b91 100644 --- a/blockmanager_test.go +++ b/blockmanager_test.go @@ -214,14 +214,14 @@ func generateHeaders(genesisBlockHeader *wire.BlockHeader, // generateResponses generates the MsgCFHeaders messages from the given queries // and headers. -func generateResponses(msgs []wire.Message, +func generateResponses(msgs []query.ReqMessage, headers *headers) ([]*wire.MsgCFHeaders, error) { // Craft a response for each message. var responses []*wire.MsgCFHeaders for _, msg := range msgs { // Only GetCFHeaders expected. - q, ok := msg.(*wire.MsgGetCFHeaders) + q, ok := msg.Message().(*wire.MsgGetCFHeaders) if !ok { return nil, fmt.Errorf("got unexpected message %T", msg) @@ -350,9 +350,9 @@ func TestBlockManagerInitialInterval(t *testing.T) { requests []*query.Request, options ...query.QueryOption) chan error { - var msgs []wire.Message + var msgs []query.ReqMessage for _, q := range requests { - msgs = append(msgs, q.Req.Message()) + msgs = append(msgs, q.Req) } responses, err := generateResponses(msgs, headers) @@ -379,13 +379,13 @@ func TestBlockManagerInitialInterval(t *testing.T) { // Let the blockmanager handle the // message. progress := requests[index].HandleResp( - msgs[index], &resp, "", + msgs[index], &resp, nil, ) - if !progress.Finished { + if progress != query.Finished { errChan <- fmt.Errorf("got "+ - "response false on "+ - "send of index %d: %v", + " %v on "+ + "send of index %d: %v", progress, index, testDesc) return } @@ -400,13 +400,13 @@ func TestBlockManagerInitialInterval(t *testing.T) { // Otherwise resend the response we // just sent. 
progress = requests[index].HandleResp( - msgs[index], &resp2, "", + msgs[index], &resp2, nil, ) - if !progress.Finished { + if progress != query.Finished { errChan <- fmt.Errorf("got "+ - "response false on "+ - "resend of index %d: "+ - "%v", index, testDesc) + " %v on "+ + "send of index %d: %v", progress, + index, testDesc) return } } @@ -580,9 +580,9 @@ func TestBlockManagerInvalidInterval(t *testing.T) { requests []*query.Request, options ...query.QueryOption) chan error { - var msgs []wire.Message + var msgs []query.ReqMessage for _, q := range requests { - msgs = append(msgs, q.Req.Message()) + msgs = append(msgs, q.Req) } responses, err := generateResponses(msgs, headers) require.NoError(t, err) @@ -619,10 +619,10 @@ func TestBlockManagerInvalidInterval(t *testing.T) { // expect. for i := range responses { progress := requests[i].HandleResp( - msgs[i], responses[i], "", + msgs[i], responses[i], nil, ) if i == test.firstInvalid { - if progress.Finished { + if progress == query.Finished { t.Errorf("expected interval "+ "%d to be invalid", i) return @@ -631,7 +631,7 @@ func TestBlockManagerInvalidInterval(t *testing.T) { break } - if !progress.Finished { + if progress != query.Finished { t.Errorf("expected interval %d to be "+ "valid", i) return diff --git a/query.go b/query.go index ad75bfca2..40c968a5f 100644 --- a/query.go +++ b/query.go @@ -66,10 +66,7 @@ var ( // noProgress will be used to indicate to a query.WorkManager that a // response makes no progress towards the completion of the query. - noProgress = query.Progress{ - Finished: false, - Progressed: false, - } + noProgress = query.NoResponse ) // queries are a set of options that can be modified per-query, unlike global @@ -470,9 +467,10 @@ func (q *cfiltersQuery) request() *query.Request { // handleResponse validates that the cfilter response we get from a peer is // sane given the getcfilter query that we made. -func (q *cfiltersQuery) handleResponse(req, resp wire.Message, - _ string) query.Progress { +func (q *cfiltersQuery) handleResponse(r query.ReqMessage, resp wire.Message, + peer query.Peer) query.Progress { + req := r.Message() // The request must have been a "getcfilters" msg. request, ok := req.(*wire.MsgGetCFilters) if !ok { @@ -573,17 +571,11 @@ func (q *cfiltersQuery) handleResponse(req, resp wire.Message, // If there are still entries left in the headerIndex then the query // has made progress but has not yet completed. if len(q.headerIndex) != 0 { - return query.Progress{ - Finished: false, - Progressed: true, - } + return query.Progressed } // The headerIndex is empty and so this query is complete. - return query.Progress{ - Finished: true, - Progressed: true, - } + return query.Finished } // prepareCFiltersQuery creates a cfiltersQuery that can be used to fetch a @@ -784,6 +776,7 @@ func (s *ChainService) GetCFilter(blockHash chainhash.Hash, query.Cancel(s.quit), query.Encoding(qo.encoding), query.NumRetries(qo.numRetries), + query.ErrChan(make(chan error, 1)), } errChan := s.workManager.Query( @@ -868,7 +861,12 @@ func (s *ChainService) GetBlock(blockHash chainhash.Hash, // handleResp will be called for each message received from a peer. It // will be used to signal to the work manager whether progress has been // made or not. 
- handleResp := func(req, resp wire.Message, peer string) query.Progress { + handleResp := func(request query.ReqMessage, resp wire.Message, sp query.Peer) query.Progress { + req := request.Message() + peer := "" + if sp != nil { + peer = sp.Addr() + } // The request must have been a "getdata" msg. _, ok := req.(*wire.MsgGetData) if !ok { @@ -933,10 +931,7 @@ func (s *ChainService) GetBlock(blockHash chainhash.Hash, // we declare it sane. We can kill the query and pass the // response back to the caller. foundBlock = block - return query.Progress{ - Finished: true, - Progressed: true, - } + return query.Finished } // Prepare the query request. @@ -968,6 +963,7 @@ func (s *ChainService) GetBlock(blockHash chainhash.Hash, query.Encoding(qo.encoding), query.NumRetries(qo.numRetries), query.Cancel(s.quit), + query.ErrChan(make(chan error, 1)), } // Send the request to the work manager and await a response. diff --git a/query/interface.go b/query/interface.go index d19598b28..70b82cbcf 100644 --- a/query/interface.go +++ b/query/interface.go @@ -43,6 +43,9 @@ type queryOptions struct { // that a query can be retried. If this is set then numRetries has no // effect. noRetryMax bool + + // errChan error channel with which the workmananger sends error. + errChan chan error } // QueryOption is a functional option argument to any of the network query @@ -67,6 +70,14 @@ func (qo *queryOptions) applyQueryOptions(options ...QueryOption) { } } +// ErrChan is a query option that specifies the error channel which the workmanager +// sends any error to. +func ErrChan(err chan error) QueryOption { + return func(qo *queryOptions) { + qo.errChan = err + } +} + // NumRetries is a query option that specifies the number of times a query // should be retried. func NumRetries(num uint8) QueryOption { @@ -107,19 +118,34 @@ func Cancel(cancel chan struct{}) QueryOption { } } -// Progress encloses the result of handling a response for a given Request, -// determining whether the response did progress the query. -type Progress struct { - // Finished is true if the query was finished as a result of the - // received response. - Finished bool - - // Progressed is true if the query made progress towards fully - // answering the request as a result of the received response. This is - // used for the requests types where more than one response is - // expected. - Progressed bool -} +// Progress encloses the result of handling a response for a given Request. +type Progress string + +var ( + + // Finished indicates we have received the complete, valid response for this request, + // and so we are done with it. + Finished Progress = "Received complete and valid response for request." + + // Progressed indicates that we have received a valid response, but we are expecting more. + Progressed Progress = "Received valid response, expecting more response for query." + + // UnFinishedRequest indicates that we have received some response, but we need to rescheule the job + // to completely fetch all the response required for this request. + UnFinishedRequest Progress = "Received valid response, reschedule to complete request" + + // ResponseErr indicates we obtained a valid response but response fails checks and needs to + // be rescheduled. + ResponseErr Progress = "Received valid response but fails checks " + + // IgnoreRequest indicates that we have received a valid response but the workmanager need take + // no action on the result of this job. 
+ IgnoreRequest Progress = "Received response but ignoring" + + // NoResponse indicates that we have received an invalid response for this request, and we need + // to wait for a valid one. + NoResponse Progress = "Received invalid response" +) // Request is the main struct that defines a bitcoin network query to be sent to // connected peers. @@ -138,7 +164,7 @@ type Request struct { // should validate the response and immediately return the progress. // The response should be handed off to another goroutine for // processing. - HandleResp func(req, resp wire.Message, peer string) Progress + HandleResp func(req ReqMessage, resp wire.Message, peer Peer) Progress // SendQuery handles sending request to the worker's peer. It returns an error, // if one is encountered while sending the request. diff --git a/query/worker.go b/query/worker.go index a03dd984b..82c4948c6 100644 --- a/query/worker.go +++ b/query/worker.go @@ -17,6 +17,14 @@ var ( // ErrJobCanceled is returned if the job is canceled before the query // has been answered. ErrJobCanceled = errors.New("job canceled") + + // ErrIgnoreRequest is returned if we want to ignore the request after getting + // a response. + ErrIgnoreRequest = errors.New("ignore request") + + // ErrResponseErr is returned if we received a compatible response for the query but, it did not pass + // preliminary verification. + ErrResponseErr = errors.New("received response with error") ) // queryJob is the internal struct that wraps the Query to work on, in @@ -42,9 +50,10 @@ func (q *queryJob) Index() float64 { // jobResult is the final result of the worker's handling of the queryJob. type jobResult struct { - job *queryJob - peer Peer - err error + job *queryJob + peer Peer + err error + unfinished bool } // worker is responsible for polling work from its work queue, and handing it @@ -152,8 +161,9 @@ nexJobLoop: // Wait for the correct response to be received from the peer, // or an error happening. var ( - jobErr error - timeout = time.NewTimer(job.timeout) + jobErr error + jobUnfinished bool + timeout = time.NewTimer(job.timeout) ) feedbackLoop: @@ -164,36 +174,49 @@ nexJobLoop: // our request. case resp := <-msgChan: progress := job.HandleResp( - job.Req.Message(), resp, peer.Addr(), + job.Req, resp, peer, ) log.Tracef("Worker %v handled msg %T while "+ - "waiting for response to %T (job=%v). "+ - "Finished=%v, progressed=%v", - peer.Addr(), resp, job.Req, job.Index(), - progress.Finished, progress.Progressed) - - // If the response did not answer our query, we - // check whether it did progress it. - if !progress.Finished { - // If it did make progress we reset the - // timeout. This ensures that the - // queries with multiple responses - // expected won't timeout before all - // responses have been handled. - // TODO(halseth): separate progress - // timeout value. - if progress.Progressed { - timeout.Stop() - timeout = time.NewTimer( - job.timeout, - ) - } + "waiting for response to %T (job=%v). ", + peer.Addr(), resp, job.Req, job.Index()) + + switch { + case progress == Finished: + + // Wait for valid response if we have not gotten any one yet. + case progress == NoResponse: + + continue feedbackLoop + + // Increase job's timeout if valid response has been received, and we + // are awaiting more to prevent premature timeout. + case progress == Progressed: + + timeout.Stop() + timeout = time.NewTimer( + job.timeout, + ) + continue feedbackLoop + + // Assign true to jobUnfinished to indicate that we need to reschedule job to complete request. 
+ case progress == UnFinishedRequest: + + jobUnfinished = true + + // Assign ErrIgnoreRequest to indicate that workmanager should take no action on receipt of + // this request. + case progress == IgnoreRequest: + + jobErr = ErrIgnoreRequest + + // Assign ErrResponseErr to jobErr if we received a valid response that did not pass checks. + case progress == ResponseErr: + + jobErr = ErrResponseErr } - // We did get a valid response, and can break - // the loop. break feedbackLoop // If the timeout is reached before a valid response @@ -259,9 +282,10 @@ nexJobLoop: // getting a new job. select { case results <- &jobResult{ - job: resultJob, - peer: peer, - err: jobErr, + job: resultJob, + peer: peer, + err: jobErr, + unfinished: jobUnfinished, }: case <-quit: return diff --git a/query/worker_test.go b/query/worker_test.go index 2cf6f1751..84e06659c 100644 --- a/query/worker_test.go +++ b/query/worker_test.go @@ -35,6 +35,15 @@ var ( finalResp = &wire.MsgTx{ Version: 222, } + UnfinishedRequestResp = &wire.MsgTx{ + Version: 333, + } + finalRespWithErr = &wire.MsgTx{ + Version: 444, + } + IgnoreRequestResp = &wire.MsgTx{ + Version: 444, + } ) type mockPeer struct { @@ -69,25 +78,26 @@ func (m *mockPeer) Addr() string { func makeJob() *queryJob { q := &Request{ Req: req, - HandleResp: func(req, resp wire.Message, _ string) Progress { + HandleResp: func(req ReqMessage, resp wire.Message, peer Peer) Progress { if resp == finalResp { - return Progress{ - Finished: true, - Progressed: true, - } + return Finished } if resp == progressResp { - return Progress{ - Finished: false, - Progressed: true, - } + return Progressed + } + + if resp == UnfinishedRequestResp { + return UnFinishedRequest } - return Progress{ - Finished: false, - Progressed: false, + if resp == finalRespWithErr { + return ResponseErr + } + if resp == IgnoreRequestResp { + return IgnoreRequest } + return NoResponse }, SendQuery: func(peer Peer, req ReqMessage) error { m := peer.(*mockPeer) @@ -233,6 +243,11 @@ func TestWorkerIgnoreMsgs(t *testing.T) { t.Fatalf("result's job index should not be different from task's") } + // Make sure job does not return as unfinished. + if result.unfinished { + t.Fatalf("got unfinished job") + } + // And the correct peer. if result.peer != ctx.peer { t.Fatalf("expected peer to be %v, was %v", @@ -300,6 +315,11 @@ func TestWorkerTimeout(t *testing.T) { ctx.peer.Addr(), result.peer) } + // Make sure job does not return as unfinished. + if result.unfinished { + t.Fatalf("got unfinished job") + } + // It will immediately attempt to fetch another task. select { case ctx.nextJob <- task: @@ -365,6 +385,11 @@ func TestWorkerDisconnect(t *testing.T) { ctx.peer.Addr(), result.peer) } + // Make sure job does not return as unfinished. + if result.unfinished { + t.Fatalf("got unfinished job") + } + // No more jobs should be accepted by the worker after it has exited. select { case ctx.nextJob <- task: @@ -393,70 +418,121 @@ func TestWorkerProgress(t *testing.T) { } // Create a task with a small timeout, and give it to the worker. - task := makeJob() - task.timeout = taskTimeout - - select { - case ctx.nextJob <- task: - case <-time.After(1 * time.Second): - t.Fatalf("did not pick up job") + type testResp struct { + name string + response *wire.MsgTx + err *error + unfinished bool } - // The request should be given to the peer. 
- select { - case <-ctx.peer.requests: - case <-time.After(time.Second): - t.Fatalf("request not sent") - } + testCases := []testResp{ - // Send a few other responses that indicates progress, but not success. - // We add a small delay between each time we send a response. In total - // the delay will be larger than the query timeout, but since we are - // making progress, the timeout won't trigger. - for i := 0; i < 5; i++ { - select { - case ctx.peer.responses <- progressResp: - case <-time.After(time.Second): - t.Fatalf("resp not received") - } + { + name: "final response.", + response: finalResp, + }, - time.Sleep(taskTimeout / 2) - } + { + name: "Unfinished request response.", + response: UnfinishedRequestResp, + unfinished: true, + }, - // Finally send the final response. - select { - case ctx.peer.responses <- finalResp: - case <-time.After(time.Second): - t.Fatalf("resp not received") - } + { + name: "ignore request", + response: IgnoreRequestResp, + err: &ErrIgnoreRequest, + }, - // The worker should respond with a job finised. - var result *jobResult - select { - case result = <-ctx.jobResults: - case <-time.After(time.Second): - t.Fatalf("response not received") + { + name: "final response, with err", + response: finalRespWithErr, + err: &ErrResponseErr, + }, } - if result.err != nil { - t.Fatalf("expected no error, got: %v", result.err) - } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + task := makeJob() + task.timeout = taskTimeout - // Make sure the QueryJob instance in the result is different from the initial one - // supplied to the worker - if result.job == task { - t.Fatalf("result's job should be different from the task's") - } + select { + case ctx.nextJob <- task: + case <-time.After(1 * time.Second): + t.Fatalf("did not pick up job") + } - // Make sure we are receiving the corresponding result for the given task. - if result.job.Index() != task.Index() { - t.Fatalf("result's job index should not be different from task's") - } + // The request should be given to the peer. + select { + case <-ctx.peer.requests: + case <-time.After(time.Second): + t.Fatalf("request not sent") + } - // And the correct peer. - if result.peer != ctx.peer { - t.Fatalf("expected peer to be %v, was %v", - ctx.peer.Addr(), result.peer) + // Send a few other responses that indicates progress, but not success. + // We add a small delay between each time we send a response. In total + // the delay will be larger than the query timeout, but since we are + // making progress, the timeout won't trigger. + for i := 0; i < 5; i++ { + select { + case ctx.peer.responses <- progressResp: + case <-time.After(time.Second): + t.Fatalf("resp not received") + } + + time.Sleep(taskTimeout / 2) + } + + // Finally send the final response. + select { + case ctx.peer.responses <- tc.response: + case <-time.After(time.Second): + t.Fatalf("resp not received") + } + + // The worker should respond with a job finished. 
+ var result *jobResult + select { + case result = <-ctx.jobResults: + case <-time.After(time.Second): + t.Fatalf("response not received") + } + + if tc.err == nil && result.err != nil { + t.Fatalf("expected no error, got: %v", result.err) + } + + if tc.err != nil && result.err != *tc.err { + t.Fatalf("expected error, %v but got: %v", *tc.err, + result.err) + } + + // Make sure the QueryJob instance in the result is different from the initial one + // supplied to the worker + if result.job == task { + t.Fatalf("result's job should be different from task's") + } + + // Make sure we are receiving the corresponding result for the given task. + if result.job.Index() != task.Index() { + t.Fatalf("result's job index should not be different from task's") + } + + // And the correct peer. + if result.peer != ctx.peer { + t.Fatalf("expected peer to be %v, was %v", + ctx.peer.Addr(), result.peer) + } + + // Make sure job does not return as unfinished. + if tc.unfinished && !result.unfinished { + t.Fatalf("expected job unfinished but got job finished") + } + + if !tc.unfinished && result.unfinished { + t.Fatalf("expected job finished but got unfinished job") + } + }) } } @@ -532,6 +608,11 @@ func TestWorkerJobCanceled(t *testing.T) { t.Fatalf("result's job index should not be different from task's") } + // Make sure job does not return as unfinished. + if result.unfinished { + t.Fatalf("got unfinished job") + } + // And the correct peer. if result.peer != ctx.peer { t.Fatalf("expected peer to be %v, was %v", diff --git a/query/workmanager.go b/query/workmanager.go index 847f53113..8dffec1eb 100644 --- a/query/workmanager.go +++ b/query/workmanager.go @@ -26,7 +26,6 @@ var ( type batch struct { requests []*Request options *queryOptions - errChan chan error } // Worker is the interface that must be satisfied by workers managed by the @@ -94,6 +93,9 @@ type Config struct { // Ranking is used to rank the connected peers when determining who to // give work to. Ranking PeerRanking + + // IsEligibleWorkerFunc determines which peer is eligible to receive a job. + IsEligibleWorkerFunc func(r *activeWorker, next *queryJob) bool } // peerWorkManager is the main access point for outside callers, and satisfies @@ -192,7 +194,9 @@ func (w *peerWorkManager) workDispatcher() { // batches and send on their error channel. defer func() { for _, b := range currentBatches { - b.errChan <- ErrWorkManagerShuttingDown + if b.errChan != nil { + b.errChan <- ErrWorkManagerShuttingDown + } } }() @@ -212,7 +216,7 @@ Loop: next := work.Peek().(*queryJob) // Find the peers with free work slots available. - var freeWorkers []string + var freeEligibleWorkers []string for p, r := range workers { // Only one active job at a time is currently // supported. @@ -220,15 +224,24 @@ Loop: continue } - freeWorkers = append(freeWorkers, p) + // If there is a specified eligibilty function for + // the peer, use it to determine which peers we can + // send jobs to. + if w.cfg.IsEligibleWorkerFunc != nil { + if !w.cfg.IsEligibleWorkerFunc(r, next) { + continue + } + } + + freeEligibleWorkers = append(freeEligibleWorkers, p) } // Use the historical data to rank them. - w.cfg.Ranking.Order(freeWorkers) + w.cfg.Ranking.Order(freeEligibleWorkers) // Give the job to the highest ranked peer with free // slots available. - for _, p := range freeWorkers { + for _, p := range freeEligibleWorkers { r := workers[p] // The worker has free work slots, it should @@ -322,13 +335,19 @@ Loop: // batch's error channel. 
We do this since a // cancellation applies to the whole batch. if batch != nil { - batch.errChan <- result.err + if batch.errChan != nil { + batch.errChan <- result.err + } delete(currentBatches, batchNum) log.Debugf("Canceled batch %v", batchNum) continue Loop } + // Take no action if we are to ignore request. + case result.err == ErrIgnoreRequest: + log.Debugf("received ignore request") + continue Loop // If the query ended with any other error, put it back // into the work queue if it has not reached the @@ -356,7 +375,9 @@ Loop: // Return the error and cancel the // batch. - batch.errChan <- result.err + if batch.errChan != nil { + batch.errChan <- result.err + } delete(currentBatches, batchNum) log.Debugf("Canceled batch %v", @@ -388,6 +409,19 @@ Loop: // Reward the peer for the successful query. w.cfg.Ranking.Reward(result.peer.Addr()) + // If the result is unfinished add 0.0005 to the job index to maintain the + // required priority then push to work queue + if result.unfinished { + result.job.index = result.job.Index() + 0.0005 + log.Debugf("job %v is unfinished, creating new index", result.job.Index()) + + heap.Push(work, result.job) + batch.rem++ + currentQueries[result.job.Index()] = batchNum + } else { + log.Debugf("job %v is Finished", result.job.Index()) + } + // Decrement the number of queries remaining in // the batch. if batch != nil { @@ -399,7 +433,9 @@ Loop: // for this batch, we can notify that // it finished, and delete it. if batch.rem == 0 { - batch.errChan <- nil + if batch.errChan != nil { + batch.errChan <- nil + } delete(currentBatches, batchNum) log.Tracef("Batch %v done", @@ -414,7 +450,9 @@ Loop: if batch != nil { select { case <-batch.timeout: - batch.errChan <- ErrQueryTimeout + if batch.errChan != nil { + batch.errChan <- ErrQueryTimeout + } delete(currentBatches, batchNum) log.Warnf("Query(%d) failed with "+ @@ -463,7 +501,7 @@ Loop: maxRetries: batch.options.numRetries, timeout: time.After(batch.options.timeout), rem: len(batch.requests), - errChan: batch.errChan, + errChan: batch.options.errChan, } batchIndex++ @@ -482,18 +520,19 @@ func (w *peerWorkManager) Query(requests []*Request, qo := defaultQueryOptions() qo.applyQueryOptions(options...) - errChan := make(chan error, 1) + newBatch := &batch{ + requests: requests, + options: qo, + } // Add query messages to the queue of batches to handle. select { - case w.newBatches <- &batch{ - requests: requests, - options: qo, - errChan: errChan, - }: + case w.newBatches <- newBatch: case <-w.quit: - errChan <- ErrWorkManagerShuttingDown + if newBatch.options.errChan != nil { + newBatch.options.errChan <- ErrWorkManagerShuttingDown + } } - return errChan + return newBatch.options.errChan } diff --git a/query/workmanager_test.go b/query/workmanager_test.go index 940ea2d1c..c8483d5af 100644 --- a/query/workmanager_test.go +++ b/query/workmanager_test.go @@ -63,9 +63,14 @@ func (p *mockPeerRanking) Punish(peer string) { func (p *mockPeerRanking) Reward(peer string) { } +type ctx struct { + wm WorkManager + peerChan chan Peer +} + // startWorkManager starts a new workmanager with the given number of mock // workers. 
-func startWorkManager(t *testing.T, numWorkers int) (WorkManager, +func startWorkManager(t *testing.T, numWorkers int) (ctx, []*mockWorker) { // We set up a custom NewWorker closure for the WorkManager, such that @@ -116,7 +121,10 @@ func startWorkManager(t *testing.T, numWorkers int) (WorkManager, workers[i] = w } - return wm, workers + return ctx{ + wm: wm, + peerChan: peerChan, + }, workers } // TestWorkManagerWorkDispatcherSingleWorker tests that the workDispatcher @@ -126,8 +134,8 @@ func TestWorkManagerWorkDispatcherSingleWorker(t *testing.T) { const numQueries = 100 // Start work manager with a sinlge worker. - wm, workers := startWorkManager(t, 1) - + c, workers := startWorkManager(t, 1) + wm := c.wm // Schedule a batch of queries. var queries []*Request for i := 0; i < numQueries; i++ { @@ -138,7 +146,7 @@ func TestWorkManagerWorkDispatcherSingleWorker(t *testing.T) { queries = append(queries, q) } - errChan := wm.Query(queries) + errChan := wm.Query(queries, ErrChan(make(chan error, 1))) wk := workers[0] @@ -189,7 +197,8 @@ func TestWorkManagerWorkDispatcherFailures(t *testing.T) { // Start work manager with as many workers as queries. This is not very // realistic, but makes the work manager able to schedule all queries // concurrently. - wm, workers := startWorkManager(t, numQueries) + c, workers := startWorkManager(t, numQueries) + wm := c.wm // When the jobs gets scheduled, keep track of which worker was // assigned the job. @@ -226,7 +235,7 @@ func TestWorkManagerWorkDispatcherFailures(t *testing.T) { } // Send the batch, and Retrieve all jobs immediately. - errChan := wm.Query(queries[:]) + errChan := wm.Query(queries[:], ErrChan(make(chan error, 1))) var jobs [numQueries]sched for i := 0; i < numQueries; i++ { @@ -305,11 +314,12 @@ func TestWorkManagerWorkDispatcherFailures(t *testing.T) { // TestWorkManagerErrQueryTimeout tests that the workers that return query // timeout are not sent jobs until they return a different error. func TestWorkManagerErrQueryTimeout(t *testing.T) { - const numQueries = 2 + const numQueries = 1 const numWorkers = 1 // Start work manager. - wm, workers := startWorkManager(t, numWorkers) + c, workers := startWorkManager(t, numWorkers) + wm := c.wm // When the jobs gets scheduled, keep track of which worker was // assigned the job. @@ -329,46 +339,38 @@ func TestWorkManagerErrQueryTimeout(t *testing.T) { scheduledJobs[i] = make(chan sched) } - // Fot each worker, spin up a goroutine that will forward the job it - // got to our slice of scheduled jobs, such that we can handle them in + // Spin up goroutine for only one worker. Forward gotten jobs + // to our slice of scheduled jobs, such that we can handle them in // order. - for i := 0; i < len(workers); i++ { - wk := workers[i] - go func() { - for { - job := <-wk.nextJob - scheduledJobs[int(job.index)] <- sched{ - wk: wk, - job: job, - } + wk := workers[0] + go func() { + for { + job := <-wk.nextJob + scheduledJobs[int(job.index)] <- sched{ + wk: wk, + job: job, } - }() - } + } + }() // Send the batch, and Retrieve all jobs immediately. 
- errChan := wm.Query(queries[:]) + errChan := wm.Query(queries[:], ErrChan(make(chan error, 1))) - var iter int var s sched - for i := 0; i < numQueries; i++ { - select { - case s = <-scheduledJobs[i]: - if s.job.index != float64(i) { - t.Fatalf("wrong index") - } - if iter == 1 { - t.Fatalf("Expected only one scheduled job") - } - iter++ - case <-errChan: - t.Fatalf("did not expect on errChan") - case <-time.After(time.Second): - if iter < 1 { - t.Fatalf("next job not received") - } + + // Ensure job is sent to the worker. + select { + case s = <-scheduledJobs[0]: + if s.job.index != float64(0) { + t.Fatalf("wrong index") } + case <-errChan: + t.Fatalf("did not expect on errChan") + case <-time.After(time.Second): + t.Fatalf("next job not received") } + // Return jobResult with an ErrQueryTimeout. select { case s.wk.results <- &jobResult{ job: s.job, @@ -380,9 +382,9 @@ func TestWorkManagerErrQueryTimeout(t *testing.T) { t.Fatalf("result not handled") } - // Finally, make sure the job is not retried as there are no available - // peer to retry it. - + // Make sure the job is not retried as there are no available + // peer to retry it. The only available worker should be waiting in the + // worker feedback loop. select { case <-scheduledJobs[0]: t.Fatalf("did not expect job rescheduled") @@ -408,7 +410,8 @@ func TestWorkManagerCancelBatch(t *testing.T) { const numQueries = 100 // Start the workDispatcher goroutine. - wm, workers := startWorkManager(t, 1) + c, workers := startWorkManager(t, 1) + wm := c.wm wk := workers[0] // Schedule a batch of queries. @@ -422,7 +425,7 @@ func TestWorkManagerCancelBatch(t *testing.T) { // Send the query, and include a channel to cancel the batch. cancelChan := make(chan struct{}) - errChan := wm.Query(queries, Cancel(cancelChan)) + errChan := wm.Query(queries, Cancel(cancelChan), ErrChan(make(chan error, 1))) // Respond with a result to half of the queries. for i := 0; i < numQueries/2; i++ { @@ -493,7 +496,8 @@ func TestWorkManagerWorkRankingScheduling(t *testing.T) { const numQueries = 4 const numWorkers = 8 - workMgr, workers := startWorkManager(t, numWorkers) + c, workers := startWorkManager(t, numWorkers) + workMgr := c.wm require.IsType(t, workMgr, &peerWorkManager{}) wm := workMgr.(*peerWorkManager) //nolint:forcetypeassert @@ -513,7 +517,7 @@ func TestWorkManagerWorkRankingScheduling(t *testing.T) { } // Send the batch, and Retrieve all jobs immediately. - errChan := wm.Query(queries) + errChan := wm.Query(queries, ErrChan(make(chan error, 1))) // The 4 first workers should get the job. var jobs []*queryJob @@ -576,7 +580,7 @@ func TestWorkManagerWorkRankingScheduling(t *testing.T) { } queries = append(queries, q) } - _ = wm.Query(queries) + _ = wm.Query(queries, ErrChan(make(chan error, 1))) // The new jobs should be scheduled on the even numbered workers. for i := 0; i < len(workers); i += 2 { @@ -596,7 +600,8 @@ func TestWorkManagerSchedulePriorityIndex(t *testing.T) { // Start work manager with as many workers as queries. This is not very // realistic, but makes the work manager able to schedule all queries // concurrently. - wm, workers := startWorkManager(t, numQueries) + c, workers := startWorkManager(t, numQueries) + wm := c.wm // When the jobs gets scheduled, keep track of which worker was // assigned the job. @@ -645,7 +650,7 @@ func TestWorkManagerSchedulePriorityIndex(t *testing.T) { } // Send the batch, and Retrieve all jobs immediately. 
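+	// As in the call below, the caller now owns the batch's error channel:
+	// it is handed to Query through the ErrChan option, and Query returns a
+	// nil channel when the option is not supplied. A rough caller-side
+	// sketch (the surrounding wiring is assumed, not part of this patch):
+	//
+	//	func runBatch(wm WorkManager, reqs []*Request,
+	//		quit chan struct{}) error {
+	//
+	//		errChan := wm.Query(
+	//			reqs, ErrChan(make(chan error, 1)),
+	//			Cancel(quit),
+	//		)
+	//
+	//		// A nil receive means every request in the batch
+	//		// completed; cancellation and shutdown errors arrive on
+	//		// the same channel.
+	//		return <-errChan
+	//	}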
- errChan := wm.Query(queries[:]) + errChan := wm.Query(queries[:], ErrChan(make(chan error, 1))) var jobs [numQueries]sched for i := uint64(0); i < numQueries; i++ { @@ -705,3 +710,278 @@ func TestWorkManagerSchedulePriorityIndex(t *testing.T) { t.Fatalf("nothing received on errChan") } } + +// TestPeerWorkManager_Stop tests the workmanager shutdown. +func TestPeerWorkManager_Stop(t *testing.T) { + const numQueries = 5 + + c, _ := startWorkManager(t, 0) + wm := c.wm + + createRequest := func(numQuery int) []*Request { + var queries []*Request + for i := 0; i < numQuery; i++ { + q := &Request{ + Req: &mockQueryEncoded{}, + } + queries = append(queries, q) + } + + return queries + } + + // Send the batch, and Retrieve all jobs immediately. + errChan := wm.Query(createRequest(numQueries), ErrChan(make(chan error, 1))) + errChan2 := wm.Query(createRequest(numQueries)) + + if errChan2 != nil { + t.Fatalf("expected Query call without ErrChan option func to return" + + "niil errChan") + } + + errChan3 := make(chan error, 1) + go func() { + err := wm.Stop() + + errChan3 <- err + }() + + select { + case <-errChan: + case <-time.After(time.Second): + t.Fatalf("expected error workmanager shutting down") + } + + select { + case err := <-errChan3: + if err != nil { + t.Fatalf("unexpected error while stopping workmanager: %v", err) + } + case <-time.After(time.Second): + t.Fatalf("workmanager stop functunction should return error") + } +} + +// TestWorkManagerErrResponseExistForQuery tests a scenario in which a workmanager handles +// an ErrIgnoreRequest. +func TestWorkManagerErrResponseExistForQuery(t *testing.T) { + const numQueries = 5 + + // Start work manager with as many workers as queries. This is not very + // realistic, but makes the work manager able to schedule all queries + // concurrently. + c, workers := startWorkManager(t, numQueries) + wm := c.wm + + // When the jobs gets scheduled, keep track of which worker was + // assigned the job. + type sched struct { + wk *mockWorker + job *queryJob + } + + // Schedule a batch of queries. + var ( + queries [numQueries]*Request + scheduledJobs [numQueries]chan sched + ) + for i := 0; i < numQueries; i++ { + q := &Request{ + Req: &mockQueryEncoded{}, + } + queries[i] = q + scheduledJobs[i] = make(chan sched) + } + + // Fot each worker, spin up a goroutine that will forward the job it + // got to our slice of scheduled jobs, such that we can handle them in + // order. + for i := 0; i < len(workers); i++ { + wk := workers[i] + go func() { + for { + job := <-wk.nextJob + scheduledJobs[int(job.index)] <- sched{ + wk: wk, + job: job, + } + } + }() + } + + // Send the batch, and Retrieve all jobs immediately. + errChan := wm.Query(queries[:], ErrChan(make(chan error, 1))) + var jobs [numQueries]sched + for i := 0; i < numQueries; i++ { + var s sched + select { + case s = <-scheduledJobs[i]: + if s.job.index != float64(i) { + t.Fatalf("wrong index") + } + + case <-errChan: + t.Fatalf("did not expect on errChan") + case <-time.After(time.Second): + t.Fatalf("next job not received") + } + + jobs[int(s.job.index)] = s + } + + // Go backwards, and make half of it return with an ErrIgnoreRequest. + for i := numQueries - 1; i >= 0; i-- { + select { + case jobs[i].wk.results <- &jobResult{ + job: jobs[i].job, + err: ErrIgnoreRequest, + }: + case <-errChan: + t.Fatalf("did not expect on errChan") + case <-time.After(time.Second): + t.Fatalf("result not handled") + } + } + + // Finally, make sure the failed jobs are not retried. 
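+	// (ErrIgnoreRequest is terminal for a job: the workmanager neither
+	// reschedules it nor counts it towards completing the batch, so the
+	// batch's error channel stays silent here.)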
+ for i := 0; i < numQueries; i++ { + var s sched + select { + case s = <-scheduledJobs[i]: + t.Fatalf("did not expect any retried job but job"+ + "%v\n retried", s.job.index) + case <-errChan: + t.Fatalf("did not expect an errChan") + case <-time.After(time.Second): + } + } +} + +// TestWorkManagerResultUnfinished tests the workmanager handling a result with an unfinished boolean set +// to true. +func TestWorkManagerResultUnfinished(t *testing.T) { + const numQueries = 10 + + // Start work manager with as many workers as queries. This is not very + // realistic, but makes the work manager able to schedule all queries + // concurrently. + c, workers := startWorkManager(t, numQueries) + wm := c.wm + + // When the jobs gets scheduled, keep track of which worker was + // assigned the job. + type sched struct { + wk *mockWorker + job *queryJob + } + + // Schedule a batch of queries. + var ( + queries [numQueries]*Request + scheduledJobs [numQueries]chan sched + ) + for i := 0; i < numQueries; i++ { + q := &Request{ + Req: &mockQueryEncoded{}, + } + queries[i] = q + scheduledJobs[i] = make(chan sched) + } + + // Fot each worker, spin up a goroutine that will forward the job it + // got to our slice of scheduled jobs, such that we can handle them in + // order. + for i := 0; i < len(workers); i++ { + wk := workers[i] + go func() { + for { + job := <-wk.nextJob + scheduledJobs[int(job.index)] <- sched{ + wk: wk, + job: job, + } + } + }() + } + + // Send the batch, and Retrieve all jobs immediately. + errChan := wm.Query(queries[:], ErrChan(make(chan error, 1))) + var jobs [numQueries]sched + for i := 0; i < numQueries; i++ { + var s sched + select { + case s = <-scheduledJobs[i]: + if s.job.index != float64(i) { + t.Fatalf("wrong index") + } + + case <-errChan: + t.Fatalf("did not expect on errChan") + case <-time.After(time.Second): + t.Fatalf("next job not received") + } + + jobs[int(s.job.index)] = s + } + + // Go backwards, and make half of it unfinished. + for i := numQueries - 1; i >= 0; i-- { + var ( + unfinished bool + ) + if i%2 == 0 { + unfinished = true + } + + select { + case jobs[i].wk.results <- &jobResult{ + job: jobs[i].job, + unfinished: unfinished, + }: + case <-errChan: + t.Fatalf("did not expect on errChan") + case <-time.After(time.Second): + t.Fatalf("result not handled") + } + } + + // Finally, make sure the failed jobs are being retried, in the same + // order as they were originally scheduled. + for i := 0; i < numQueries; i += 2 { + var s sched + select { + case s = <-scheduledJobs[i]: + + // The new tindex the job should have. + idx := float64(i) + 0.0005 + if idx != s.job.index { + t.Fatalf("expected index %v for job"+ + "but got, %v\n", idx, s.job.index) + } + case <-errChan: + t.Fatalf("did not expect on errChan") + case <-time.After(time.Second): + t.Fatalf("next job not received") + } + select { + case s.wk.results <- &jobResult{ + job: s.job, + err: nil, + }: + case <-errChan: + t.Fatalf("did not expect on errChan") + case <-time.After(time.Second): + t.Fatalf("result not handled") + } + } + + // The query should ultimately succeed. 
+ select { + case err := <-errChan: + if err != nil { + t.Fatalf("got error: %v", err) + } + case <-time.After(time.Second): + t.Fatalf("nothing received on errChan") + } +} diff --git a/query_test.go b/query_test.go index bb45deed6..f8610b6e0 100644 --- a/query_test.go +++ b/query_test.go @@ -324,10 +324,10 @@ func TestBlockCache(t *testing.T) { Header: *header, Transactions: b.MsgBlock().Transactions, } - - progress := reqs[0].HandleResp(getData, resp, "") - require.True(t, progress.Progressed) - require.True(t, progress.Finished) + progress := reqs[0].HandleResp(&encodedQuery{ + message: getData, + }, resp, nil) + require.Equal(t, query.Finished, progress) // Notify the test about the query. select { From 64b278771ff75da0b30136af7d4ab0ace06d897e Mon Sep 17 00:00:00 2001 From: Maureen Ononiwu Date: Fri, 1 Sep 2023 21:12:49 +0100 Subject: [PATCH 7/7] neutrino: parallelized block header download. This commit distributes header download across peers leveraging checckpoints and the workmanager. Signed-off-by: Maureen Ononiwu --- blockmanager.go | 545 +++++++++++++++++++++-- blockmanager_test.go | 913 +++++++++++++++++++++++++++++++++++++- neutrino.go | 52 ++- query/interface.go | 7 + query/worker.go | 26 +- query/worker_test.go | 18 +- query/workmanager.go | 7 + query/workmanager_test.go | 110 +++++ 8 files changed, 1612 insertions(+), 66 deletions(-) diff --git a/blockmanager.go b/blockmanager.go index f32a9efa6..e517f1c63 100644 --- a/blockmanager.go +++ b/blockmanager.go @@ -9,6 +9,7 @@ import ( "fmt" "math" "math/big" + "sort" "sync" "sync/atomic" "time" @@ -91,8 +92,13 @@ type blockManagerCfg struct { // the connected peers. TimeSource blockchain.MedianTimeSource - // QueryDispatcher is used to make queries to connected Bitcoin peers. - QueryDispatcher query.Dispatcher + // cfHeaderQueryDispatcher is used to make queries to connected Bitcoin peers to fetch + // CFHeaders + cfHeaderQueryDispatcher query.Dispatcher + + // cfHeaderQueryDispatcher is used to make queries to connected Bitcoin peers to fetch + // block Headers + blkHdrCheckptQueryDispatcher query.WorkManager // BanPeer bans and disconnects the given peer. BanPeer func(addr string, reason banman.Reason) error @@ -174,6 +180,10 @@ type blockManager struct { // nolint:maligned // time, newHeadersMtx should always be acquired first. newFilterHeadersMtx sync.RWMutex + // writeBatchMtx is the mutex used to hold reading and reading and writing into the + // hdrTipToResponse map. + writeBatchMtx sync.RWMutex + // newFilterHeadersSignal is condition variable which will be used to // notify any waiting callers (via Broadcast()) that the tip of the // current filter header chain has changed. This is useful when callers @@ -207,6 +217,18 @@ type blockManager struct { // nolint:maligned minRetargetTimespan int64 // target timespan / adjustment factor maxRetargetTimespan int64 // target timespan * adjustment factor blocksPerRetarget int32 // target timespan / target time per block + + // hdrTipToResponse is a map that holds the response gotten from querying peers + // using the workmanager, to fetch headers within the chain's checkpointed region. + // It is a map of the request's startheight to the fetch response. + hdrTipToResponse map[int32]*headersMsg + + // hdrTipSlice is a slice that holds request startHeight of the responses that have been + // fetched using the workmanager to fetch headers within the chain's checkpointed region. 
+ // It is used to easily access this startheight in the case we have to delete these responses + // in the hdrTipResponse map during a reorg while fetching headers within the chain's checkpointed + // region. + hdrTipSlice []int32 } // newBlockManager returns a new bitcoin block manager. Use Start to begin @@ -236,6 +258,8 @@ func newBlockManager(cfg *blockManagerCfg) (*blockManager, error) { blocksPerRetarget: int32(targetTimespan / targetTimePerBlock), minRetargetTimespan: targetTimespan / adjustmentFactor, maxRetargetTimespan: targetTimespan * adjustmentFactor, + hdrTipToResponse: make(map[int32]*headersMsg), + hdrTipSlice: make([]int32, 0), } // Next we'll create the two signals that goroutines will use to wait @@ -291,8 +315,21 @@ func (b *blockManager) Start() { } log.Trace("Starting block manager") - b.wg.Add(2) + b.wg.Add(3) go b.blockHandler() + go func() { + wm := b.cfg.blkHdrCheckptQueryDispatcher + + defer b.wg.Done() + defer func(wm query.WorkManager) { + err := wm.Stop() + if err != nil { + log.Errorf("Unable to stop block header workmanager: %v", err) + } + }(wm) + + b.processBlKHeaderInCheckPtRegionInOrder() + }() go func() { defer b.wg.Done() @@ -306,6 +343,12 @@ func (b *blockManager) Start() { return } + checkpoints := b.cfg.ChainParams.Checkpoints + numCheckpts := len(checkpoints) + if numCheckpts != 0 && b.nextCheckpoint != nil { + b.batchCheckpointedBlkHeaders() + } + log.Debug("Peer connected, starting cfHandler.") b.cfHandler() }() @@ -361,6 +404,19 @@ func (b *blockManager) NewPeer(sp *ServerPeer) { } } +// addNewPeerToList adds the peer to the peers list. +func (b *blockManager) addNewPeerToList(peers *list.List, sp *ServerPeer) { + // Ignore if in the process of shutting down. + if atomic.LoadInt32(&b.shutdown) != 0 { + return + } + + log.Infof("New valid peer %s (%s)", sp, sp.UserAgent()) + + // Add the peer as a candidate to sync from. + peers.PushBack(sp) +} + // handleNewPeerMsg deals with new peers that have signalled they may be // considered as a sync peer (they have already successfully negotiated). It // also starts syncing if needed. It is invoked from the syncHandler @@ -374,12 +430,12 @@ func (b *blockManager) handleNewPeerMsg(peers *list.List, sp *ServerPeer) { log.Infof("New valid peer %s (%s)", sp, sp.UserAgent()) // Ignore the peer if it's not a sync candidate. - if !b.isSyncCandidate(sp) { + if !sp.IsSyncCandidate() { return } // Add the peer as a candidate to sync from. - peers.PushBack(sp) + b.addNewPeerToList(peers, sp) // If we're current with our sync peer and the new peer is advertising // a higher block than the newest one we know of, request headers from @@ -419,11 +475,8 @@ func (b *blockManager) DonePeer(sp *ServerPeer) { } } -// handleDonePeerMsg deals with peers that have signalled they are done. It -// removes the peer as a candidate for syncing and in the case where it was the -// current sync peer, attempts to select a new best peer to sync from. It is -// invoked from the syncHandler goroutine. -func (b *blockManager) handleDonePeerMsg(peers *list.List, sp *ServerPeer) { +// removeDonePeerFromList removes the peer from the peers list. +func (b *blockManager) removeDonePeerFromList(peers *list.List, sp *ServerPeer) { // Remove the peer from the list of candidate peers. 
for e := peers.Front(); e != nil; e = e.Next() { if e.Value == sp { @@ -433,6 +486,17 @@ func (b *blockManager) handleDonePeerMsg(peers *list.List, sp *ServerPeer) { } log.Infof("Lost peer %s", sp) +} + +// handleDonePeerMsg deals with peers that have signalled they are done. It +// removes the peer as a candidate for syncing and in the case where it was the +// current sync peer, attempts to select a new best peer to sync from. It is +// invoked from the syncHandler goroutine. +func (b *blockManager) handleDonePeerMsg(peers *list.List, sp *ServerPeer) { + // Remove the peer from the list of candidate peers. + b.removeDonePeerFromList(peers, sp) + + log.Infof("Lost peer %s", sp) // Attempt to find a new peer to sync from if the quitting peer is the // sync peer. Also, reset the header state. @@ -1090,7 +1154,7 @@ func (b *blockManager) getCheckpointedCFHeaders(checkpoints []*chainhash.Hash, // Hand the queries to the work manager, and consume the verified // responses as they come back. - errChan := b.cfg.QueryDispatcher.Query( + errChan := b.cfg.cfHeaderQueryDispatcher.Query( q.requests(), query.Cancel(b.quit), query.NoRetryMax(), query.ErrChan(make(chan error, 1)), ) @@ -2029,7 +2093,38 @@ func (b *blockManager) blockHandler() { defer b.wg.Done() candidatePeers := list.New() -out: + checkpoints := b.cfg.ChainParams.Checkpoints + if len(checkpoints) == 0 || b.nextCheckpoint == nil { + goto unCheckPtLoop + } + + // Loop to fetch headers within the check pointed range + b.newHeadersMtx.RLock() + for b.headerTip <= uint32(checkpoints[len(checkpoints)-1].Height) { + b.newHeadersMtx.RUnlock() + select { + case m := <-b.peerChan: + switch msg := m.(type) { + case *newPeerMsg: + b.addNewPeerToList(candidatePeers, msg.peer) + case *donePeerMsg: + b.removeDonePeerFromList(candidatePeers, msg.peer) + default: + log.Tracef("Invalid message type in block "+ + "handler: %T", msg) + } + + case <-b.quit: + return + } + b.newHeadersMtx.RLock() + } + b.newHeadersMtx.RUnlock() + + log.Infof("Fetching uncheckpointed block headers from %v", b.headerTip) + b.startSync(candidatePeers) + +unCheckPtLoop: for { // Now check peer messages and quit channels. select { @@ -2053,13 +2148,367 @@ out: } case <-b.quit: - break out + break unCheckPtLoop } } log.Trace("Block handler done") } +// processBlKHeaderInCheckPtRegionInOrder handles and writes the block headers received from querying the +// workmanager while fetching headers within the block header checkpoint region. This process is carried out +// in order. +func (b *blockManager) processBlKHeaderInCheckPtRegionInOrder() { + lenCheckPts := len(b.cfg.ChainParams.Checkpoints) + + // Loop should run as long as we are in the block header checkpointed region. + b.newHeadersMtx.RLock() + for int32(b.headerTip) <= b.cfg.ChainParams.Checkpoints[lenCheckPts-1].Height { + hdrTip := b.headerTip + b.newHeadersMtx.RUnlock() + + select { + // return quickly if the blockmanager quits. + case <-b.quit: + return + default: + } + + // do not go further if we have not received the response mapped to our header tip. 
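+		// Responses are delivered by the workmanager in whatever order peers
+		// answer, so they are buffered in hdrTipToResponse and only applied
+		// once the entry matching the current header tip exists. This keeps
+		// header writes strictly ordered.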
+ b.writeBatchMtx.RLock() + msg, ok := b.hdrTipToResponse[int32(hdrTip)] + b.writeBatchMtx.RUnlock() + + if !ok { + b.newHeadersMtx.RLock() + continue + } + + b.syncPeerMutex.Lock() + b.syncPeer = msg.peer + b.syncPeerMutex.Unlock() + + b.handleHeadersMsg(msg) + err := b.resetHeaderListToChainTip() + if err != nil { + log.Errorf(err.Error()) + } + + finalNode := b.headerList.Back() + newHdrTip := finalNode.Height + newHdrTipHash := finalNode.Header.BlockHash() + prevCheckPt := b.findPreviousHeaderCheckpoint(newHdrTip) + + log.Tracef("New headertip %v", newHdrTip) + log.Debugf("New headertip %v", newHdrTip) + + // If our header tip has not increased, there is a problem with the headers we received and so we + // delete all the header response within our previous header tip and our new header tip, then send + // another query to the workmanager. + if uint32(newHdrTip) <= hdrTip { + b.deleteHeaderTipResp(newHdrTip, int32(hdrTip)) + + log.Tracef("while fetching checkpointed headers received invalid headers") + + q := CheckpointedBlockHeadersQuery{ + blockMgr: b, + msgs: []*headerQuery{ + + { + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: []*chainhash.Hash{&newHdrTipHash}, + HashStop: *b.nextCheckpoint.Hash, + }, + startHeight: newHdrTip, + initialHeight: prevCheckPt.Height, + startHash: newHdrTipHash, + endHeight: b.nextCheckpoint.Height, + initialHash: newHdrTipHash, + // Set it as high priority so that workmanager can schedule before any other queries. + index: 0, + }, + }, + } + + b.cfg.blkHdrCheckptQueryDispatcher.Query( + q.requests(), query.Cancel(b.quit), query.Timeout(1*time.Hour), query.NoRetryMax(), + ) + + log.Tracef("Sending query to workmanager from processBlKHeaderInCheckPtRegionInOrder loop") + } + b.newHeadersMtx.RLock() + } + b.newHeadersMtx.RUnlock() + + b.syncPeerMutex.Lock() + b.syncPeer = nil + b.syncPeerMutex.Unlock() + + log.Infof("Successfully completed fetching checkpointed block headers") +} + +// batchCheckpointedBlkHeaders creates headerQuery to fetch block headers +// within the chain's checkpointed region. 
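+// One headerQuery is built per checkpoint interval, starting from the current
+// header tip, plus a final query that runs from the last checkpoint towards
+// the still unknown chain tip with HashStop set to the zero hash. For
+// example, with the header tip at 0 and checkpoints at heights 100_000 and
+// 200_000, the batch is roughly:
+//
+//	      0 -> 100_000                               (HashStop = checkpoint 0)
+//	100_000 -> 200_000                               (HashStop = checkpoint 1)
+//	200_000 -> 200_000 + wire.MaxBlockHeadersPerMsg  (HashStop = zeroHash)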
+func (b *blockManager) batchCheckpointedBlkHeaders() { + var queryMsgs []*headerQuery + curHeight := b.headerTip + curHash := b.headerTipHash + nextCheckpoint := b.nextCheckpoint + nextCheckptHash := nextCheckpoint.Hash + nextCheckptHeight := nextCheckpoint.Height + + log.Infof("Fetching set of checkpointed blockheaders from "+ + "height=%v, hash=%v\n", curHeight, curHash) + + for nextCheckpoint != nil { + endHash := nextCheckptHash + endHeight := nextCheckptHeight + tmpCurHash := curHash + + msg := &headerQuery{ + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: blockchain.BlockLocator([]*chainhash.Hash{&tmpCurHash}), + HashStop: *endHash, + }, + startHeight: int32(curHeight), + initialHeight: int32(curHeight), + startHash: curHash, + endHeight: endHeight, + initialHash: tmpCurHash, + } + + log.Debugf("Fetching set of checkpointed blockheaders from "+ + "start_height=%v to end-height=%v", curHeight, endHash) + + queryMsgs = append(queryMsgs, msg) + curHeight = uint32(endHeight) + curHash = *endHash + + nextCheckpoint := b.findNextHeaderCheckpoint(int32(curHeight)) + if nextCheckpoint == nil { + break + } + + nextCheckptHeight = nextCheckpoint.Height + nextCheckptHash = nextCheckpoint.Hash + } + + msg := &headerQuery{ + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: blockchain.BlockLocator([]*chainhash.Hash{nextCheckptHash}), + HashStop: zeroHash, + }, + startHeight: nextCheckptHeight, + initialHeight: nextCheckptHeight, + startHash: *nextCheckptHash, + endHeight: nextCheckptHeight + wire.MaxBlockHeadersPerMsg, + initialHash: *nextCheckptHash, + } + + log.Debugf("Fetching set of checkpointed blockheaders from "+ + "start_height=%v to end-height=%v", curHeight, zeroHash) + + queryMsgs = append(queryMsgs, msg) + + log.Debugf("Attempting to query for %v blockheader batches", len(queryMsgs)) + + q := CheckpointedBlockHeadersQuery{ + blockMgr: b, + msgs: queryMsgs, + } + + b.cfg.blkHdrCheckptQueryDispatcher.Query( + q.requests(), query.Cancel(b.quit), query.Timeout(1*time.Hour), query.NoRetryMax(), + ) +} + +// CheckpointedBlockHeadersQuery holds all information necessary to perform and +// // handle a query for checkpointed block headers. +type CheckpointedBlockHeadersQuery struct { + blockMgr *blockManager + msgs []*headerQuery +} + +// requests creates the query.Requests for this block headers query. +func (c *CheckpointedBlockHeadersQuery) requests() []*query.Request { + reqs := make([]*query.Request, len(c.msgs)) + for idx, m := range c.msgs { + reqs[idx] = &query.Request{ + Req: m, + SendQuery: c.PushHeadersMsg, + HandleResp: c.handleResponse, + CloneReq: cloneHeaderQuery, + } + } + + return reqs +} + +// cloneHeaderQuery clones the query.ReqMessage containing the headerQuery Struct. +func cloneHeaderQuery(req query.ReqMessage) query.ReqMessage { + oldReq, ok := req.(*headerQuery) + if !ok { + log.Errorf("request not of type *wire.MsgGetHeaders") + } + oldReqMessage := req.Message().(*wire.MsgGetHeaders) + message := &headerQuery{ + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: oldReqMessage.BlockLocatorHashes, + HashStop: oldReqMessage.HashStop, + }, + startHeight: oldReq.startHeight, + initialHeight: oldReq.initialHeight, + startHash: oldReq.startHash, + endHeight: oldReq.endHeight, + } + + return message +} + +// PushHeadersMsg is the internal response handler used for requests for this +// block Headers query. 
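+// In practice this is the request's SendQuery hook: it pushes the getheaders
+// message to the peer, and returns ErrIgnoreRequest instead when a response
+// for the request's start height has already been buffered, so already
+// answered work is not re-sent.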
+func (c *CheckpointedBlockHeadersQuery) PushHeadersMsg(peer query.Peer, + task query.ReqMessage) error { + + request, _ := task.Message().(*wire.MsgGetHeaders) + + requestMsg := task.(*headerQuery) + + // check if we have response for the query already. If we do return an error. + c.blockMgr.writeBatchMtx.RLock() + _, ok := c.blockMgr.hdrTipToResponse[requestMsg.startHeight] + c.blockMgr.writeBatchMtx.RUnlock() + if ok { + log.Debugf("Response already received, peer=%v, "+ + "start_height=%v, end_height=%v, index=%v", peer.Addr(), + requestMsg.startHeight, requestMsg.endHeight) + return query.ErrIgnoreRequest + } + + sp := peer.(*ServerPeer) + err := sp.PushGetHeadersMsg(request.BlockLocatorHashes, &request.HashStop) + if err != nil { + log.Errorf(err.Error()) + return err + } + + return nil +} + +// handleResponse is the internal response handler used for requests for this +// block header query. +func (c *CheckpointedBlockHeadersQuery) handleResponse(req query.ReqMessage, resp wire.Message, + peer query.Peer) query.Progress { + + sp := peer.(*ServerPeer) + if peer == nil { + return query.NoResponse + } + + msg, ok := resp.(*wire.MsgHeaders) + if !ok { + // We are only looking for msgHeaders messages. + return query.NoResponse + } + + request, ok := req.(*headerQuery) + if !ok { + // request should only be of type headerQuery. + return query.NoResponse + } + + // Check if we already have a response for this request startHeight, if we do modify our jobErr variable + // so that worker can send appropriate error to workmanager. + c.blockMgr.writeBatchMtx.RLock() + _, ok = c.blockMgr.hdrTipToResponse[request.startHeight] + c.blockMgr.writeBatchMtx.RUnlock() + if ok { + return query.IgnoreRequest + } + + // If we received an empty response from peer, return with an error to break worker's + // feed back loop. + hdrLength := len(msg.Headers) + if hdrLength == 0 { + return query.ResponseErr + } + + // The initialHash represents the lower bound checkpoint for this checkpoint region. + // We verify, if the header received at that checkpoint height has the same hash as the + // checkpoint's hash. If it does not, mimicking the handleheaders function behaviour, we + // disconnect the peer and return a failed progress to reschedule the query. + if msg.Headers[0].PrevBlock != request.startHash && + request.startHash == request.initialHash { + + sp.Disconnect() + + return query.ResponseErr + } + + // If the peer sends us more headers than we need, it is probably not aligned with our chain, so we disconnect + // peer and return a failed progress. + reqMessage := request.Message().(*wire.MsgGetHeaders) + + if hdrLength > int(request.endHeight-request.startHeight) { + sp.Disconnect() + return query.ResponseErr + } + + // Write header into hdrTipResponse map, add the request's startHeight to the hdrTipSlice, for tracking + // and handling by the processBlKHeaderInCheckPtRegionInOrder loop. + c.blockMgr.writeBatchMtx.Lock() + c.blockMgr.hdrTipToResponse[request.startHeight] = &headersMsg{ + headers: msg, + peer: sp, + } + i := sort.Search(len(c.blockMgr.hdrTipSlice), func(i int) bool { + return c.blockMgr.hdrTipSlice[i] >= request.startHeight + }) + + c.blockMgr.hdrTipSlice = append(c.blockMgr.hdrTipSlice[:i], append([]int32{request.startHeight}, c.blockMgr.hdrTipSlice[i:]...)...) + c.blockMgr.writeBatchMtx.Unlock() + + // Check if job is unfinished, if it is, we modify the job accordingly and send back to the workmanager to be rescheduled. 
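+	// A single getheaders reply carries at most wire.MaxBlockHeadersPerMsg
+	// (2000) headers, so one checkpoint interval may take several replies.
+	// Advancing startHeight, startHash and the block locator below, and
+	// returning UnFinishedRequest, makes the workmanager reschedule this
+	// same job so that the next reply continues where this one stopped.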
+ if msg.Headers[hdrLength-1].BlockHash() != reqMessage.HashStop && reqMessage.HashStop != zeroHash { + // set new startHash, startHeight and blocklocator to set the next set of header for this job. + newStartHash := msg.Headers[hdrLength-1].BlockHash() + request.startHeight += int32(hdrLength) + request.startHash = newStartHash + reqMessage.BlockLocatorHashes = []*chainhash.Hash{&newStartHash} + + // Incase there is a rollback after handling reqMessage + // This ensures the job created by writecheckpt does not exceed that which we have fetched already. + c.blockMgr.writeBatchMtx.RLock() + _, ok = c.blockMgr.hdrTipToResponse[request.startHeight] + c.blockMgr.writeBatchMtx.RUnlock() + + if !ok { + return query.UnFinishedRequest + } + } + + return query.Finished +} + +// headerQuery implements ReqMessage interface for fetching block headers. +type headerQuery struct { + message wire.Message + startHeight int32 + initialHeight int32 + startHash chainhash.Hash + endHeight int32 + initialHash chainhash.Hash + index float64 +} + +func (h *headerQuery) Message() wire.Message { + return h.message +} + +func (h *headerQuery) PriorityIndex() float64 { + return h.index +} + // SyncPeer returns the current sync peer. func (b *blockManager) SyncPeer() *ServerPeer { b.syncPeerMutex.Lock() @@ -2068,11 +2517,48 @@ func (b *blockManager) SyncPeer() *ServerPeer { return b.syncPeer } -// isSyncCandidate returns whether or not the peer is a candidate to consider -// syncing from. -func (b *blockManager) isSyncCandidate(sp *ServerPeer) bool { - // The peer is not a candidate for sync if it's not a full node. - return sp.Services()&wire.SFNodeNetwork == wire.SFNodeNetwork +// deleteHeaderTipResp deletes all responses from newTip to prevTip. +func (b *blockManager) deleteHeaderTipResp(newTip, prevTip int32) { + b.writeBatchMtx.Lock() + defer b.writeBatchMtx.Unlock() + + var ( + finalIdx int + initialIdx int + ) + + for i := 0; i < len(b.hdrTipSlice) && b.hdrTipSlice[i] <= newTip; i++ { + if b.hdrTipSlice[i] < prevTip { + continue + } + + if b.hdrTipSlice[i] == prevTip { + initialIdx = i + } + + tip := b.hdrTipSlice[i] + + delete(b.hdrTipToResponse, tip) + + finalIdx = i + } + + b.hdrTipSlice = append(b.hdrTipSlice[:initialIdx], b.hdrTipSlice[finalIdx+1:]...) +} + +// resetHeaderListToChainTip resets the headerList to the chain tip. +func (b *blockManager) resetHeaderListToChainTip() error { + header, height, err := b.cfg.BlockHeaders.ChainTip() + if err != nil { + return err + } + b.headerList.ResetHeaderState(headerlist.Node{ + Header: *header, + Height: int32(height), + }) + log.Debugf("Resetting header list to chain tip %v ", b.headerTip) + + return nil } // findNextHeaderCheckpoint returns the next checkpoint after the passed height. @@ -2747,20 +3233,16 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) { } } - // When this header is a checkpoint, find the next checkpoint. - if receivedCheckpoint { - b.nextCheckpoint = b.findNextHeaderCheckpoint(finalHeight) - } - // If not current, request the next batch of headers starting from the // latest known header and ending with the next checkpoint. - if b.cfg.ChainParams.Net == chaincfg.SimNetParams.Net || !b.BlockHeadersSynced() { + // Note this must come before reassigning a new b.nextCheckpoint, so that we push headers + // only when the current headers before this takes us past the checkpointed region. 
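+	// Outside of simnet, the direct getheaders push below is only issued once
+	// b.nextCheckpoint is nil, i.e. once the chain tip is past the last
+	// checkpoint, since inside the checkpointed region the workmanager is
+	// the one driving header download.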
+ if b.cfg.ChainParams.Net == chaincfg.SimNetParams.Net || !b.BlockHeadersSynced() && + b.nextCheckpoint == nil { + locator := blockchain.BlockLocator([]*chainhash.Hash{finalHash}) - nextHash := zeroHash - if b.nextCheckpoint != nil { - nextHash = *b.nextCheckpoint.Hash - } - err := hmsg.peer.PushGetHeadersMsg(locator, &nextHash) + + err := hmsg.peer.PushGetHeadersMsg(locator, &zeroHash) if err != nil { log.Warnf("Failed to send getheaders message to "+ "peer %s: %s", hmsg.peer.Addr(), err) @@ -2768,6 +3250,11 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) { } } + // When this header is a checkpoint, find the next checkpoint. + if receivedCheckpoint { + b.nextCheckpoint = b.findNextHeaderCheckpoint(finalHeight) + } + // Since we have a new set of headers written to disk, we'll send out a // new signal to notify any waiting sub-systems that they can now maybe // proceed do to us extending the header chain. diff --git a/blockmanager_test.go b/blockmanager_test.go index 060443b91..1e3ad4de7 100644 --- a/blockmanager_test.go +++ b/blockmanager_test.go @@ -5,6 +5,7 @@ import ( "fmt" "math/rand" "reflect" + "sort" "strings" "testing" "time" @@ -89,11 +90,12 @@ func setupBlockManager(t *testing.T) (*blockManager, headerfs.BlockHeaderStore, // Set up a blockManager with the chain service we defined. bm, err := newBlockManager(&blockManagerCfg{ - ChainParams: chaincfg.SimNetParams, - BlockHeaders: hdrStore, - RegFilterHeaders: cfStore, - QueryDispatcher: &mockDispatcher{}, - TimeSource: blockchain.NewMedianTime(), + ChainParams: chaincfg.SimNetParams, + BlockHeaders: hdrStore, + RegFilterHeaders: cfStore, + cfHeaderQueryDispatcher: &mockDispatcher{}, + blkHdrCheckptQueryDispatcher: &mockDispatcher{}, + TimeSource: blockchain.NewMedianTime(), BanPeer: func(string, banman.Reason) error { return nil }, @@ -346,7 +348,7 @@ func TestBlockManagerInitialInterval(t *testing.T) { // We set up a custom query batch method for this test, as we // will use this to feed the blockmanager with our crafted // responses. - bm.cfg.QueryDispatcher.(*mockDispatcher).query = func( + bm.cfg.cfHeaderQueryDispatcher.(*mockDispatcher).query = func( requests []*query.Request, options ...query.QueryOption) chan error { @@ -576,7 +578,7 @@ func TestBlockManagerInvalidInterval(t *testing.T) { require.NoError(t, err) } - bm.cfg.QueryDispatcher.(*mockDispatcher).query = func( + bm.cfg.cfHeaderQueryDispatcher.(*mockDispatcher).query = func( requests []*query.Request, options ...query.QueryOption) chan error { @@ -884,20 +886,6 @@ func TestHandleHeaders(t *testing.T) { fakePeer, err := peer.NewOutboundPeer(&peer.Config{}, "fake:123") require.NoError(t, err) - assertPeerDisconnected := func(shouldBeDisconnected bool) { - // This is quite hacky but works: We expect the peer to be - // disconnected, which sets the unexported "disconnected" field - // to 1. - refValue := reflect.ValueOf(fakePeer).Elem() - foo := refValue.FieldByName("disconnect").Int() - - if shouldBeDisconnected { - require.EqualValues(t, 1, foo) - } else { - require.EqualValues(t, 0, foo) - } - } - // We'll want to use actual, real blocks, so we take a miner harness // that we can use to generate some. harness, err := rpctest.New( @@ -934,7 +922,7 @@ func TestHandleHeaders(t *testing.T) { // Let's feed in the correct headers. This should work fine and the peer // should not be disconnected. bm.handleHeadersMsg(hmsg) - assertPeerDisconnected(false) + assertPeerDisconnected(false, fakePeer, t) // Now scramble the headers and feed them in again. 
This should cause // the peer to be disconnected. @@ -943,5 +931,884 @@ func TestHandleHeaders(t *testing.T) { hmsg.headers.Headers[j], hmsg.headers.Headers[i] }) bm.handleHeadersMsg(hmsg) - assertPeerDisconnected(true) + assertPeerDisconnected(true, fakePeer, t) +} + +// assertPeerDisconnected asserts that the peer supplied as an argument is disconnected. +func assertPeerDisconnected(shouldBeDisconnected bool, sp *peer.Peer, t *testing.T) { + // This is quite hacky but works: We expect the peer to be + // disconnected, which sets the unexported "disconnected" field + // to 1. + refValue := reflect.ValueOf(sp).Elem() + foo := refValue.FieldByName("disconnect").Int() + + if shouldBeDisconnected { + require.EqualValues(t, 1, foo) + } else { + require.EqualValues(t, 0, foo) + } +} + +// TestBatchCheckpointedBlkHeaders tests the batch checkpointed headers function. +func TestBatchCheckpointedBlkHeaders(t *testing.T) { + t.Parallel() + + // First, we set up a block manager and a fake peer that will act as the + // test's remote peer. + bm, _, _, err := setupBlockManager(t) + require.NoError(t, err) + + // Created checkpoints for our simulated network. + checkpoints := []chaincfg.Checkpoint{ + + { + Hash: &chainhash.Hash{1}, + Height: int32(1), + }, + + { + Hash: &chainhash.Hash{2}, + Height: int32(2), + }, + + { + Hash: &chainhash.Hash{3}, + Height: int32(3), + }, + } + + modParams := chaincfg.SimNetParams + modParams.Checkpoints = append(modParams.Checkpoints, checkpoints...) + bm.cfg.ChainParams = modParams + + // set checkpoint and header tip. + bm.nextCheckpoint = &checkpoints[0] + + bm.newHeadersMtx.Lock() + bm.headerTip = 0 + bm.headerTipHash = chainhash.Hash{0} + bm.newHeadersMtx.Unlock() + + // This is the query we assert to obtain if the function works accordingly. + expectedQuery := CheckpointedBlockHeadersQuery{ + blockMgr: bm, + msgs: []*headerQuery{ + + { + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: blockchain.BlockLocator([]*chainhash.Hash{{0}}), + HashStop: *checkpoints[0].Hash, + }, + startHeight: int32(0), + initialHeight: int32(0), + startHash: chainhash.Hash{0}, + endHeight: checkpoints[0].Height, + initialHash: chainhash.Hash{0}, + }, + + { + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: blockchain.BlockLocator([]*chainhash.Hash{checkpoints[0].Hash}), + HashStop: *checkpoints[1].Hash, + }, + startHeight: checkpoints[0].Height, + initialHeight: checkpoints[0].Height, + startHash: *checkpoints[0].Hash, + endHeight: checkpoints[1].Height, + initialHash: *checkpoints[0].Hash, + }, + + { + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: blockchain.BlockLocator([]*chainhash.Hash{checkpoints[1].Hash}), + HashStop: *checkpoints[2].Hash, + }, + startHeight: checkpoints[1].Height, + initialHeight: checkpoints[1].Height, + startHash: *checkpoints[1].Hash, + endHeight: checkpoints[2].Height, + initialHash: *checkpoints[1].Hash, + }, + + { + + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: blockchain.BlockLocator([]*chainhash.Hash{checkpoints[2].Hash}), + HashStop: zeroHash, + }, + startHeight: checkpoints[2].Height, + initialHeight: checkpoints[2].Height, + startHash: *checkpoints[2].Hash, + endHeight: checkpoints[2].Height + wire.MaxBlockHeadersPerMsg, + initialHash: *checkpoints[2].Hash, + }, + }, + } + + // create request. 
+ expectedRequest := expectedQuery.requests() + + bm.cfg.blkHdrCheckptQueryDispatcher.(*mockDispatcher).query = func(requests []*query.Request, + options ...query.QueryOption) chan error { + + // assert that the requests obtained has same length as that of our expected query. + if len(requests) != len(expectedRequest) { + t.Fatalf("unequal length") + } + + for i, req := range requests { + testEqualReqMessage(req, expectedRequest[i], t) + } + + // Ensure the query options sent by query is four. This is the number of query option supplied as args while + // querying the workmanager. + if len(options) != 3 { + t.Fatalf("expected five option parameter for query but got, %v\n", len(options)) + } + return nil + } + + // call the function that we are testing. + bm.batchCheckpointedBlkHeaders() +} + +// This function tests the ProcessBlKHeaderInCheckPtRegionInOrder function. +func TestProcessBlKHeaderInCheckPtRegionInOrder(t *testing.T) { + t.Parallel() + + // First, we set up a block manager and a fake peer that will act as the + // test's remote peer. + bm, _, _, err := setupBlockManager(t) + require.NoError(t, err) + + fakePeer, err := peer.NewOutboundPeer(&peer.Config{}, "fake:123") + require.NoError(t, err) + + // We'll want to use actual, real blocks, so we take a miner harness + // that we can use to generate some. + harness, err := rpctest.New( + &chaincfg.SimNetParams, nil, []string{"--txindex"}, "", + ) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, harness.TearDown()) + }) + + err = harness.SetUp(false, 0) + require.NoError(t, err) + + // Generate 30 valid blocks that we then feed to the block manager. + blockHashes, err := harness.Client.Generate(30) + require.NoError(t, err) + + // This is the headerMessage containing 10 headers starting at height 0. + hmsgTip0 := &headersMsg{ + headers: &wire.MsgHeaders{ + Headers: make([]*wire.BlockHeader, 10), + }, + peer: &ServerPeer{ + Peer: fakePeer, + }, + } + + // This is the headerMessage containing 10 headers starting at height 10. + hmsgTip10 := &headersMsg{ + headers: &wire.MsgHeaders{ + Headers: make([]*wire.BlockHeader, 10), + }, + peer: &ServerPeer{ + Peer: fakePeer, + }, + } + + // This is the headerMessage containing 10 headers starting at height 20. + hmsgTip20 := &headersMsg{ + headers: &wire.MsgHeaders{ + Headers: make([]*wire.BlockHeader, 10), + }, + peer: &ServerPeer{ + Peer: fakePeer, + }, + } + + // Loop through the generated blockHashes and add headers to their appropriate slices. + for i := range blockHashes { + header, err := harness.Client.GetBlockHeader(blockHashes[i]) + require.NoError(t, err) + + if i < 10 { + hmsgTip0.headers.Headers[i] = header + } + + if i >= 10 && i < 20 { + hmsgTip10.headers.Headers[i-10] = header + } + + if i >= 20 { + hmsgTip20.headers.Headers[i-20] = header + } + } + + // initialize the hdrTipSlice. + bm.hdrTipSlice = make([]int32, 0) + + // Create checkpoint for our test chain. + checkpoint := chaincfg.Checkpoint{ + Hash: blockHashes[29], + Height: int32(30), + } + bm.cfg.ChainParams.Checkpoints = append(bm.cfg.ChainParams.Checkpoints, []chaincfg.Checkpoint{ + checkpoint, + }...) + bm.nextCheckpoint = &checkpoint + + // If ProcessBlKHeaderInCheckPtRegionInOrder loop receives invalid headers assert the query parameters being sent + // to the workmanager is expected. + bm.cfg.blkHdrCheckptQueryDispatcher.(*mockDispatcher).query = func(requests []*query.Request, + options ...query.QueryOption) chan error { + + // The function should send only one request. 
+		if len(requests) != 1 {
+			t.Fatalf("expected only one request")
+		}
+
+		finalNode := bm.headerList.Back()
+		newHdrTip := finalNode.Height
+		newHdrTipHash := finalNode.Header.BlockHash()
+		prevCheckPt := bm.findPreviousHeaderCheckpoint(newHdrTip)
+
+		testEqualReqMessage(requests[0], &query.Request{
+
+			Req: &headerQuery{
+				message: &wire.MsgGetHeaders{
+					BlockLocatorHashes: []*chainhash.Hash{&newHdrTipHash},
+					HashStop:           *bm.nextCheckpoint.Hash,
+				},
+				startHeight:   newHdrTip,
+				initialHeight: prevCheckPt.Height,
+				startHash:     newHdrTipHash,
+				endHeight:     bm.nextCheckpoint.Height,
+				initialHash:   newHdrTipHash,
+				index:         0,
+			},
+		}, t)
+
+		// The function should include only three query options while querying.
+		if len(options) != 3 {
+			t.Fatalf("expected three option parameters for query but got %v\n", len(options))
+		}
+		return nil
+	}
+
+	// Call the function in a goroutine.
+	go bm.processBlKHeaderInCheckPtRegionInOrder()
+
+	// At this point syncPeer should be nil.
+	bm.syncPeerMutex.RLock()
+	if bm.syncPeer != nil {
+		bm.syncPeerMutex.RUnlock()
+		t.Fatalf("syncPeer should be nil initially")
+	}
+	bm.syncPeerMutex.RUnlock()
+
+	// Set the header tip to zero and write a response at height 10, then ensure the
+	// processBlKHeaderInCheckPtRegionInOrder loop does not handle the response as it does not correspond to the
+	// current header tip.
+	bm.newHeadersMtx.Lock()
+	bm.headerTip = 0
+	bm.newHeadersMtx.Unlock()
+
+	bm.writeBatchMtx.Lock()
+	newTipWrite := int32(10)
+	bm.hdrTipToResponse[newTipWrite] = hmsgTip10
+	i := sort.Search(len(bm.hdrTipSlice), func(i int) bool {
+		return bm.hdrTipSlice[i] >= newTipWrite
+	})
+
+	bm.hdrTipSlice = append(bm.hdrTipSlice[:i], append([]int32{newTipWrite}, bm.hdrTipSlice[i:]...)...)
+	bm.writeBatchMtx.Unlock()
+
+	// SyncPeer should still be nil to indicate that the loop did not handle the response.
+	bm.syncPeerMutex.RLock()
+	if bm.syncPeer != nil {
+		bm.syncPeerMutex.RUnlock()
+		t.Fatalf("syncPeer should be nil")
+	}
+	bm.syncPeerMutex.RUnlock()
+
+	// Set the header tip to 20 to show that even when the chain's tip is higher than the available tips in the
+	// hdrTipToResponse map, the loop still does not handle it.
+	bm.newHeadersMtx.Lock()
+	bm.headerTip = 20
+	bm.newHeadersMtx.Unlock()
+
+	// SyncPeer should still be nil to indicate that the loop did not handle the response.
+	bm.syncPeerMutex.RLock()
+	if bm.syncPeer != nil {
+		bm.syncPeerMutex.RUnlock()
+		t.Fatalf("syncPeer should be nil")
+	}
+	bm.syncPeerMutex.RUnlock()
+
+	// Set headerTip to zero and write a response at height 0 to the hdrTipToResponse map. The loop should handle
+	// this response now and, after that, the following response that corresponds to its new tip.
+	bm.newHeadersMtx.Lock()
+	bm.headerTip = 0
+	bm.newHeadersMtx.Unlock()
+
+	bm.writeBatchMtx.Lock()
+	newTipWrite = int32(0)
+	i = sort.Search(len(bm.hdrTipSlice), func(i int) bool {
+		return bm.hdrTipSlice[i] >= newTipWrite
+	})
+
+	bm.hdrTipSlice = append(bm.hdrTipSlice[:i], append([]int32{newTipWrite}, bm.hdrTipSlice[i:]...)...)
+
+	bm.hdrTipToResponse[newTipWrite] = hmsgTip0
+	bm.writeBatchMtx.Unlock()
+
+	// Allow time for handling the response.
+	time.Sleep(1 * time.Second)
+	bm.syncPeerMutex.RLock()
+	if bm.syncPeer == nil {
+		bm.syncPeerMutex.RUnlock()
+		t.Fatalf("syncPeer should not be nil")
+	}
+	bm.syncPeerMutex.RUnlock()
+
+	// The header tip should be 20 as the loop would handle the response at height 0 and then the previously
+	// written height 10.
+	bm.newHeadersMtx.RLock()
+	if bm.headerTip != 20 {
+		hdrTip := bm.headerTip
+		bm.newHeadersMtx.RUnlock()
+		t.Fatalf("expected header tip at 20 but got %v\n", hdrTip)
+	}
+	bm.newHeadersMtx.RUnlock()
+
+	// Now scramble the headers and feed them in again. This should cause
+	// the loop to delete this response from the map and re-request these headers from
+	// the workmanager.
+	rand.Shuffle(len(hmsgTip20.headers.Headers), func(i, j int) {
+		hmsgTip20.headers.Headers[i], hmsgTip20.headers.Headers[j] =
+			hmsgTip20.headers.Headers[j], hmsgTip20.headers.Headers[i]
+	})
+
+	// Write this response at height 20; this would cause the loop to handle it.
+	bm.writeBatchMtx.Lock()
+	newTipWrite = int32(20)
+	bm.hdrTipToResponse[newTipWrite] = hmsgTip20
+	i = sort.Search(len(bm.hdrTipSlice), func(i int) bool {
+		return bm.hdrTipSlice[i] >= newTipWrite
+	})
+
+	bm.hdrTipSlice = append(bm.hdrTipSlice[:i], append([]int32{newTipWrite}, bm.hdrTipSlice[i:]...)...)
+
+	bm.writeBatchMtx.Unlock()
+
+	// Allow time for handling.
+	time.Sleep(1 * time.Second)
+
+	// The header tip should not advance as the headers are invalid.
+	bm.newHeadersMtx.RLock()
+	if bm.headerTip != 20 {
+		hdrTip := bm.headerTip
+		bm.newHeadersMtx.RUnlock()
+		t.Fatalf("expected header tip at 20 but got %v\n", hdrTip)
+	}
+	bm.newHeadersMtx.RUnlock()
+
+	// SyncPeer should not be nil as we are still in the loop.
+	bm.syncPeerMutex.RLock()
+	if bm.syncPeer == nil {
+		bm.syncPeerMutex.RUnlock()
+		t.Fatalf("syncPeer should not be nil")
+	}
+	bm.syncPeerMutex.RUnlock()
+
+	// The response at header tip 20 should be deleted.
+	bm.writeBatchMtx.RLock()
+	_, ok := bm.hdrTipToResponse[int32(20)]
+	bm.writeBatchMtx.RUnlock()
+
+	if ok {
+		t.Fatalf("expected response at header tip 20 to be deleted")
+	}
+}
+
+// TestCheckpointedBlockHeadersQuery_handleResponse tests the handleResponse method
+// of the CheckpointedBlockHeadersQuery.
+func TestCheckpointedBlockHeadersQuery_handleResponse(t *testing.T) {
+	t.Parallel()
+
+	// handleRespTestCase holds all the information required to test different scenarios while
+	// using the function.
+	type handleRespTestCase struct {
+
+		// name of the test case.
+		name string
+
+		// resp is the response to be passed to the handleResp method as an arg.
+		resp wire.Message
+
+		// req is the request message to be passed to the handleResp method as an arg.
+		req query.ReqMessage
+
+		// progress is the expected progress to be returned by the handleResp method.
+		progress query.Progress
+
+		// lastBlock is the block whose hash is used as the request's HashStop.
+		lastBlock wire.BlockHeader
+
+		// peerDisconnected indicates whether the peer should be disconnected after the handleResp method is done.
+		peerDisconnected bool
+	}
+
+	testCases := []handleRespTestCase{
+
+		{
+			// Scenario in which we have a request type that is not the same as the expected headerQuery type. It
+			// should return query.NoResponse progress.
+			name: "invalid request type",
+			resp: &wire.MsgHeaders{
+				Headers: make([]*wire.BlockHeader, 0, 5),
+			},
+			req:      &encodedQuery{},
+			progress: query.NoResponse,
+		},
+
+		{
+			// Scenario in which we have a response type that is not the same as the expected wire.MsgHeaders type.
+			// It should return query.NoResponse progress.
+ name: "invalid response type", + resp: &wire.MsgCFHeaders{}, + req: &headerQuery{ + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: []*chainhash.Hash{ + {1}, + }, + }, + startHeight: 1, + initialHeight: 1, + startHash: chainhash.Hash{1}, + endHeight: 6, + initialHash: chainhash.Hash{1}, + }, + progress: query.NoResponse, + lastBlock: wire.BlockHeader{ + PrevBlock: chainhash.Hash{5}, + }, + }, + + { + // Scenario in which we have the response in the hdrTipResponseMap. While calling these testcases, we + // initialize the hdrTipToResponse map to contain a response at height 0 and 6. Since this request ahs a + // startheight of 0, its response would be in the map already, aligning with this scenario. This scenario + // should return query.IgnoreRequest, + name: "response start Height in hdrTipResponse map", + resp: &wire.MsgHeaders{ + Headers: make([]*wire.BlockHeader, 0, 4), + }, + req: &headerQuery{ + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: []*chainhash.Hash{ + {0}, + }, + }, + startHeight: 0, + initialHeight: 0, + startHash: chainhash.Hash{0}, + endHeight: 5, + initialHash: chainhash.Hash{0}, + }, + progress: query.IgnoreRequest, + lastBlock: wire.BlockHeader{ + PrevBlock: chainhash.Hash{4}, + }, + }, + + { + // Scenario in which the valid response we receive is of length, zero. We should return + // query.ResponseErr. + name: "response header length 0", + resp: &wire.MsgHeaders{ + Headers: make([]*wire.BlockHeader, 0), + }, + req: &headerQuery{ + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: []*chainhash.Hash{ + {1}, + }, + }, + startHeight: 1, + initialHeight: 1, + startHash: chainhash.Hash{1}, + endHeight: 5, + initialHash: chainhash.Hash{1}, + }, + progress: query.ResponseErr, + lastBlock: wire.BlockHeader{ + PrevBlock: chainhash.Hash{4}, + }, + }, + + { + // Scenario in which the response received is at the request's initialHeight (lower bound height in + // checkpoint request) but its first block's previous hash is not same as the checkpoint hash. It + // should return query.ResponseErr. + name: "response at initialHash has disconnected start Hash", + resp: &wire.MsgHeaders{ + Headers: []*wire.BlockHeader{ + { + PrevBlock: chainhash.Hash{2}, + }, + { + PrevBlock: chainhash.Hash{3}, + }, + { + PrevBlock: chainhash.Hash{4}, + }, + }, + }, + req: &headerQuery{ + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: []*chainhash.Hash{ + {1}, + }, + }, + startHeight: 1, + initialHeight: 1, + startHash: chainhash.Hash{1}, + endHeight: 5, + initialHash: chainhash.Hash{1}, + }, + progress: query.ResponseErr, + lastBlock: wire.BlockHeader{ + PrevBlock: chainhash.Hash{4}, + }, + peerDisconnected: true, + }, + + { + // Scenario in which the response is not at the initial Hash (lower bound hash in the + // checkpoint request) but the response is complete and valid. It should return query.Finished. 
+ name: "response not at initialHash, valid complete headers", + resp: &wire.MsgHeaders{ + Headers: []*wire.BlockHeader{ + { + PrevBlock: chainhash.Hash{2}, + }, + { + PrevBlock: chainhash.Hash{3}, + }, + { + PrevBlock: chainhash.Hash{4}, + }, + }, + }, + req: &headerQuery{ + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: []*chainhash.Hash{ + {2}, + }, + }, + startHeight: 2, + initialHeight: 1, + startHash: chainhash.Hash{2}, + endHeight: 5, + initialHash: chainhash.Hash{2}, + }, + progress: query.Finished, + lastBlock: wire.BlockHeader{ + PrevBlock: chainhash.Hash{4}, + }, + }, + + { + // Scenario in which the response is not at initial hash (lower bound height in + // checkpoint request) and the response is unfinished. The jobErr should be nil and return + // finalRespNoProgress query.progress. + name: "response not at initial Hash, unfinished response", + resp: &wire.MsgHeaders{ + Headers: []*wire.BlockHeader{ + { + PrevBlock: chainhash.Hash{2}, + }, + { + PrevBlock: chainhash.Hash{3}, + }, + { + PrevBlock: chainhash.Hash{4}, + }, + }, + }, + req: &headerQuery{ + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: []*chainhash.Hash{ + {2}, + }, + }, + startHeight: 2, + initialHeight: 1, + startHash: chainhash.Hash{2}, + endHeight: 6, + initialHash: chainhash.Hash{2}, + }, + progress: query.UnFinishedRequest, + lastBlock: wire.BlockHeader{ + PrevBlock: chainhash.Hash{5}, + }, + }, + + { + // Scenario in which the response length is greater than expected. Peer should be + // disconnected and the method should return query.ResponseErr. + name: "response header length more than expected", + resp: &wire.MsgHeaders{ + Headers: []*wire.BlockHeader{ + { + PrevBlock: chainhash.Hash{1}, + }, + { + PrevBlock: chainhash.Hash{2}, + }, + { + PrevBlock: chainhash.Hash{3}, + }, + { + PrevBlock: chainhash.Hash{4}, + }, + { + PrevBlock: chainhash.Hash{5}, + }, + { + PrevBlock: chainhash.Hash{6}, + }, + }, + }, + req: &headerQuery{ + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: []*chainhash.Hash{ + {1}, + }, + }, + startHeight: 1, + initialHeight: 1, + startHash: chainhash.Hash{1}, + endHeight: 6, + initialHash: chainhash.Hash{1}, + }, + progress: query.ResponseErr, + lastBlock: wire.BlockHeader{ + PrevBlock: chainhash.Hash{5}, + }, + peerDisconnected: true, + }, + + { + // Scenario in which response is complete and a valid header. Its start height is at the initial height. + // progress should be query.Finished. + name: "complete response valid headers", + resp: &wire.MsgHeaders{ + Headers: []*wire.BlockHeader{ + { + PrevBlock: chainhash.Hash{1}, + }, + { + PrevBlock: chainhash.Hash{2}, + }, + { + PrevBlock: chainhash.Hash{3}, + }, + }, + }, + req: &headerQuery{ + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: []*chainhash.Hash{ + {1}, + }, + }, + startHeight: 1, + initialHeight: 1, + startHash: chainhash.Hash{1}, + endHeight: 4, + initialHash: chainhash.Hash{1}, + }, + progress: query.Finished, + lastBlock: wire.BlockHeader{ + PrevBlock: chainhash.Hash{3}, + }, + }, + + { + // Scenario in which response is at initialHash and the response is incomplete. + // It should return query.UnFinishedRequest. 
+ name: "response at initial hash, incomplete response, valid headers", + resp: &wire.MsgHeaders{ + Headers: []*wire.BlockHeader{ + { + PrevBlock: chainhash.Hash{1}, + }, + { + PrevBlock: chainhash.Hash{2}, + }, + { + PrevBlock: chainhash.Hash{3}, + }, + }, + }, + req: &headerQuery{ + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: []*chainhash.Hash{ + {1}, + }, + }, + startHeight: 1, + initialHeight: 1, + startHash: chainhash.Hash{1}, + endHeight: 6, + initialHash: chainhash.Hash{1}, + }, + progress: query.UnFinishedRequest, + lastBlock: wire.BlockHeader{ + PrevBlock: chainhash.Hash{5}, + }, + }, + + { + // Scenario in which response is incomplete but valid. The new response's start height created in this + // scenario is present in the hdrTipResponseMap. The startHeight is 6 and response at height 6 has been + // preveiously written in to the hdrTipResponse map for the sake of this test. + name: "incomplete response, valid headers, new resp in hdrTipToResponse map", + resp: &wire.MsgHeaders{ + Headers: []*wire.BlockHeader{ + { + PrevBlock: chainhash.Hash{1}, + }, + { + PrevBlock: chainhash.Hash{2}, + }, + { + PrevBlock: chainhash.Hash{3}, + }, + { + PrevBlock: chainhash.Hash{4}, + }, + { + PrevBlock: chainhash.Hash{5}, + }, + }, + }, + req: &headerQuery{ + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: []*chainhash.Hash{ + {1}, + }, + }, + startHeight: 1, + initialHeight: 1, + startHash: chainhash.Hash{1}, + endHeight: 10, + initialHash: chainhash.Hash{1}, + }, + progress: query.Finished, + lastBlock: wire.BlockHeader{ + PrevBlock: chainhash.Hash{9}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // set up block manager. + bm, _, _, err := setupBlockManager(t) + require.NoError(t, err) + + var oldReqStartHeight int32 + + bm.hdrTipToResponse[0] = &headersMsg{ + headers: &wire.MsgHeaders{}, + } + bm.hdrTipToResponse[6] = &headersMsg{ + headers: &wire.MsgHeaders{}, + } + + fakePeer, err := peer.NewOutboundPeer(&peer.Config{}, "fake:123") + require.NoError(t, err) + + blkHdrquery := &CheckpointedBlockHeadersQuery{ + blockMgr: bm, + } + req := tc.req + r, ok := tc.req.(*headerQuery) + if ok { + reqMessage, ok := req.Message().(*wire.MsgGetHeaders) + if !ok { + t.Fatalf("request message not of type wire.MsgGetHeaders") + } + reqMessage.HashStop = tc.lastBlock.BlockHash() + req = r + oldReqStartHeight = r.startHeight + } + actualProgress := blkHdrquery.handleResponse(req, tc.resp, &ServerPeer{ + Peer: fakePeer, + }) + + if tc.progress != actualProgress { + t.Fatalf("unexpected progress.Expected: %v but got: %v", tc.progress, actualProgress) + } + + if actualProgress == query.UnFinishedRequest { + resp := tc.resp.(*wire.MsgHeaders) + request := req.(*headerQuery) + if request.startHash != resp.Headers[len(resp.Headers)-1].BlockHash() { + t.Fatalf("unexpected new startHash") + } + + if request.startHeight != oldReqStartHeight+int32(len(resp.Headers)) { + t.Fatalf("unexpected new start height") + } + + requestMessage := req.Message().(*wire.MsgGetHeaders) + + if *requestMessage.BlockLocatorHashes[0] != request.startHash { + t.Fatalf("unexpected new blockLocator") + } + } + + assertPeerDisconnected(tc.peerDisconnected, fakePeer, t) + }) + } +} + +// testEqualReqMessage tests if two query.Request are same. 
+func testEqualReqMessage(a, b *query.Request, t *testing.T) {
+	aMessage := a.Req.(*headerQuery)
+	bMessage := b.Req.(*headerQuery)
+
+	if aMessage.startHeight != bMessage.startHeight {
+		t.Fatalf("dissimilar startHeight")
+	}
+	if aMessage.startHash != bMessage.startHash {
+		t.Fatalf("dissimilar startHash")
+	}
+	if aMessage.endHeight != bMessage.endHeight {
+		t.Fatalf("dissimilar endHeight")
+	}
+	if aMessage.initialHash != bMessage.initialHash {
+		t.Fatalf("dissimilar initialHash")
+	}
+
+	aMessageGetHeaders := aMessage.Message().(*wire.MsgGetHeaders)
+	bMessageGetHeaders := bMessage.Message().(*wire.MsgGetHeaders)
+
+	if !reflect.DeepEqual(aMessageGetHeaders.BlockLocatorHashes, bMessageGetHeaders.BlockLocatorHashes) {
+		t.Fatalf("dissimilar block locator hashes")
+	}
+
+	if aMessageGetHeaders.HashStop != bMessageGetHeaders.HashStop {
+		t.Fatalf("dissimilar hash stop")
+	}
+	if a.Req.PriorityIndex() != b.Req.PriorityIndex() {
+		t.Fatalf("dissimilar priority index")
+	}
 }
diff --git a/neutrino.go b/neutrino.go
index 7ee45edd0..5179e0d9d 100644
--- a/neutrino.go
+++ b/neutrino.go
@@ -195,6 +195,31 @@ func NewServerPeer(s *ChainService, isPersistent bool) *ServerPeer {
 	}
 }
 
+// IsSyncCandidate returns whether or not the peer is a candidate to consider
+// syncing from.
+func (sp *ServerPeer) IsSyncCandidate() bool {
+	// The peer is not a candidate for sync if it's not a full node.
+	return sp.Services()&wire.SFNodeNetwork == wire.SFNodeNetwork
+}
+
+// IsPeerBehindStartHeight returns a boolean indicating whether the peer's last block
+// height is behind the start height of the request. It returns true if the peer is
+// behind the request's start height and false otherwise.
+func (sp *ServerPeer) IsPeerBehindStartHeight(req query.ReqMessage) bool {
+	queryGetHeaders, ok := req.(*headerQuery)
+
+	if !ok {
+		log.Debugf("request is not of type headerQuery")
+
+		return true
+	}
+
+	if sp.LastBlock() < queryGetHeaders.startHeight {
+		return true
+	}
+	return false
+}
+
 // newestBlock returns the current best block hash and height using the format
 // required by the configuration for the peer package.
func (sp *ServerPeer) newestBlock() (*chainhash.Hash, int32, error) { @@ -800,15 +825,21 @@ func NewChainService(cfg Config) (*ChainService, error) { } bm, err := newBlockManager(&blockManagerCfg{ - ChainParams: s.chainParams, - BlockHeaders: s.BlockHeaders, - RegFilterHeaders: s.RegFilterHeaders, - TimeSource: s.timeSource, - QueryDispatcher: s.workManager, - BanPeer: s.BanPeer, - GetBlock: s.GetBlock, - firstPeerSignal: s.firstPeerConnect, - queryAllPeers: s.queryAllPeers, + ChainParams: s.chainParams, + BlockHeaders: s.BlockHeaders, + RegFilterHeaders: s.RegFilterHeaders, + TimeSource: s.timeSource, + cfHeaderQueryDispatcher: s.workManager, + BanPeer: s.BanPeer, + GetBlock: s.GetBlock, + firstPeerSignal: s.firstPeerConnect, + queryAllPeers: s.queryAllPeers, + blkHdrCheckptQueryDispatcher: query.NewWorkManager(&query.Config{ + ConnectedPeers: s.ConnectedPeers, + NewWorker: query.NewWorker, + Ranking: query.NewPeerRanking(), + IsEligibleWorkerFunc: query.IsWorkerEligibleForBlkHdrFetch, + }), }) if err != nil { return nil, err @@ -1610,6 +1641,9 @@ func (s *ChainService) Start() error { s.addrManager.Start() s.blockManager.Start() s.blockSubscriptionMgr.Start() + if err := s.blockManager.cfg.blkHdrCheckptQueryDispatcher.Start(); err != nil { + return fmt.Errorf("unable to start block header work manager: %v", err) + } if err := s.workManager.Start(); err != nil { return fmt.Errorf("unable to start work manager: %v", err) } diff --git a/query/interface.go b/query/interface.go index 70b82cbcf..1a6517d4b 100644 --- a/query/interface.go +++ b/query/interface.go @@ -224,4 +224,11 @@ type Peer interface { // OnDisconnect returns a channel that will be closed when this peer is // disconnected. OnDisconnect() <-chan struct{} + + // IsPeerBehindStartHeight returns a boolean indicating if the peer's known last height is behind + // the request's start Height which it receives as an argument. + IsPeerBehindStartHeight(req ReqMessage) bool + + // IsSyncCandidate returns true if the peer is a sync candidate. + IsSyncCandidate() bool } diff --git a/query/worker.go b/query/worker.go index 82c4948c6..dec49ae00 100644 --- a/query/worker.go +++ b/query/worker.go @@ -95,7 +95,7 @@ func (w *worker) Run(results chan<- *jobResult, quit <-chan struct{}) { msgChan, cancel := peer.SubscribeRecvMsg() defer cancel() -nexJobLoop: +nextJobLoop: for { log.Tracef("Worker %v waiting for more work", peer.Addr()) @@ -154,7 +154,7 @@ nexJobLoop: case <-quit: return } - goto nexJobLoop + goto nextJobLoop } } @@ -308,6 +308,28 @@ nexJobLoop: } } +func (w *worker) IsSyncCandidate() bool { + return w.peer.IsSyncCandidate() +} + +func (w *worker) IsPeerBehindStartHeight(req ReqMessage) bool { + return w.peer.IsPeerBehindStartHeight(req) +} + +// IsWorkerEligibleForBlkHdrFetch is the eligibility function used for the BlockHdrWorkManager to determine workers +// eligible to receive jobs (the job is to fetch headers). If the peer is not a sync candidate or if its last known +// block height is behind the job query's start height, it returns false. Otherwise, it returns true. +func IsWorkerEligibleForBlkHdrFetch(r *activeWorker, next *queryJob) bool { + if !r.w.IsSyncCandidate() { + return false + } + + if r.w.IsPeerBehindStartHeight(next.Req) { + return false + } + return true +} + // NewJob returns a channel where work that is to be handled by the worker can // be sent. 
If the worker reads a queryJob from this channel, it is guaranteed // that a response will eventually be deliverd on the results channel (except diff --git a/query/worker_test.go b/query/worker_test.go index 84e06659c..dcb640f3b 100644 --- a/query/worker_test.go +++ b/query/worker_test.go @@ -10,9 +10,10 @@ import ( ) type mockQueryEncoded struct { - message *wire.MsgGetData - encoding wire.MessageEncoding - index float64 + message *wire.MsgGetData + encoding wire.MessageEncoding + index float64 + startHeight int } func (m *mockQueryEncoded) Message() wire.Message { @@ -52,6 +53,8 @@ type mockPeer struct { responses chan<- wire.Message subscriptions chan chan wire.Message quit chan struct{} + bestHeight int + fullNode bool err error } @@ -72,6 +75,15 @@ func (m *mockPeer) Addr() string { return m.addr } +func (m *mockPeer) IsPeerBehindStartHeight(request ReqMessage) bool { + r := request.(*mockQueryEncoded) + return m.bestHeight < r.startHeight +} + +func (m *mockPeer) IsSyncCandidate() bool { + return m.fullNode +} + // makeJob returns a new query job that will be done when it is given the // finalResp message. Similarly ot will progress on being given the // progressResp message, while any other message will be ignored. diff --git a/query/workmanager.go b/query/workmanager.go index 8dffec1eb..0f691a966 100644 --- a/query/workmanager.go +++ b/query/workmanager.go @@ -46,6 +46,13 @@ type Worker interface { // delivered on the results channel (except when the quit channel has // been closed). NewJob() chan<- *queryJob + + // IsPeerBehindStartHeight returns a boolean indicating if the peer's known last height is behind + // the request's start Height which it receives as an argument. + IsPeerBehindStartHeight(req ReqMessage) bool + + // IsSyncCandidate returns if the peer is a sync candidate. + IsSyncCandidate() bool } // PeerRanking is an interface that must be satisfied by the underlying module diff --git a/query/workmanager_test.go b/query/workmanager_test.go index c8483d5af..075d14637 100644 --- a/query/workmanager_test.go +++ b/query/workmanager_test.go @@ -15,6 +15,14 @@ type mockWorker struct { results chan *jobResult } +func (m *mockWorker) IsPeerBehindStartHeight(req ReqMessage) bool { + return m.peer.IsPeerBehindStartHeight(req) +} + +func (m *mockWorker) IsSyncCandidate() bool { + return m.peer.IsSyncCandidate() +} + var _ Worker = (*mockWorker)(nil) func (m *mockWorker) NewJob() chan<- *queryJob { @@ -985,3 +993,105 @@ func TestWorkManagerResultUnfinished(t *testing.T) { t.Fatalf("nothing received on errChan") } } + +// TestIsWorkerEligibleForBlkHdrFetch tests the IsWorkerEligibleForBlkHdrFetch function. 
+func TestIsWorkerEligibleForBlkHdrFetch(t *testing.T) {
+	type testArgs struct {
+		name                string
+		activeWorker        *activeWorker
+		job                 *queryJob
+		expectedEligibility bool
+	}
+
+	testCases := []testArgs{
+		{
+			name: "peer sync candidate, best height behind job start height",
+			activeWorker: &activeWorker{
+				w: &mockWorker{
+					peer: &mockPeer{
+						bestHeight: 5,
+						fullNode:   true,
+					},
+				},
+			},
+			job: &queryJob{
+				Request: &Request{
+					Req: &mockQueryEncoded{
+						startHeight: 10,
+					},
+				},
+			},
+			expectedEligibility: false,
+		},
+
+		{
+			name: "peer sync candidate, best height ahead of job start height",
+			activeWorker: &activeWorker{
+				w: &mockWorker{
+					peer: &mockPeer{
+						bestHeight: 10,
+						fullNode:   true,
+					},
+				},
+			},
+			job: &queryJob{
+				Request: &Request{
+					Req: &mockQueryEncoded{
+						startHeight: 5,
+					},
+				},
+			},
+			expectedEligibility: true,
+		},
+
+		{
+			name: "peer not sync candidate, best height behind job start height",
+			activeWorker: &activeWorker{
+				w: &mockWorker{
+					peer: &mockPeer{
+						bestHeight: 5,
+						fullNode:   false,
+					},
+				},
+			},
+			job: &queryJob{
+				Request: &Request{
+					Req: &mockQueryEncoded{
+						startHeight: 10,
+					},
+				},
+			},
+			expectedEligibility: false,
+		},
+
+		{
+			name: "peer not sync candidate, best height ahead of job start height",
+			activeWorker: &activeWorker{
+				w: &mockWorker{
+					peer: &mockPeer{
+						bestHeight: 10,
+						fullNode:   false,
+					},
+				},
+			},
+			job: &queryJob{
+				Request: &Request{
+					Req: &mockQueryEncoded{
+						startHeight: 5,
+					},
+				},
+			},
+			expectedEligibility: false,
+		},
+	}
+
+	for _, test := range testCases {
+		t.Run(test.name, func(t *testing.T) {
+			isEligible := IsWorkerEligibleForBlkHdrFetch(test.activeWorker, test.job)
+			if isEligible != test.expectedEligibility {
+				t.Fatalf("Expected '%v' for eligibility check but got "+
+					"'%v'\n", test.expectedEligibility, isEligible)
+			}
+		})
+	}
+}
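
The eligibility hook exercised above is just a function of an activeWorker and the queryJob it is being considered for, so policies other than IsWorkerEligibleForBlkHdrFetch can be supplied through the work manager's IsEligibleWorkerFunc field. The sketch below is illustrative only and not part of this patch; the name isWorkerSyncCandidateOnly is hypothetical. Because it touches the unexported activeWorker and queryJob types, it would have to live inside the query package. It accepts any sync-candidate peer and skips the height check, which could suit requests that carry no start height:

// isWorkerSyncCandidateOnly is a hypothetical eligibility function with the
// same shape as IsWorkerEligibleForBlkHdrFetch. It could be wired up as the
// IsEligibleWorkerFunc in a query.Config from inside the query package.
func isWorkerSyncCandidateOnly(r *activeWorker, next *queryJob) bool {
	// Ignore the job's start height entirely; only require the peer to be
	// a full-node sync candidate.
	_ = next

	return r.w.IsSyncCandidate()
}

Keeping the policy in a standalone function like this keeps worker selection decoupled from any particular ReqMessage implementation, so request-specific checks such as IsPeerBehindStartHeight remain opt-in.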