Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AsyncRuntime::oneshot #1025

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions openraft/src/async_runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,17 @@ pub trait AsyncRuntime: Debug + Default + OptionalSend + OptionalSync + 'static
/// Type of a thread-local random number generator.
type ThreadLocalRng: rand::Rng;

/// Type of a `oneshot` sender.
type OneshotSender<T: OptionalSend>: AsyncOneshotSendExt<T> + OptionalSend + OptionalSync + Debug + Sized;

type OneshotReceiverError: std::error::Error + OptionalSend;

/// Type of a `oneshot` receiver.
type OneshotReceiver<T: OptionalSend>: OptionalSend
+ OptionalSync
+ Future<Output = Result<T, Self::OneshotReceiverError>>
+ Unpin;

/// Spawn a new task.
fn spawn<T>(future: T) -> Self::JoinHandle<T::Output>
where
Expand Down Expand Up @@ -72,12 +83,24 @@ pub trait AsyncRuntime: Debug + Default + OptionalSend + OptionalSync + 'static
/// This is a per-thread instance, which cannot be shared across threads or
/// sent to another thread.
fn thread_rng() -> Self::ThreadLocalRng;

/// Creates a new one-shot channel for sending single values.
///
/// The function returns separate "send" and "receive" handles. The `Sender`
/// handle is used by the producer to send the value. The `Receiver` handle is
/// used by the consumer to receive the value.
///
/// Each handle can be used on separate tasks.
fn oneshot<T>() -> (Self::OneshotSender<T>, Self::OneshotReceiver<T>)
where T: OptionalSend;
}

/// `Tokio` is the default asynchronous executor.
#[derive(Debug, Default)]
pub struct TokioRuntime;

pub struct TokioSendWrapper<T: OptionalSend>(pub tokio::sync::oneshot::Sender<T>);

impl AsyncRuntime for TokioRuntime {
type JoinError = tokio::task::JoinError;
type JoinHandle<T: OptionalSend + 'static> = tokio::task::JoinHandle<T>;
Expand All @@ -86,6 +109,9 @@ impl AsyncRuntime for TokioRuntime {
type TimeoutError = tokio::time::error::Elapsed;
type Timeout<R, T: Future<Output = R> + OptionalSend> = tokio::time::Timeout<T>;
type ThreadLocalRng = rand::rngs::ThreadRng;
type OneshotSender<T: OptionalSend> = TokioSendWrapper<T>;
type OneshotReceiver<T: OptionalSend> = tokio::sync::oneshot::Receiver<T>;
type OneshotReceiverError = tokio::sync::oneshot::error::RecvError;

#[inline]
fn spawn<T>(future: T) -> Self::JoinHandle<T::Output>
Expand Down Expand Up @@ -132,4 +158,42 @@ impl AsyncRuntime for TokioRuntime {
fn thread_rng() -> Self::ThreadLocalRng {
rand::thread_rng()
}

#[inline]
fn oneshot<T>() -> (Self::OneshotSender<T>, Self::OneshotReceiver<T>)
where T: OptionalSend {
let (tx, rx) = tokio::sync::oneshot::channel();
(TokioSendWrapper(tx), rx)
}
}

pub trait AsyncOneshotSendExt<T>: Unpin {
/// Attempts to send a value on this channel, returning it back if it could
/// not be sent.
///
/// This method consumes `self` as only one value may ever be sent on a `oneshot`
/// channel. It is not marked async because sending a message to an `oneshot`
/// channel never requires any form of waiting. Because of this, the `send`
/// method can be used in both synchronous and asynchronous code without
/// problems.
fn send(self, t: T) -> Result<(), T>;
}

impl<T: OptionalSend> AsyncOneshotSendExt<T> for TokioSendWrapper<T> {
#[inline]
fn send(self, t: T) -> Result<(), T> {
self.0.send(t)
}
}

impl<T: OptionalSend> Debug for TokioSendWrapper<T> {
default fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("TokioSendWrapper").finish()
}
}

impl<T: Debug + OptionalSend> Debug for TokioSendWrapper<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("TokioSendWrapper").field(&self.0).finish()
}
}
41 changes: 29 additions & 12 deletions openraft/src/core/raft_core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ use futures::TryFutureExt;
use maplit::btreeset;
use tokio::select;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::sync::watch;
use tracing::Instrument;
use tracing::Level;
use tracing::Span;

