diff --git a/cluster_benchmark/tests/benchmark/store.rs b/cluster_benchmark/tests/benchmark/store.rs index f6d5360cc..2b3b9bf73 100644 --- a/cluster_benchmark/tests/benchmark/store.rs +++ b/cluster_benchmark/tests/benchmark/store.rs @@ -19,7 +19,6 @@ use openraft::Entry; use openraft::EntryPayload; use openraft::LogId; use openraft::OptionalSend; -use openraft::OptionalSync; use openraft::RaftLogId; use openraft::RaftTypeConfig; use openraft::SnapshotMeta; @@ -225,8 +224,14 @@ impl RaftLogStorage for Arc { } #[tracing::instrument(level = "trace", skip_all)] - async fn append(&mut self, entries: I, callback: LogFlushed) -> Result<(), StorageError> - where I: IntoIterator> + Send { + async fn append( + &mut self, + entries: I, + callback: LogFlushed, + ) -> Result<(), StorageError> + where + I: IntoIterator> + Send, + { { let mut log = self.log.write().await; log.extend(entries.into_iter().map(|entry| (entry.get_log_id().index, entry))); diff --git a/examples/memstore/src/log_store.rs b/examples/memstore/src/log_store.rs index 25715e781..867465ac6 100644 --- a/examples/memstore/src/log_store.rs +++ b/examples/memstore/src/log_store.rs @@ -93,7 +93,7 @@ impl LogStoreInner { Ok(self.vote) } - async fn append(&mut self, entries: I, callback: LogFlushed) -> Result<(), StorageError> + async fn append(&mut self, entries: I, callback: LogFlushed) -> Result<(), StorageError> where I: IntoIterator { // Simple implementation that calls the flush-before-return `append_to_log`. for entry in entries { @@ -188,14 +188,8 @@ mod impl_log_store { inner.read_vote().await } - async fn append( - &mut self, - entries: I, - callback: LogFlushed, - ) -> Result<(), StorageError> - where - I: IntoIterator, - { + async fn append(&mut self, entries: I, callback: LogFlushed) -> Result<(), StorageError> + where I: IntoIterator { let mut inner = self.inner.lock().await; inner.append(entries, callback).await } diff --git a/examples/raft-kv-memstore-singlethreaded/src/store.rs b/examples/raft-kv-memstore-singlethreaded/src/store.rs index 425130e19..227fba770 100644 --- a/examples/raft-kv-memstore-singlethreaded/src/store.rs +++ b/examples/raft-kv-memstore-singlethreaded/src/store.rs @@ -321,7 +321,7 @@ impl RaftLogStorage for Rc { } #[tracing::instrument(level = "trace", skip(self, entries, callback))] - async fn append(&mut self, entries: I, callback: LogFlushed) -> Result<(), StorageError> + async fn append(&mut self, entries: I, callback: LogFlushed) -> Result<(), StorageError> where I: IntoIterator> { // Simple implementation that calls the flush-before-return `append_to_log`. let mut log = self.log.borrow_mut(); diff --git a/examples/raft-kv-rocksdb/src/store.rs b/examples/raft-kv-rocksdb/src/store.rs index bd678d04a..4c1bac8b9 100644 --- a/examples/raft-kv-rocksdb/src/store.rs +++ b/examples/raft-kv-rocksdb/src/store.rs @@ -436,7 +436,7 @@ impl RaftLogStorage for LogStore { } #[tracing::instrument(level = "trace", skip_all)] - async fn append(&mut self, entries: I, callback: LogFlushed) -> StorageResult<()> + async fn append(&mut self, entries: I, callback: LogFlushed) -> StorageResult<()> where I: IntoIterator> + Send, I::IntoIter: Send, diff --git a/openraft/src/async_runtime.rs b/openraft/src/async_runtime.rs index 5e9c73e2a..c602d48a7 100644 --- a/openraft/src/async_runtime.rs +++ b/openraft/src/async_runtime.rs @@ -18,7 +18,7 @@ use crate::TokioInstant; /// ## Note /// /// The default asynchronous runtime is `tokio`. -pub trait AsyncRuntime: Debug + Default + OptionalSend + OptionalSync + 'static { +pub trait AsyncRuntime: Debug + Default + PartialEq + Eq + OptionalSend + OptionalSync + 'static { /// The error type of [`Self::JoinHandle`]. type JoinError: Debug + Display + OptionalSend; @@ -44,6 +44,18 @@ pub trait AsyncRuntime: Debug + Default + OptionalSend + OptionalSync + 'static /// Type of a thread-local random number generator. type ThreadLocalRng: rand::Rng; + /// Type of a `oneshot` sender. + type OneshotSender: AsyncOneshotSendExt + OptionalSend + OptionalSync + Debug + Sized; + + /// Type of a `oneshot` receiver error. + type OneshotReceiverError: std::error::Error + OptionalSend; + + /// Type of a `oneshot` receiver. + type OneshotReceiver: OptionalSend + + OptionalSync + + Future> + + Unpin; + /// Spawn a new task. fn spawn(future: T) -> Self::JoinHandle where @@ -72,12 +84,24 @@ pub trait AsyncRuntime: Debug + Default + OptionalSend + OptionalSync + 'static /// This is a per-thread instance, which cannot be shared across threads or /// sent to another thread. fn thread_rng() -> Self::ThreadLocalRng; + + /// Creates a new one-shot channel for sending single values. + /// + /// The function returns separate "send" and "receive" handles. The `Sender` + /// handle is used by the producer to send the value. The `Receiver` handle is + /// used by the consumer to receive the value. + /// + /// Each handle can be used on separate tasks. + fn oneshot() -> (Self::OneshotSender, Self::OneshotReceiver) + where T: OptionalSend; } /// `Tokio` is the default asynchronous executor. -#[derive(Debug, Default)] +#[derive(Debug, Default, PartialEq, Eq)] pub struct TokioRuntime; +pub struct TokioOneShotSender(pub tokio::sync::oneshot::Sender); + impl AsyncRuntime for TokioRuntime { type JoinError = tokio::task::JoinError; type JoinHandle = tokio::task::JoinHandle; @@ -86,6 +110,9 @@ impl AsyncRuntime for TokioRuntime { type TimeoutError = tokio::time::error::Elapsed; type Timeout + OptionalSend> = tokio::time::Timeout; type ThreadLocalRng = rand::rngs::ThreadRng; + type OneshotSender = TokioOneShotSender; + type OneshotReceiver = tokio::sync::oneshot::Receiver; + type OneshotReceiverError = tokio::sync::oneshot::error::RecvError; #[inline] fn spawn(future: T) -> Self::JoinHandle @@ -132,4 +159,36 @@ impl AsyncRuntime for TokioRuntime { fn thread_rng() -> Self::ThreadLocalRng { rand::thread_rng() } + + #[inline] + fn oneshot() -> (Self::OneshotSender, Self::OneshotReceiver) + where T: OptionalSend { + let (tx, rx) = tokio::sync::oneshot::channel(); + (TokioOneShotSender(tx), rx) + } +} + +pub trait AsyncOneshotSendExt: Unpin { + /// Attempts to send a value on this channel, returning it back if it could + /// not be sent. + /// + /// This method consumes `self` as only one value may ever be sent on a `oneshot` + /// channel. It is not marked async because sending a message to an `oneshot` + /// channel never requires any form of waiting. Because of this, the `send` + /// method can be used in both synchronous and asynchronous code without + /// problems. + fn send(self, t: T) -> Result<(), T>; +} + +impl AsyncOneshotSendExt for TokioOneShotSender { + #[inline] + fn send(self, t: T) -> Result<(), T> { + self.0.send(t) + } +} + +impl Debug for TokioOneShotSender { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("TokioSendWrapper").finish() + } } diff --git a/openraft/src/core/raft_core.rs b/openraft/src/core/raft_core.rs index 9ce86f542..921a08fb7 100644 --- a/openraft/src/core/raft_core.rs +++ b/openraft/src/core/raft_core.rs @@ -15,12 +15,12 @@ use futures::TryFutureExt; use maplit::btreeset; use tokio::select; use tokio::sync::mpsc; -use tokio::sync::oneshot; use tokio::sync::watch; use tracing::Instrument; use tracing::Level; use tracing::Span; +use crate::async_runtime::AsyncOneshotSendExt; use crate::config::Config; use crate::config::RuntimeConfig; use crate::core::balancer::Balancer; @@ -215,7 +215,10 @@ where SM: RaftStateMachine, { /// The main loop of the Raft protocol. - pub(crate) async fn main(mut self, rx_shutdown: oneshot::Receiver<()>) -> Result<(), Fatal> { + pub(crate) async fn main( + mut self, + rx_shutdown: ::OneshotReceiver<()>, + ) -> Result<(), Fatal> { let span = tracing::span!(parent: &self.span, Level::DEBUG, "main"); let res = self.do_main(rx_shutdown).instrument(span).await; @@ -239,7 +242,10 @@ where } #[tracing::instrument(level="trace", skip_all, fields(id=display(self.id), cluster=%self.config.cluster_name))] - async fn do_main(&mut self, rx_shutdown: oneshot::Receiver<()>) -> Result<(), Fatal> { + async fn do_main( + &mut self, + rx_shutdown: ::OneshotReceiver<()>, + ) -> Result<(), Fatal> { tracing::debug!("raft node is initializing"); self.engine.startup(); @@ -432,7 +438,7 @@ where &mut self, changes: ChangeMembers, retain: bool, - tx: ResultSender, ClientWriteError>, + tx: ResultSender, ClientWriteError>, ) { let res = self.engine.state.membership_state.change_handler().apply(changes, retain); let new_membership = match res { @@ -593,7 +599,7 @@ where pub(crate) fn handle_initialize( &mut self, member_nodes: BTreeMap, - tx: ResultSender<(), InitializeError>, + tx: ResultSender>, ) { tracing::debug!(member_nodes = debug(&member_nodes), "{}", func_name!()); @@ -616,7 +622,7 @@ where /// Reject a request due to the Raft node being in a state which prohibits the request. #[tracing::instrument(level = "trace", skip(self, tx))] - pub(crate) fn reject_with_forward_to_leader(&self, tx: ResultSender) + pub(crate) fn reject_with_forward_to_leader(&self, tx: ResultSender) where E: From> { let mut leader_id = self.current_leader(); let leader_node = self.get_leader_node(leader_id); @@ -680,7 +686,7 @@ where { tracing::debug!("append_to_log"); - let (tx, rx) = oneshot::channel(); + let (tx, rx) = C::AsyncRuntime::oneshot(); let callback = LogFlushed::new(Some(last_log_id), tx); self.log_store.append(entries, callback).await?; rx.await @@ -865,7 +871,10 @@ where /// Run an event handling loop #[tracing::instrument(level="debug", skip_all, fields(id=display(self.id)))] - async fn runtime_loop(&mut self, mut rx_shutdown: oneshot::Receiver<()>) -> Result<(), Fatal> { + async fn runtime_loop( + &mut self, + mut rx_shutdown: ::OneshotReceiver<()>, + ) -> Result<(), Fatal> { // Ratio control the ratio of number of RaftMsg to process to number of Notify to process. let mut balancer = Balancer::new(10_000); @@ -1067,7 +1076,7 @@ where } #[tracing::instrument(level = "debug", skip_all)] - pub(super) fn handle_vote_request(&mut self, req: VoteRequest, tx: VoteTx) { + pub(super) fn handle_vote_request(&mut self, req: VoteRequest, tx: VoteTx) { tracing::info!(req = display(req.summary()), func = func_name!()); let resp = self.engine.handle_vote_req(req); @@ -1078,11 +1087,7 @@ where } #[tracing::instrument(level = "debug", skip_all)] - pub(super) fn handle_append_entries_request( - &mut self, - req: AppendEntriesRequest, - tx: AppendEntriesTx, - ) { + pub(super) fn handle_append_entries_request(&mut self, req: AppendEntriesRequest, tx: AppendEntriesTx) { tracing::debug!(req = display(req.summary()), func = func_name!()); let is_ok = self.engine.handle_append_entries(&req.vote, req.prev_log_id, req.entries, Some(tx)); @@ -1657,7 +1662,7 @@ where // Create a channel to let state machine worker to send the snapshot and the replication // worker to receive it. - let (tx, rx) = oneshot::channel(); + let (tx, rx) = C::AsyncRuntime::oneshot(); let cmd = sm::Command::get_snapshot(tx); self.sm_handle diff --git a/openraft/src/core/raft_msg/external_command.rs b/openraft/src/core/raft_msg/external_command.rs index 17c983b3c..5714df39c 100644 --- a/openraft/src/core/raft_msg/external_command.rs +++ b/openraft/src/core/raft_msg/external_command.rs @@ -23,7 +23,7 @@ pub(crate) enum ExternalCommand { Snapshot, /// Get a snapshot from the state machine, send back via a oneshot::Sender. - GetSnapshot { tx: ResultSender>> }, + GetSnapshot { tx: ResultSender>> }, /// Purge logs covered by a snapshot up to a specified index. /// diff --git a/openraft/src/core/raft_msg/mod.rs b/openraft/src/core/raft_msg/mod.rs index b6679d43a..f4ea72cdc 100644 --- a/openraft/src/core/raft_msg/mod.rs +++ b/openraft/src/core/raft_msg/mod.rs @@ -1,7 +1,5 @@ use std::collections::BTreeMap; -use tokio::sync::oneshot; - use crate::core::raft_msg::external_command::ExternalCommand; use crate::error::CheckIsLeaderError; use crate::error::ClientWriteError; @@ -15,10 +13,13 @@ use crate::raft::ClientWriteResponse; use crate::raft::SnapshotResponse; use crate::raft::VoteRequest; use crate::raft::VoteResponse; +use crate::type_config::alias::AsyncRuntimeOf; use crate::type_config::alias::LogIdOf; use crate::type_config::alias::NodeIdOf; use crate::type_config::alias::NodeOf; +use crate::type_config::alias::OneshotSenderOf; use crate::type_config::alias::SnapshotDataOf; +use crate::AsyncRuntime; use crate::ChangeMembers; use crate::MessageSummary; use crate::RaftTypeConfig; @@ -28,22 +29,23 @@ use crate::Vote; pub(crate) mod external_command; /// A oneshot TX to send result from `RaftCore` to external caller, e.g. `Raft::append_entries`. -pub(crate) type ResultSender = oneshot::Sender>; +pub(crate) type ResultSender = OneshotSenderOf>; -pub(crate) type ResultReceiver = oneshot::Receiver>; +pub(crate) type ResultReceiver = + as AsyncRuntime>::OneshotReceiver>; /// TX for Vote Response -pub(crate) type VoteTx = ResultSender>; +pub(crate) type VoteTx = ResultSender>>; /// TX for Append Entries Response -pub(crate) type AppendEntriesTx = ResultSender>; +pub(crate) type AppendEntriesTx = ResultSender>>; /// TX for Client Write Response -pub(crate) type ClientWriteTx = ResultSender, ClientWriteError, NodeOf>>; +pub(crate) type ClientWriteTx = ResultSender, ClientWriteError, NodeOf>>; /// TX for Linearizable Read Response pub(crate) type ClientReadTx = - ResultSender<(Option>, Option>), CheckIsLeaderError, NodeOf>>; + ResultSender>, Option>), CheckIsLeaderError, NodeOf>>; /// A message sent by application to the [`RaftCore`]. /// @@ -53,18 +55,18 @@ where C: RaftTypeConfig { AppendEntries { rpc: AppendEntriesRequest, - tx: AppendEntriesTx, + tx: AppendEntriesTx, }, RequestVote { rpc: VoteRequest, - tx: VoteTx, + tx: VoteTx, }, InstallFullSnapshot { vote: Vote, snapshot: Snapshot, - tx: ResultSender>, + tx: ResultSender>, }, /// Begin receiving a snapshot from the leader. @@ -74,7 +76,7 @@ where C: RaftTypeConfig /// will be returned in a Err BeginReceivingSnapshot { vote: Vote, - tx: ResultSender>, HigherVote>, + tx: ResultSender>, HigherVote>, }, ClientWriteRequest { @@ -88,7 +90,7 @@ where C: RaftTypeConfig Initialize { members: BTreeMap, - tx: ResultSender<(), InitializeError>, + tx: ResultSender>, }, ChangeMembership { @@ -98,7 +100,7 @@ where C: RaftTypeConfig /// config will be converted into learners, otherwise they will be removed. retain: bool, - tx: ResultSender, ClientWriteError>, + tx: ResultSender, ClientWriteError>, }, ExternalCoreRequest { diff --git a/openraft/src/core/sm/command.rs b/openraft/src/core/sm/command.rs index 2c49c2597..11ddcf8bc 100644 --- a/openraft/src/core/sm/command.rs +++ b/openraft/src/core/sm/command.rs @@ -54,12 +54,12 @@ where C: RaftTypeConfig Command::new(payload) } - pub(crate) fn get_snapshot(tx: ResultSender>>) -> Self { + pub(crate) fn get_snapshot(tx: ResultSender>>) -> Self { let payload = CommandPayload::GetSnapshot { tx }; Command::new(payload) } - pub(crate) fn begin_receiving_snapshot(tx: ResultSender>, HigherVote>) -> Self { + pub(crate) fn begin_receiving_snapshot(tx: ResultSender>, HigherVote>) -> Self { let payload = CommandPayload::BeginReceivingSnapshot { tx }; Command::new(payload) } @@ -91,11 +91,11 @@ where C: RaftTypeConfig /// Get the latest built snapshot. GetSnapshot { - tx: ResultSender>>, + tx: ResultSender>>, }, BeginReceivingSnapshot { - tx: ResultSender>, HigherVote>, + tx: ResultSender>, HigherVote>, }, InstallFullSnapshot { diff --git a/openraft/src/core/sm/mod.rs b/openraft/src/core/sm/mod.rs index 05a9d6987..af810f6c0 100644 --- a/openraft/src/core/sm/mod.rs +++ b/openraft/src/core/sm/mod.rs @@ -6,6 +6,7 @@ use tokio::sync::mpsc; +use crate::async_runtime::AsyncOneshotSendExt; use crate::core::ApplyResult; use crate::core::ApplyingEntry; use crate::entry::RaftPayload; @@ -219,7 +220,7 @@ where } #[tracing::instrument(level = "info", skip_all)] - async fn get_snapshot(&mut self, tx: ResultSender>>) -> Result<(), StorageError> { + async fn get_snapshot(&mut self, tx: ResultSender>>) -> Result<(), StorageError> { tracing::info!("{}", func_name!()); let snapshot = self.state_machine.get_current_snapshot().await?; diff --git a/openraft/src/engine/command.rs b/openraft/src/engine/command.rs index a61ff97f4..6673ebc75 100644 --- a/openraft/src/engine/command.rs +++ b/openraft/src/engine/command.rs @@ -1,7 +1,6 @@ use std::fmt::Debug; -use tokio::sync::oneshot; - +use crate::async_runtime::AsyncOneshotSendExt; use crate::core::sm; use crate::engine::CommandKind; use crate::error::Infallible; @@ -14,10 +13,11 @@ use crate::raft::InstallSnapshotResponse; use crate::raft::SnapshotResponse; use crate::raft::VoteRequest; use crate::raft::VoteResponse; +use crate::type_config::alias::OneshotSenderOf; use crate::LeaderId; use crate::LogId; -use crate::Node; use crate::NodeId; +use crate::OptionalSend; use crate::RaftTypeConfig; use crate::Vote; @@ -98,7 +98,7 @@ where C: RaftTypeConfig /// Send result to caller Respond { when: Option>, - resp: Respond, + resp: Respond, }, } @@ -218,31 +218,26 @@ where NID: NodeId } /// A command to send return value to the caller via a `oneshot::Sender`. -#[derive(Debug)] -#[derive(PartialEq, Eq)] +#[derive(Debug, PartialEq, Eq)] #[derive(derive_more::From)] -pub(crate) enum Respond -where - NID: NodeId, - N: Node, +pub(crate) enum Respond +where C: RaftTypeConfig { - Vote(ValueSender, Infallible>>), - AppendEntries(ValueSender, Infallible>>), - ReceiveSnapshotChunk(ValueSender>), - InstallSnapshot(ValueSender, InstallSnapshotError>>), - InstallFullSnapshot(ValueSender, Infallible>>), - Initialize(ValueSender>>), + Vote(ValueSender, Infallible>>), + AppendEntries(ValueSender, Infallible>>), + ReceiveSnapshotChunk(ValueSender>), + InstallSnapshot(ValueSender, InstallSnapshotError>>), + InstallFullSnapshot(ValueSender, Infallible>>), + Initialize(ValueSender>>), } -impl Respond -where - NID: NodeId, - N: Node, +impl Respond +where C: RaftTypeConfig { - pub(crate) fn new(res: T, tx: oneshot::Sender) -> Self + pub(crate) fn new(res: T, tx: OneshotSenderOf) -> Self where - T: Debug + PartialEq + Eq, - Self: From>, + T: Debug + PartialEq + Eq + OptionalSend, + Self: From>, { Respond::from(ValueSender::new(res, tx)) } @@ -260,27 +255,38 @@ where } #[derive(Debug)] -pub(crate) struct ValueSender -where T: Debug + PartialEq + Eq +pub(crate) struct ValueSender +where + T: Debug + PartialEq + Eq + OptionalSend, + C: RaftTypeConfig, { value: T, - tx: oneshot::Sender, + tx: OneshotSenderOf, } -impl PartialEq for ValueSender -where T: Debug + PartialEq + Eq +impl PartialEq for ValueSender +where + T: Debug + PartialEq + Eq + OptionalSend, + C: RaftTypeConfig, { fn eq(&self, other: &Self) -> bool { self.value == other.value } } -impl Eq for ValueSender where T: Debug + PartialEq + Eq {} +impl Eq for ValueSender +where + T: Debug + PartialEq + Eq + OptionalSend, + C: RaftTypeConfig, +{ +} -impl ValueSender -where T: Debug + PartialEq + Eq +impl ValueSender +where + T: Debug + PartialEq + Eq + OptionalSend, + C: RaftTypeConfig, { - pub(crate) fn new(res: T, tx: oneshot::Sender) -> Self { + pub(crate) fn new(res: T, tx: OneshotSenderOf) -> Self { Self { value: res, tx } } diff --git a/openraft/src/engine/engine_impl.rs b/openraft/src/engine/engine_impl.rs index 8fbe07c0a..2cf2074b1 100644 --- a/openraft/src/engine/engine_impl.rs +++ b/openraft/src/engine/engine_impl.rs @@ -2,6 +2,7 @@ use std::time::Duration; use validit::Valid; +use crate::async_runtime::AsyncOneshotSendExt; use crate::core::raft_msg::AppendEntriesTx; use crate::core::raft_msg::ResultSender; use crate::core::sm; @@ -44,6 +45,7 @@ use crate::Instant; use crate::LogId; use crate::LogIdOptionExt; use crate::Membership; +use crate::OptionalSend; use crate::RaftLogId; use crate::RaftTypeConfig; use crate::Snapshot; @@ -222,9 +224,11 @@ where C: RaftTypeConfig #[tracing::instrument(level = "debug", skip_all)] pub(crate) fn get_leader_handler_or_reject( &mut self, - tx: Option>, - ) -> Option<(LeaderHandler, Option>)> + tx: Option>, + ) -> Option<(LeaderHandler, Option>)> where + T: OptionalSend, + E: OptionalSend, E: From>, { let res = self.leader_handler(); @@ -391,7 +395,7 @@ where C: RaftTypeConfig vote: &Vote, prev_log_id: Option>, entries: Vec, - tx: Option>, + tx: Option>, ) -> bool { tracing::debug!( vote = display(vote), @@ -454,7 +458,7 @@ where C: RaftTypeConfig &mut self, vote: Vote, snapshot: Snapshot, - tx: ResultSender>, + tx: ResultSender>, ) { tracing::info!(vote = display(vote), snapshot = display(&snapshot), "{}", func_name!()); @@ -487,7 +491,7 @@ where C: RaftTypeConfig pub(crate) fn handle_begin_receiving_snapshot( &mut self, vote: Vote, - tx: ResultSender>, HigherVote>, + tx: ResultSender>, HigherVote>, ) { tracing::info!(vote = display(vote), "{}", func_name!()); diff --git a/openraft/src/engine/handler/vote_handler/accept_vote_test.rs b/openraft/src/engine/handler/vote_handler/accept_vote_test.rs index 6e71a591d..7de499e7a 100644 --- a/openraft/src/engine/handler/vote_handler/accept_vote_test.rs +++ b/openraft/src/engine/handler/vote_handler/accept_vote_test.rs @@ -2,7 +2,6 @@ use std::sync::Arc; use maplit::btreeset; use pretty_assertions::assert_eq; -use tokio::sync::oneshot; use crate::core::ServerState; use crate::engine::testing::UTConfig; @@ -12,7 +11,9 @@ use crate::engine::Respond; use crate::error::Infallible; use crate::raft::VoteResponse; use crate::testing::log_id; +use crate::type_config::alias::AsyncRuntimeOf; use crate::utime::UTime; +use crate::AsyncRuntime; use crate::EffectiveMembership; use crate::Membership; use crate::TokioInstant; @@ -51,12 +52,12 @@ fn test_accept_vote_reject_smaller_vote() -> anyhow::Result<()> { // When a vote is reject, it generate SendResultCommand and return an error. let mut eng = eng(); - let (tx, _rx) = oneshot::channel(); + let (tx, _rx) = AsyncRuntimeOf::::oneshot(); let resp = eng.vote_handler().accept_vote(&Vote::new(1, 2), tx, |_state, _err| mk_res()); assert!(resp.is_none()); - let (tx, _rx) = oneshot::channel(); + let (tx, _rx) = AsyncRuntimeOf::::oneshot(); assert_eq!( vec![ // @@ -76,7 +77,7 @@ fn test_accept_vote_granted_greater_vote() -> anyhow::Result<()> { // When a vote is accepted, it generate SaveVote command and return an Ok. let mut eng = eng(); - let (tx, _rx) = oneshot::channel(); + let (tx, _rx) = AsyncRuntimeOf::::oneshot(); let resp = eng.vote_handler().accept_vote(&Vote::new(3, 3), tx, |_state, _err| mk_res()); assert!(resp.is_some()); diff --git a/openraft/src/engine/handler/vote_handler/mod.rs b/openraft/src/engine/handler/vote_handler/mod.rs index c6b6bdd1b..92f4eb647 100644 --- a/openraft/src/engine/handler/vote_handler/mod.rs +++ b/openraft/src/engine/handler/vote_handler/mod.rs @@ -11,9 +11,11 @@ use crate::error::RejectVoteRequest; use crate::internal_server_state::InternalServerState; use crate::leader::Leading; use crate::raft_state::LogStateReader; +use crate::type_config::alias::InstantOf; use crate::utime::UTime; use crate::AsyncRuntime; use crate::Instant; +use crate::OptionalSend; use crate::RaftState; use crate::RaftTypeConfig; use crate::Vote; @@ -50,17 +52,14 @@ where C: RaftTypeConfig pub(crate) fn accept_vote( &mut self, vote: &Vote, - tx: ResultSender, + tx: ResultSender, f: F, - ) -> Option> + ) -> Option> where - T: Debug + Eq, - E: Debug + Eq, - Respond: From>>, - F: Fn( - &RaftState::Instant>, - RejectVoteRequest, - ) -> Result, + T: Debug + Eq + OptionalSend, + E: Debug + Eq + OptionalSend, + Respond: From>>, + F: Fn(&RaftState>, RejectVoteRequest) -> Result, { let vote_res = self.update_vote(vote); diff --git a/openraft/src/raft/external_request.rs b/openraft/src/raft/external_request.rs index c9b9b425e..a7309c38d 100644 --- a/openraft/src/raft/external_request.rs +++ b/openraft/src/raft/external_request.rs @@ -3,7 +3,19 @@ use crate::type_config::alias::InstantOf; use crate::type_config::alias::NodeIdOf; use crate::type_config::alias::NodeOf; +use crate::OptionalSend; use crate::RaftState; +use crate::RaftTypeConfig; + +pub trait BoxCoreFnInternal: FnOnce(&RaftState, NodeOf, InstantOf>) + OptionalSend +where C: RaftTypeConfig +{ +} + +impl, NodeOf, InstantOf>) + OptionalSend> BoxCoreFnInternal + for T +{ +} /// Boxed trait object for external request function run in `RaftCore` task. -pub(crate) type BoxCoreFn = Box, NodeOf, InstantOf>) + Send + 'static>; +pub(crate) type BoxCoreFn = Box + 'static>; diff --git a/openraft/src/raft/mod.rs b/openraft/src/raft/mod.rs index 5ae484b05..3fece00cf 100644 --- a/openraft/src/raft/mod.rs +++ b/openraft/src/raft/mod.rs @@ -28,13 +28,13 @@ pub use message::SnapshotResponse; pub use message::VoteRequest; pub use message::VoteResponse; use tokio::sync::mpsc; -use tokio::sync::oneshot; use tokio::sync::watch; use tokio::sync::Mutex; use tracing::trace_span; use tracing::Instrument; use tracing::Level; +use crate::async_runtime::AsyncOneshotSendExt; use crate::config::Config; use crate::config::RuntimeConfig; use crate::core::command_state::CommandState; @@ -70,6 +70,7 @@ use crate::ChangeMembers; use crate::LogId; use crate::LogIdOptionExt; use crate::MessageSummary; +use crate::OptionalSend; use crate::RaftState; pub use crate::RaftTypeConfig; use crate::Snapshot; @@ -180,7 +181,7 @@ where C: RaftTypeConfig let (tx_metrics, rx_metrics) = watch::channel(RaftMetrics::new_initial(id)); let (tx_data_metrics, rx_data_metrics) = watch::channel(RaftDataMetrics::default()); let (tx_server_metrics, rx_server_metrics) = watch::channel(RaftServerMetrics::default()); - let (tx_shutdown, rx_shutdown) = oneshot::channel(); + let (tx_shutdown, rx_shutdown) = C::AsyncRuntime::oneshot(); let tick_handle = Tick::spawn( Duration::from_millis(config.heartbeat_interval * 3 / 2), @@ -335,7 +336,7 @@ where C: RaftTypeConfig ) -> Result, RaftError> { tracing::debug!(rpc = display(rpc.summary()), "Raft::append_entries"); - let (tx, rx) = oneshot::channel(); + let (tx, rx) = C::AsyncRuntime::oneshot(); self.inner.call_core(RaftMsg::AppendEntries { rpc, tx }, rx).await } @@ -347,7 +348,7 @@ where C: RaftTypeConfig pub async fn vote(&self, rpc: VoteRequest) -> Result, RaftError> { tracing::info!(rpc = display(rpc.summary()), "Raft::vote()"); - let (tx, rx) = oneshot::channel(); + let (tx, rx) = C::AsyncRuntime::oneshot(); self.inner.call_core(RaftMsg::RequestVote { rpc, tx }, rx).await } @@ -359,7 +360,7 @@ where C: RaftTypeConfig pub async fn get_snapshot(&self) -> Result>, RaftError> { tracing::debug!("Raft::get_snapshot()"); - let (tx, rx) = oneshot::channel(); + let (tx, rx) = C::AsyncRuntime::oneshot(); let cmd = ExternalCommand::GetSnapshot { tx }; self.inner.call_core(RaftMsg::ExternalCommand { cmd }, rx).await } @@ -372,7 +373,7 @@ where C: RaftTypeConfig ) -> Result>, RaftError>> { tracing::info!("Raft::begin_receiving_snapshot()"); - let (tx, rx) = oneshot::channel(); + let (tx, rx) = C::AsyncRuntime::oneshot(); let resp = self.inner.call_core(RaftMsg::BeginReceivingSnapshot { vote, tx }, rx).await?; Ok(resp) } @@ -390,7 +391,7 @@ where C: RaftTypeConfig ) -> Result, Fatal> { tracing::info!("Raft::install_full_snapshot()"); - let (tx, rx) = oneshot::channel(); + let (tx, rx) = C::AsyncRuntime::oneshot(); let res = self.inner.call_core(RaftMsg::InstallFullSnapshot { vote, snapshot, tx }, rx).await; match res { Ok(x) => Ok(x), @@ -491,7 +492,7 @@ where C: RaftTypeConfig #[deprecated(since = "0.9.0", note = "use `Raft::ensure_linearizable()` instead")] #[tracing::instrument(level = "debug", skip(self))] pub async fn is_leader(&self) -> Result<(), RaftError>> { - let (tx, rx) = oneshot::channel(); + let (tx, rx) = C::AsyncRuntime::oneshot(); let _ = self.inner.call_core(RaftMsg::CheckIsLeaderRequest { tx }, rx).await?; Ok(()) } @@ -575,7 +576,7 @@ where C: RaftTypeConfig (Option>, Option>), RaftError>, > { - let (tx, rx) = oneshot::channel(); + let (tx, rx) = C::AsyncRuntime::oneshot(); let (read_log_id, applied) = self.inner.call_core(RaftMsg::CheckIsLeaderRequest { tx }, rx).await?; Ok((read_log_id, applied)) } @@ -603,7 +604,7 @@ where C: RaftTypeConfig &self, app_data: C::D, ) -> Result, RaftError>> { - let (tx, rx) = oneshot::channel(); + let (tx, rx) = C::AsyncRuntime::oneshot(); self.inner.call_core(RaftMsg::ClientWriteRequest { app_data, tx }, rx).await } @@ -636,7 +637,7 @@ where C: RaftTypeConfig where T: IntoNodes + Debug, { - let (tx, rx) = oneshot::channel(); + let (tx, rx) = C::AsyncRuntime::oneshot(); self.inner .call_core( RaftMsg::Initialize { @@ -671,7 +672,7 @@ where C: RaftTypeConfig node: C::Node, blocking: bool, ) -> Result, RaftError>> { - let (tx, rx) = oneshot::channel(); + let (tx, rx) = C::AsyncRuntime::oneshot(); let resp = self .inner .call_core( @@ -801,7 +802,7 @@ where C: RaftTypeConfig "change_membership: start to commit joint config" ); - let (tx, rx) = oneshot::channel(); + let (tx, rx) = C::AsyncRuntime::oneshot(); // res is error if membership can not be changed. // If no error, it will enter a joint state let res = self @@ -832,7 +833,7 @@ where C: RaftTypeConfig tracing::debug!("committed a joint config: {} {:?}", log_id, joint); tracing::debug!("the second step is to change to uniform config: {:?}", changes); - let (tx, rx) = oneshot::channel(); + let (tx, rx) = C::AsyncRuntime::oneshot(); let res = self.inner.call_core(RaftMsg::ChangeMembership { changes, retain, tx }, rx).await; if let Err(e) = &res { @@ -862,10 +863,12 @@ where C: RaftTypeConfig /// ``` pub async fn with_raft_state(&self, func: F) -> Result> where - F: FnOnce(&RaftState::Instant>) -> V + Send + 'static, - V: Send + 'static, + F: FnOnce(&RaftState::Instant>) -> V + + OptionalSend + + 'static, + V: OptionalSend + 'static, { - let (tx, rx) = oneshot::channel(); + let (tx, rx) = C::AsyncRuntime::oneshot(); self.external_request(|st| { let result = func(st); @@ -899,7 +902,8 @@ where C: RaftTypeConfig /// If the API channel is already closed (Raft is in shutdown), then the request functor is /// destroyed right away and not called at all. pub fn external_request(&self, req: F) - where F: FnOnce(&RaftState::Instant>) + Send + 'static { + where F: FnOnce(&RaftState::Instant>) + OptionalSend + 'static + { let req: BoxCoreFn = Box::new(req); let _ignore_error = self.inner.tx_api.send(RaftMsg::ExternalCoreRequest { req }); } diff --git a/openraft/src/raft/raft_inner.rs b/openraft/src/raft/raft_inner.rs index ceecf7f4c..d13fdfeec 100644 --- a/openraft/src/raft/raft_inner.rs +++ b/openraft/src/raft/raft_inner.rs @@ -3,7 +3,6 @@ use std::fmt::Debug; use std::sync::Arc; use tokio::sync::mpsc; -use tokio::sync::oneshot; use tokio::sync::watch; use tokio::sync::Mutex; use tracing::Level; @@ -17,9 +16,11 @@ use crate::error::RaftError; use crate::metrics::RaftDataMetrics; use crate::metrics::RaftServerMetrics; use crate::raft::core_state::CoreState; +use crate::type_config::alias::OneshotSenderOf; use crate::AsyncRuntime; use crate::Config; use crate::MessageSummary; +use crate::OptionalSend; use crate::RaftMetrics; use crate::RaftTypeConfig; @@ -39,7 +40,7 @@ where C: RaftTypeConfig // TODO(xp): it does not need to be a async mutex. #[allow(clippy::type_complexity)] - pub(in crate::raft) tx_shutdown: Mutex>>, + pub(in crate::raft) tx_shutdown: Mutex>>, pub(in crate::raft) core_state: Mutex>, /// The ongoing snapshot transmission. @@ -55,10 +56,11 @@ where C: RaftTypeConfig pub(crate) async fn call_core( &self, mes: RaftMsg, - rx: oneshot::Receiver>, + rx: ::OneshotReceiver>, ) -> Result> where - E: Debug, + E: Debug + OptionalSend, + T: OptionalSend, { let sum = if tracing::enabled!(Level::DEBUG) { Some(mes.summary()) diff --git a/openraft/src/replication/mod.rs b/openraft/src/replication/mod.rs index d0927629a..d716b5c93 100644 --- a/openraft/src/replication/mod.rs +++ b/openraft/src/replication/mod.rs @@ -733,7 +733,7 @@ where #[tracing::instrument(level = "info", skip_all)] async fn stream_snapshot( &mut self, - snapshot_rx: DataWithId>>>, + snapshot_rx: DataWithId>>>, ) -> Result>, ReplicationError> { let request_id = snapshot_rx.request_id(); let rx = snapshot_rx.into_data(); diff --git a/openraft/src/replication/request.rs b/openraft/src/replication/request.rs index bbcda719c..ee2ad6d07 100644 --- a/openraft/src/replication/request.rs +++ b/openraft/src/replication/request.rs @@ -22,7 +22,7 @@ where C: RaftTypeConfig Self::Data(Data::new_logs(id, log_id_range)) } - pub(crate) fn snapshot(id: Option, snapshot_rx: ResultReceiver>>) -> Self { + pub(crate) fn snapshot(id: Option, snapshot_rx: ResultReceiver>>) -> Self { Self::Data(Data::new_snapshot(id, snapshot_rx)) } @@ -73,7 +73,7 @@ where C: RaftTypeConfig { Heartbeat, Logs(DataWithId>), - Snapshot(DataWithId>>>), + Snapshot(DataWithId>>>), SnapshotCallback(DataWithId>), } @@ -148,7 +148,7 @@ where C: RaftTypeConfig Self::Logs(DataWithId::new(request_id, log_id_range)) } - pub(crate) fn new_snapshot(request_id: Option, snapshot_rx: ResultReceiver>>) -> Self { + pub(crate) fn new_snapshot(request_id: Option, snapshot_rx: ResultReceiver>>) -> Self { Self::Snapshot(DataWithId::new(request_id, snapshot_rx)) } diff --git a/openraft/src/storage/adapter.rs b/openraft/src/storage/adapter.rs index 145123938..fd6526e77 100644 --- a/openraft/src/storage/adapter.rs +++ b/openraft/src/storage/adapter.rs @@ -147,7 +147,7 @@ where S::get_log_reader(self.storage_mut().await.deref_mut()).await } - async fn append(&mut self, entries: I, callback: LogFlushed) -> Result<(), StorageError> + async fn append(&mut self, entries: I, callback: LogFlushed) -> Result<(), StorageError> where I: IntoIterator + OptionalSend { // Default implementation that calls the flush-before-return `append_to_log`. diff --git a/openraft/src/storage/callback.rs b/openraft/src/storage/callback.rs index a0bda0895..b4f666ae2 100644 --- a/openraft/src/storage/callback.rs +++ b/openraft/src/storage/callback.rs @@ -4,26 +4,27 @@ use std::io; use tokio::sync::oneshot; +use crate::async_runtime::AsyncOneshotSendExt; use crate::display_ext::DisplayOption; +use crate::type_config::alias::OneshotSenderOf; use crate::LogId; -use crate::NodeId; use crate::RaftTypeConfig; use crate::StorageIOError; /// A oneshot callback for completion of log io operation. -pub struct LogFlushed -where NID: NodeId +pub struct LogFlushed +where C: RaftTypeConfig { - last_log_id: Option>, - tx: oneshot::Sender>, io::Error>>, + last_log_id: Option>, + tx: OneshotSenderOf>, io::Error>>, } -impl LogFlushed -where NID: NodeId +impl LogFlushed +where C: RaftTypeConfig { pub(crate) fn new( - last_log_id: Option>, - tx: oneshot::Sender>, io::Error>>, + last_log_id: Option>, + tx: OneshotSenderOf>, io::Error>>, ) -> Self { Self { last_log_id, tx } } diff --git a/openraft/src/storage/v2.rs b/openraft/src/storage/v2.rs index 77cbf28fb..2c28b7f2a 100644 --- a/openraft/src/storage/v2.rs +++ b/openraft/src/storage/v2.rs @@ -120,7 +120,7 @@ where C: RaftTypeConfig /// /// - There must not be a **hole** in logs. Because Raft only examine the last log id to ensure /// correctness. - async fn append(&mut self, entries: I, callback: LogFlushed) -> Result<(), StorageError> + async fn append(&mut self, entries: I, callback: LogFlushed) -> Result<(), StorageError> where I: IntoIterator + OptionalSend, I::IntoIter: OptionalSend; diff --git a/openraft/src/testing/mod.rs b/openraft/src/testing/mod.rs index 09530dd52..13a216abc 100644 --- a/openraft/src/testing/mod.rs +++ b/openraft/src/testing/mod.rs @@ -6,12 +6,12 @@ use std::collections::BTreeSet; use anyerror::AnyError; pub use store_builder::StoreBuilder; pub use suite::Suite; -use tokio::sync::oneshot; use crate::entry::RaftEntry; use crate::log_id::RaftLogId; use crate::storage::LogFlushed; use crate::storage::RaftLogStorage; +use crate::AsyncRuntime; use crate::CommittedLeaderId; use crate::LogId; use crate::RaftTypeConfig; @@ -55,7 +55,7 @@ where let entries = entries.into_iter().collect::>(); let last_log_id = entries.last().map(|e| *e.get_log_id()).unwrap(); - let (tx, rx) = oneshot::channel(); + let (tx, rx) = ::oneshot(); let cb = LogFlushed::new(Some(last_log_id), tx); log_store.append(entries, cb).await?; rx.await.unwrap().map_err(|e| StorageIOError::write_logs(AnyError::error(e)))?; diff --git a/openraft/src/testing/suite.rs b/openraft/src/testing/suite.rs index 0e4a20a34..49eb2a895 100644 --- a/openraft/src/testing/suite.rs +++ b/openraft/src/testing/suite.rs @@ -6,7 +6,6 @@ use std::time::Duration; use anyerror::AnyError; use maplit::btreeset; -use tokio::sync::oneshot; use crate::entry::RaftEntry; use crate::log_id::RaftLogId; @@ -1171,7 +1170,7 @@ where let entries = entries.into_iter().collect::>(); let last_log_id = *entries.last().unwrap().get_log_id(); - let (tx, rx) = oneshot::channel(); + let (tx, rx) = ::oneshot(); let cb = LogFlushed::new(Some(last_log_id), tx); diff --git a/openraft/src/timer/timeout_test.rs b/openraft/src/timer/timeout_test.rs index 98de77cc9..445ebd51a 100644 --- a/openraft/src/timer/timeout_test.rs +++ b/openraft/src/timer/timeout_test.rs @@ -1,11 +1,12 @@ use std::time::Duration; -use tokio::sync::oneshot; use tokio::time::sleep; use tokio::time::Instant; +use crate::async_runtime::AsyncOneshotSendExt; use crate::timer::timeout::RaftTimer; use crate::timer::Timeout; +use crate::AsyncRuntime; use crate::TokioRuntime; #[cfg(not(feature = "singlethreaded"))] @@ -24,7 +25,7 @@ fn test_timeout() -> anyhow::Result<()> { async fn test_timeout_inner() -> anyhow::Result<()> { tracing::info!("--- set timeout, recv result"); { - let (tx, rx) = oneshot::channel(); + let (tx, rx) = ::oneshot(); let now = Instant::now(); let _t = Timeout::::new( || { @@ -43,7 +44,7 @@ async fn test_timeout_inner() -> anyhow::Result<()> { tracing::info!("--- update timeout"); { - let (tx, rx) = oneshot::channel(); + let (tx, rx) = ::oneshot(); let now = Instant::now(); let t = Timeout::::new( || { @@ -65,7 +66,7 @@ async fn test_timeout_inner() -> anyhow::Result<()> { tracing::info!("--- update timeout to a lower value wont take effect"); { - let (tx, rx) = oneshot::channel(); + let (tx, rx) = ::oneshot(); let now = Instant::now(); let t = Timeout::::new( || { @@ -87,7 +88,7 @@ async fn test_timeout_inner() -> anyhow::Result<()> { tracing::info!("--- drop the `Timeout` will cancel the callback"); { - let (tx, rx) = oneshot::channel(); + let (tx, rx) = ::oneshot(); let now = Instant::now(); let t = Timeout::::new( || { diff --git a/openraft/src/type_config.rs b/openraft/src/type_config.rs index 023c37cad..9fa647d1c 100644 --- a/openraft/src/type_config.rs +++ b/openraft/src/type_config.rs @@ -91,6 +91,9 @@ pub(crate) mod alias { pub(crate) type InstantOf = as crate::AsyncRuntime>::Instant; pub(crate) type TimeoutErrorOf = as crate::AsyncRuntime>::TimeoutError; pub(crate) type TimeoutOf = as crate::AsyncRuntime>::Timeout; + pub(crate) type OneshotSenderOf = as crate::AsyncRuntime>::OneshotSender; + pub(crate) type OneshotReceiverErrorOf = as crate::AsyncRuntime>::OneshotReceiverError; + pub(crate) type OneshotReceiverOf = as crate::AsyncRuntime>::OneshotReceiver; // Usually used types pub(crate) type LogIdOf = crate::LogId>; diff --git a/stores/rocksstore-v2/src/lib.rs b/stores/rocksstore-v2/src/lib.rs index 918c52f4f..4428ac1b0 100644 --- a/stores/rocksstore-v2/src/lib.rs +++ b/stores/rocksstore-v2/src/lib.rs @@ -353,7 +353,7 @@ impl RaftLogStorage for RocksLogStore { async fn append( &mut self, entries: I, - callback: LogFlushed, + callback: LogFlushed, ) -> Result<(), StorageError> where I: IntoIterator> + Send,