diff --git a/blockmanager.go b/blockmanager.go index 9c9e6f4c..e517f1c6 100644 --- a/blockmanager.go +++ b/blockmanager.go @@ -5,9 +5,11 @@ package neutrino import ( "bytes" "container/list" + "errors" "fmt" "math" "math/big" + "sort" "sync" "sync/atomic" "time" @@ -90,8 +92,13 @@ type blockManagerCfg struct { // the connected peers. TimeSource blockchain.MedianTimeSource - // QueryDispatcher is used to make queries to connected Bitcoin peers. - QueryDispatcher query.Dispatcher + // cfHeaderQueryDispatcher is used to make queries to connected Bitcoin peers to fetch + // CFHeaders + cfHeaderQueryDispatcher query.Dispatcher + + // cfHeaderQueryDispatcher is used to make queries to connected Bitcoin peers to fetch + // block Headers + blkHdrCheckptQueryDispatcher query.WorkManager // BanPeer bans and disconnects the given peer. BanPeer func(addr string, reason banman.Reason) error @@ -173,6 +180,10 @@ type blockManager struct { // nolint:maligned // time, newHeadersMtx should always be acquired first. newFilterHeadersMtx sync.RWMutex + // writeBatchMtx is the mutex used to hold reading and reading and writing into the + // hdrTipToResponse map. + writeBatchMtx sync.RWMutex + // newFilterHeadersSignal is condition variable which will be used to // notify any waiting callers (via Broadcast()) that the tip of the // current filter header chain has changed. This is useful when callers @@ -206,6 +217,18 @@ type blockManager struct { // nolint:maligned minRetargetTimespan int64 // target timespan / adjustment factor maxRetargetTimespan int64 // target timespan * adjustment factor blocksPerRetarget int32 // target timespan / target time per block + + // hdrTipToResponse is a map that holds the response gotten from querying peers + // using the workmanager, to fetch headers within the chain's checkpointed region. + // It is a map of the request's startheight to the fetch response. + hdrTipToResponse map[int32]*headersMsg + + // hdrTipSlice is a slice that holds request startHeight of the responses that have been + // fetched using the workmanager to fetch headers within the chain's checkpointed region. + // It is used to easily access this startheight in the case we have to delete these responses + // in the hdrTipResponse map during a reorg while fetching headers within the chain's checkpointed + // region. + hdrTipSlice []int32 } // newBlockManager returns a new bitcoin block manager. Use Start to begin @@ -235,6 +258,8 @@ func newBlockManager(cfg *blockManagerCfg) (*blockManager, error) { blocksPerRetarget: int32(targetTimespan / targetTimePerBlock), minRetargetTimespan: targetTimespan / adjustmentFactor, maxRetargetTimespan: targetTimespan * adjustmentFactor, + hdrTipToResponse: make(map[int32]*headersMsg), + hdrTipSlice: make([]int32, 0), } // Next we'll create the two signals that goroutines will use to wait @@ -290,8 +315,21 @@ func (b *blockManager) Start() { } log.Trace("Starting block manager") - b.wg.Add(2) + b.wg.Add(3) go b.blockHandler() + go func() { + wm := b.cfg.blkHdrCheckptQueryDispatcher + + defer b.wg.Done() + defer func(wm query.WorkManager) { + err := wm.Stop() + if err != nil { + log.Errorf("Unable to stop block header workmanager: %v", err) + } + }(wm) + + b.processBlKHeaderInCheckPtRegionInOrder() + }() go func() { defer b.wg.Done() @@ -305,6 +343,12 @@ func (b *blockManager) Start() { return } + checkpoints := b.cfg.ChainParams.Checkpoints + numCheckpts := len(checkpoints) + if numCheckpts != 0 && b.nextCheckpoint != nil { + b.batchCheckpointedBlkHeaders() + } + log.Debug("Peer connected, starting cfHandler.") b.cfHandler() }() @@ -360,6 +404,19 @@ func (b *blockManager) NewPeer(sp *ServerPeer) { } } +// addNewPeerToList adds the peer to the peers list. +func (b *blockManager) addNewPeerToList(peers *list.List, sp *ServerPeer) { + // Ignore if in the process of shutting down. + if atomic.LoadInt32(&b.shutdown) != 0 { + return + } + + log.Infof("New valid peer %s (%s)", sp, sp.UserAgent()) + + // Add the peer as a candidate to sync from. + peers.PushBack(sp) +} + // handleNewPeerMsg deals with new peers that have signalled they may be // considered as a sync peer (they have already successfully negotiated). It // also starts syncing if needed. It is invoked from the syncHandler @@ -373,12 +430,12 @@ func (b *blockManager) handleNewPeerMsg(peers *list.List, sp *ServerPeer) { log.Infof("New valid peer %s (%s)", sp, sp.UserAgent()) // Ignore the peer if it's not a sync candidate. - if !b.isSyncCandidate(sp) { + if !sp.IsSyncCandidate() { return } // Add the peer as a candidate to sync from. - peers.PushBack(sp) + b.addNewPeerToList(peers, sp) // If we're current with our sync peer and the new peer is advertising // a higher block than the newest one we know of, request headers from @@ -418,11 +475,8 @@ func (b *blockManager) DonePeer(sp *ServerPeer) { } } -// handleDonePeerMsg deals with peers that have signalled they are done. It -// removes the peer as a candidate for syncing and in the case where it was the -// current sync peer, attempts to select a new best peer to sync from. It is -// invoked from the syncHandler goroutine. -func (b *blockManager) handleDonePeerMsg(peers *list.List, sp *ServerPeer) { +// removeDonePeerFromList removes the peer from the peers list. +func (b *blockManager) removeDonePeerFromList(peers *list.List, sp *ServerPeer) { // Remove the peer from the list of candidate peers. for e := peers.Front(); e != nil; e = e.Next() { if e.Value == sp { @@ -432,6 +486,17 @@ func (b *blockManager) handleDonePeerMsg(peers *list.List, sp *ServerPeer) { } log.Infof("Lost peer %s", sp) +} + +// handleDonePeerMsg deals with peers that have signalled they are done. It +// removes the peer as a candidate for syncing and in the case where it was the +// current sync peer, attempts to select a new best peer to sync from. It is +// invoked from the syncHandler goroutine. +func (b *blockManager) handleDonePeerMsg(peers *list.List, sp *ServerPeer) { + // Remove the peer from the list of candidate peers. + b.removeDonePeerFromList(peers, sp) + + log.Infof("Lost peer %s", sp) // Attempt to find a new peer to sync from if the quitting peer is the // sync peer. Also, reset the header state. @@ -808,12 +873,31 @@ func (b *blockManager) getUncheckpointedCFHeaders( // handle a query for checkpointed filter headers. type checkpointedCFHeadersQuery struct { blockMgr *blockManager - msgs []wire.Message + msgs []*encodedQuery checkpoints []*chainhash.Hash stopHashes map[chainhash.Hash]uint32 headerChan chan *wire.MsgCFHeaders } +// encodedQuery holds all the information needed to query a message that pushes requests +// using the QueryMessagingWithEncoding method. +type encodedQuery struct { + message wire.Message + encoding wire.MessageEncoding + priorityIndex float64 +} + +// Message returns the wire.Message of encodedQuery's struct. +func (e *encodedQuery) Message() wire.Message { + return e.message +} + +// PriorityIndex returns the specified priority the caller wants +// the request to take. +func (e *encodedQuery) PriorityIndex() float64 { + return e.priorityIndex +} + // requests creates the query.Requests for this CF headers query. func (c *checkpointedCFHeadersQuery) requests() []*query.Request { reqs := make([]*query.Request, len(c.msgs)) @@ -821,6 +905,8 @@ func (c *checkpointedCFHeadersQuery) requests() []*query.Request { reqs[idx] = &query.Request{ Req: m, HandleResp: c.handleResponse, + SendQuery: sendQueryMessageWithEncoding, + CloneReq: cloneMsgCFHeaders, } } return reqs @@ -828,43 +914,37 @@ func (c *checkpointedCFHeadersQuery) requests() []*query.Request { // handleResponse is the internal response handler used for requests for this // CFHeaders query. -func (c *checkpointedCFHeadersQuery) handleResponse(req, resp wire.Message, - peerAddr string) query.Progress { +func (c *checkpointedCFHeadersQuery) handleResponse(request query.ReqMessage, resp wire.Message, + peer query.Peer) query.Progress { + + peerAddr := "" + if peer != nil { + peerAddr = peer.Addr() + } + req := request.Message() r, ok := resp.(*wire.MsgCFHeaders) if !ok { // We are only looking for cfheaders messages. - return query.Progress{ - Finished: false, - Progressed: false, - } + return query.NoResponse } q, ok := req.(*wire.MsgGetCFHeaders) if !ok { // We sent a getcfheaders message, so that's what we should be // comparing against. - return query.Progress{ - Finished: false, - Progressed: false, - } + return query.NoResponse } // The response doesn't match the query. if q.FilterType != r.FilterType || q.StopHash != r.StopHash { - return query.Progress{ - Finished: false, - Progressed: false, - } + return query.NoResponse } checkPointIndex, ok := c.stopHashes[r.StopHash] if !ok { // We never requested a matching stop hash. - return query.Progress{ - Finished: false, - Progressed: false, - } + return query.NoResponse } // Use either the genesis header or the previous checkpoint index as @@ -898,10 +978,7 @@ func (c *checkpointedCFHeadersQuery) handleResponse(req, resp wire.Message, log.Errorf("Unable to ban peer %v: %v", peerAddr, err) } - return query.Progress{ - Finished: false, - Progressed: false, - } + return query.NoResponse } // At this point, the response matches the query, and the relevant @@ -912,16 +989,47 @@ func (c *checkpointedCFHeadersQuery) handleResponse(req, resp wire.Message, select { case c.headerChan <- r: case <-c.blockMgr.quit: - return query.Progress{ - Finished: false, - Progressed: false, - } + return query.NoResponse } - return query.Progress{ - Finished: true, - Progressed: true, + return query.Finished +} + +// sendQueryMessageWithEncoding sends a message to the peer with encoding. +func sendQueryMessageWithEncoding(peer query.Peer, req query.ReqMessage) error { + sp, ok := peer.(*ServerPeer) + if !ok { + err := "peer is not of type ServerPeer" + log.Errorf(err) + return errors.New(err) + } + request, ok := req.(*encodedQuery) + if !ok { + return errors.New("invalid request type") } + sp.QueueMessageWithEncoding(request.message, nil, request.encoding) + + return nil +} + +// cloneMsgCFHeaders clones query.ReqMessage that contains the MsgGetCFHeaders message. +func cloneMsgCFHeaders(req query.ReqMessage) query.ReqMessage { + oldReq, ok := req.(*encodedQuery) + if !ok { + log.Errorf("request not of type *encodedQuery") + } + oldReqMessage, ok := oldReq.message.(*wire.MsgGetCFHeaders) + if !ok { + log.Errorf("request not of type *wire.MsgGetCFHeaders") + } + newReq := &encodedQuery{ + message: wire.NewMsgGetCFHeaders( + oldReqMessage.FilterType, oldReqMessage.StartHeight, &oldReqMessage.StopHash, + ), + encoding: oldReq.encoding, + priorityIndex: oldReq.priorityIndex, + } + return newReq } // getCheckpointedCFHeaders catches a filter header store up with the @@ -959,7 +1067,7 @@ func (b *blockManager) getCheckpointedCFHeaders(checkpoints []*chainhash.Hash, // the remaining checkpoint intervals. numCheckpts := uint32(len(checkpoints)) - startingInterval numQueries := (numCheckpts + maxCFCheckptsPerQuery - 1) / maxCFCheckptsPerQuery - queryMsgs := make([]wire.Message, 0, numQueries) + queryMsgs := make([]*encodedQuery, 0, numQueries) // We'll also create an additional set of maps that we'll use to // re-order the responses as we get them in. @@ -1004,9 +1112,12 @@ func (b *blockManager) getCheckpointedCFHeaders(checkpoints []*chainhash.Hash, // Once we have the stop hash, we can construct the query // message itself. - queryMsg := wire.NewMsgGetCFHeaders( - fType, startHeightRange, &stopHash, - ) + queryMsg := &encodedQuery{ + message: wire.NewMsgGetCFHeaders( + fType, startHeightRange, &stopHash, + ), + encoding: wire.WitnessEncoding, + } // We'll mark that the ith interval is queried by this message, // and also map the stop hash back to the index of this message. @@ -1043,8 +1154,8 @@ func (b *blockManager) getCheckpointedCFHeaders(checkpoints []*chainhash.Hash, // Hand the queries to the work manager, and consume the verified // responses as they come back. - errChan := b.cfg.QueryDispatcher.Query( - q.requests(), query.Cancel(b.quit), query.NoRetryMax(), + errChan := b.cfg.cfHeaderQueryDispatcher.Query( + q.requests(), query.Cancel(b.quit), query.NoRetryMax(), query.ErrChan(make(chan error, 1)), ) // Keep waiting for more headers as long as we haven't received an @@ -1982,7 +2093,38 @@ func (b *blockManager) blockHandler() { defer b.wg.Done() candidatePeers := list.New() -out: + checkpoints := b.cfg.ChainParams.Checkpoints + if len(checkpoints) == 0 || b.nextCheckpoint == nil { + goto unCheckPtLoop + } + + // Loop to fetch headers within the check pointed range + b.newHeadersMtx.RLock() + for b.headerTip <= uint32(checkpoints[len(checkpoints)-1].Height) { + b.newHeadersMtx.RUnlock() + select { + case m := <-b.peerChan: + switch msg := m.(type) { + case *newPeerMsg: + b.addNewPeerToList(candidatePeers, msg.peer) + case *donePeerMsg: + b.removeDonePeerFromList(candidatePeers, msg.peer) + default: + log.Tracef("Invalid message type in block "+ + "handler: %T", msg) + } + + case <-b.quit: + return + } + b.newHeadersMtx.RLock() + } + b.newHeadersMtx.RUnlock() + + log.Infof("Fetching uncheckpointed block headers from %v", b.headerTip) + b.startSync(candidatePeers) + +unCheckPtLoop: for { // Now check peer messages and quit channels. select { @@ -2006,13 +2148,367 @@ out: } case <-b.quit: - break out + break unCheckPtLoop } } log.Trace("Block handler done") } +// processBlKHeaderInCheckPtRegionInOrder handles and writes the block headers received from querying the +// workmanager while fetching headers within the block header checkpoint region. This process is carried out +// in order. +func (b *blockManager) processBlKHeaderInCheckPtRegionInOrder() { + lenCheckPts := len(b.cfg.ChainParams.Checkpoints) + + // Loop should run as long as we are in the block header checkpointed region. + b.newHeadersMtx.RLock() + for int32(b.headerTip) <= b.cfg.ChainParams.Checkpoints[lenCheckPts-1].Height { + hdrTip := b.headerTip + b.newHeadersMtx.RUnlock() + + select { + // return quickly if the blockmanager quits. + case <-b.quit: + return + default: + } + + // do not go further if we have not received the response mapped to our header tip. + b.writeBatchMtx.RLock() + msg, ok := b.hdrTipToResponse[int32(hdrTip)] + b.writeBatchMtx.RUnlock() + + if !ok { + b.newHeadersMtx.RLock() + continue + } + + b.syncPeerMutex.Lock() + b.syncPeer = msg.peer + b.syncPeerMutex.Unlock() + + b.handleHeadersMsg(msg) + err := b.resetHeaderListToChainTip() + if err != nil { + log.Errorf(err.Error()) + } + + finalNode := b.headerList.Back() + newHdrTip := finalNode.Height + newHdrTipHash := finalNode.Header.BlockHash() + prevCheckPt := b.findPreviousHeaderCheckpoint(newHdrTip) + + log.Tracef("New headertip %v", newHdrTip) + log.Debugf("New headertip %v", newHdrTip) + + // If our header tip has not increased, there is a problem with the headers we received and so we + // delete all the header response within our previous header tip and our new header tip, then send + // another query to the workmanager. + if uint32(newHdrTip) <= hdrTip { + b.deleteHeaderTipResp(newHdrTip, int32(hdrTip)) + + log.Tracef("while fetching checkpointed headers received invalid headers") + + q := CheckpointedBlockHeadersQuery{ + blockMgr: b, + msgs: []*headerQuery{ + + { + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: []*chainhash.Hash{&newHdrTipHash}, + HashStop: *b.nextCheckpoint.Hash, + }, + startHeight: newHdrTip, + initialHeight: prevCheckPt.Height, + startHash: newHdrTipHash, + endHeight: b.nextCheckpoint.Height, + initialHash: newHdrTipHash, + // Set it as high priority so that workmanager can schedule before any other queries. + index: 0, + }, + }, + } + + b.cfg.blkHdrCheckptQueryDispatcher.Query( + q.requests(), query.Cancel(b.quit), query.Timeout(1*time.Hour), query.NoRetryMax(), + ) + + log.Tracef("Sending query to workmanager from processBlKHeaderInCheckPtRegionInOrder loop") + } + b.newHeadersMtx.RLock() + } + b.newHeadersMtx.RUnlock() + + b.syncPeerMutex.Lock() + b.syncPeer = nil + b.syncPeerMutex.Unlock() + + log.Infof("Successfully completed fetching checkpointed block headers") +} + +// batchCheckpointedBlkHeaders creates headerQuery to fetch block headers +// within the chain's checkpointed region. +func (b *blockManager) batchCheckpointedBlkHeaders() { + var queryMsgs []*headerQuery + curHeight := b.headerTip + curHash := b.headerTipHash + nextCheckpoint := b.nextCheckpoint + nextCheckptHash := nextCheckpoint.Hash + nextCheckptHeight := nextCheckpoint.Height + + log.Infof("Fetching set of checkpointed blockheaders from "+ + "height=%v, hash=%v\n", curHeight, curHash) + + for nextCheckpoint != nil { + endHash := nextCheckptHash + endHeight := nextCheckptHeight + tmpCurHash := curHash + + msg := &headerQuery{ + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: blockchain.BlockLocator([]*chainhash.Hash{&tmpCurHash}), + HashStop: *endHash, + }, + startHeight: int32(curHeight), + initialHeight: int32(curHeight), + startHash: curHash, + endHeight: endHeight, + initialHash: tmpCurHash, + } + + log.Debugf("Fetching set of checkpointed blockheaders from "+ + "start_height=%v to end-height=%v", curHeight, endHash) + + queryMsgs = append(queryMsgs, msg) + curHeight = uint32(endHeight) + curHash = *endHash + + nextCheckpoint := b.findNextHeaderCheckpoint(int32(curHeight)) + if nextCheckpoint == nil { + break + } + + nextCheckptHeight = nextCheckpoint.Height + nextCheckptHash = nextCheckpoint.Hash + } + + msg := &headerQuery{ + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: blockchain.BlockLocator([]*chainhash.Hash{nextCheckptHash}), + HashStop: zeroHash, + }, + startHeight: nextCheckptHeight, + initialHeight: nextCheckptHeight, + startHash: *nextCheckptHash, + endHeight: nextCheckptHeight + wire.MaxBlockHeadersPerMsg, + initialHash: *nextCheckptHash, + } + + log.Debugf("Fetching set of checkpointed blockheaders from "+ + "start_height=%v to end-height=%v", curHeight, zeroHash) + + queryMsgs = append(queryMsgs, msg) + + log.Debugf("Attempting to query for %v blockheader batches", len(queryMsgs)) + + q := CheckpointedBlockHeadersQuery{ + blockMgr: b, + msgs: queryMsgs, + } + + b.cfg.blkHdrCheckptQueryDispatcher.Query( + q.requests(), query.Cancel(b.quit), query.Timeout(1*time.Hour), query.NoRetryMax(), + ) +} + +// CheckpointedBlockHeadersQuery holds all information necessary to perform and +// // handle a query for checkpointed block headers. +type CheckpointedBlockHeadersQuery struct { + blockMgr *blockManager + msgs []*headerQuery +} + +// requests creates the query.Requests for this block headers query. +func (c *CheckpointedBlockHeadersQuery) requests() []*query.Request { + reqs := make([]*query.Request, len(c.msgs)) + for idx, m := range c.msgs { + reqs[idx] = &query.Request{ + Req: m, + SendQuery: c.PushHeadersMsg, + HandleResp: c.handleResponse, + CloneReq: cloneHeaderQuery, + } + } + + return reqs +} + +// cloneHeaderQuery clones the query.ReqMessage containing the headerQuery Struct. +func cloneHeaderQuery(req query.ReqMessage) query.ReqMessage { + oldReq, ok := req.(*headerQuery) + if !ok { + log.Errorf("request not of type *wire.MsgGetHeaders") + } + oldReqMessage := req.Message().(*wire.MsgGetHeaders) + message := &headerQuery{ + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: oldReqMessage.BlockLocatorHashes, + HashStop: oldReqMessage.HashStop, + }, + startHeight: oldReq.startHeight, + initialHeight: oldReq.initialHeight, + startHash: oldReq.startHash, + endHeight: oldReq.endHeight, + } + + return message +} + +// PushHeadersMsg is the internal response handler used for requests for this +// block Headers query. +func (c *CheckpointedBlockHeadersQuery) PushHeadersMsg(peer query.Peer, + task query.ReqMessage) error { + + request, _ := task.Message().(*wire.MsgGetHeaders) + + requestMsg := task.(*headerQuery) + + // check if we have response for the query already. If we do return an error. + c.blockMgr.writeBatchMtx.RLock() + _, ok := c.blockMgr.hdrTipToResponse[requestMsg.startHeight] + c.blockMgr.writeBatchMtx.RUnlock() + if ok { + log.Debugf("Response already received, peer=%v, "+ + "start_height=%v, end_height=%v, index=%v", peer.Addr(), + requestMsg.startHeight, requestMsg.endHeight) + return query.ErrIgnoreRequest + } + + sp := peer.(*ServerPeer) + err := sp.PushGetHeadersMsg(request.BlockLocatorHashes, &request.HashStop) + if err != nil { + log.Errorf(err.Error()) + return err + } + + return nil +} + +// handleResponse is the internal response handler used for requests for this +// block header query. +func (c *CheckpointedBlockHeadersQuery) handleResponse(req query.ReqMessage, resp wire.Message, + peer query.Peer) query.Progress { + + sp := peer.(*ServerPeer) + if peer == nil { + return query.NoResponse + } + + msg, ok := resp.(*wire.MsgHeaders) + if !ok { + // We are only looking for msgHeaders messages. + return query.NoResponse + } + + request, ok := req.(*headerQuery) + if !ok { + // request should only be of type headerQuery. + return query.NoResponse + } + + // Check if we already have a response for this request startHeight, if we do modify our jobErr variable + // so that worker can send appropriate error to workmanager. + c.blockMgr.writeBatchMtx.RLock() + _, ok = c.blockMgr.hdrTipToResponse[request.startHeight] + c.blockMgr.writeBatchMtx.RUnlock() + if ok { + return query.IgnoreRequest + } + + // If we received an empty response from peer, return with an error to break worker's + // feed back loop. + hdrLength := len(msg.Headers) + if hdrLength == 0 { + return query.ResponseErr + } + + // The initialHash represents the lower bound checkpoint for this checkpoint region. + // We verify, if the header received at that checkpoint height has the same hash as the + // checkpoint's hash. If it does not, mimicking the handleheaders function behaviour, we + // disconnect the peer and return a failed progress to reschedule the query. + if msg.Headers[0].PrevBlock != request.startHash && + request.startHash == request.initialHash { + + sp.Disconnect() + + return query.ResponseErr + } + + // If the peer sends us more headers than we need, it is probably not aligned with our chain, so we disconnect + // peer and return a failed progress. + reqMessage := request.Message().(*wire.MsgGetHeaders) + + if hdrLength > int(request.endHeight-request.startHeight) { + sp.Disconnect() + return query.ResponseErr + } + + // Write header into hdrTipResponse map, add the request's startHeight to the hdrTipSlice, for tracking + // and handling by the processBlKHeaderInCheckPtRegionInOrder loop. + c.blockMgr.writeBatchMtx.Lock() + c.blockMgr.hdrTipToResponse[request.startHeight] = &headersMsg{ + headers: msg, + peer: sp, + } + i := sort.Search(len(c.blockMgr.hdrTipSlice), func(i int) bool { + return c.blockMgr.hdrTipSlice[i] >= request.startHeight + }) + + c.blockMgr.hdrTipSlice = append(c.blockMgr.hdrTipSlice[:i], append([]int32{request.startHeight}, c.blockMgr.hdrTipSlice[i:]...)...) + c.blockMgr.writeBatchMtx.Unlock() + + // Check if job is unfinished, if it is, we modify the job accordingly and send back to the workmanager to be rescheduled. + if msg.Headers[hdrLength-1].BlockHash() != reqMessage.HashStop && reqMessage.HashStop != zeroHash { + // set new startHash, startHeight and blocklocator to set the next set of header for this job. + newStartHash := msg.Headers[hdrLength-1].BlockHash() + request.startHeight += int32(hdrLength) + request.startHash = newStartHash + reqMessage.BlockLocatorHashes = []*chainhash.Hash{&newStartHash} + + // Incase there is a rollback after handling reqMessage + // This ensures the job created by writecheckpt does not exceed that which we have fetched already. + c.blockMgr.writeBatchMtx.RLock() + _, ok = c.blockMgr.hdrTipToResponse[request.startHeight] + c.blockMgr.writeBatchMtx.RUnlock() + + if !ok { + return query.UnFinishedRequest + } + } + + return query.Finished +} + +// headerQuery implements ReqMessage interface for fetching block headers. +type headerQuery struct { + message wire.Message + startHeight int32 + initialHeight int32 + startHash chainhash.Hash + endHeight int32 + initialHash chainhash.Hash + index float64 +} + +func (h *headerQuery) Message() wire.Message { + return h.message +} + +func (h *headerQuery) PriorityIndex() float64 { + return h.index +} + // SyncPeer returns the current sync peer. func (b *blockManager) SyncPeer() *ServerPeer { b.syncPeerMutex.Lock() @@ -2021,11 +2517,48 @@ func (b *blockManager) SyncPeer() *ServerPeer { return b.syncPeer } -// isSyncCandidate returns whether or not the peer is a candidate to consider -// syncing from. -func (b *blockManager) isSyncCandidate(sp *ServerPeer) bool { - // The peer is not a candidate for sync if it's not a full node. - return sp.Services()&wire.SFNodeNetwork == wire.SFNodeNetwork +// deleteHeaderTipResp deletes all responses from newTip to prevTip. +func (b *blockManager) deleteHeaderTipResp(newTip, prevTip int32) { + b.writeBatchMtx.Lock() + defer b.writeBatchMtx.Unlock() + + var ( + finalIdx int + initialIdx int + ) + + for i := 0; i < len(b.hdrTipSlice) && b.hdrTipSlice[i] <= newTip; i++ { + if b.hdrTipSlice[i] < prevTip { + continue + } + + if b.hdrTipSlice[i] == prevTip { + initialIdx = i + } + + tip := b.hdrTipSlice[i] + + delete(b.hdrTipToResponse, tip) + + finalIdx = i + } + + b.hdrTipSlice = append(b.hdrTipSlice[:initialIdx], b.hdrTipSlice[finalIdx+1:]...) +} + +// resetHeaderListToChainTip resets the headerList to the chain tip. +func (b *blockManager) resetHeaderListToChainTip() error { + header, height, err := b.cfg.BlockHeaders.ChainTip() + if err != nil { + return err + } + b.headerList.ResetHeaderState(headerlist.Node{ + Header: *header, + Height: int32(height), + }) + log.Debugf("Resetting header list to chain tip %v ", b.headerTip) + + return nil } // findNextHeaderCheckpoint returns the next checkpoint after the passed height. @@ -2700,20 +3233,16 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) { } } - // When this header is a checkpoint, find the next checkpoint. - if receivedCheckpoint { - b.nextCheckpoint = b.findNextHeaderCheckpoint(finalHeight) - } - // If not current, request the next batch of headers starting from the // latest known header and ending with the next checkpoint. - if b.cfg.ChainParams.Net == chaincfg.SimNetParams.Net || !b.BlockHeadersSynced() { + // Note this must come before reassigning a new b.nextCheckpoint, so that we push headers + // only when the current headers before this takes us past the checkpointed region. + if b.cfg.ChainParams.Net == chaincfg.SimNetParams.Net || !b.BlockHeadersSynced() && + b.nextCheckpoint == nil { + locator := blockchain.BlockLocator([]*chainhash.Hash{finalHash}) - nextHash := zeroHash - if b.nextCheckpoint != nil { - nextHash = *b.nextCheckpoint.Hash - } - err := hmsg.peer.PushGetHeadersMsg(locator, &nextHash) + + err := hmsg.peer.PushGetHeadersMsg(locator, &zeroHash) if err != nil { log.Warnf("Failed to send getheaders message to "+ "peer %s: %s", hmsg.peer.Addr(), err) @@ -2721,6 +3250,11 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) { } } + // When this header is a checkpoint, find the next checkpoint. + if receivedCheckpoint { + b.nextCheckpoint = b.findNextHeaderCheckpoint(finalHeight) + } + // Since we have a new set of headers written to disk, we'll send out a // new signal to notify any waiting sub-systems that they can now maybe // proceed do to us extending the header chain. diff --git a/blockmanager_test.go b/blockmanager_test.go index 45554b7c..1e3ad4de 100644 --- a/blockmanager_test.go +++ b/blockmanager_test.go @@ -5,6 +5,7 @@ import ( "fmt" "math/rand" "reflect" + "sort" "strings" "testing" "time" @@ -89,11 +90,12 @@ func setupBlockManager(t *testing.T) (*blockManager, headerfs.BlockHeaderStore, // Set up a blockManager with the chain service we defined. bm, err := newBlockManager(&blockManagerCfg{ - ChainParams: chaincfg.SimNetParams, - BlockHeaders: hdrStore, - RegFilterHeaders: cfStore, - QueryDispatcher: &mockDispatcher{}, - TimeSource: blockchain.NewMedianTime(), + ChainParams: chaincfg.SimNetParams, + BlockHeaders: hdrStore, + RegFilterHeaders: cfStore, + cfHeaderQueryDispatcher: &mockDispatcher{}, + blkHdrCheckptQueryDispatcher: &mockDispatcher{}, + TimeSource: blockchain.NewMedianTime(), BanPeer: func(string, banman.Reason) error { return nil }, @@ -214,14 +216,14 @@ func generateHeaders(genesisBlockHeader *wire.BlockHeader, // generateResponses generates the MsgCFHeaders messages from the given queries // and headers. -func generateResponses(msgs []wire.Message, +func generateResponses(msgs []query.ReqMessage, headers *headers) ([]*wire.MsgCFHeaders, error) { // Craft a response for each message. var responses []*wire.MsgCFHeaders for _, msg := range msgs { // Only GetCFHeaders expected. - q, ok := msg.(*wire.MsgGetCFHeaders) + q, ok := msg.Message().(*wire.MsgGetCFHeaders) if !ok { return nil, fmt.Errorf("got unexpected message %T", msg) @@ -346,11 +348,11 @@ func TestBlockManagerInitialInterval(t *testing.T) { // We set up a custom query batch method for this test, as we // will use this to feed the blockmanager with our crafted // responses. - bm.cfg.QueryDispatcher.(*mockDispatcher).query = func( + bm.cfg.cfHeaderQueryDispatcher.(*mockDispatcher).query = func( requests []*query.Request, options ...query.QueryOption) chan error { - var msgs []wire.Message + var msgs []query.ReqMessage for _, q := range requests { msgs = append(msgs, q.Req) } @@ -379,13 +381,13 @@ func TestBlockManagerInitialInterval(t *testing.T) { // Let the blockmanager handle the // message. progress := requests[index].HandleResp( - msgs[index], &resp, "", + msgs[index], &resp, nil, ) - if !progress.Finished { + if progress != query.Finished { errChan <- fmt.Errorf("got "+ - "response false on "+ - "send of index %d: %v", + " %v on "+ + "send of index %d: %v", progress, index, testDesc) return } @@ -400,13 +402,13 @@ func TestBlockManagerInitialInterval(t *testing.T) { // Otherwise resend the response we // just sent. progress = requests[index].HandleResp( - msgs[index], &resp2, "", + msgs[index], &resp2, nil, ) - if !progress.Finished { + if progress != query.Finished { errChan <- fmt.Errorf("got "+ - "response false on "+ - "resend of index %d: "+ - "%v", index, testDesc) + " %v on "+ + "send of index %d: %v", progress, + index, testDesc) return } } @@ -576,11 +578,11 @@ func TestBlockManagerInvalidInterval(t *testing.T) { require.NoError(t, err) } - bm.cfg.QueryDispatcher.(*mockDispatcher).query = func( + bm.cfg.cfHeaderQueryDispatcher.(*mockDispatcher).query = func( requests []*query.Request, options ...query.QueryOption) chan error { - var msgs []wire.Message + var msgs []query.ReqMessage for _, q := range requests { msgs = append(msgs, q.Req) } @@ -619,10 +621,10 @@ func TestBlockManagerInvalidInterval(t *testing.T) { // expect. for i := range responses { progress := requests[i].HandleResp( - msgs[i], responses[i], "", + msgs[i], responses[i], nil, ) if i == test.firstInvalid { - if progress.Finished { + if progress == query.Finished { t.Errorf("expected interval "+ "%d to be invalid", i) return @@ -631,7 +633,7 @@ func TestBlockManagerInvalidInterval(t *testing.T) { break } - if !progress.Finished { + if progress != query.Finished { t.Errorf("expected interval %d to be "+ "valid", i) return @@ -884,20 +886,6 @@ func TestHandleHeaders(t *testing.T) { fakePeer, err := peer.NewOutboundPeer(&peer.Config{}, "fake:123") require.NoError(t, err) - assertPeerDisconnected := func(shouldBeDisconnected bool) { - // This is quite hacky but works: We expect the peer to be - // disconnected, which sets the unexported "disconnected" field - // to 1. - refValue := reflect.ValueOf(fakePeer).Elem() - foo := refValue.FieldByName("disconnect").Int() - - if shouldBeDisconnected { - require.EqualValues(t, 1, foo) - } else { - require.EqualValues(t, 0, foo) - } - } - // We'll want to use actual, real blocks, so we take a miner harness // that we can use to generate some. harness, err := rpctest.New( @@ -934,7 +922,7 @@ func TestHandleHeaders(t *testing.T) { // Let's feed in the correct headers. This should work fine and the peer // should not be disconnected. bm.handleHeadersMsg(hmsg) - assertPeerDisconnected(false) + assertPeerDisconnected(false, fakePeer, t) // Now scramble the headers and feed them in again. This should cause // the peer to be disconnected. @@ -943,5 +931,884 @@ func TestHandleHeaders(t *testing.T) { hmsg.headers.Headers[j], hmsg.headers.Headers[i] }) bm.handleHeadersMsg(hmsg) - assertPeerDisconnected(true) + assertPeerDisconnected(true, fakePeer, t) +} + +// assertPeerDisconnected asserts that the peer supplied as an argument is disconnected. +func assertPeerDisconnected(shouldBeDisconnected bool, sp *peer.Peer, t *testing.T) { + // This is quite hacky but works: We expect the peer to be + // disconnected, which sets the unexported "disconnected" field + // to 1. + refValue := reflect.ValueOf(sp).Elem() + foo := refValue.FieldByName("disconnect").Int() + + if shouldBeDisconnected { + require.EqualValues(t, 1, foo) + } else { + require.EqualValues(t, 0, foo) + } +} + +// TestBatchCheckpointedBlkHeaders tests the batch checkpointed headers function. +func TestBatchCheckpointedBlkHeaders(t *testing.T) { + t.Parallel() + + // First, we set up a block manager and a fake peer that will act as the + // test's remote peer. + bm, _, _, err := setupBlockManager(t) + require.NoError(t, err) + + // Created checkpoints for our simulated network. + checkpoints := []chaincfg.Checkpoint{ + + { + Hash: &chainhash.Hash{1}, + Height: int32(1), + }, + + { + Hash: &chainhash.Hash{2}, + Height: int32(2), + }, + + { + Hash: &chainhash.Hash{3}, + Height: int32(3), + }, + } + + modParams := chaincfg.SimNetParams + modParams.Checkpoints = append(modParams.Checkpoints, checkpoints...) + bm.cfg.ChainParams = modParams + + // set checkpoint and header tip. + bm.nextCheckpoint = &checkpoints[0] + + bm.newHeadersMtx.Lock() + bm.headerTip = 0 + bm.headerTipHash = chainhash.Hash{0} + bm.newHeadersMtx.Unlock() + + // This is the query we assert to obtain if the function works accordingly. + expectedQuery := CheckpointedBlockHeadersQuery{ + blockMgr: bm, + msgs: []*headerQuery{ + + { + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: blockchain.BlockLocator([]*chainhash.Hash{{0}}), + HashStop: *checkpoints[0].Hash, + }, + startHeight: int32(0), + initialHeight: int32(0), + startHash: chainhash.Hash{0}, + endHeight: checkpoints[0].Height, + initialHash: chainhash.Hash{0}, + }, + + { + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: blockchain.BlockLocator([]*chainhash.Hash{checkpoints[0].Hash}), + HashStop: *checkpoints[1].Hash, + }, + startHeight: checkpoints[0].Height, + initialHeight: checkpoints[0].Height, + startHash: *checkpoints[0].Hash, + endHeight: checkpoints[1].Height, + initialHash: *checkpoints[0].Hash, + }, + + { + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: blockchain.BlockLocator([]*chainhash.Hash{checkpoints[1].Hash}), + HashStop: *checkpoints[2].Hash, + }, + startHeight: checkpoints[1].Height, + initialHeight: checkpoints[1].Height, + startHash: *checkpoints[1].Hash, + endHeight: checkpoints[2].Height, + initialHash: *checkpoints[1].Hash, + }, + + { + + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: blockchain.BlockLocator([]*chainhash.Hash{checkpoints[2].Hash}), + HashStop: zeroHash, + }, + startHeight: checkpoints[2].Height, + initialHeight: checkpoints[2].Height, + startHash: *checkpoints[2].Hash, + endHeight: checkpoints[2].Height + wire.MaxBlockHeadersPerMsg, + initialHash: *checkpoints[2].Hash, + }, + }, + } + + // create request. + expectedRequest := expectedQuery.requests() + + bm.cfg.blkHdrCheckptQueryDispatcher.(*mockDispatcher).query = func(requests []*query.Request, + options ...query.QueryOption) chan error { + + // assert that the requests obtained has same length as that of our expected query. + if len(requests) != len(expectedRequest) { + t.Fatalf("unequal length") + } + + for i, req := range requests { + testEqualReqMessage(req, expectedRequest[i], t) + } + + // Ensure the query options sent by query is four. This is the number of query option supplied as args while + // querying the workmanager. + if len(options) != 3 { + t.Fatalf("expected five option parameter for query but got, %v\n", len(options)) + } + return nil + } + + // call the function that we are testing. + bm.batchCheckpointedBlkHeaders() +} + +// This function tests the ProcessBlKHeaderInCheckPtRegionInOrder function. +func TestProcessBlKHeaderInCheckPtRegionInOrder(t *testing.T) { + t.Parallel() + + // First, we set up a block manager and a fake peer that will act as the + // test's remote peer. + bm, _, _, err := setupBlockManager(t) + require.NoError(t, err) + + fakePeer, err := peer.NewOutboundPeer(&peer.Config{}, "fake:123") + require.NoError(t, err) + + // We'll want to use actual, real blocks, so we take a miner harness + // that we can use to generate some. + harness, err := rpctest.New( + &chaincfg.SimNetParams, nil, []string{"--txindex"}, "", + ) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, harness.TearDown()) + }) + + err = harness.SetUp(false, 0) + require.NoError(t, err) + + // Generate 30 valid blocks that we then feed to the block manager. + blockHashes, err := harness.Client.Generate(30) + require.NoError(t, err) + + // This is the headerMessage containing 10 headers starting at height 0. + hmsgTip0 := &headersMsg{ + headers: &wire.MsgHeaders{ + Headers: make([]*wire.BlockHeader, 10), + }, + peer: &ServerPeer{ + Peer: fakePeer, + }, + } + + // This is the headerMessage containing 10 headers starting at height 10. + hmsgTip10 := &headersMsg{ + headers: &wire.MsgHeaders{ + Headers: make([]*wire.BlockHeader, 10), + }, + peer: &ServerPeer{ + Peer: fakePeer, + }, + } + + // This is the headerMessage containing 10 headers starting at height 20. + hmsgTip20 := &headersMsg{ + headers: &wire.MsgHeaders{ + Headers: make([]*wire.BlockHeader, 10), + }, + peer: &ServerPeer{ + Peer: fakePeer, + }, + } + + // Loop through the generated blockHashes and add headers to their appropriate slices. + for i := range blockHashes { + header, err := harness.Client.GetBlockHeader(blockHashes[i]) + require.NoError(t, err) + + if i < 10 { + hmsgTip0.headers.Headers[i] = header + } + + if i >= 10 && i < 20 { + hmsgTip10.headers.Headers[i-10] = header + } + + if i >= 20 { + hmsgTip20.headers.Headers[i-20] = header + } + } + + // initialize the hdrTipSlice. + bm.hdrTipSlice = make([]int32, 0) + + // Create checkpoint for our test chain. + checkpoint := chaincfg.Checkpoint{ + Hash: blockHashes[29], + Height: int32(30), + } + bm.cfg.ChainParams.Checkpoints = append(bm.cfg.ChainParams.Checkpoints, []chaincfg.Checkpoint{ + checkpoint, + }...) + bm.nextCheckpoint = &checkpoint + + // If ProcessBlKHeaderInCheckPtRegionInOrder loop receives invalid headers assert the query parameters being sent + // to the workmanager is expected. + bm.cfg.blkHdrCheckptQueryDispatcher.(*mockDispatcher).query = func(requests []*query.Request, + options ...query.QueryOption) chan error { + + // The function should send only one request. + if len(requests) != 1 { + t.Fatalf("expected only one request") + } + + finalNode := bm.headerList.Back() + newHdrTip := finalNode.Height + newHdrTipHash := finalNode.Header.BlockHash() + prevCheckPt := bm.findPreviousHeaderCheckpoint(newHdrTip) + + testEqualReqMessage(requests[0], &query.Request{ + + Req: &headerQuery{ + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: []*chainhash.Hash{&newHdrTipHash}, + HashStop: *bm.nextCheckpoint.Hash, + }, + startHeight: newHdrTip, + initialHeight: prevCheckPt.Height, + startHash: newHdrTipHash, + endHeight: bm.nextCheckpoint.Height, + initialHash: newHdrTipHash, + index: 0, + }, + }, t) + + // The function should include only four query options while querying. + if len(options) != 3 { + t.Fatalf("expected three option parameter for query but got, %v\n", len(options)) + } + return nil + } + + // Call the function in a goroutine. + go bm.processBlKHeaderInCheckPtRegionInOrder() + + // At this point syncPeer should be nil. + bm.syncPeerMutex.RLock() + if bm.syncPeer != nil { + bm.syncPeerMutex.RUnlock() + t.Fatalf("syncPeer should be nil initially") + } + bm.syncPeerMutex.RUnlock() + + // Set header tip to zero and write a response at height 10, ensure the ProcessBlKHeaderInCheckPtRegionInOrder loop + // does not handle the response as it does not correspond to the current header tip. + bm.newHeadersMtx.Lock() + bm.headerTip = 0 + bm.newHeadersMtx.Unlock() + + bm.writeBatchMtx.Lock() + newTipWrite := int32(10) + bm.hdrTipToResponse[newTipWrite] = hmsgTip10 + i := sort.Search(len(bm.hdrTipSlice), func(i int) bool { + return bm.hdrTipSlice[i] >= newTipWrite + }) + + bm.hdrTipSlice = append(bm.hdrTipSlice[:i], append([]int32{newTipWrite}, bm.hdrTipSlice[i:]...)...) + bm.writeBatchMtx.Unlock() + + // SyncPeer should still be nil to indicate that the loop did not handle the response. + bm.syncPeerMutex.RLock() + if bm.syncPeer != nil { + bm.syncPeerMutex.RUnlock() + t.Fatalf("syncPeer should be nil") + } + bm.syncPeerMutex.RUnlock() + + // Set header tip to 20 to indicate that even when the chain's tip is higher that the available tips in the + // hdrTipToResponse map, the loop does not still handle it. + bm.newHeadersMtx.Lock() + bm.headerTip = 20 + bm.newHeadersMtx.Unlock() + + // SyncPeer should still be nil to indicate that the loop did not handle the response. + bm.syncPeerMutex.RLock() + if bm.syncPeer != nil { + bm.syncPeerMutex.RUnlock() + t.Fatalf("syncPeer should be nil") + } + bm.syncPeerMutex.RUnlock() + + // Set headerTip to zero and write a response at height 0 to the hdrTipToResponse map. The loop should handle this + // response now and the following response that would correspond to its new tip after this. + bm.newHeadersMtx.Lock() + bm.headerTip = 0 + bm.newHeadersMtx.Unlock() + + bm.writeBatchMtx.Lock() + newTipWrite = int32(0) + i = sort.Search(len(bm.hdrTipSlice), func(i int) bool { + return bm.hdrTipSlice[i] >= newTipWrite + }) + + bm.hdrTipSlice = append(bm.hdrTipSlice[:i], append([]int32{newTipWrite}, bm.hdrTipSlice[i:]...)...) + + bm.hdrTipToResponse[newTipWrite] = hmsgTip0 + bm.writeBatchMtx.Unlock() + + // Allow time for handling the response. + time.Sleep(1 * time.Second) + bm.syncPeerMutex.RLock() + if bm.syncPeer == nil { + bm.syncPeerMutex.RUnlock() + t.Fatalf("syncPeer should not be nil") + } + bm.syncPeerMutex.RUnlock() + + // Header tip should be 20 as th the loop would handle response at height 0 then the previously written + // height 10. + bm.newHeadersMtx.RLock() + if bm.headerTip != 20 { + hdrTip := bm.headerTip + bm.newHeadersMtx.RUnlock() + t.Fatalf("expected header tip at 10 but got %v\n", hdrTip) + } + bm.newHeadersMtx.RUnlock() + + // Now scramble the headers and feed them in again. This should cause + // the loop to delete this response from the map and re-request for this header from + // the workmanager. + rand.Shuffle(len(hmsgTip20.headers.Headers), func(i, j int) { + hmsgTip20.headers.Headers[i], hmsgTip20.headers.Headers[j] = + hmsgTip20.headers.Headers[j], hmsgTip20.headers.Headers[i] + }) + + // Write this header at height 20, this would cause the loop to handle it. + bm.writeBatchMtx.Lock() + newTipWrite = int32(20) + bm.hdrTipToResponse[newTipWrite] = hmsgTip20 + i = sort.Search(len(bm.hdrTipSlice), func(i int) bool { + return bm.hdrTipSlice[i] >= newTipWrite + }) + + bm.hdrTipSlice = append(bm.hdrTipSlice[:i], append([]int32{newTipWrite}, bm.hdrTipSlice[i:]...)...) + + bm.writeBatchMtx.Unlock() + + // Allow time for handling. + time.Sleep(1 * time.Second) + + // HeadrTip should not advance as headers are invalid. + bm.newHeadersMtx.RLock() + if bm.headerTip != 20 { + hdrTip := bm.headerTip + bm.newHeadersMtx.RUnlock() + t.Fatalf("expected header tip at 20 but got %v\n", hdrTip) + } + bm.newHeadersMtx.RUnlock() + + // Syncpeer should not be nil as we are still in the loop. + bm.syncPeerMutex.RLock() + if bm.syncPeer == nil { + bm.syncPeerMutex.RUnlock() + t.Fatalf("syncPeer should not be nil") + } + bm.syncPeerMutex.RUnlock() + + // The response at header tip 20 should be deleted. + bm.writeBatchMtx.RLock() + _, ok := bm.hdrTipToResponse[int32(20)] + bm.writeBatchMtx.RUnlock() + + if ok { + t.Fatalf("expected response to header tip deleted") + } +} + +// TestCheckpointedBlockHeadersQuery_handleResponse tests the handleResponse method +// of the CheckpointedBlockHeadersQuery. +func TestCheckpointedBlockHeadersQuery_handleResponse(t *testing.T) { + t.Parallel() + + // handleRespTestCase holds all the information required to test different scenarios while + // using the function. + type handleRespTestCase struct { + + // name of the testcase. + name string + + // resp is the response argument to be sent to the handleResp method as an arg. + resp wire.Message + + // req is the request method to be sent to the handleResp method as an arg. + req query.ReqMessage + + // progress is the expected progress to be returned by the handleResp method. + progress query.Progress + + // lastblock is the block with which we obtain its hash to be used as the request's hashStop. + lastBlock wire.BlockHeader + + // peerDisconnected indicates if the peer would be disconnected after the handleResp method is done. + peerDisconnected bool + } + + testCases := []handleRespTestCase{ + + { + // Scenario in which we have a request type that is not the same as the expected headerQuery type.It should + // return no error and NoProgressNoFinalResp query.Progress. + name: "invalid request type", + resp: &wire.MsgHeaders{ + Headers: make([]*wire.BlockHeader, 0, 5), + }, + req: &encodedQuery{}, + progress: query.NoResponse, + }, + + { + // Scenario in which we have a response type that is not same as the expected wire.MsgHeaders. It should + // return no error and NoProgressNoFinalResp query.Progress. + name: "invalid response type", + resp: &wire.MsgCFHeaders{}, + req: &headerQuery{ + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: []*chainhash.Hash{ + {1}, + }, + }, + startHeight: 1, + initialHeight: 1, + startHash: chainhash.Hash{1}, + endHeight: 6, + initialHash: chainhash.Hash{1}, + }, + progress: query.NoResponse, + lastBlock: wire.BlockHeader{ + PrevBlock: chainhash.Hash{5}, + }, + }, + + { + // Scenario in which we have the response in the hdrTipResponseMap. While calling these testcases, we + // initialize the hdrTipToResponse map to contain a response at height 0 and 6. Since this request ahs a + // startheight of 0, its response would be in the map already, aligning with this scenario. This scenario + // should return query.IgnoreRequest, + name: "response start Height in hdrTipResponse map", + resp: &wire.MsgHeaders{ + Headers: make([]*wire.BlockHeader, 0, 4), + }, + req: &headerQuery{ + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: []*chainhash.Hash{ + {0}, + }, + }, + startHeight: 0, + initialHeight: 0, + startHash: chainhash.Hash{0}, + endHeight: 5, + initialHash: chainhash.Hash{0}, + }, + progress: query.IgnoreRequest, + lastBlock: wire.BlockHeader{ + PrevBlock: chainhash.Hash{4}, + }, + }, + + { + // Scenario in which the valid response we receive is of length, zero. We should return + // query.ResponseErr. + name: "response header length 0", + resp: &wire.MsgHeaders{ + Headers: make([]*wire.BlockHeader, 0), + }, + req: &headerQuery{ + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: []*chainhash.Hash{ + {1}, + }, + }, + startHeight: 1, + initialHeight: 1, + startHash: chainhash.Hash{1}, + endHeight: 5, + initialHash: chainhash.Hash{1}, + }, + progress: query.ResponseErr, + lastBlock: wire.BlockHeader{ + PrevBlock: chainhash.Hash{4}, + }, + }, + + { + // Scenario in which the response received is at the request's initialHeight (lower bound height in + // checkpoint request) but its first block's previous hash is not same as the checkpoint hash. It + // should return query.ResponseErr. + name: "response at initialHash has disconnected start Hash", + resp: &wire.MsgHeaders{ + Headers: []*wire.BlockHeader{ + { + PrevBlock: chainhash.Hash{2}, + }, + { + PrevBlock: chainhash.Hash{3}, + }, + { + PrevBlock: chainhash.Hash{4}, + }, + }, + }, + req: &headerQuery{ + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: []*chainhash.Hash{ + {1}, + }, + }, + startHeight: 1, + initialHeight: 1, + startHash: chainhash.Hash{1}, + endHeight: 5, + initialHash: chainhash.Hash{1}, + }, + progress: query.ResponseErr, + lastBlock: wire.BlockHeader{ + PrevBlock: chainhash.Hash{4}, + }, + peerDisconnected: true, + }, + + { + // Scenario in which the response is not at the initial Hash (lower bound hash in the + // checkpoint request) but the response is complete and valid. It should return query.Finished. + name: "response not at initialHash, valid complete headers", + resp: &wire.MsgHeaders{ + Headers: []*wire.BlockHeader{ + { + PrevBlock: chainhash.Hash{2}, + }, + { + PrevBlock: chainhash.Hash{3}, + }, + { + PrevBlock: chainhash.Hash{4}, + }, + }, + }, + req: &headerQuery{ + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: []*chainhash.Hash{ + {2}, + }, + }, + startHeight: 2, + initialHeight: 1, + startHash: chainhash.Hash{2}, + endHeight: 5, + initialHash: chainhash.Hash{2}, + }, + progress: query.Finished, + lastBlock: wire.BlockHeader{ + PrevBlock: chainhash.Hash{4}, + }, + }, + + { + // Scenario in which the response is not at initial hash (lower bound height in + // checkpoint request) and the response is unfinished. The jobErr should be nil and return + // finalRespNoProgress query.progress. + name: "response not at initial Hash, unfinished response", + resp: &wire.MsgHeaders{ + Headers: []*wire.BlockHeader{ + { + PrevBlock: chainhash.Hash{2}, + }, + { + PrevBlock: chainhash.Hash{3}, + }, + { + PrevBlock: chainhash.Hash{4}, + }, + }, + }, + req: &headerQuery{ + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: []*chainhash.Hash{ + {2}, + }, + }, + startHeight: 2, + initialHeight: 1, + startHash: chainhash.Hash{2}, + endHeight: 6, + initialHash: chainhash.Hash{2}, + }, + progress: query.UnFinishedRequest, + lastBlock: wire.BlockHeader{ + PrevBlock: chainhash.Hash{5}, + }, + }, + + { + // Scenario in which the response length is greater than expected. Peer should be + // disconnected and the method should return query.ResponseErr. + name: "response header length more than expected", + resp: &wire.MsgHeaders{ + Headers: []*wire.BlockHeader{ + { + PrevBlock: chainhash.Hash{1}, + }, + { + PrevBlock: chainhash.Hash{2}, + }, + { + PrevBlock: chainhash.Hash{3}, + }, + { + PrevBlock: chainhash.Hash{4}, + }, + { + PrevBlock: chainhash.Hash{5}, + }, + { + PrevBlock: chainhash.Hash{6}, + }, + }, + }, + req: &headerQuery{ + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: []*chainhash.Hash{ + {1}, + }, + }, + startHeight: 1, + initialHeight: 1, + startHash: chainhash.Hash{1}, + endHeight: 6, + initialHash: chainhash.Hash{1}, + }, + progress: query.ResponseErr, + lastBlock: wire.BlockHeader{ + PrevBlock: chainhash.Hash{5}, + }, + peerDisconnected: true, + }, + + { + // Scenario in which response is complete and a valid header. Its start height is at the initial height. + // progress should be query.Finished. + name: "complete response valid headers", + resp: &wire.MsgHeaders{ + Headers: []*wire.BlockHeader{ + { + PrevBlock: chainhash.Hash{1}, + }, + { + PrevBlock: chainhash.Hash{2}, + }, + { + PrevBlock: chainhash.Hash{3}, + }, + }, + }, + req: &headerQuery{ + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: []*chainhash.Hash{ + {1}, + }, + }, + startHeight: 1, + initialHeight: 1, + startHash: chainhash.Hash{1}, + endHeight: 4, + initialHash: chainhash.Hash{1}, + }, + progress: query.Finished, + lastBlock: wire.BlockHeader{ + PrevBlock: chainhash.Hash{3}, + }, + }, + + { + // Scenario in which response is at initialHash and the response is incomplete. + // It should return query.UnFinishedRequest. + name: "response at initial hash, incomplete response, valid headers", + resp: &wire.MsgHeaders{ + Headers: []*wire.BlockHeader{ + { + PrevBlock: chainhash.Hash{1}, + }, + { + PrevBlock: chainhash.Hash{2}, + }, + { + PrevBlock: chainhash.Hash{3}, + }, + }, + }, + req: &headerQuery{ + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: []*chainhash.Hash{ + {1}, + }, + }, + startHeight: 1, + initialHeight: 1, + startHash: chainhash.Hash{1}, + endHeight: 6, + initialHash: chainhash.Hash{1}, + }, + progress: query.UnFinishedRequest, + lastBlock: wire.BlockHeader{ + PrevBlock: chainhash.Hash{5}, + }, + }, + + { + // Scenario in which response is incomplete but valid. The new response's start height created in this + // scenario is present in the hdrTipResponseMap. The startHeight is 6 and response at height 6 has been + // preveiously written in to the hdrTipResponse map for the sake of this test. + name: "incomplete response, valid headers, new resp in hdrTipToResponse map", + resp: &wire.MsgHeaders{ + Headers: []*wire.BlockHeader{ + { + PrevBlock: chainhash.Hash{1}, + }, + { + PrevBlock: chainhash.Hash{2}, + }, + { + PrevBlock: chainhash.Hash{3}, + }, + { + PrevBlock: chainhash.Hash{4}, + }, + { + PrevBlock: chainhash.Hash{5}, + }, + }, + }, + req: &headerQuery{ + message: &wire.MsgGetHeaders{ + BlockLocatorHashes: []*chainhash.Hash{ + {1}, + }, + }, + startHeight: 1, + initialHeight: 1, + startHash: chainhash.Hash{1}, + endHeight: 10, + initialHash: chainhash.Hash{1}, + }, + progress: query.Finished, + lastBlock: wire.BlockHeader{ + PrevBlock: chainhash.Hash{9}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // set up block manager. + bm, _, _, err := setupBlockManager(t) + require.NoError(t, err) + + var oldReqStartHeight int32 + + bm.hdrTipToResponse[0] = &headersMsg{ + headers: &wire.MsgHeaders{}, + } + bm.hdrTipToResponse[6] = &headersMsg{ + headers: &wire.MsgHeaders{}, + } + + fakePeer, err := peer.NewOutboundPeer(&peer.Config{}, "fake:123") + require.NoError(t, err) + + blkHdrquery := &CheckpointedBlockHeadersQuery{ + blockMgr: bm, + } + req := tc.req + r, ok := tc.req.(*headerQuery) + if ok { + reqMessage, ok := req.Message().(*wire.MsgGetHeaders) + if !ok { + t.Fatalf("request message not of type wire.MsgGetHeaders") + } + reqMessage.HashStop = tc.lastBlock.BlockHash() + req = r + oldReqStartHeight = r.startHeight + } + actualProgress := blkHdrquery.handleResponse(req, tc.resp, &ServerPeer{ + Peer: fakePeer, + }) + + if tc.progress != actualProgress { + t.Fatalf("unexpected progress.Expected: %v but got: %v", tc.progress, actualProgress) + } + + if actualProgress == query.UnFinishedRequest { + resp := tc.resp.(*wire.MsgHeaders) + request := req.(*headerQuery) + if request.startHash != resp.Headers[len(resp.Headers)-1].BlockHash() { + t.Fatalf("unexpected new startHash") + } + + if request.startHeight != oldReqStartHeight+int32(len(resp.Headers)) { + t.Fatalf("unexpected new start height") + } + + requestMessage := req.Message().(*wire.MsgGetHeaders) + + if *requestMessage.BlockLocatorHashes[0] != request.startHash { + t.Fatalf("unexpected new blockLocator") + } + } + + assertPeerDisconnected(tc.peerDisconnected, fakePeer, t) + }) + } +} + +// testEqualReqMessage tests if two query.Request are same. +func testEqualReqMessage(a, b *query.Request, t *testing.T) { + aMessage := a.Req.(*headerQuery) + bMessage := b.Req.(*headerQuery) + + if aMessage.startHeight != bMessage.startHeight { + t.Fatalf("dissimilar startHeight") + } + if aMessage.startHash != bMessage.startHash { + t.Fatalf("dissimilar startHash") + } + if aMessage.endHeight != bMessage.endHeight { + t.Fatalf("dissimilar endHash") + } + if aMessage.initialHash != bMessage.initialHash { + t.Fatalf("dissimilar initialHash") + } + + aMessageGetHeaders := aMessage.Message().(*wire.MsgGetHeaders) + bMessageGetHeaders := bMessage.Message().(*wire.MsgGetHeaders) + + if !reflect.DeepEqual(aMessageGetHeaders.BlockLocatorHashes, bMessageGetHeaders.BlockLocatorHashes) { + t.Fatalf("dissimilar blocklocator hash") + } + + if aMessageGetHeaders.HashStop != bMessageGetHeaders.HashStop { + t.Fatalf("dissimilar hashstop") + } + if a.Req.PriorityIndex() != b.Req.PriorityIndex() { + t.Fatalf("dissimilar priority index") + } } diff --git a/neutrino.go b/neutrino.go index 7ee45edd..5179e0d9 100644 --- a/neutrino.go +++ b/neutrino.go @@ -195,6 +195,31 @@ func NewServerPeer(s *ChainService, isPersistent bool) *ServerPeer { } } +// IsSyncCandidate returns whether or not the peer is a candidate to consider +// syncing from. +func (sp *ServerPeer) IsSyncCandidate() bool { + // The peer is not a candidate for sync if it's not a full node. + return sp.Services()&wire.SFNodeNetwork == wire.SFNodeNetwork +} + +// IsPeerBehindStartHeight returns a boolean indicating if the peer's last block height +// is behind the start height of the request. If the peer is not behind the request start +// height false is returned, otherwise, true is. +func (sp *ServerPeer) IsPeerBehindStartHeight(req query.ReqMessage) bool { + queryGetHeaders, ok := req.(*headerQuery) + + if !ok { + log.Debugf("request is not type headerQuery") + + return true + } + + if sp.LastBlock() < queryGetHeaders.startHeight { + return true + } + return false +} + // newestBlock returns the current best block hash and height using the format // required by the configuration for the peer package. func (sp *ServerPeer) newestBlock() (*chainhash.Hash, int32, error) { @@ -800,15 +825,21 @@ func NewChainService(cfg Config) (*ChainService, error) { } bm, err := newBlockManager(&blockManagerCfg{ - ChainParams: s.chainParams, - BlockHeaders: s.BlockHeaders, - RegFilterHeaders: s.RegFilterHeaders, - TimeSource: s.timeSource, - QueryDispatcher: s.workManager, - BanPeer: s.BanPeer, - GetBlock: s.GetBlock, - firstPeerSignal: s.firstPeerConnect, - queryAllPeers: s.queryAllPeers, + ChainParams: s.chainParams, + BlockHeaders: s.BlockHeaders, + RegFilterHeaders: s.RegFilterHeaders, + TimeSource: s.timeSource, + cfHeaderQueryDispatcher: s.workManager, + BanPeer: s.BanPeer, + GetBlock: s.GetBlock, + firstPeerSignal: s.firstPeerConnect, + queryAllPeers: s.queryAllPeers, + blkHdrCheckptQueryDispatcher: query.NewWorkManager(&query.Config{ + ConnectedPeers: s.ConnectedPeers, + NewWorker: query.NewWorker, + Ranking: query.NewPeerRanking(), + IsEligibleWorkerFunc: query.IsWorkerEligibleForBlkHdrFetch, + }), }) if err != nil { return nil, err @@ -1610,6 +1641,9 @@ func (s *ChainService) Start() error { s.addrManager.Start() s.blockManager.Start() s.blockSubscriptionMgr.Start() + if err := s.blockManager.cfg.blkHdrCheckptQueryDispatcher.Start(); err != nil { + return fmt.Errorf("unable to start block header work manager: %v", err) + } if err := s.workManager.Start(); err != nil { return fmt.Errorf("unable to start work manager: %v", err) } diff --git a/query.go b/query.go index 66a506dd..40c968a5 100644 --- a/query.go +++ b/query.go @@ -66,10 +66,7 @@ var ( // noProgress will be used to indicate to a query.WorkManager that a // response makes no progress towards the completion of the query. - noProgress = query.Progress{ - Finished: false, - Progressed: false, - } + noProgress = query.NoResponse ) // queries are a set of options that can be modified per-query, unlike global @@ -430,26 +427,50 @@ type cfiltersQuery struct { headerIndex map[chainhash.Hash]int targetHash chainhash.Hash targetFilter *gcs.Filter + mtx sync.Mutex } // request couples a query message with the handler to be used for the response // in a query.Request struct. func (q *cfiltersQuery) request() *query.Request { - msg := wire.NewMsgGetCFilters( - q.filterType, uint32(q.startHeight), q.stopHash, - ) + msg := &encodedQuery{ + message: wire.NewMsgGetCFilters( + q.filterType, uint32(q.startHeight), q.stopHash, + ), + encoding: wire.WitnessEncoding, + } return &query.Request{ Req: msg, HandleResp: q.handleResponse, + SendQuery: sendQueryMessageWithEncoding, + CloneReq: func(req query.ReqMessage) query.ReqMessage { + oldReq, ok := req.(*encodedQuery) + if !ok { + log.Errorf("request not of type *encodedQuery") + } + oldReqMessage, ok := oldReq.message.(*wire.MsgGetCFilters) + if !ok { + log.Errorf("request not of type *wire.MsgGetCFilters") + } + newReq := &encodedQuery{ + message: wire.NewMsgGetCFilters( + oldReqMessage.FilterType, oldReqMessage.StartHeight, &oldReqMessage.StopHash, + ), + encoding: oldReq.encoding, + priorityIndex: oldReq.priorityIndex, + } + return newReq + }, } } // handleResponse validates that the cfilter response we get from a peer is // sane given the getcfilter query that we made. -func (q *cfiltersQuery) handleResponse(req, resp wire.Message, - _ string) query.Progress { +func (q *cfiltersQuery) handleResponse(r query.ReqMessage, resp wire.Message, + peer query.Peer) query.Progress { + req := r.Message() // The request must have been a "getcfilters" msg. request, ok := req.(*wire.MsgGetCFilters) if !ok { @@ -476,6 +497,8 @@ func (q *cfiltersQuery) handleResponse(req, resp wire.Message, // If this filter is for a block not in our index, we can ignore it, as // we either already got it, or it is out of our queried range. + q.mtx.Lock() + defer q.mtx.Unlock() i, ok := q.headerIndex[response.BlockHash] if !ok { return noProgress @@ -548,17 +571,11 @@ func (q *cfiltersQuery) handleResponse(req, resp wire.Message, // If there are still entries left in the headerIndex then the query // has made progress but has not yet completed. if len(q.headerIndex) != 0 { - return query.Progress{ - Finished: false, - Progressed: true, - } + return query.Progressed } // The headerIndex is empty and so this query is complete. - return query.Progress{ - Finished: true, - Progressed: true, - } + return query.Finished } // prepareCFiltersQuery creates a cfiltersQuery that can be used to fetch a @@ -759,6 +776,7 @@ func (s *ChainService) GetCFilter(blockHash chainhash.Hash, query.Cancel(s.quit), query.Encoding(qo.encoding), query.NumRetries(qo.numRetries), + query.ErrChan(make(chan error, 1)), } errChan := s.workManager.Query( @@ -833,13 +851,22 @@ func (s *ChainService) GetBlock(blockHash chainhash.Hash, // Construct the appropriate getdata message to fetch the target block. getData := wire.NewMsgGetData() _ = getData.AddInvVect(inv) + msg := &encodedQuery{ + message: getData, + encoding: wire.WitnessEncoding, + } var foundBlock *btcutil.Block // handleResp will be called for each message received from a peer. It // will be used to signal to the work manager whether progress has been // made or not. - handleResp := func(req, resp wire.Message, peer string) query.Progress { + handleResp := func(request query.ReqMessage, resp wire.Message, sp query.Peer) query.Progress { + req := request.Message() + peer := "" + if sp != nil { + peer = sp.Addr() + } // The request must have been a "getdata" msg. _, ok := req.(*wire.MsgGetData) if !ok { @@ -904,16 +931,31 @@ func (s *ChainService) GetBlock(blockHash chainhash.Hash, // we declare it sane. We can kill the query and pass the // response back to the caller. foundBlock = block - return query.Progress{ - Finished: true, - Progressed: true, - } + return query.Finished } // Prepare the query request. request := &query.Request{ - Req: getData, + Req: msg, HandleResp: handleResp, + SendQuery: sendQueryMessageWithEncoding, + CloneReq: func(req query.ReqMessage) query.ReqMessage { + newMsg := wire.NewMsgGetData() + _ = newMsg.AddInvVect(inv) + + oldReq, ok := req.(*encodedQuery) + if !ok { + log.Errorf("request not of type *encodedQuery") + } + + newReq := &encodedQuery{ + message: newMsg, + encoding: oldReq.encoding, + priorityIndex: oldReq.priorityIndex, + } + + return newReq + }, } // Prepare the query options. @@ -921,6 +963,7 @@ func (s *ChainService) GetBlock(blockHash chainhash.Hash, query.Encoding(qo.encoding), query.NumRetries(qo.numRetries), query.Cancel(s.quit), + query.ErrChan(make(chan error, 1)), } // Send the request to the work manager and await a response. diff --git a/query/interface.go b/query/interface.go index dca5f42d..1a6517d4 100644 --- a/query/interface.go +++ b/query/interface.go @@ -43,6 +43,9 @@ type queryOptions struct { // that a query can be retried. If this is set then numRetries has no // effect. noRetryMax bool + + // errChan error channel with which the workmananger sends error. + errChan chan error } // QueryOption is a functional option argument to any of the network query @@ -67,6 +70,14 @@ func (qo *queryOptions) applyQueryOptions(options ...QueryOption) { } } +// ErrChan is a query option that specifies the error channel which the workmanager +// sends any error to. +func ErrChan(err chan error) QueryOption { + return func(qo *queryOptions) { + qo.errChan = err + } +} + // NumRetries is a query option that specifies the number of times a query // should be retried. func NumRetries(num uint8) QueryOption { @@ -107,25 +118,40 @@ func Cancel(cancel chan struct{}) QueryOption { } } -// Progress encloses the result of handling a response for a given Request, -// determining whether the response did progress the query. -type Progress struct { - // Finished is true if the query was finished as a result of the - // received response. - Finished bool - - // Progressed is true if the query made progress towards fully - // answering the request as a result of the received response. This is - // used for the requests types where more than one response is - // expected. - Progressed bool -} +// Progress encloses the result of handling a response for a given Request. +type Progress string + +var ( + + // Finished indicates we have received the complete, valid response for this request, + // and so we are done with it. + Finished Progress = "Received complete and valid response for request." + + // Progressed indicates that we have received a valid response, but we are expecting more. + Progressed Progress = "Received valid response, expecting more response for query." + + // UnFinishedRequest indicates that we have received some response, but we need to rescheule the job + // to completely fetch all the response required for this request. + UnFinishedRequest Progress = "Received valid response, reschedule to complete request" + + // ResponseErr indicates we obtained a valid response but response fails checks and needs to + // be rescheduled. + ResponseErr Progress = "Received valid response but fails checks " + + // IgnoreRequest indicates that we have received a valid response but the workmanager need take + // no action on the result of this job. + IgnoreRequest Progress = "Received response but ignoring" + + // NoResponse indicates that we have received an invalid response for this request, and we need + // to wait for a valid one. + NoResponse Progress = "Received invalid response" +) // Request is the main struct that defines a bitcoin network query to be sent to // connected peers. type Request struct { - // Req is the message request to send. - Req wire.Message + // Req contains the message request to send. + Req ReqMessage // HandleResp is a response handler that will be called for every // message received from the peer that the request was made to. It @@ -138,7 +164,26 @@ type Request struct { // should validate the response and immediately return the progress. // The response should be handed off to another goroutine for // processing. - HandleResp func(req, resp wire.Message, peer string) Progress + HandleResp func(req ReqMessage, resp wire.Message, peer Peer) Progress + + // SendQuery handles sending request to the worker's peer. It returns an error, + // if one is encountered while sending the request. + SendQuery func(peer Peer, request ReqMessage) error + + // CloneReq clones the message. + CloneReq func(message ReqMessage) ReqMessage +} + +// ReqMessage is an interface which all structs containing information +// required to process a message request must implement. +type ReqMessage interface { + + // Message returns the message request. + Message() wire.Message + + // PriorityIndex returns the priority the caller prefers the request + // would take. + PriorityIndex() float64 } // WorkManager defines an API for a manager that dispatches queries to bitcoin @@ -167,11 +212,6 @@ type Dispatcher interface { // Peer is the interface that defines the methods needed by the query package // to be able to make requests and receive responses from a network peer. type Peer interface { - // QueueMessageWithEncoding adds the passed bitcoin message to the peer - // send queue. - QueueMessageWithEncoding(msg wire.Message, doneChan chan<- struct{}, - encoding wire.MessageEncoding) - // SubscribeRecvMsg adds a OnRead subscription to the peer. All bitcoin // messages received from this peer will be sent on the returned // channel. A closure is also returned, that should be called to cancel @@ -184,4 +224,11 @@ type Peer interface { // OnDisconnect returns a channel that will be closed when this peer is // disconnected. OnDisconnect() <-chan struct{} + + // IsPeerBehindStartHeight returns a boolean indicating if the peer's known last height is behind + // the request's start Height which it receives as an argument. + IsPeerBehindStartHeight(req ReqMessage) bool + + // IsSyncCandidate returns true if the peer is a sync candidate. + IsSyncCandidate() bool } diff --git a/query/worker.go b/query/worker.go index dc15a18c..dec49ae0 100644 --- a/query/worker.go +++ b/query/worker.go @@ -3,8 +3,6 @@ package query import ( "errors" "time" - - "github.com/btcsuite/btcd/wire" ) var ( @@ -19,15 +17,22 @@ var ( // ErrJobCanceled is returned if the job is canceled before the query // has been answered. ErrJobCanceled = errors.New("job canceled") + + // ErrIgnoreRequest is returned if we want to ignore the request after getting + // a response. + ErrIgnoreRequest = errors.New("ignore request") + + // ErrResponseErr is returned if we received a compatible response for the query but, it did not pass + // preliminary verification. + ErrResponseErr = errors.New("received response with error") ) // queryJob is the internal struct that wraps the Query to work on, in // addition to some information about the query. type queryJob struct { tries uint8 - index uint64 + index float64 timeout time.Duration - encoding wire.MessageEncoding cancelChan <-chan struct{} *Request } @@ -39,15 +44,16 @@ var _ Task = (*queryJob)(nil) // Index returns the queryJob's index within the work queue. // // NOTE: Part of the Task interface. -func (q *queryJob) Index() uint64 { +func (q *queryJob) Index() float64 { return q.index } // jobResult is the final result of the worker's handling of the queryJob. type jobResult struct { - job *queryJob - peer Peer - err error + job *queryJob + peer Peer + err error + unfinished bool } // worker is responsible for polling work from its work queue, and handing it @@ -89,6 +95,7 @@ func (w *worker) Run(results chan<- *jobResult, quit <-chan struct{}) { msgChan, cancel := peer.SubscribeRecvMsg() defer cancel() +nextJobLoop: for { log.Tracef("Worker %v waiting for more work", peer.Addr()) @@ -133,17 +140,33 @@ func (w *worker) Run(results chan<- *jobResult, quit <-chan struct{}) { log.Tracef("Worker %v queuing job %T with index %v", peer.Addr(), job.Req, job.Index()) - peer.QueueMessageWithEncoding(job.Req, nil, job.encoding) + err := job.SendQuery(peer, job.Req) + + // If any error occurs while sending query, quickly send the result + // containing the error to the workmanager. + if err != nil { + select { + case results <- &jobResult{ + job: job, + peer: peer, + err: err, + }: + case <-quit: + return + } + goto nextJobLoop + } } // Wait for the correct response to be received from the peer, // or an error happening. var ( - jobErr error - timeout = time.NewTimer(job.timeout) + jobErr error + jobUnfinished bool + timeout = time.NewTimer(job.timeout) ) - Loop: + feedbackLoop: for { select { // A message was received from the peer, use the @@ -151,37 +174,50 @@ func (w *worker) Run(results chan<- *jobResult, quit <-chan struct{}) { // our request. case resp := <-msgChan: progress := job.HandleResp( - job.Req, resp, peer.Addr(), + job.Req, resp, peer, ) log.Tracef("Worker %v handled msg %T while "+ - "waiting for response to %T (job=%v). "+ - "Finished=%v, progressed=%v", - peer.Addr(), resp, job.Req, job.Index(), - progress.Finished, progress.Progressed) - - // If the response did not answer our query, we - // check whether it did progress it. - if !progress.Finished { - // If it did make progress we reset the - // timeout. This ensures that the - // queries with multiple responses - // expected won't timeout before all - // responses have been handled. - // TODO(halseth): separate progress - // timeout value. - if progress.Progressed { - timeout.Stop() - timeout = time.NewTimer( - job.timeout, - ) - } - continue Loop + "waiting for response to %T (job=%v). ", + peer.Addr(), resp, job.Req, job.Index()) + + switch { + case progress == Finished: + + // Wait for valid response if we have not gotten any one yet. + case progress == NoResponse: + + continue feedbackLoop + + // Increase job's timeout if valid response has been received, and we + // are awaiting more to prevent premature timeout. + case progress == Progressed: + + timeout.Stop() + timeout = time.NewTimer( + job.timeout, + ) + + continue feedbackLoop + + // Assign true to jobUnfinished to indicate that we need to reschedule job to complete request. + case progress == UnFinishedRequest: + + jobUnfinished = true + + // Assign ErrIgnoreRequest to indicate that workmanager should take no action on receipt of + // this request. + case progress == IgnoreRequest: + + jobErr = ErrIgnoreRequest + + // Assign ErrResponseErr to jobErr if we received a valid response that did not pass checks. + case progress == ResponseErr: + + jobErr = ErrResponseErr } - // We did get a valid response, and can break - // the loop. - break Loop + break feedbackLoop // If the timeout is reached before a valid response // has been received, we exit with an error. @@ -193,7 +229,7 @@ func (w *worker) Run(results chan<- *jobResult, quit <-chan struct{}) { "with job index %v", peer.Addr(), job.Req, job.Index()) - break Loop + break feedbackLoop // If the peer disconnects before giving us a valid // answer, we'll also exit with an error. @@ -203,7 +239,7 @@ func (w *worker) Run(results chan<- *jobResult, quit <-chan struct{}) { job.Index()) jobErr = ErrPeerDisconnected - break Loop + break feedbackLoop // If the job was canceled, we report this back to the // work manager. @@ -212,7 +248,7 @@ func (w *worker) Run(results chan<- *jobResult, quit <-chan struct{}) { peer.Addr(), job.Index()) jobErr = ErrJobCanceled - break Loop + break feedbackLoop case <-quit: return @@ -222,18 +258,49 @@ func (w *worker) Run(results chan<- *jobResult, quit <-chan struct{}) { // Stop to allow garbage collection. timeout.Stop() + // This is necessary to avoid a situation where future changes to the job's request affect the current job. + // For example: suppose we want to fetch headers between checkpoints 0 and 20,000. The maximum number of headers + // that a peer can send in one message is 2000. When we receive 2000 headers for one request, + // we update the job's request, changing its startheight and blocklocator to match the next batch of headers + // that we want to fetch. Since we are not done with fetching our target of 20,000 headers, + // we will have to make more changes to the job's request in the future. This could alter previous requests, + // resulting in unwanted behaviour. + resultJob := &queryJob{ + index: job.Index(), + Request: &Request{ + Req: job.CloneReq(job.Req), + HandleResp: job.Request.HandleResp, + CloneReq: job.Request.CloneReq, + SendQuery: job.Request.SendQuery, + }, + cancelChan: job.cancelChan, + tries: job.tries, + timeout: job.timeout, + } + // We have a result ready for the query, hand it off before // getting a new job. select { case results <- &jobResult{ - job: job, - peer: peer, - err: jobErr, + job: resultJob, + peer: peer, + err: jobErr, + unfinished: jobUnfinished, }: case <-quit: return } + // If the error is a timeout still wait for the response as we are assured a response as long as there was a + // request but reschedule on another worker to quickly fetch a response so as not to be slowed down by this + // worker. We either get a response or the peer stalls (i.e. disconnects due to an elongated time without + // a response) + if jobErr == ErrQueryTimeout { + jobErr = nil + + goto feedbackLoop + } + // If the peer disconnected, we can exit immediately. if jobErr == ErrPeerDisconnected { return @@ -241,6 +308,28 @@ func (w *worker) Run(results chan<- *jobResult, quit <-chan struct{}) { } } +func (w *worker) IsSyncCandidate() bool { + return w.peer.IsSyncCandidate() +} + +func (w *worker) IsPeerBehindStartHeight(req ReqMessage) bool { + return w.peer.IsPeerBehindStartHeight(req) +} + +// IsWorkerEligibleForBlkHdrFetch is the eligibility function used for the BlockHdrWorkManager to determine workers +// eligible to receive jobs (the job is to fetch headers). If the peer is not a sync candidate or if its last known +// block height is behind the job query's start height, it returns false. Otherwise, it returns true. +func IsWorkerEligibleForBlkHdrFetch(r *activeWorker, next *queryJob) bool { + if !r.w.IsSyncCandidate() { + return false + } + + if r.w.IsPeerBehindStartHeight(next.Req) { + return false + } + return true +} + // NewJob returns a channel where work that is to be handled by the worker can // be sent. If the worker reads a queryJob from this channel, it is guaranteed // that a response will eventually be deliverd on the results channel (except diff --git a/query/worker_test.go b/query/worker_test.go index 8cb5c17d..dcb640f3 100644 --- a/query/worker_test.go +++ b/query/worker_test.go @@ -1,6 +1,7 @@ package query import ( + "errors" "fmt" "testing" "time" @@ -8,14 +9,42 @@ import ( "github.com/btcsuite/btcd/wire" ) +type mockQueryEncoded struct { + message *wire.MsgGetData + encoding wire.MessageEncoding + index float64 + startHeight int +} + +func (m *mockQueryEncoded) Message() wire.Message { + return m.message +} + +func (m *mockQueryEncoded) PriorityIndex() float64 { + return m.index +} + var ( - req = &wire.MsgGetData{} + msg = &wire.MsgGetData{} + req = &mockQueryEncoded{ + message: msg, + encoding: wire.WitnessEncoding, + } progressResp = &wire.MsgTx{ Version: 111, } finalResp = &wire.MsgTx{ Version: 222, } + UnfinishedRequestResp = &wire.MsgTx{ + Version: 333, + } + finalRespWithErr = &wire.MsgTx{ + Version: 444, + } + IgnoreRequestResp = &wire.MsgTx{ + Version: 444, + } ) type mockPeer struct { @@ -24,16 +53,13 @@ type mockPeer struct { responses chan<- wire.Message subscriptions chan chan wire.Message quit chan struct{} + bestHeight int + fullNode bool + err error } var _ Peer = (*mockPeer)(nil) -func (m *mockPeer) QueueMessageWithEncoding(msg wire.Message, - doneChan chan<- struct{}, encoding wire.MessageEncoding) { - - m.requests <- msg -} - func (m *mockPeer) SubscribeRecvMsg() (<-chan wire.Message, func()) { msgChan := make(chan wire.Message) m.subscriptions <- msgChan @@ -49,37 +75,70 @@ func (m *mockPeer) Addr() string { return m.addr } +func (m *mockPeer) IsPeerBehindStartHeight(request ReqMessage) bool { + r := request.(*mockQueryEncoded) + return m.bestHeight < r.startHeight +} + +func (m *mockPeer) IsSyncCandidate() bool { + return m.fullNode +} + // makeJob returns a new query job that will be done when it is given the // finalResp message. Similarly ot will progress on being given the // progressResp message, while any other message will be ignored. func makeJob() *queryJob { q := &Request{ Req: req, - HandleResp: func(req, resp wire.Message, _ string) Progress { + HandleResp: func(req ReqMessage, resp wire.Message, peer Peer) Progress { if resp == finalResp { - return Progress{ - Finished: true, - Progressed: true, - } + return Finished } if resp == progressResp { - return Progress{ - Finished: false, - Progressed: true, - } + return Progressed + } + + if resp == UnfinishedRequestResp { + return UnFinishedRequest } - return Progress{ - Finished: false, - Progressed: false, + if resp == finalRespWithErr { + return ResponseErr } + if resp == IgnoreRequestResp { + return IgnoreRequest + } + return NoResponse + }, + SendQuery: func(peer Peer, req ReqMessage) error { + m := peer.(*mockPeer) + + if m.err != nil { + return m.err + } + + m.requests <- req.Message() + return nil + }, + CloneReq: func(req ReqMessage) ReqMessage { + oldReq := req.(*mockQueryEncoded) + + newMsg := &wire.MsgGetData{ + InvList: oldReq.message.InvList, + } + + clone := &mockQueryEncoded{ + message: newMsg, + } + + return clone }, } + return &queryJob{ index: 123, timeout: 30 * time.Second, - encoding: defaultQueryEncoding, cancelChan: nil, Request: q, } @@ -185,9 +244,20 @@ func TestWorkerIgnoreMsgs(t *testing.T) { t.Fatalf("response error: %v", result.err) } - // Make sure the result was given for the intended job. - if result.job != task { - t.Fatalf("got result for unexpected job") + // Make sure the QueryJob instance in the result is different from the initial one + // supplied to the worker + if result.job == task { + t.Fatalf("result's job should be different from the task's") + } + + // Make sure we are receiving the corresponding result for the given task. + if result.job.Index() != task.Index() { + t.Fatalf("result's job index should not be different from task's") + } + + // Make sure job does not return as unfinished. + if result.unfinished { + t.Fatalf("got unfinished job") } // And the correct peer. @@ -240,9 +310,15 @@ func TestWorkerTimeout(t *testing.T) { t.Fatalf("expected timeout, got: %v", result.err) } - // Make sure the result was given for the intended job. - if result.job != task { - t.Fatalf("got result for unexpected job") + // Make sure the QueryJob instance in the result is different from the initial one + // supplied to the worker + if result.job == task { + t.Fatalf("result's job should be different from the task's") + } + + // Make sure we are receiving the corresponding result for the given task. + if result.job.Index() != task.Index() { + t.Fatalf("result's job index should not be different from task's") } // And the correct peer. @@ -251,11 +327,16 @@ func TestWorkerTimeout(t *testing.T) { ctx.peer.Addr(), result.peer) } + // Make sure job does not return as unfinished. + if result.unfinished { + t.Fatalf("got unfinished job") + } + // It will immediately attempt to fetch another task. select { case ctx.nextJob <- task: + t.Fatalf("worker still in feedback loop picked up job") case <-time.After(1 * time.Second): - t.Fatalf("did not pick up job") } } @@ -299,9 +380,15 @@ func TestWorkerDisconnect(t *testing.T) { t.Fatalf("expected peer disconnect, got: %v", result.err) } - // Make sure the result was given for the intended job. - if result.job != task { - t.Fatalf("got result for unexpected job") + // Make sure the QueryJob instance in the result is different from the initial one + // supplied to the worker + if result.job == task { + t.Fatalf("result's job should be different from the task's") + } + + // Make sure we are receiving the corresponding result for the given task. + if result.job.Index() != task.Index() { + t.Fatalf("result's job index should not be different from task's") } // And the correct peer. @@ -310,6 +397,11 @@ func TestWorkerDisconnect(t *testing.T) { ctx.peer.Addr(), result.peer) } + // Make sure job does not return as unfinished. + if result.unfinished { + t.Fatalf("got unfinished job") + } + // No more jobs should be accepted by the worker after it has exited. select { case ctx.nextJob <- task: @@ -338,64 +430,121 @@ func TestWorkerProgress(t *testing.T) { } // Create a task with a small timeout, and give it to the worker. - task := makeJob() - task.timeout = taskTimeout - - select { - case ctx.nextJob <- task: - case <-time.After(1 * time.Second): - t.Fatalf("did not pick up job") + type testResp struct { + name string + response *wire.MsgTx + err *error + unfinished bool } - // The request should be given to the peer. - select { - case <-ctx.peer.requests: - case <-time.After(time.Second): - t.Fatalf("request not sent") - } + testCases := []testResp{ - // Send a few other responses that indicates progress, but not success. - // We add a small delay between each time we send a response. In total - // the delay will be larger than the query timeout, but since we are - // making progress, the timeout won't trigger. - for i := 0; i < 5; i++ { - select { - case ctx.peer.responses <- progressResp: - case <-time.After(time.Second): - t.Fatalf("resp not received") - } + { + name: "final response.", + response: finalResp, + }, - time.Sleep(taskTimeout / 2) - } + { + name: "Unfinished request response.", + response: UnfinishedRequestResp, + unfinished: true, + }, - // Finally send the final response. - select { - case ctx.peer.responses <- finalResp: - case <-time.After(time.Second): - t.Fatalf("resp not received") - } + { + name: "ignore request", + response: IgnoreRequestResp, + err: &ErrIgnoreRequest, + }, - // The worker should respond with a job finised. - var result *jobResult - select { - case result = <-ctx.jobResults: - case <-time.After(time.Second): - t.Fatalf("response not received") + { + name: "final response, with err", + response: finalRespWithErr, + err: &ErrResponseErr, + }, } - if result.err != nil { - t.Fatalf("expected no error, got: %v", result.err) - } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + task := makeJob() + task.timeout = taskTimeout - // Make sure the result was given for the intended task. - if result.job != task { - t.Fatalf("got result for unexpected job") - } + select { + case ctx.nextJob <- task: + case <-time.After(1 * time.Second): + t.Fatalf("did not pick up job") + } - // And the correct peer. - if result.peer != ctx.peer { - t.Fatalf("expected peer to be %v, was %v", - ctx.peer.Addr(), result.peer) + // The request should be given to the peer. + select { + case <-ctx.peer.requests: + case <-time.After(time.Second): + t.Fatalf("request not sent") + } + + // Send a few other responses that indicates progress, but not success. + // We add a small delay between each time we send a response. In total + // the delay will be larger than the query timeout, but since we are + // making progress, the timeout won't trigger. + for i := 0; i < 5; i++ { + select { + case ctx.peer.responses <- progressResp: + case <-time.After(time.Second): + t.Fatalf("resp not received") + } + + time.Sleep(taskTimeout / 2) + } + + // Finally send the final response. + select { + case ctx.peer.responses <- tc.response: + case <-time.After(time.Second): + t.Fatalf("resp not received") + } + + // The worker should respond with a job finished. + var result *jobResult + select { + case result = <-ctx.jobResults: + case <-time.After(time.Second): + t.Fatalf("response not received") + } + + if tc.err == nil && result.err != nil { + t.Fatalf("expected no error, got: %v", result.err) + } + + if tc.err != nil && result.err != *tc.err { + t.Fatalf("expected error, %v but got: %v", *tc.err, + result.err) + } + + // Make sure the QueryJob instance in the result is different from the initial one + // supplied to the worker + if result.job == task { + t.Fatalf("result's job should be different from task's") + } + + // Make sure we are receiving the corresponding result for the given task. + if result.job.Index() != task.Index() { + t.Fatalf("result's job index should not be different from task's") + } + + // And the correct peer. + if result.peer != ctx.peer { + t.Fatalf("expected peer to be %v, was %v", + ctx.peer.Addr(), result.peer) + } + + // Make sure job does not return as unfinished. + if tc.unfinished && !result.unfinished { + t.Fatalf("expected job unfinished but got job finished") + } + + if !tc.unfinished && result.unfinished { + t.Fatalf("expected job finished but got unfinished job") + } + }) } } @@ -460,9 +609,20 @@ func TestWorkerJobCanceled(t *testing.T) { t.Fatalf("expected job canceled, got: %v", result.err) } - // Make sure the result was given for the intended task. - if result.job != task { - t.Fatalf("got result for unexpected job") + // Make sure the QueryJob instance in the result is different from the initial one + // supplied to the worker + if result.job == task { + t.Fatalf("result's job should be different from the task's") + } + + // Make sure we are receiving the corresponding result for the given task. + if result.job.Index() != task.Index() { + t.Fatalf("result's job index should not be different from task's") + } + + // Make sure job does not return as unfinished. + if result.unfinished { + t.Fatalf("got unfinished job") } // And the correct peer. @@ -472,3 +632,71 @@ func TestWorkerJobCanceled(t *testing.T) { } } } + +// TestWorkerSendQueryErr will test if the result would return an error +// that would be handled by the worker if there is an error returned while +// sending a query. +func TestWorkerSendQueryErr(t *testing.T) { + t.Parallel() + + ctx, err := startWorker() + if err != nil { + t.Fatalf("unable to start worker: %v", err) + } + + cancelChan := make(chan struct{}) + + // Give the worker a new job. + taskJob := makeJob() + taskJob.cancelChan = cancelChan + + // Assign error to be returned while sending query. + ctx.peer.err = errors.New("query error") + + // Send job to worker + select { + case ctx.nextJob <- taskJob: + case <-time.After(1 * time.Second): + t.Fatalf("did not pick up job") + } + + // Request should not be sent as there should be an error while + // querying. + select { + case <-ctx.peer.requests: + t.Fatalf("request sent when query failed") + case <-time.After(time.Second): + } + + // jobResult should be sent by worker at this point. + var result *jobResult + select { + case result = <-ctx.jobResults: + case <-time.After(time.Second): + t.Fatalf("response not received") + } + + // jobResult should contain error. + if result.err != ctx.peer.err { + t.Fatalf("expected result's error to be %v, was %v", + ctx.peer.err, result.err) + } + + // Make sure the QueryJob instance in the result is same as the taskJob's. + if result.job != taskJob { + t.Fatalf("result's job should be same as the taskJob's") + } + + // And the correct peer. + if result.peer != ctx.peer { + t.Fatalf("expected peer to be %v, was %v", + ctx.peer.Addr(), result.peer) + } + + // The worker should be in the nextJob Loop. + select { + case ctx.nextJob <- taskJob: + case <-time.After(1 * time.Second): + t.Fatalf("did not pick up job") + } +} diff --git a/query/workmanager.go b/query/workmanager.go index e99f57ab..0f691a96 100644 --- a/query/workmanager.go +++ b/query/workmanager.go @@ -26,7 +26,6 @@ var ( type batch struct { requests []*Request options *queryOptions - errChan chan error } // Worker is the interface that must be satisfied by workers managed by the @@ -47,6 +46,13 @@ type Worker interface { // delivered on the results channel (except when the quit channel has // been closed). NewJob() chan<- *queryJob + + // IsPeerBehindStartHeight returns a boolean indicating if the peer's known last height is behind + // the request's start Height which it receives as an argument. + IsPeerBehindStartHeight(req ReqMessage) bool + + // IsSyncCandidate returns if the peer is a sync candidate. + IsSyncCandidate() bool } // PeerRanking is an interface that must be satisfied by the underlying module @@ -94,6 +100,9 @@ type Config struct { // Ranking is used to rank the connected peers when determining who to // give work to. Ranking PeerRanking + + // IsEligibleWorkerFunc determines which peer is eligible to receive a job. + IsEligibleWorkerFunc func(r *activeWorker, next *queryJob) bool } // peerWorkManager is the main access point for outside callers, and satisfies @@ -192,15 +201,17 @@ func (w *peerWorkManager) workDispatcher() { // batches and send on their error channel. defer func() { for _, b := range currentBatches { - b.errChan <- ErrWorkManagerShuttingDown + if b.errChan != nil { + b.errChan <- ErrWorkManagerShuttingDown + } } }() // We set up a counter that we'll increase with each incoming query, // and will serve as the priority of each. In addition we map each // query to the batch they are part of. - queryIndex := uint64(0) - currentQueries := make(map[uint64]uint64) + queryIndex := float64(0) + currentQueries := make(map[float64]uint64) workers := make(map[string]*activeWorker) @@ -212,7 +223,7 @@ Loop: next := work.Peek().(*queryJob) // Find the peers with free work slots available. - var freeWorkers []string + var freeEligibleWorkers []string for p, r := range workers { // Only one active job at a time is currently // supported. @@ -220,15 +231,24 @@ Loop: continue } - freeWorkers = append(freeWorkers, p) + // If there is a specified eligibilty function for + // the peer, use it to determine which peers we can + // send jobs to. + if w.cfg.IsEligibleWorkerFunc != nil { + if !w.cfg.IsEligibleWorkerFunc(r, next) { + continue + } + } + + freeEligibleWorkers = append(freeEligibleWorkers, p) } // Use the historical data to rank them. - w.cfg.Ranking.Order(freeWorkers) + w.cfg.Ranking.Order(freeEligibleWorkers) // Give the job to the highest ranked peer with free // slots available. - for _, p := range freeWorkers { + for _, p := range freeEligibleWorkers { r := workers[p] // The worker has free work slots, it should @@ -298,7 +318,9 @@ Loop: // Delete the job from the worker's active job, such // that the slot gets opened for more work. r := workers[result.peer.Addr()] - r.activeJob = nil + if result.err != ErrQueryTimeout { + r.activeJob = nil + } // Get the index of this query's batch, and delete it // from the map of current queries, since we don't have @@ -320,13 +342,19 @@ Loop: // batch's error channel. We do this since a // cancellation applies to the whole batch. if batch != nil { - batch.errChan <- result.err + if batch.errChan != nil { + batch.errChan <- result.err + } delete(currentBatches, batchNum) log.Debugf("Canceled batch %v", batchNum) continue Loop } + // Take no action if we are to ignore request. + case result.err == ErrIgnoreRequest: + log.Debugf("received ignore request") + continue Loop // If the query ended with any other error, put it back // into the work queue if it has not reached the @@ -354,7 +382,9 @@ Loop: // Return the error and cancel the // batch. - batch.errChan <- result.err + if batch.errChan != nil { + batch.errChan <- result.err + } delete(currentBatches, batchNum) log.Debugf("Canceled batch %v", @@ -386,6 +416,19 @@ Loop: // Reward the peer for the successful query. w.cfg.Ranking.Reward(result.peer.Addr()) + // If the result is unfinished add 0.0005 to the job index to maintain the + // required priority then push to work queue + if result.unfinished { + result.job.index = result.job.Index() + 0.0005 + log.Debugf("job %v is unfinished, creating new index", result.job.Index()) + + heap.Push(work, result.job) + batch.rem++ + currentQueries[result.job.Index()] = batchNum + } else { + log.Debugf("job %v is Finished", result.job.Index()) + } + // Decrement the number of queries remaining in // the batch. if batch != nil { @@ -397,7 +440,9 @@ Loop: // for this batch, we can notify that // it finished, and delete it. if batch.rem == 0 { - batch.errChan <- nil + if batch.errChan != nil { + batch.errChan <- nil + } delete(currentBatches, batchNum) log.Tracef("Batch %v done", @@ -412,7 +457,9 @@ Loop: if batch != nil { select { case <-batch.timeout: - batch.errChan <- ErrQueryTimeout + if batch.errChan != nil { + batch.errChan <- ErrQueryTimeout + } delete(currentBatches, batchNum) log.Warnf("Query(%d) failed with "+ @@ -435,15 +482,25 @@ Loop: "work queue", batchIndex, len(batch.requests)) for _, q := range batch.requests { + idx := queryIndex + + // If priority index is set, use that index. + if q.Req.PriorityIndex() != 0 { + idx = q.Req.PriorityIndex() + } heap.Push(work, &queryJob{ - index: queryIndex, + index: idx, timeout: minQueryTimeout, - encoding: batch.options.encoding, cancelChan: batch.options.cancelChan, Request: q, }) - currentQueries[queryIndex] = batchIndex - queryIndex++ + currentQueries[idx] = batchIndex + + // Only increment queryIndex if it was + // assigned to this job. + if q.Req.PriorityIndex() == 0 { + queryIndex++ + } } currentBatches[batchIndex] = &batchProgress{ @@ -451,7 +508,7 @@ Loop: maxRetries: batch.options.numRetries, timeout: time.After(batch.options.timeout), rem: len(batch.requests), - errChan: batch.errChan, + errChan: batch.options.errChan, } batchIndex++ @@ -470,18 +527,19 @@ func (w *peerWorkManager) Query(requests []*Request, qo := defaultQueryOptions() qo.applyQueryOptions(options...) - errChan := make(chan error, 1) + newBatch := &batch{ + requests: requests, + options: qo, + } // Add query messages to the queue of batches to handle. select { - case w.newBatches <- &batch{ - requests: requests, - options: qo, - errChan: errChan, - }: + case w.newBatches <- newBatch: case <-w.quit: - errChan <- ErrWorkManagerShuttingDown + if newBatch.options.errChan != nil { + newBatch.options.errChan <- ErrWorkManagerShuttingDown + } } - return errChan + return newBatch.options.errChan } diff --git a/query/workmanager_test.go b/query/workmanager_test.go index b7bec809..075d1463 100644 --- a/query/workmanager_test.go +++ b/query/workmanager_test.go @@ -15,6 +15,14 @@ type mockWorker struct { results chan *jobResult } +func (m *mockWorker) IsPeerBehindStartHeight(req ReqMessage) bool { + return m.peer.IsPeerBehindStartHeight(req) +} + +func (m *mockWorker) IsSyncCandidate() bool { + return m.peer.IsSyncCandidate() +} + var _ Worker = (*mockWorker)(nil) func (m *mockWorker) NewJob() chan<- *queryJob { @@ -63,9 +71,14 @@ func (p *mockPeerRanking) Punish(peer string) { func (p *mockPeerRanking) Reward(peer string) { } +type ctx struct { + wm WorkManager + peerChan chan Peer +} + // startWorkManager starts a new workmanager with the given number of mock // workers. -func startWorkManager(t *testing.T, numWorkers int) (WorkManager, +func startWorkManager(t *testing.T, numWorkers int) (ctx, []*mockWorker) { // We set up a custom NewWorker closure for the WorkManager, such that @@ -116,7 +129,10 @@ func startWorkManager(t *testing.T, numWorkers int) (WorkManager, workers[i] = w } - return wm, workers + return ctx{ + wm: wm, + peerChan: peerChan, + }, workers } // TestWorkManagerWorkDispatcherSingleWorker tests that the workDispatcher @@ -126,22 +142,25 @@ func TestWorkManagerWorkDispatcherSingleWorker(t *testing.T) { const numQueries = 100 // Start work manager with a sinlge worker. - wm, workers := startWorkManager(t, 1) - + c, workers := startWorkManager(t, 1) + wm := c.wm // Schedule a batch of queries. var queries []*Request for i := 0; i < numQueries; i++ { - q := &Request{} + q := &Request{ + Req: &mockQueryEncoded{}, + } + queries = append(queries, q) } - errChan := wm.Query(queries) + errChan := wm.Query(queries, ErrChan(make(chan error, 1))) wk := workers[0] // Each query should be sent on the nextJob queue, in the order they // had in their batch. - for i := uint64(0); i < numQueries; i++ { + for i := float64(0); i < numQueries; i++ { var job *queryJob select { case job = <-wk.nextJob: @@ -186,7 +205,8 @@ func TestWorkManagerWorkDispatcherFailures(t *testing.T) { // Start work manager with as many workers as queries. This is not very // realistic, but makes the work manager able to schedule all queries // concurrently. - wm, workers := startWorkManager(t, numQueries) + c, workers := startWorkManager(t, numQueries) + wm := c.wm // When the jobs gets scheduled, keep track of which worker was // assigned the job. @@ -199,7 +219,9 @@ func TestWorkManagerWorkDispatcherFailures(t *testing.T) { var scheduledJobs [numQueries]chan sched var queries [numQueries]*Request for i := 0; i < numQueries; i++ { - q := &Request{} + q := &Request{ + Req: &mockQueryEncoded{}, + } queries[i] = q scheduledJobs[i] = make(chan sched) } @@ -212,7 +234,7 @@ func TestWorkManagerWorkDispatcherFailures(t *testing.T) { go func() { for { job := <-wk.nextJob - scheduledJobs[job.index] <- sched{ + scheduledJobs[int(job.index)] <- sched{ wk: wk, job: job, } @@ -221,14 +243,14 @@ func TestWorkManagerWorkDispatcherFailures(t *testing.T) { } // Send the batch, and Retrieve all jobs immediately. - errChan := wm.Query(queries[:]) + errChan := wm.Query(queries[:], ErrChan(make(chan error, 1))) var jobs [numQueries]sched - for i := uint64(0); i < numQueries; i++ { + for i := 0; i < numQueries; i++ { var s sched select { case s = <-scheduledJobs[i]: - if s.job.index != i { + if s.job.index != float64(i) { t.Fatalf("wrong index") } @@ -238,7 +260,7 @@ func TestWorkManagerWorkDispatcherFailures(t *testing.T) { t.Fatalf("next job not received") } - jobs[s.job.index] = s + jobs[int(s.job.index)] = s } // Go backwards, and fail half of them. @@ -262,10 +284,10 @@ func TestWorkManagerWorkDispatcherFailures(t *testing.T) { // Finally, make sure the failed jobs are being retried, in the same // order as they were originally scheduled. - for i := uint64(0); i < numQueries; i += 2 { + for i := float64(0); i < numQueries; i += 2 { var s sched select { - case s = <-scheduledJobs[i]: + case s = <-scheduledJobs[int(i)]: if s.job.index != i { t.Fatalf("wrong index") } @@ -297,25 +319,121 @@ func TestWorkManagerWorkDispatcherFailures(t *testing.T) { } } +// TestWorkManagerErrQueryTimeout tests that the workers that return query +// timeout are not sent jobs until they return a different error. +func TestWorkManagerErrQueryTimeout(t *testing.T) { + const numQueries = 1 + const numWorkers = 1 + + // Start work manager. + c, workers := startWorkManager(t, numWorkers) + wm := c.wm + + // When the jobs gets scheduled, keep track of which worker was + // assigned the job. + type sched struct { + wk *mockWorker + job *queryJob + } + + // Schedule a batch of queries. + var scheduledJobs [numQueries]chan sched + var queries [numQueries]*Request + for i := 0; i < numQueries; i++ { + q := &Request{ + Req: &mockQueryEncoded{}, + } + queries[i] = q + scheduledJobs[i] = make(chan sched) + } + + // Spin up goroutine for only one worker. Forward gotten jobs + // to our slice of scheduled jobs, such that we can handle them in + // order. + wk := workers[0] + go func() { + for { + job := <-wk.nextJob + scheduledJobs[int(job.index)] <- sched{ + wk: wk, + job: job, + } + } + }() + + // Send the batch, and Retrieve all jobs immediately. + errChan := wm.Query(queries[:], ErrChan(make(chan error, 1))) + + var s sched + + // Ensure job is sent to the worker. + select { + case s = <-scheduledJobs[0]: + if s.job.index != float64(0) { + t.Fatalf("wrong index") + } + case <-errChan: + t.Fatalf("did not expect on errChan") + case <-time.After(time.Second): + t.Fatalf("next job not received") + } + + // Return jobResult with an ErrQueryTimeout. + select { + case s.wk.results <- &jobResult{ + job: s.job, + err: ErrQueryTimeout, + }: + case <-errChan: + t.Fatalf("did not expect on errChan") + case <-time.After(time.Second): + t.Fatalf("result not handled") + } + + // Make sure the job is not retried as there are no available + // peer to retry it. The only available worker should be waiting in the + // worker feedback loop. + select { + case <-scheduledJobs[0]: + t.Fatalf("did not expect job rescheduled") + case <-errChan: + t.Fatalf("did not expect on errChan") + case <-time.After(time.Second): + } + + // There should be no errChan message as query is still incomplete. + select { + case err := <-errChan: + if err != nil { + t.Fatalf("got error: %v", err) + } + t.Fatalf("expected no errChan message") + case <-time.After(time.Second): + } +} + // TestWorkManagerCancelBatch checks that we can cancel a batch query midway, // and that the jobs it contains are canceled. func TestWorkManagerCancelBatch(t *testing.T) { const numQueries = 100 // Start the workDispatcher goroutine. - wm, workers := startWorkManager(t, 1) + c, workers := startWorkManager(t, 1) + wm := c.wm wk := workers[0] // Schedule a batch of queries. var queries []*Request for i := 0; i < numQueries; i++ { - q := &Request{} + q := &Request{ + Req: &mockQueryEncoded{}, + } queries = append(queries, q) } // Send the query, and include a channel to cancel the batch. cancelChan := make(chan struct{}) - errChan := wm.Query(queries, Cancel(cancelChan)) + errChan := wm.Query(queries, Cancel(cancelChan), ErrChan(make(chan error, 1))) // Respond with a result to half of the queries. for i := 0; i < numQueries/2; i++ { @@ -386,7 +504,8 @@ func TestWorkManagerWorkRankingScheduling(t *testing.T) { const numQueries = 4 const numWorkers = 8 - workMgr, workers := startWorkManager(t, numWorkers) + c, workers := startWorkManager(t, numWorkers) + workMgr := c.wm require.IsType(t, workMgr, &peerWorkManager{}) wm := workMgr.(*peerWorkManager) //nolint:forcetypeassert @@ -399,19 +518,21 @@ func TestWorkManagerWorkRankingScheduling(t *testing.T) { // Schedule a batch of queries. var queries []*Request for i := 0; i < numQueries; i++ { - q := &Request{} + q := &Request{ + Req: &mockQueryEncoded{}, + } queries = append(queries, q) } // Send the batch, and Retrieve all jobs immediately. - errChan := wm.Query(queries) + errChan := wm.Query(queries, ErrChan(make(chan error, 1))) // The 4 first workers should get the job. var jobs []*queryJob for i := 0; i < numQueries; i++ { select { case job := <-workers[i].nextJob: - if job.index != uint64(i) { + if job.index != float64(i) { t.Fatalf("unexpected job") } jobs = append(jobs, job) @@ -462,10 +583,12 @@ func TestWorkManagerWorkRankingScheduling(t *testing.T) { // Send a new set of queries. queries = nil for i := 0; i < numQueries; i++ { - q := &Request{} + q := &Request{ + Req: &mockQueryEncoded{}, + } queries = append(queries, q) } - _ = wm.Query(queries) + _ = wm.Query(queries, ErrChan(make(chan error, 1))) // The new jobs should be scheduled on the even numbered workers. for i := 0; i < len(workers); i += 2 { @@ -476,3 +599,499 @@ func TestWorkManagerWorkRankingScheduling(t *testing.T) { } } } + +// TestWorkManagerSchedulePriorityIndex tests that the workmanager acknowledges +// priority index. +func TestWorkManagerSchedulePriorityIndex(t *testing.T) { + const numQueries = 3 + + // Start work manager with as many workers as queries. This is not very + // realistic, but makes the work manager able to schedule all queries + // concurrently. + c, workers := startWorkManager(t, numQueries) + wm := c.wm + + // When the jobs gets scheduled, keep track of which worker was + // assigned the job. + type sched struct { + wk *mockWorker + job *queryJob + } + + // Schedule a batch of queries. + var scheduledJobs [5]chan sched + var queries [numQueries]*Request + for i := 0; i < numQueries; i++ { + var q *Request + idx := i + if i == 0 { + q = &Request{ + Req: &mockQueryEncoded{}, + } + } else { + // Assign priority index. + idx = i + 2 + q = &Request{ + Req: &mockQueryEncoded{ + index: float64(idx), + }, + } + } + queries[i] = q + scheduledJobs[idx] = make(chan sched) + } + + // Fot each worker, spin up a goroutine that will forward the job it + // got to our slice of scheduled jobs, such that we can handle them in + // order. + for i := 0; i < len(workers); i++ { + wk := workers[i] + go func() { + for { + job := <-wk.nextJob + scheduledJobs[int(job.index)] <- sched{ + wk: wk, + job: job, + } + } + }() + } + + // Send the batch, and Retrieve all jobs immediately. + errChan := wm.Query(queries[:], ErrChan(make(chan error, 1))) + + var jobs [numQueries]sched + for i := uint64(0); i < numQueries; i++ { + var expectedIndex float64 + + if i == 0 { + expectedIndex = float64(0) + } else { + expectedIndex = float64(i + 2) + } + var s sched + select { + case s = <-scheduledJobs[int(expectedIndex)]: + + if s.job.index != expectedIndex { + t.Fatalf("wrong index: Got %v but expected %v", s.job.index, + expectedIndex) + } + case <-errChan: + t.Fatalf("did not expect an errChan") + case <-time.After(time.Second): + t.Fatalf("next job not received") + } + + jobs[i] = s + } + + // Go backwards send results for job. + for i := numQueries - 1; i >= 0; i-- { + select { + case jobs[i].wk.results <- &jobResult{ + job: jobs[i].job, + }: + case <-errChan: + t.Fatalf("did not expect on errChan") + case <-time.After(time.Second): + t.Fatalf("result not handled") + } + } + + // Finally, make sure no jobs are retried. + for i := uint64(0); i < numQueries; i++ { + select { + case <-scheduledJobs[i]: + t.Fatalf("did not expect a retried job") + case <-time.After(time.Second): + } + } + + // The query should ultimately succeed. + select { + case err := <-errChan: + if err != nil { + t.Fatalf("got error: %v", err) + } + case <-time.After(time.Second): + t.Fatalf("nothing received on errChan") + } +} + +// TestPeerWorkManager_Stop tests the workmanager shutdown. +func TestPeerWorkManager_Stop(t *testing.T) { + const numQueries = 5 + + c, _ := startWorkManager(t, 0) + wm := c.wm + + createRequest := func(numQuery int) []*Request { + var queries []*Request + for i := 0; i < numQuery; i++ { + q := &Request{ + Req: &mockQueryEncoded{}, + } + queries = append(queries, q) + } + + return queries + } + + // Send the batch, and Retrieve all jobs immediately. + errChan := wm.Query(createRequest(numQueries), ErrChan(make(chan error, 1))) + errChan2 := wm.Query(createRequest(numQueries)) + + if errChan2 != nil { + t.Fatalf("expected Query call without ErrChan option func to return" + + "niil errChan") + } + + errChan3 := make(chan error, 1) + go func() { + err := wm.Stop() + + errChan3 <- err + }() + + select { + case <-errChan: + case <-time.After(time.Second): + t.Fatalf("expected error workmanager shutting down") + } + + select { + case err := <-errChan3: + if err != nil { + t.Fatalf("unexpected error while stopping workmanager: %v", err) + } + case <-time.After(time.Second): + t.Fatalf("workmanager stop functunction should return error") + } +} + +// TestWorkManagerErrResponseExistForQuery tests a scenario in which a workmanager handles +// an ErrIgnoreRequest. +func TestWorkManagerErrResponseExistForQuery(t *testing.T) { + const numQueries = 5 + + // Start work manager with as many workers as queries. This is not very + // realistic, but makes the work manager able to schedule all queries + // concurrently. + c, workers := startWorkManager(t, numQueries) + wm := c.wm + + // When the jobs gets scheduled, keep track of which worker was + // assigned the job. + type sched struct { + wk *mockWorker + job *queryJob + } + + // Schedule a batch of queries. + var ( + queries [numQueries]*Request + scheduledJobs [numQueries]chan sched + ) + for i := 0; i < numQueries; i++ { + q := &Request{ + Req: &mockQueryEncoded{}, + } + queries[i] = q + scheduledJobs[i] = make(chan sched) + } + + // Fot each worker, spin up a goroutine that will forward the job it + // got to our slice of scheduled jobs, such that we can handle them in + // order. + for i := 0; i < len(workers); i++ { + wk := workers[i] + go func() { + for { + job := <-wk.nextJob + scheduledJobs[int(job.index)] <- sched{ + wk: wk, + job: job, + } + } + }() + } + + // Send the batch, and Retrieve all jobs immediately. + errChan := wm.Query(queries[:], ErrChan(make(chan error, 1))) + var jobs [numQueries]sched + for i := 0; i < numQueries; i++ { + var s sched + select { + case s = <-scheduledJobs[i]: + if s.job.index != float64(i) { + t.Fatalf("wrong index") + } + + case <-errChan: + t.Fatalf("did not expect on errChan") + case <-time.After(time.Second): + t.Fatalf("next job not received") + } + + jobs[int(s.job.index)] = s + } + + // Go backwards, and make half of it return with an ErrIgnoreRequest. + for i := numQueries - 1; i >= 0; i-- { + select { + case jobs[i].wk.results <- &jobResult{ + job: jobs[i].job, + err: ErrIgnoreRequest, + }: + case <-errChan: + t.Fatalf("did not expect on errChan") + case <-time.After(time.Second): + t.Fatalf("result not handled") + } + } + + // Finally, make sure the failed jobs are not retried. + for i := 0; i < numQueries; i++ { + var s sched + select { + case s = <-scheduledJobs[i]: + t.Fatalf("did not expect any retried job but job"+ + "%v\n retried", s.job.index) + case <-errChan: + t.Fatalf("did not expect an errChan") + case <-time.After(time.Second): + } + } +} + +// TestWorkManagerResultUnfinished tests the workmanager handling a result with an unfinished boolean set +// to true. +func TestWorkManagerResultUnfinished(t *testing.T) { + const numQueries = 10 + + // Start work manager with as many workers as queries. This is not very + // realistic, but makes the work manager able to schedule all queries + // concurrently. + c, workers := startWorkManager(t, numQueries) + wm := c.wm + + // When the jobs gets scheduled, keep track of which worker was + // assigned the job. + type sched struct { + wk *mockWorker + job *queryJob + } + + // Schedule a batch of queries. + var ( + queries [numQueries]*Request + scheduledJobs [numQueries]chan sched + ) + for i := 0; i < numQueries; i++ { + q := &Request{ + Req: &mockQueryEncoded{}, + } + queries[i] = q + scheduledJobs[i] = make(chan sched) + } + + // Fot each worker, spin up a goroutine that will forward the job it + // got to our slice of scheduled jobs, such that we can handle them in + // order. + for i := 0; i < len(workers); i++ { + wk := workers[i] + go func() { + for { + job := <-wk.nextJob + scheduledJobs[int(job.index)] <- sched{ + wk: wk, + job: job, + } + } + }() + } + + // Send the batch, and Retrieve all jobs immediately. + errChan := wm.Query(queries[:], ErrChan(make(chan error, 1))) + var jobs [numQueries]sched + for i := 0; i < numQueries; i++ { + var s sched + select { + case s = <-scheduledJobs[i]: + if s.job.index != float64(i) { + t.Fatalf("wrong index") + } + + case <-errChan: + t.Fatalf("did not expect on errChan") + case <-time.After(time.Second): + t.Fatalf("next job not received") + } + + jobs[int(s.job.index)] = s + } + + // Go backwards, and make half of it unfinished. + for i := numQueries - 1; i >= 0; i-- { + var ( + unfinished bool + ) + if i%2 == 0 { + unfinished = true + } + + select { + case jobs[i].wk.results <- &jobResult{ + job: jobs[i].job, + unfinished: unfinished, + }: + case <-errChan: + t.Fatalf("did not expect on errChan") + case <-time.After(time.Second): + t.Fatalf("result not handled") + } + } + + // Finally, make sure the failed jobs are being retried, in the same + // order as they were originally scheduled. + for i := 0; i < numQueries; i += 2 { + var s sched + select { + case s = <-scheduledJobs[i]: + + // The new tindex the job should have. + idx := float64(i) + 0.0005 + if idx != s.job.index { + t.Fatalf("expected index %v for job"+ + "but got, %v\n", idx, s.job.index) + } + case <-errChan: + t.Fatalf("did not expect on errChan") + case <-time.After(time.Second): + t.Fatalf("next job not received") + } + select { + case s.wk.results <- &jobResult{ + job: s.job, + err: nil, + }: + case <-errChan: + t.Fatalf("did not expect on errChan") + case <-time.After(time.Second): + t.Fatalf("result not handled") + } + } + + // The query should ultimately succeed. + select { + case err := <-errChan: + if err != nil { + t.Fatalf("got error: %v", err) + } + case <-time.After(time.Second): + t.Fatalf("nothing received on errChan") + } +} + +// TestIsWorkerEligibleForBlkHdrFetch tests the IsWorkerEligibleForBlkHdrFetch function. +func TestIsWorkerEligibleForBlkHdrFetch(t *testing.T) { + type testArgs struct { + name string + activeWorker *activeWorker + job *queryJob + expectedEligibility bool + } + + testCases := []testArgs{ + { + name: "peer sync candidate, best height behind job start Height", + activeWorker: &activeWorker{ + w: &mockWorker{ + peer: &mockPeer{ + bestHeight: 5, + fullNode: true, + }, + }, + }, + job: &queryJob{ + Request: &Request{ + Req: &mockQueryEncoded{ + startHeight: 10, + }, + }, + }, + expectedEligibility: false, + }, + + { + name: "peer sync candidate, best height ahead job start Height", + activeWorker: &activeWorker{ + w: &mockWorker{ + peer: &mockPeer{ + bestHeight: 10, + fullNode: true, + }, + }, + }, + job: &queryJob{ + Request: &Request{ + Req: &mockQueryEncoded{ + startHeight: 5, + }, + }, + }, + expectedEligibility: true, + }, + + { + name: "peer not sync candidate, best height behind job start Height", + activeWorker: &activeWorker{ + w: &mockWorker{ + peer: &mockPeer{ + bestHeight: 5, + fullNode: false, + }, + }, + }, + job: &queryJob{ + Request: &Request{ + Req: &mockQueryEncoded{ + startHeight: 10, + }, + }, + }, + expectedEligibility: false, + }, + + { + name: "peer not sync candidate, best height ahead job start Height", + activeWorker: &activeWorker{ + w: &mockWorker{ + peer: &mockPeer{ + bestHeight: 10, + fullNode: false, + }, + }, + }, + job: &queryJob{ + Request: &Request{ + Req: &mockQueryEncoded{ + startHeight: 5, + }, + }, + }, + expectedEligibility: false, + }, + } + + for _, test := range testCases { + t.Run(test.name, func(t *testing.T) { + isEligible := IsWorkerEligibleForBlkHdrFetch(test.activeWorker, test.job) + if isEligible != test.expectedEligibility { + t.Fatalf("Expected '%v'for eligibility check but got"+ + "'%v'\n", test.expectedEligibility, isEligible) + } + }) + } +} diff --git a/query/workqueue.go b/query/workqueue.go index 9a92ce8f..cc9b2f2e 100644 --- a/query/workqueue.go +++ b/query/workqueue.go @@ -4,7 +4,7 @@ package query // work queue. type Task interface { // Index returns this Task's index in the work queue. - Index() uint64 + Index() float64 } // workQueue is struct implementing the heap interface, and is used to keep a diff --git a/query/workqueue_test.go b/query/workqueue_test.go index d9abc66a..b0f7bae0 100644 --- a/query/workqueue_test.go +++ b/query/workqueue_test.go @@ -6,12 +6,12 @@ import ( ) type task struct { - index uint64 + index float64 } var _ Task = (*task)(nil) -func (t *task) Index() uint64 { +func (t *task) Index() float64 { return t.index } @@ -27,7 +27,7 @@ func TestWorkQueue(t *testing.T) { // Create a simple list of tasks and add them all to the queue. var tasks []*task - for i := uint64(0); i < numTasks; i++ { + for i := float64(0); i < numTasks; i++ { tasks = append(tasks, &task{ index: i, }) @@ -40,11 +40,11 @@ func TestWorkQueue(t *testing.T) { // Check that it reports the expected number of elements. l := q.Len() if l != numTasks { - t.Fatalf("expected %d length, was %d", numTasks, l) + t.Fatalf("expected %v length, was %v", numTasks, l) } // Pop half, and make sure they arrive in the right order. - for i := uint64(0); i < numTasks/2; i++ { + for i := float64(0); i < numTasks/2; i++ { peek := q.Peek().(*task) pop := heap.Pop(q) @@ -54,7 +54,7 @@ func TestWorkQueue(t *testing.T) { } if peek.index != i { - t.Fatalf("wrong index: %d", peek.index) + t.Fatalf("wrong index: %v", peek.index) } } @@ -63,7 +63,7 @@ func TestWorkQueue(t *testing.T) { heap.Push(q, tasks[0]) } - for i := uint64(numTasks/2 - 3); i < numTasks; i++ { + for i := float64(numTasks/2 - 3); i < numTasks; i++ { peek := q.Peek().(*task) pop := heap.Pop(q) @@ -80,13 +80,13 @@ func TestWorkQueue(t *testing.T) { } if peek.index != exp { - t.Fatalf("wrong index: %d", peek.index) + t.Fatalf("wrong index: %v", peek.index) } } // Finally, the queue should be empty. l = q.Len() if l != 0 { - t.Fatalf("expected %d length, was %d", 0, l) + t.Fatalf("expected %v length, was %v", 0, l) } } diff --git a/query_test.go b/query_test.go index cbc9a74a..f8610b6e 100644 --- a/query_test.go +++ b/query_test.go @@ -303,9 +303,9 @@ func TestBlockCache(t *testing.T) { defer close(errChan) require.Len(t, reqs, 1) - require.IsType(t, &wire.MsgGetData{}, reqs[0].Req) + require.IsType(t, &wire.MsgGetData{}, reqs[0].Req.Message()) - getData := reqs[0].Req.(*wire.MsgGetData) + getData := reqs[0].Req.Message().(*wire.MsgGetData) require.Len(t, getData.InvList, 1) inv := getData.InvList[0] @@ -324,10 +324,10 @@ func TestBlockCache(t *testing.T) { Header: *header, Transactions: b.MsgBlock().Transactions, } - - progress := reqs[0].HandleResp(getData, resp, "") - require.True(t, progress.Progressed) - require.True(t, progress.Finished) + progress := reqs[0].HandleResp(&encodedQuery{ + message: getData, + }, resp, nil) + require.Equal(t, query.Finished, progress) // Notify the test about the query. select {