diff --git a/node/core/approval-voting/src/lib.rs b/node/core/approval-voting/src/lib.rs index f5e888c7c538..a6a74da50480 100644 --- a/node/core/approval-voting/src/lib.rs +++ b/node/core/approval-voting/src/lib.rs @@ -2519,6 +2519,17 @@ async fn launch_approval( // do nothing. we'll just be a no-show and that'll cause others to rise up. metrics_guard.take().on_approval_unavailable(); }, + &RecoveryError::ChannelClosed => { + gum::warn!( + target: LOG_TARGET, + ?para_id, + ?candidate_hash, + "Channel closed while recovering data for candidate {:?}", + (candidate_hash, candidate.descriptor.para_id), + ); + // do nothing. we'll just be a no-show and that'll cause others to rise up. + metrics_guard.take().on_approval_unavailable(); + }, &RecoveryError::Invalid => { gum::warn!( target: LOG_TARGET, diff --git a/node/core/dispute-coordinator/src/participation/mod.rs b/node/core/dispute-coordinator/src/participation/mod.rs index b6a41bcff9dd..25b7352807f6 100644 --- a/node/core/dispute-coordinator/src/participation/mod.rs +++ b/node/core/dispute-coordinator/src/participation/mod.rs @@ -319,7 +319,7 @@ async fn participate( send_result(&mut result_sender, req, ParticipationOutcome::Invalid).await; return }, - Ok(Err(RecoveryError::Unavailable)) => { + Ok(Err(RecoveryError::Unavailable)) | Ok(Err(RecoveryError::ChannelClosed)) => { send_result(&mut result_sender, req, ParticipationOutcome::Unavailable).await; return }, diff --git a/node/network/availability-recovery/src/lib.rs b/node/network/availability-recovery/src/lib.rs index c771e31a6c40..e8503ee454a2 100644 --- a/node/network/availability-recovery/src/lib.rs +++ b/node/network/availability-recovery/src/lib.rs @@ -20,24 +20,28 @@ use std::{ collections::{HashMap, VecDeque}, + iter::Iterator, num::NonZeroUsize, pin::Pin, time::Duration, }; use futures::{ - channel::oneshot, - future::{FutureExt, RemoteHandle}, + channel::oneshot::{self, channel}, + future::{Future, FutureExt, RemoteHandle}, pin_mut, prelude::*, - 
stream::FuturesUnordered, + sink::SinkExt, + stream::{FuturesUnordered, StreamExt}, task::{Context, Poll}, }; use lru::LruCache; use rand::seq::SliceRandom; use fatality::Nested; -use polkadot_erasure_coding::{branch_hash, branches, obtain_chunks_v1, recovery_threshold}; +use polkadot_erasure_coding::{ + branch_hash, branches, obtain_chunks_v1, recovery_threshold, Error as ErasureEncodingError, +}; #[cfg(not(test))] use polkadot_node_network_protocol::request_response::CHUNK_REQUEST_TIMEOUT; use polkadot_node_network_protocol::{ @@ -150,6 +154,8 @@ struct RequestFromBackers { // a random shuffling of the validators from the backing group which indicates the order // in which we connect to them and request the chunk. shuffled_backers: Vec, + // channel to the erasure task handler. + erasure_task_tx: futures::channel::mpsc::Sender, } struct RequestChunksFromValidators { @@ -162,9 +168,12 @@ struct RequestChunksFromValidators { /// a random shuffling of the validators which indicates the order in which we connect to the validators and /// request the chunk from them. shuffling: VecDeque, + /// Chunks received so far. received_chunks: HashMap, /// Pending chunk requests with soft timeout. requesting_chunks: FuturesUndead, (ValidatorIndex, RequestError)>>, + // channel to the erasure task handler. + erasure_task_tx: futures::channel::mpsc::Sender, } struct RecoveryParams { @@ -198,6 +207,18 @@ enum Source { RequestChunks(RequestChunksFromValidators), } +/// Expensive erasure coding computations that we want to run on a blocking thread. +enum ErasureTask { + /// Reconstructs `AvailableData` from chunks given `n_validators`. + Reconstruct( + usize, + HashMap, + oneshot::Sender>, + ), + /// Re-encode `AvailableData` into erasure chunks in order to verify the provided root hash of the Merkle tree. + Reencode(usize, Hash, AvailableData, oneshot::Sender>), +} + /// A stateful reconstruction of availability data in reference to /// a candidate hash. 
struct RecoveryTask { @@ -208,13 +229,19 @@ struct RecoveryTask { /// The source to obtain the availability data from. source: Source, + + // channel to the erasure task handler. + erasure_task_tx: futures::channel::mpsc::Sender, } impl RequestFromBackers { - fn new(mut backers: Vec) -> Self { + fn new( + mut backers: Vec, + erasure_task_tx: futures::channel::mpsc::Sender, + ) -> Self { backers.shuffle(&mut rand::thread_rng()); - RequestFromBackers { shuffled_backers: backers } + RequestFromBackers { shuffled_backers: backers, erasure_task_tx } } // Run this phase to completion. @@ -251,12 +278,21 @@ impl RequestFromBackers { match response.await { Ok(req_res::v1::AvailableDataFetchingResponse::AvailableData(data)) => { - if reconstructed_data_matches_root( - params.validators.len(), - &params.erasure_root, - &data, - &params.metrics, - ) { + let (reencode_tx, reencode_rx) = channel(); + self.erasure_task_tx + .send(ErasureTask::Reencode( + params.validators.len(), + params.erasure_root, + data, + reencode_tx, + )) + .await + .map_err(|_| RecoveryError::ChannelClosed)?; + + let reencode_response = + reencode_rx.await.map_err(|_| RecoveryError::ChannelClosed)?; + + if let Some(data) = reencode_response { gum::trace!( target: LOG_TARGET, candidate_hash = ?params.candidate_hash, @@ -289,7 +325,10 @@ impl RequestFromBackers { } impl RequestChunksFromValidators { - fn new(n_validators: u32) -> Self { + fn new( + n_validators: u32, + erasure_task_tx: futures::channel::mpsc::Sender, + ) -> Self { let mut shuffling: Vec<_> = (0..n_validators).map(ValidatorIndex).collect(); shuffling.shuffle(&mut rand::thread_rng()); @@ -299,20 +338,29 @@ impl RequestChunksFromValidators { shuffling: shuffling.into(), received_chunks: HashMap::new(), requesting_chunks: FuturesUndead::new(), + erasure_task_tx, } } fn is_unavailable(&self, params: &RecoveryParams) -> bool { is_unavailable( - self.received_chunks.len(), + self.chunk_count(), self.requesting_chunks.total_len(), self.shuffling.len(),
params.threshold, ) } + fn chunk_count(&self) -> usize { + self.received_chunks.len() + } + + fn insert_chunk(&mut self, validator_index: ValidatorIndex, chunk: ErasureChunk) { + self.received_chunks.insert(validator_index, chunk); + } + fn can_conclude(&self, params: &RecoveryParams) -> bool { - self.received_chunks.len() >= params.threshold || self.is_unavailable(params) + self.chunk_count() >= params.threshold || self.is_unavailable(params) } /// Desired number of parallel requests. @@ -329,7 +377,7 @@ impl RequestChunksFromValidators { // 4. We request more chunks to make up for it ... let max_requests_boundary = std::cmp::min(N_PARALLEL, threshold); // How many chunks are still needed? - let remaining_chunks = threshold.saturating_sub(self.received_chunks.len()); + let remaining_chunks = threshold.saturating_sub(self.chunk_count()); // What is the current error rate, so we can make up for it? let inv_error_rate = self.total_received_responses.checked_div(self.error_count).unwrap_or(0); @@ -430,7 +478,7 @@ impl RequestChunksFromValidators { validator_index = ?chunk.index, "Received valid chunk", ); - self.received_chunks.insert(chunk.index, chunk); + self.insert_chunk(chunk.index, chunk); } else { metrics.on_chunk_request_invalid(); self.error_count += 1; @@ -488,7 +536,7 @@ impl RequestChunksFromValidators { gum::debug!( target: LOG_TARGET, candidate_hash = ?params.candidate_hash, - received_chunks_count = ?self.received_chunks.len(), + received_chunks_count = ?self.chunk_count(), requested_chunks_count = ?self.requesting_chunks.len(), threshold = ?params.threshold, "Can conclude availability for a candidate", @@ -530,7 +578,7 @@ impl RequestChunksFromValidators { validator_index = ?chunk.index, "Found valid chunk on disk" ); - self.received_chunks.insert(chunk.index, chunk); + self.insert_chunk(chunk.index, chunk); } else { gum::error!( target: LOG_TARGET, @@ -557,7 +605,7 @@ impl RequestChunksFromValidators { target: LOG_TARGET, candidate_hash = 
?params.candidate_hash, erasure_root = ?params.erasure_root, - received = %self.received_chunks.len(), + received = %self.chunk_count(), requesting = %self.requesting_chunks.len(), total_requesting = %self.requesting_chunks.total_len(), n_validators = %params.validators.len(), @@ -575,20 +623,41 @@ impl RequestChunksFromValidators { // If received_chunks has more than threshold entries, attempt to recover the data. // If that fails, or a re-encoding of it doesn't match the expected erasure root, // return Err(RecoveryError::Invalid) - if self.received_chunks.len() >= params.threshold { + if self.chunk_count() >= params.threshold { let recovery_duration = metrics.time_erasure_recovery(); - return match polkadot_erasure_coding::reconstruct_v1( - params.validators.len(), - self.received_chunks.values().map(|c| (&c.chunk[..], c.index.0 as usize)), - ) { + // Send request to reconstruct available data from chunks. + let (available_data_tx, available_data_rx) = channel(); + self.erasure_task_tx + .send(ErasureTask::Reconstruct( + params.validators.len(), + std::mem::take(&mut self.received_chunks), + available_data_tx, + )) + .await + .map_err(|_| RecoveryError::ChannelClosed)?; + + let available_data_response = + available_data_rx.await.map_err(|_| RecoveryError::ChannelClosed)?; + + return match available_data_response { Ok(data) => { - if reconstructed_data_matches_root( - params.validators.len(), - &params.erasure_root, - &data, - &metrics, - ) { + // Send request to re-encode the chunks and check Merkle root. 
+ let (reencode_tx, reencode_rx) = channel(); + self.erasure_task_tx + .send(ErasureTask::Reencode( + params.validators.len(), + params.erasure_root, + data, + reencode_tx, + )) + .await + .map_err(|_| RecoveryError::ChannelClosed)?; + + let reencode_response = + reencode_rx.await.map_err(|_| RecoveryError::ChannelClosed)?; + + if let Some(data) = reencode_response { gum::trace!( target: LOG_TARGET, candidate_hash = ?params.candidate_hash, @@ -746,9 +815,12 @@ where match from_backers.run(&self.params, &mut self.sender).await { Ok(data) => break Ok(data), Err(RecoveryError::Invalid) => break Err(RecoveryError::Invalid), + Err(RecoveryError::ChannelClosed) => + break Err(RecoveryError::ChannelClosed), Err(RecoveryError::Unavailable) => self.source = Source::RequestChunks(RequestChunksFromValidators::new( self.params.validators.len() as _, + self.erasure_task_tx.clone(), )), } }, @@ -838,6 +910,7 @@ impl TryFrom> for CachedRecovery { // We don't want to cache unavailable state, as that state might change, so if // requested again we want to try again! 
Err(RecoveryError::Unavailable) => Err(()), + Err(RecoveryError::ChannelClosed) => Err(()), } } } @@ -904,9 +977,9 @@ async fn launch_recovery_task( response_sender: oneshot::Sender>, metrics: &Metrics, recovery_strategy: &RecoveryStrategy, + erasure_task_tx: futures::channel::mpsc::Sender, ) -> error::Result<()> { let candidate_hash = receipt.hash(); - let params = RecoveryParams { validator_authority_keys: session_info.discovery_keys.clone(), validators: session_info.validators.clone(), @@ -943,12 +1016,21 @@ async fn launch_recovery_task( let phase = backing_group .and_then(|g| session_info.validator_groups.get(g)) - .map(|group| Source::RequestFromBackers(RequestFromBackers::new(group.clone()))) + .map(|group| { + Source::RequestFromBackers(RequestFromBackers::new( + group.clone(), + erasure_task_tx.clone(), + )) + }) .unwrap_or_else(|| { - Source::RequestChunks(RequestChunksFromValidators::new(params.validators.len() as _)) + Source::RequestChunks(RequestChunksFromValidators::new( + params.validators.len() as _, + erasure_task_tx.clone(), + )) }); - let recovery_task = RecoveryTask { sender: ctx.sender().clone(), params, source: phase }; + let recovery_task = + RecoveryTask { sender: ctx.sender().clone(), params, source: phase, erasure_task_tx }; let (remote, remote_handle) = recovery_task.run().remote_handle(); @@ -980,6 +1062,7 @@ async fn handle_recover( response_sender: oneshot::Sender>, metrics: &Metrics, recovery_strategy: &RecoveryStrategy, + erasure_task_tx: futures::channel::mpsc::Sender, ) -> error::Result<()> { let candidate_hash = receipt.hash(); @@ -1024,6 +1107,7 @@ async fn handle_recover( response_sender, metrics, recovery_strategy, + erasure_task_tx, ) .await, None => { @@ -1061,6 +1145,7 @@ async fn query_chunk_size( rx.await.map_err(error::Error::CanceledQueryFullData) } + #[overseer::contextbounds(AvailabilityRecovery, prefix = self::overseer)] impl AvailabilityRecoverySubsystem { /// Create a new instance of `AvailabilityRecoverySubsystem` 
which never requests the @@ -1106,10 +1191,65 @@ impl AvailabilityRecoverySubsystem { let mut state = State::default(); let Self { recovery_strategy, mut req_receiver, metrics } = self; + let (erasure_task_tx, erasure_task_rx) = futures::channel::mpsc::channel(16); + let mut erasure_task_rx = erasure_task_rx.fuse(); + + // `ThreadPoolBuilder` spawns the tasks using `spawn_blocking`. For each worker there will be a `mpsc` channel created. + // Each of these workers take the `Receiver` and poll it in an infinite loop. + // All of the sender ends of the channel are sent as a vec which we then use to create a `Cycle` iterator. + // We use this iterator to assign work in a round-robin fashion to the workers in the pool. + // + // How work is dispatched to the pool from the recovery tasks: + // - Once a recovery task finishes retrieving the availability data, it needs to reconstruct from chunks and/or + // re-encode the data which are heavy CPU computations. + // To do so it sends an `ErasureTask` to the main loop via the `erasure_task` channel, and waits for the results + // over a `oneshot` channel. + // - In the subsystem main loop we poll the `erasure_task_rx` receiver. + // - We forward the received `ErasureTask` to the `next()` sender yielded by the `Cycle` iterator. + // - Some worker thread handles it and sends the response over the `oneshot` channel. + + // Create a thread pool with 2 workers. + let mut to_pool = ThreadPoolBuilder::build( + // Pool is guaranteed to have at least 1 worker thread. + NonZeroUsize::new(2).expect("There are 2 threads; qed"), + metrics.clone(), + &mut ctx, + ) + .into_iter() + .cycle(); + loop { let recv_req = req_receiver.recv(|| vec![COST_INVALID_REQUEST]).fuse(); pin_mut!(recv_req); futures::select! 
{ + erasure_task = erasure_task_rx.next() => { + match erasure_task { + Some(task) => { + let send_result = to_pool + .next() + .expect("Pool size is `NonZeroUsize`; qed") + .send(task) + .await + .map_err(|_| RecoveryError::ChannelClosed); + + if let Err(err) = send_result { + gum::warn!( + target: LOG_TARGET, + ?err, + "Failed to send erasure coding task", + ); + } + }, + None => { + gum::debug!( + target: LOG_TARGET, + "Erasure task channel closed", + ); + + return Err(SubsystemError::with_origin("availability-recovery", RecoveryError::ChannelClosed)) + } + } + } v = ctx.recv().fuse() => { match v? { FromOrchestra::Signal(signal) => if handle_signal( @@ -1135,6 +1275,7 @@ impl AvailabilityRecoverySubsystem { response_sender, &metrics, &recovery_strategy, + erasure_task_tx.clone(), ).await { gum::warn!( target: LOG_TARGET, @@ -1194,3 +1335,92 @@ impl AvailabilityRecoverySubsystem { } } } + +// A simple thread pool implementation using `spawn_blocking` threads. +struct ThreadPoolBuilder; + +const MAX_THREADS: NonZeroUsize = match NonZeroUsize::new(4) { + Some(max_threads) => max_threads, + None => panic!("MAX_THREADS must be non-zero"), +}; + +impl ThreadPoolBuilder { + // Creates a pool of `size` workers, where 1 <= `size` <= `MAX_THREADS`. + // + // Each worker is created by `spawn_blocking` and takes the receiver side of a channel + // while all of the senders are returned to the caller. Each worker runs `erasure_task_thread` that + // polls the `Receiver` for an `ErasureTask` which is expected to be CPU intensive. The larger + // the input (more or larger chunks/availability data), the more CPU cycles will be spent. + // + // For example, for 32KB PoVs, we'd expect re-encode to eat as much as 90ms and 500ms for 2.5MiB. + // + // After executing such a task, the worker sends the response via a provided `oneshot` sender. + // + // The caller is responsible for routing work to the workers. 
+ #[overseer::contextbounds(AvailabilityRecovery, prefix = self::overseer)] + pub fn build( + size: NonZeroUsize, + metrics: Metrics, + ctx: &mut Context, + ) -> Vec> { + // At least 1 task, at most `MAX_THREADS`. + let size = std::cmp::min(size, MAX_THREADS); + let mut senders = Vec::new(); + + for index in 0..size.into() { + let (tx, rx) = futures::channel::mpsc::channel(8); + senders.push(tx); + + if let Err(e) = ctx + .spawn_blocking("erasure-task", Box::pin(erasure_task_thread(metrics.clone(), rx))) + { + gum::warn!( + target: LOG_TARGET, + err = ?e, + index, + "Failed to spawn an erasure task", + ); + } + } + senders + } +} + +// Handles CPU intensive operation on a dedicated blocking thread. +async fn erasure_task_thread( + metrics: Metrics, + mut ingress: futures::channel::mpsc::Receiver, +) { + loop { + match ingress.next().await { + Some(ErasureTask::Reconstruct(n_validators, chunks, sender)) => { + let _ = sender.send(polkadot_erasure_coding::reconstruct_v1( + n_validators, + chunks.values().map(|c| (&c.chunk[..], c.index.0 as usize)), + )); + }, + Some(ErasureTask::Reencode(n_validators, root, available_data, sender)) => { + let metrics = metrics.clone(); + + let maybe_data = if reconstructed_data_matches_root( + n_validators, + &root, + &available_data, + &metrics, + ) { + Some(available_data) + } else { + None + }; + + let _ = sender.send(maybe_data); + }, + None => { + gum::debug!( + target: LOG_TARGET, + "Erasure task channel closed. 
Node shutting down ?", + ); + }, + } + } +} diff --git a/node/network/availability-recovery/src/tests.rs b/node/network/availability-recovery/src/tests.rs index b9c5abee191f..26a99e91a5e2 100644 --- a/node/network/availability-recovery/src/tests.rs +++ b/node/network/availability-recovery/src/tests.rs @@ -1584,7 +1584,9 @@ fn invalid_local_chunk_is_ignored() { fn parallel_request_calculation_works_as_expected() { let num_validators = 100; let threshold = recovery_threshold(num_validators).unwrap(); - let mut phase = RequestChunksFromValidators::new(100); + let (erasure_task_tx, _erasure_task_rx) = futures::channel::mpsc::channel(16); + + let mut phase = RequestChunksFromValidators::new(100, erasure_task_tx); assert_eq!(phase.get_desired_request_count(threshold), threshold); phase.error_count = 1; phase.total_received_responses = 1; @@ -1593,20 +1595,20 @@ fn parallel_request_calculation_works_as_expected() { let dummy_chunk = ErasureChunk { chunk: Vec::new(), index: ValidatorIndex(0), proof: Proof::dummy_proof() }; - phase.received_chunks.insert(ValidatorIndex(0), dummy_chunk.clone()); + phase.insert_chunk(ValidatorIndex(0), dummy_chunk.clone()); phase.total_received_responses = 2; // With given error rate - still saturating: assert_eq!(phase.get_desired_request_count(threshold), threshold); for i in 1..9 { - phase.received_chunks.insert(ValidatorIndex(i), dummy_chunk.clone()); + phase.insert_chunk(ValidatorIndex(i), dummy_chunk.clone()); } phase.total_received_responses += 8; // error rate: 1/10 // remaining chunks needed: threshold (34) - 9 // expected: 24 * (1+ 1/10) = (next greater integer) = 27 assert_eq!(phase.get_desired_request_count(threshold), 27); - phase.received_chunks.insert(ValidatorIndex(9), dummy_chunk.clone()); + phase.insert_chunk(ValidatorIndex(9), dummy_chunk.clone()); phase.error_count = 0; // With error count zero - we should fetch exactly as needed: - assert_eq!(phase.get_desired_request_count(threshold), threshold - 
phase.received_chunks.len()); + assert_eq!(phase.get_desired_request_count(threshold), threshold - phase.chunk_count()); } diff --git a/node/subsystem-types/src/errors.rs b/node/subsystem-types/src/errors.rs index d633ac2ef959..44136362a69e 100644 --- a/node/subsystem-types/src/errors.rs +++ b/node/subsystem-types/src/errors.rs @@ -75,6 +75,9 @@ pub enum RecoveryError { /// A requested chunk is unavailable. Unavailable, + + /// Erasure task channel closed, usually means node is shutting down. + ChannelClosed, } impl std::fmt::Display for RecoveryError { @@ -82,6 +85,7 @@ impl std::fmt::Display for RecoveryError { let msg = match self { RecoveryError::Invalid => "Invalid", RecoveryError::Unavailable => "Unavailable", + RecoveryError::ChannelClosed => "ChannelClosed", }; write!(f, "{}", msg)