Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AsyncRuntime::oneshot #1026

Merged
merged 11 commits into from
Mar 2, 2024
11 changes: 8 additions & 3 deletions cluster_benchmark/tests/benchmark/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ use openraft::Entry;
use openraft::EntryPayload;
use openraft::LogId;
use openraft::OptionalSend;
use openraft::OptionalSync;
use openraft::RaftLogId;
use openraft::RaftTypeConfig;
use openraft::SnapshotMeta;
Expand Down Expand Up @@ -225,8 +224,14 @@ impl RaftLogStorage<TypeConfig> for Arc<LogStore> {
}

#[tracing::instrument(level = "trace", skip_all)]
async fn append<I>(&mut self, entries: I, callback: LogFlushed<NodeId>) -> Result<(), StorageError<NodeId>>
where I: IntoIterator<Item = Entry<TypeConfig>> + Send {
async fn append<I>(
&mut self,
entries: I,
callback: LogFlushed<TypeConfig>,
) -> Result<(), StorageError<NodeId>>
where
I: IntoIterator<Item = Entry<TypeConfig>> + Send,
{
{
let mut log = self.log.write().await;
log.extend(entries.into_iter().map(|entry| (entry.get_log_id().index, entry)));
Expand Down
12 changes: 3 additions & 9 deletions examples/memstore/src/log_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ impl<C: RaftTypeConfig> LogStoreInner<C> {
Ok(self.vote)
}

async fn append<I>(&mut self, entries: I, callback: LogFlushed<C::NodeId>) -> Result<(), StorageError<C::NodeId>>
async fn append<I>(&mut self, entries: I, callback: LogFlushed<C>) -> Result<(), StorageError<C::NodeId>>
where I: IntoIterator<Item = C::Entry> {
// Simple implementation that calls the flush-before-return `append_to_log`.
for entry in entries {
Expand Down Expand Up @@ -188,14 +188,8 @@ mod impl_log_store {
inner.read_vote().await
}

async fn append<I>(
&mut self,
entries: I,
callback: LogFlushed<C::NodeId>,
) -> Result<(), StorageError<C::NodeId>>
where
I: IntoIterator<Item = C::Entry>,
{
async fn append<I>(&mut self, entries: I, callback: LogFlushed<C>) -> Result<(), StorageError<C::NodeId>>
where I: IntoIterator<Item = C::Entry> {
let mut inner = self.inner.lock().await;
inner.append(entries, callback).await
}
Expand Down
2 changes: 1 addition & 1 deletion examples/raft-kv-memstore-singlethreaded/src/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ impl RaftLogStorage<TypeConfig> for Rc<LogStore> {
}

#[tracing::instrument(level = "trace", skip(self, entries, callback))]
async fn append<I>(&mut self, entries: I, callback: LogFlushed<NodeId>) -> Result<(), StorageError<NodeId>>
async fn append<I>(&mut self, entries: I, callback: LogFlushed<TypeConfig>) -> Result<(), StorageError<NodeId>>
where I: IntoIterator<Item = Entry<TypeConfig>> {
// Simple implementation that calls the flush-before-return `append_to_log`.
let mut log = self.log.borrow_mut();
Expand Down
2 changes: 1 addition & 1 deletion examples/raft-kv-rocksdb/src/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ impl RaftLogStorage<TypeConfig> for LogStore {
}

#[tracing::instrument(level = "trace", skip_all)]
async fn append<I>(&mut self, entries: I, callback: LogFlushed<NodeId>) -> StorageResult<()>
async fn append<I>(&mut self, entries: I, callback: LogFlushed<TypeConfig>) -> StorageResult<()>
where
I: IntoIterator<Item = Entry<TypeConfig>> + Send,
I::IntoIter: Send,
Expand Down
63 changes: 61 additions & 2 deletions openraft/src/async_runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::TokioInstant;
/// ## Note
///
/// The default asynchronous runtime is `tokio`.
pub trait AsyncRuntime: Debug + Default + OptionalSend + OptionalSync + 'static {
pub trait AsyncRuntime: Debug + Default + PartialEq + Eq + OptionalSend + OptionalSync + 'static {
/// The error type of [`Self::JoinHandle`].
type JoinError: Debug + Display + OptionalSend;

Expand All @@ -44,6 +44,18 @@ pub trait AsyncRuntime: Debug + Default + OptionalSend + OptionalSync + 'static
/// Type of a thread-local random number generator.
type ThreadLocalRng: rand::Rng;

/// Type of a `oneshot` sender.
type OneshotSender<T: OptionalSend>: AsyncOneshotSendExt<T> + OptionalSend + OptionalSync + Debug + Sized;

/// Type of a `oneshot` receiver error.
type OneshotReceiverError: std::error::Error + OptionalSend;

/// Type of a `oneshot` receiver.
type OneshotReceiver<T: OptionalSend>: OptionalSend
+ OptionalSync
+ Future<Output = Result<T, Self::OneshotReceiverError>>
+ Unpin;

/// Spawn a new task.
fn spawn<T>(future: T) -> Self::JoinHandle<T::Output>
where
Expand Down Expand Up @@ -72,12 +84,24 @@ pub trait AsyncRuntime: Debug + Default + OptionalSend + OptionalSync + 'static
/// This is a per-thread instance, which cannot be shared across threads or
/// sent to another thread.
fn thread_rng() -> Self::ThreadLocalRng;

/// Creates a new one-shot channel for sending single values.
///
/// The function returns separate "send" and "receive" handles. The `Sender`
/// handle is used by the producer to send the value. The `Receiver` handle is
/// used by the consumer to receive the value.
///
/// Each handle can be used on separate tasks.
fn oneshot<T>() -> (Self::OneshotSender<T>, Self::OneshotReceiver<T>)
where T: OptionalSend;
}

/// `Tokio` is the default asynchronous executor.
#[derive(Debug, Default)]
#[derive(Debug, Default, PartialEq, Eq)]
pub struct TokioRuntime;

/// Newtype around [`tokio::sync::oneshot::Sender`].
///
/// It is the `OneshotSender` type of [`TokioRuntime`] and implements
/// [`AsyncOneshotSendExt`] by forwarding to the wrapped tokio sender.
pub struct TokioOneShotSender<T: OptionalSend>(pub tokio::sync::oneshot::Sender<T>);

impl AsyncRuntime for TokioRuntime {
type JoinError = tokio::task::JoinError;
type JoinHandle<T: OptionalSend + 'static> = tokio::task::JoinHandle<T>;
Expand All @@ -86,6 +110,9 @@ impl AsyncRuntime for TokioRuntime {
type TimeoutError = tokio::time::error::Elapsed;
type Timeout<R, T: Future<Output = R> + OptionalSend> = tokio::time::Timeout<T>;
type ThreadLocalRng = rand::rngs::ThreadRng;
type OneshotSender<T: OptionalSend> = TokioOneShotSender<T>;
type OneshotReceiver<T: OptionalSend> = tokio::sync::oneshot::Receiver<T>;
type OneshotReceiverError = tokio::sync::oneshot::error::RecvError;

#[inline]
fn spawn<T>(future: T) -> Self::JoinHandle<T::Output>
Expand Down Expand Up @@ -132,4 +159,36 @@ impl AsyncRuntime for TokioRuntime {
fn thread_rng() -> Self::ThreadLocalRng {
rand::thread_rng()
}

/// Builds a tokio `oneshot` channel and returns its two halves.
///
/// The sending half is wrapped in [`TokioOneShotSender`]; the receiving half
/// is the plain tokio receiver.
#[inline]
fn oneshot<T>() -> (Self::OneshotSender<T>, Self::OneshotReceiver<T>)
where T: OptionalSend {
    let (sender, receiver) = tokio::sync::oneshot::channel::<T>();
    let sender = TokioOneShotSender(sender);
    (sender, receiver)
}
}

/// The sending half of a `oneshot` channel, abstracted over the async runtime.
///
/// [`AsyncRuntime::OneshotSender`] is required to implement this trait.
pub trait AsyncOneshotSendExt<T>: Unpin {
/// Attempts to send a value on this channel, returning it back if it could
/// not be sent.
///
/// This method consumes `self` as only one value may ever be sent on a `oneshot`
/// channel. It is not marked async because sending a message to a `oneshot`
/// channel never requires any form of waiting. Because of this, the `send`
/// method can be used in both synchronous and asynchronous code without
/// problems.
fn send(self, t: T) -> Result<(), T>;
}

impl<T: OptionalSend> AsyncOneshotSendExt<T> for TokioOneShotSender<T> {
#[inline]
fn send(self, t: T) -> Result<(), T> {
self.0.send(t)
}
}

impl<T: OptionalSend> Debug for TokioOneShotSender<T> {
    /// Formats the sender as its bare type name.
    ///
    /// The wrapped tokio sender is intentionally not printed; the output is an
    /// empty tuple struct named after this type so that log output matches the
    /// type that actually appears in the code.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_tuple("TokioOneShotSender").finish()
    }
}
35 changes: 20 additions & 15 deletions openraft/src/core/raft_core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ use futures::TryFutureExt;
use maplit::btreeset;
use tokio::select;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::sync::watch;
use tracing::Instrument;
use tracing::Level;
use tracing::Span;

use crate::async_runtime::AsyncOneshotSendExt;
use crate::config::Config;
use crate::config::RuntimeConfig;
use crate::core::balancer::Balancer;
Expand Down Expand Up @@ -215,7 +215,10 @@ where
SM: RaftStateMachine<C>,
{
/// The main loop of the Raft protocol.
pub(crate) async fn main(mut self, rx_shutdown: oneshot::Receiver<()>) -> Result<(), Fatal<C::NodeId>> {
pub(crate) async fn main(
mut self,
rx_shutdown: <C::AsyncRuntime as AsyncRuntime>::OneshotReceiver<()>,
) -> Result<(), Fatal<C::NodeId>> {
let span = tracing::span!(parent: &self.span, Level::DEBUG, "main");
let res = self.do_main(rx_shutdown).instrument(span).await;

Expand All @@ -239,7 +242,10 @@ where
}

#[tracing::instrument(level="trace", skip_all, fields(id=display(self.id), cluster=%self.config.cluster_name))]
async fn do_main(&mut self, rx_shutdown: oneshot::Receiver<()>) -> Result<(), Fatal<C::NodeId>> {
async fn do_main(
&mut self,
rx_shutdown: <C::AsyncRuntime as AsyncRuntime>::OneshotReceiver<()>,
) -> Result<(), Fatal<C::NodeId>> {
tracing::debug!("raft node is initializing");

self.engine.startup();
Expand Down Expand Up @@ -432,7 +438,7 @@ where
&mut self,
changes: ChangeMembers<C::NodeId, C::Node>,
retain: bool,
tx: ResultSender<ClientWriteResponse<C>, ClientWriteError<C::NodeId, C::Node>>,
tx: ResultSender<C, ClientWriteResponse<C>, ClientWriteError<C::NodeId, C::Node>>,
) {
let res = self.engine.state.membership_state.change_handler().apply(changes, retain);
let new_membership = match res {
Expand Down Expand Up @@ -593,7 +599,7 @@ where
pub(crate) fn handle_initialize(
&mut self,
member_nodes: BTreeMap<C::NodeId, C::Node>,
tx: ResultSender<(), InitializeError<C::NodeId, C::Node>>,
tx: ResultSender<C, (), InitializeError<C::NodeId, C::Node>>,
) {
tracing::debug!(member_nodes = debug(&member_nodes), "{}", func_name!());

Expand All @@ -616,7 +622,7 @@ where

/// Reject a request due to the Raft node being in a state which prohibits the request.
#[tracing::instrument(level = "trace", skip(self, tx))]
pub(crate) fn reject_with_forward_to_leader<T, E>(&self, tx: ResultSender<T, E>)
pub(crate) fn reject_with_forward_to_leader<T: OptionalSend, E: OptionalSend>(&self, tx: ResultSender<C, T, E>)
where E: From<ForwardToLeader<C::NodeId, C::Node>> {
let mut leader_id = self.current_leader();
let leader_node = self.get_leader_node(leader_id);
Expand Down Expand Up @@ -680,7 +686,7 @@ where
{
tracing::debug!("append_to_log");

let (tx, rx) = oneshot::channel();
let (tx, rx) = C::AsyncRuntime::oneshot();
let callback = LogFlushed::new(Some(last_log_id), tx);
self.log_store.append(entries, callback).await?;
rx.await
Expand Down Expand Up @@ -865,7 +871,10 @@ where

/// Run an event handling loop
#[tracing::instrument(level="debug", skip_all, fields(id=display(self.id)))]
async fn runtime_loop(&mut self, mut rx_shutdown: oneshot::Receiver<()>) -> Result<(), Fatal<C::NodeId>> {
async fn runtime_loop(
&mut self,
mut rx_shutdown: <C::AsyncRuntime as AsyncRuntime>::OneshotReceiver<()>,
) -> Result<(), Fatal<C::NodeId>> {
// The balancer controls the ratio of the number of RaftMsg to process to the number of Notify to process.
let mut balancer = Balancer::new(10_000);

Expand Down Expand Up @@ -1067,7 +1076,7 @@ where
}

#[tracing::instrument(level = "debug", skip_all)]
pub(super) fn handle_vote_request(&mut self, req: VoteRequest<C::NodeId>, tx: VoteTx<C::NodeId>) {
pub(super) fn handle_vote_request(&mut self, req: VoteRequest<C::NodeId>, tx: VoteTx<C>) {
tracing::info!(req = display(req.summary()), func = func_name!());

let resp = self.engine.handle_vote_req(req);
Expand All @@ -1078,11 +1087,7 @@ where
}

#[tracing::instrument(level = "debug", skip_all)]
pub(super) fn handle_append_entries_request(
&mut self,
req: AppendEntriesRequest<C>,
tx: AppendEntriesTx<C::NodeId>,
) {
pub(super) fn handle_append_entries_request(&mut self, req: AppendEntriesRequest<C>, tx: AppendEntriesTx<C>) {
tracing::debug!(req = display(req.summary()), func = func_name!());

let is_ok = self.engine.handle_append_entries(&req.vote, req.prev_log_id, req.entries, Some(tx));
Expand Down Expand Up @@ -1657,7 +1662,7 @@ where

// Create a channel that lets the state machine worker send the snapshot and the
// replication worker receive it.
let (tx, rx) = oneshot::channel();
let (tx, rx) = C::AsyncRuntime::oneshot();

let cmd = sm::Command::get_snapshot(tx);
self.sm_handle
Expand Down
2 changes: 1 addition & 1 deletion openraft/src/core/raft_msg/external_command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub(crate) enum ExternalCommand<C: RaftTypeConfig> {
Snapshot,

/// Get a snapshot from the state machine, send back via a oneshot::Sender.
GetSnapshot { tx: ResultSender<Option<Snapshot<C>>> },
GetSnapshot { tx: ResultSender<C, Option<Snapshot<C>>> },

/// Purge logs covered by a snapshot up to a specified index.
///
Expand Down
30 changes: 16 additions & 14 deletions openraft/src/core/raft_msg/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
use std::collections::BTreeMap;

use tokio::sync::oneshot;

use crate::core::raft_msg::external_command::ExternalCommand;
use crate::error::CheckIsLeaderError;
use crate::error::ClientWriteError;
Expand All @@ -15,10 +13,13 @@ use crate::raft::ClientWriteResponse;
use crate::raft::SnapshotResponse;
use crate::raft::VoteRequest;
use crate::raft::VoteResponse;
use crate::type_config::alias::AsyncRuntimeOf;
use crate::type_config::alias::LogIdOf;
use crate::type_config::alias::NodeIdOf;
use crate::type_config::alias::NodeOf;
use crate::type_config::alias::OneshotSenderOf;
use crate::type_config::alias::SnapshotDataOf;
use crate::AsyncRuntime;
use crate::ChangeMembers;
use crate::MessageSummary;
use crate::RaftTypeConfig;
Expand All @@ -28,22 +29,23 @@ use crate::Vote;
pub(crate) mod external_command;

/// A oneshot TX to send result from `RaftCore` to external caller, e.g. `Raft::append_entries`.
pub(crate) type ResultSender<T, E = Infallible> = oneshot::Sender<Result<T, E>>;
pub(crate) type ResultSender<C, T, E = Infallible> = OneshotSenderOf<C, Result<T, E>>;

pub(crate) type ResultReceiver<T, E = Infallible> = oneshot::Receiver<Result<T, E>>;
pub(crate) type ResultReceiver<C, T, E = Infallible> =
<AsyncRuntimeOf<C> as AsyncRuntime>::OneshotReceiver<Result<T, E>>;

/// TX for Vote Response
pub(crate) type VoteTx<NID> = ResultSender<VoteResponse<NID>>;
pub(crate) type VoteTx<C> = ResultSender<C, VoteResponse<NodeIdOf<C>>>;

/// TX for Append Entries Response
pub(crate) type AppendEntriesTx<NID> = ResultSender<AppendEntriesResponse<NID>>;
pub(crate) type AppendEntriesTx<C> = ResultSender<C, AppendEntriesResponse<NodeIdOf<C>>>;

/// TX for Client Write Response
pub(crate) type ClientWriteTx<C> = ResultSender<ClientWriteResponse<C>, ClientWriteError<NodeIdOf<C>, NodeOf<C>>>;
pub(crate) type ClientWriteTx<C> = ResultSender<C, ClientWriteResponse<C>, ClientWriteError<NodeIdOf<C>, NodeOf<C>>>;

/// TX for Linearizable Read Response
pub(crate) type ClientReadTx<C> =
ResultSender<(Option<LogIdOf<C>>, Option<LogIdOf<C>>), CheckIsLeaderError<NodeIdOf<C>, NodeOf<C>>>;
ResultSender<C, (Option<LogIdOf<C>>, Option<LogIdOf<C>>), CheckIsLeaderError<NodeIdOf<C>, NodeOf<C>>>;

/// A message sent by application to the [`RaftCore`].
///
Expand All @@ -53,18 +55,18 @@ where C: RaftTypeConfig
{
AppendEntries {
rpc: AppendEntriesRequest<C>,
tx: AppendEntriesTx<C::NodeId>,
tx: AppendEntriesTx<C>,
},

RequestVote {
rpc: VoteRequest<C::NodeId>,
tx: VoteTx<C::NodeId>,
tx: VoteTx<C>,
},

InstallFullSnapshot {
vote: Vote<C::NodeId>,
snapshot: Snapshot<C>,
tx: ResultSender<SnapshotResponse<C::NodeId>>,
tx: ResultSender<C, SnapshotResponse<C::NodeId>>,
},

/// Begin receiving a snapshot from the leader.
Expand All @@ -74,7 +76,7 @@ where C: RaftTypeConfig
/// will be returned in a Err
BeginReceivingSnapshot {
vote: Vote<C::NodeId>,
tx: ResultSender<Box<SnapshotDataOf<C>>, HigherVote<C::NodeId>>,
tx: ResultSender<C, Box<SnapshotDataOf<C>>, HigherVote<C::NodeId>>,
},

ClientWriteRequest {
Expand All @@ -88,7 +90,7 @@ where C: RaftTypeConfig

Initialize {
members: BTreeMap<C::NodeId, C::Node>,
tx: ResultSender<(), InitializeError<C::NodeId, C::Node>>,
tx: ResultSender<C, (), InitializeError<C::NodeId, C::Node>>,
},

ChangeMembership {
Expand All @@ -98,7 +100,7 @@ where C: RaftTypeConfig
/// config will be converted into learners, otherwise they will be removed.
retain: bool,

tx: ResultSender<ClientWriteResponse<C>, ClientWriteError<C::NodeId, C::Node>>,
tx: ResultSender<C, ClientWriteResponse<C>, ClientWriteError<C::NodeId, C::Node>>,
},

ExternalCoreRequest {
Expand Down
Loading
Loading