use crate::async_runtime::AsyncOneshotSendExt;
use crate::config::Config;
use crate::config::RuntimeConfig;
use crate::core::balancer::Balancer;
Expand Down Expand Up @@ -215,7 +215,10 @@ where
SM: RaftStateMachine<C>,
{
/// The main loop of the Raft protocol.
pub(crate) async fn main(mut self, rx_shutdown: oneshot::Receiver<()>) -> Result<(), Fatal<C::NodeId>> {
pub(crate) async fn main(
mut self,
rx_shutdown: <C::AsyncRuntime as AsyncRuntime>::OneshotReceiver<()>,
) -> Result<(), Fatal<C::NodeId>> {
let span = tracing::span!(parent: &self.span, Level::DEBUG, "main");
let res = self.do_main(rx_shutdown).instrument(span).await;

Expand All @@ -239,7 +242,10 @@ where
}

#[tracing::instrument(level="trace", skip_all, fields(id=display(self.id), cluster=%self.config.cluster_name))]
async fn do_main(&mut self, rx_shutdown: oneshot::Receiver<()>) -> Result<(), Fatal<C::NodeId>> {
async fn do_main(
&mut self,
rx_shutdown: <C::AsyncRuntime as AsyncRuntime>::OneshotReceiver<()>,
) -> Result<(), Fatal<C::NodeId>> {
tracing::debug!("raft node is initializing");

self.engine.startup();
Expand Down Expand Up @@ -432,7 +438,7 @@ where
&mut self,
changes: ChangeMembers<C::NodeId, C::Node>,
retain: bool,
tx: ResultSender<ClientWriteResponse<C>, ClientWriteError<C::NodeId, C::Node>>,
tx: ResultSender<AsyncRuntimeOf<C>, ClientWriteResponse<C>, ClientWriteError<C::NodeId, C::Node>>,
) {
let res = self.engine.state.membership_state.change_handler().apply(changes, retain);
let new_membership = match res {
Expand Down Expand Up @@ -593,7 +599,7 @@ where
pub(crate) fn handle_initialize(
&mut self,
member_nodes: BTreeMap<C::NodeId, C::Node>,
tx: ResultSender<(), InitializeError<C::NodeId, C::Node>>,
tx: ResultSender<AsyncRuntimeOf<C>, (), InitializeError<C::NodeId, C::Node>>,
) {
tracing::debug!(member_nodes = debug(&member_nodes), "{}", func_name!());

Expand All @@ -616,8 +622,12 @@ where

/// Reject a request due to the Raft node being in a state which prohibits the request.
#[tracing::instrument(level = "trace", skip(self, tx))]
pub(crate) fn reject_with_forward_to_leader<T, E>(&self, tx: ResultSender<T, E>)
where E: From<ForwardToLeader<C::NodeId, C::Node>> {
pub(crate) fn reject_with_forward_to_leader<T: OptionalSend, E: OptionalSend>(
&self,
tx: ResultSender<AsyncRuntimeOf<C>, T, E>,
) where
E: From<ForwardToLeader<C::NodeId, C::Node>>,
{
let mut leader_id = self.current_leader();
let leader_node = self.get_leader_node(leader_id);

Expand Down Expand Up @@ -680,7 +690,7 @@ where
{
tracing::debug!("append_to_log");

let (tx, rx) = oneshot::channel();
let (tx, rx) = C::AsyncRuntime::oneshot();
let callback = LogFlushed::new(Some(last_log_id), tx);
self.log_store.append(entries, callback).await?;
rx.await
Expand Down Expand Up @@ -865,7 +875,10 @@ where

/// Run an event handling loop
#[tracing::instrument(level="debug", skip_all, fields(id=display(self.id)))]
async fn runtime_loop(&mut self, mut rx_shutdown: oneshot::Receiver<()>) -> Result<(), Fatal<C::NodeId>> {
async fn runtime_loop(
&mut self,
mut rx_shutdown: <C::AsyncRuntime as AsyncRuntime>::OneshotReceiver<()>,
) -> Result<(), Fatal<C::NodeId>> {
// Ratio control the ratio of number of RaftMsg to process to number of Notify to process.
let mut balancer = Balancer::new(10_000);

Expand Down Expand Up @@ -1067,7 +1080,11 @@ where
}

#[tracing::instrument(level = "debug", skip_all)]
pub(super) fn handle_vote_request(&mut self, req: VoteRequest<C::NodeId>, tx: VoteTx<C::NodeId>) {
pub(super) fn handle_vote_request(
&mut self,
req: VoteRequest<C::NodeId>,
tx: VoteTx<AsyncRuntimeOf<C>, C::NodeId>,
) {
tracing::info!(req = display(req.summary()), func = func_name!());

let resp = self.engine.handle_vote_req(req);
Expand All @@ -1081,7 +1098,7 @@ where
pub(super) fn handle_append_entries_request(
&mut self,
req: AppendEntriesRequest<C>,
tx: AppendEntriesTx<C::NodeId>,
tx: AppendEntriesTx<AsyncRuntimeOf<C>, C::NodeId>,
) {
tracing::debug!(req = display(req.summary()), func = func_name!());

Expand Down Expand Up @@ -1657,7 +1674,7 @@ where

// Create a channel to let state machine worker to send the snapshot and the replication
// worker to receive it.
let (tx, rx) = oneshot::channel();
let (tx, rx) = C::AsyncRuntime::oneshot();

let cmd = sm::Command::get_snapshot(tx);
self.sm_handle
Expand Down
5 changes: 4 additions & 1 deletion openraft/src/core/raft_msg/external_command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use std::fmt;

use crate::core::raft_msg::ResultSender;
use crate::type_config::alias::AsyncRuntimeOf;
use crate::RaftTypeConfig;
use crate::Snapshot;

Expand All @@ -23,7 +24,9 @@ pub(crate) enum ExternalCommand<C: RaftTypeConfig> {
Snapshot,

/// Get a snapshot from the state machine, send back via a oneshot::Sender.
GetSnapshot { tx: ResultSender<Option<Snapshot<C>>> },
GetSnapshot {
tx: ResultSender<AsyncRuntimeOf<C>, Option<Snapshot<C>>>,
},

/// Purge logs covered by a snapshot up to a specified index.
///
Expand Down
34 changes: 19 additions & 15 deletions openraft/src/core/raft_msg/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
use std::collections::BTreeMap;

use tokio::sync::oneshot;

use crate::core::raft_msg::external_command::ExternalCommand;
use crate::error::CheckIsLeaderError;
use crate::error::ClientWriteError;
Expand All @@ -15,10 +13,12 @@ use crate::raft::ClientWriteResponse;
use crate::raft::SnapshotResponse;
use crate::raft::VoteRequest;
use crate::raft::VoteResponse;
use crate::type_config::alias::AsyncRuntimeOf;
use crate::type_config::alias::LogIdOf;
use crate::type_config::alias::NodeIdOf;
use crate::type_config::alias::NodeOf;
use crate::type_config::alias::SnapshotDataOf;
use crate::AsyncRuntime;
use crate::ChangeMembers;
use crate::MessageSummary;
use crate::RaftTypeConfig;
Expand All @@ -28,22 +28,26 @@ use crate::Vote;
pub(crate) mod external_command;

/// A oneshot TX to send result from `RaftCore` to external caller, e.g. `Raft::append_entries`.
pub(crate) type ResultSender<T, E = Infallible> = oneshot::Sender<Result<T, E>>;
pub(crate) type ResultSender<Runtime, T, E = Infallible> = <Runtime as AsyncRuntime>::OneshotSender<Result<T, E>>;

pub(crate) type ResultReceiver<T, E = Infallible> = oneshot::Receiver<Result<T, E>>;
pub(crate) type ResultReceiver<Runtime, T, E = Infallible> = <Runtime as AsyncRuntime>::OneshotReceiver<Result<T, E>>;

/// TX for Vote Response
pub(crate) type VoteTx<NID> = ResultSender<VoteResponse<NID>>;
pub(crate) type VoteTx<Runtime, NID> = ResultSender<Runtime, VoteResponse<NID>>;

/// TX for Append Entries Response
pub(crate) type AppendEntriesTx<NID> = ResultSender<AppendEntriesResponse<NID>>;
pub(crate) type AppendEntriesTx<Runtime, NID> = ResultSender<Runtime, AppendEntriesResponse<NID>>;

/// TX for Client Write Response
pub(crate) type ClientWriteTx<C> = ResultSender<ClientWriteResponse<C>, ClientWriteError<NodeIdOf<C>, NodeOf<C>>>;
pub(crate) type ClientWriteTx<C> =
ResultSender<AsyncRuntimeOf<C>, ClientWriteResponse<C>, ClientWriteError<NodeIdOf<C>, NodeOf<C>>>;

/// TX for Linearizable Read Response
pub(crate) type ClientReadTx<C> =
ResultSender<(Option<LogIdOf<C>>, Option<LogIdOf<C>>), CheckIsLeaderError<NodeIdOf<C>, NodeOf<C>>>;
pub(crate) type ClientReadTx<C> = ResultSender<
AsyncRuntimeOf<C>,
(Option<LogIdOf<C>>, Option<LogIdOf<C>>),
CheckIsLeaderError<NodeIdOf<C>, NodeOf<C>>,
>;

/// A message sent by application to the [`RaftCore`].
///
Expand All @@ -53,18 +57,18 @@ where C: RaftTypeConfig
{
AppendEntries {
rpc: AppendEntriesRequest<C>,
tx: AppendEntriesTx<C::NodeId>,
tx: AppendEntriesTx<AsyncRuntimeOf<C>, C::NodeId>,
},

RequestVote {
rpc: VoteRequest<C::NodeId>,
tx: VoteTx<C::NodeId>,
tx: VoteTx<AsyncRuntimeOf<C>, C::NodeId>,
},

InstallCompleteSnapshot {
vote: Vote<C::NodeId>,
snapshot: Snapshot<C>,
tx: ResultSender<SnapshotResponse<C::NodeId>>,
tx: ResultSender<AsyncRuntimeOf<C>, SnapshotResponse<C::NodeId>>,
},

/// Begin receiving a snapshot from the leader.
Expand All @@ -74,7 +78,7 @@ where C: RaftTypeConfig
/// will be returned in a Err
BeginReceivingSnapshot {
vote: Vote<C::NodeId>,
tx: ResultSender<Box<SnapshotDataOf<C>>, HigherVote<C::NodeId>>,
tx: ResultSender<AsyncRuntimeOf<C>, Box<SnapshotDataOf<C>>, HigherVote<C::NodeId>>,
},

ClientWriteRequest {
Expand All @@ -88,7 +92,7 @@ where C: RaftTypeConfig

Initialize {
members: BTreeMap<C::NodeId, C::Node>,
tx: ResultSender<(), InitializeError<C::NodeId, C::Node>>,
tx: ResultSender<AsyncRuntimeOf<C>, (), InitializeError<C::NodeId, C::Node>>,
},

ChangeMembership {
Expand All @@ -98,7 +102,7 @@ where C: RaftTypeConfig
/// config will be converted into learners, otherwise they will be removed.
retain: bool,

tx: ResultSender<ClientWriteResponse<C>, ClientWriteError<C::NodeId, C::Node>>,
tx: ResultSender<AsyncRuntimeOf<C>, ClientWriteResponse<C>, ClientWriteError<C::NodeId, C::Node>>,
},

ExternalCoreRequest {
Expand Down
10 changes: 6 additions & 4 deletions openraft/src/core/sm/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,14 @@ where C: RaftTypeConfig
Command::new(payload)
}

pub(crate) fn get_snapshot(tx: ResultSender<Option<Snapshot<C>>>) -> Self {
pub(crate) fn get_snapshot(tx: ResultSender<C::AsyncRuntime, Option<Snapshot<C>>>) -> Self {
let payload = CommandPayload::GetSnapshot { tx };
Command::new(payload)
}

pub(crate) fn begin_receiving_snapshot(tx: ResultSender<Box<SnapshotDataOf<C>>, HigherVote<C::NodeId>>) -> Self {
pub(crate) fn begin_receiving_snapshot(
tx: ResultSender<C::AsyncRuntime, Box<SnapshotDataOf<C>>, HigherVote<C::NodeId>>,
) -> Self {
let payload = CommandPayload::BeginReceivingSnapshot { tx };
Command::new(payload)
}
Expand Down Expand Up @@ -91,11 +93,11 @@ where C: RaftTypeConfig

/// Get the latest built snapshot.
GetSnapshot {
tx: ResultSender<Option<Snapshot<C>>>,
tx: ResultSender<C::AsyncRuntime, Option<Snapshot<C>>>,
},

BeginReceivingSnapshot {
tx: ResultSender<Box<SnapshotDataOf<C>>, HigherVote<C::NodeId>>,
tx: ResultSender<C::AsyncRuntime, Box<SnapshotDataOf<C>>, HigherVote<C::NodeId>>,
},

InstallCompleteSnapshot {
Expand Down
7 changes: 6 additions & 1 deletion openraft/src/core/sm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@

use tokio::sync::mpsc;

use crate::async_runtime::AsyncOneshotSendExt;
use crate::core::ApplyResult;
use crate::core::ApplyingEntry;
use crate::entry::RaftPayload;
use crate::storage::RaftStateMachine;
use crate::summary::MessageSummary;
use crate::type_config::alias::AsyncRuntimeOf;
use crate::AsyncRuntime;
use crate::RaftLogId;
use crate::RaftSnapshotBuilder;
Expand Down Expand Up @@ -219,7 +221,10 @@ where
}

#[tracing::instrument(level = "info", skip_all)]
async fn get_snapshot(&mut self, tx: ResultSender<Option<Snapshot<C>>>) -> Result<(), StorageError<C::NodeId>> {
async fn get_snapshot(
&mut self,
tx: ResultSender<AsyncRuntimeOf<C>, Option<Snapshot<C>>>,
) -> Result<(), StorageError<C::NodeId>> {
tracing::info!("{}", func_name!());

let snapshot = self.state_machine.get_current_snapshot().await?;
Expand Down
Loading
Loading