Merge pull request #91 from fjarri/tokio
Add a `session::tokio` module with a convenience function for executing a session
fjarri authored Feb 10, 2025
2 parents f4c9515 + 20f4509 commit 40d1f38
Showing 11 changed files with 822 additions and 242 deletions.
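For context, here is a minimal sketch of how the new `dev::tokio::run_async` helper is called, condensed from the updated `async_runner.rs` test in this diff. The standalone `#[tokio::main]` wrapper, the use of `std` collections, and treating the `manul_example` crate (from this repository's `examples/` directory) as a dependency are assumptions for illustration; the `run_async` call itself mirrors the test.

use std::collections::BTreeSet;

use manul::{
    dev::{tokio::run_async, BinaryFormat, TestSessionParams, TestSigner},
    signature::Keypair,
};
use manul_example::simple::SimpleProtocolEntryPoint;
use rand_core::OsRng;

#[tokio::main]
async fn main() {
    // Three test parties, each identified by its verifying key.
    let signers = (0..3).map(TestSigner::new).collect::<Vec<_>>();
    let all_ids = signers
        .iter()
        .map(|signer| signer.verifying_key())
        .collect::<BTreeSet<_>>();

    // One (signer, entry point) pair per party.
    let entry_points = signers
        .into_iter()
        .map(|signer| (signer, SimpleProtocolEntryPoint::new(all_ids.clone())))
        .collect::<Vec<_>>();

    // Drive all the sessions to completion on the tokio runtime; the final flag
    // selects whether message processing is offloaded to spawned tasks.
    run_async::<_, TestSessionParams<BinaryFormat>>(&mut OsRng, entry_points, true)
        .await
        .unwrap();
}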
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -35,6 +35,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Conversion from `u8` to `RoundId` and comparison of `RoundId` with `u8`. ([#84])
- `Misbehaving::override_finalize()` for malicious finalization logic. ([#87])
- `From<SerializableMap<K, V>> for BTreeMap<K, V>` impl. ([#88])
- `session::tokio` submodule containing functions for executing a session in an async `tokio` environment and supporting types. Gated behind the `tokio` feature. ([#91])
- `dev::tokio` submodule containing functions for executing multiple sessions in an async `tokio` environment. Gated behind the `tokio` feature. ([#91])


### Fixed
@@ -52,6 +54,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#86]: https://github.com/entropyxyz/manul/pull/86
[#87]: https://github.com/entropyxyz/manul/pull/87
[#88]: https://github.com/entropyxyz/manul/pull/88
[#91]: https://github.com/entropyxyz/manul/pull/91


## [0.1.0] - 2024-11-19
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion examples/Cargo.toml
@@ -21,5 +21,5 @@ displaydoc = "0.2"
tokio = { version = "1", features = ["rt", "sync", "time", "macros"] }
rand = "0.8"
digest = "0.10"
manul = { path = "../manul", features = ["dev"] }
manul = { path = "../manul", features = ["dev", "tokio"] }
test-log = { version = "0.2", features = ["trace", "color"] }
260 changes: 20 additions & 240 deletions examples/tests/async_runner.rs
@@ -1,263 +1,43 @@
extern crate alloc;

use alloc::collections::{BTreeMap, BTreeSet};
use alloc::collections::BTreeSet;

use manul::{
dev::{BinaryFormat, TestSessionParams, TestSigner},
protocol::Protocol,
session::{CanFinalize, LocalError, Message, RoundOutcome, Session, SessionId, SessionParameters, SessionReport},
dev::{tokio::run_async, BinaryFormat, TestSessionParams, TestSigner},
signature::Keypair,
};
use manul_example::simple::{SimpleProtocol, SimpleProtocolEntryPoint};
use rand::Rng;
use manul_example::simple::SimpleProtocolEntryPoint;
use rand_core::OsRng;
use tokio::{
sync::mpsc,
time::{sleep, Duration},
};
use tracing::{debug, trace};

struct MessageOut<SP: SessionParameters> {
from: SP::Verifier,
to: SP::Verifier,
message: Message<SP::Verifier>,
}

struct MessageIn<SP: SessionParameters> {
from: SP::Verifier,
message: Message<SP::Verifier>,
}

/// Runs a session. Simulates what each participating party would run as the protocol progresses.
async fn run_session<P, SP>(
tx: mpsc::Sender<MessageOut<SP>>,
rx: mpsc::Receiver<MessageIn<SP>>,
session: Session<P, SP>,
) -> Result<SessionReport<P, SP>, LocalError>
where
P: Protocol<SP::Verifier>,
SP: SessionParameters,
{
let rng = &mut OsRng;

let mut rx = rx;

let mut session = session;
// Some rounds can finalize early and put off sending messages to the next round. Such messages
// will be stored here and applied after the messages for this round are sent.
let mut cached_messages = Vec::new();

let key = session.verifier();

// Each iteration of the loop progresses the session as follows:
// - Send out messages as dictated by the session "destinations".
// - Apply any cached messages.
// - Enter a nested loop:
// - Try to finalize the session; if we're done, exit the inner loop.
// - Wait until we get an incoming message.
// - Process the message we received and continue the loop.
// - When all messages have been sent and received as specified by the protocol, finalize the
// round.
// - If the protocol outcome is a new round, go to the top of the loop and start over with a
// new session.
loop {
debug!("{key:?}: *** starting round {:?} ***", session.round_id());

// This is kept in the main task since it's mutable,
// and we don't want to bother with synchronization.
let mut accum = session.make_accumulator();

// Note: generating/sending messages and verifying newly received messages
// can be done in parallel, with the results being assembled into `accum`
// sequentially in the host task.

let destinations = session.message_destinations();
for destination in destinations.iter() {
// In production usage, this will happen in a spawned task
// (since it can take some time to create a message),
// and the artifact will be sent back to the host task
// to be added to the accumulator.
let (message, artifact) = session.make_message(rng, destination)?;
debug!("{key:?}: Sending a message to {destination:?}",);
tx.send(MessageOut {
from: key.clone(),
to: destination.clone(),
message,
})
.await
.unwrap();

// This would happen in a host task
session.add_artifact(&mut accum, artifact)?;
}

for preprocessed in cached_messages {
// In production usage, this would happen in a spawned task and relayed back to the main task.
debug!("{key:?}: Applying a cached message");
let processed = session.process_message(preprocessed);

// This would happen in a host task.
session.add_processed_message(&mut accum, processed)?;
}

loop {
match session.can_finalize(&accum) {
CanFinalize::Yes => break,
CanFinalize::NotYet => {}
// Due to already registered invalid messages from nodes,
// even if the remaining nodes send correct messages, it won't be enough.
// Terminating.
CanFinalize::Never => {
tracing::warn!("{key:?}: This session cannot ever be finalized. Terminating.");
return session.terminate_due_to_errors(accum);
}
}

debug!("{key:?}: Waiting for a message");
let incoming = rx.recv().await.unwrap();

// Perform quick checks before proceeding with the verification.
match session
.preprocess_message(&mut accum, &incoming.from, incoming.message)?
.ok()
{
Some(preprocessed) => {
// In production usage, this would happen in a separate task.
debug!("{key:?}: Applying a message from {:?}", incoming.from);
let processed = session.process_message(preprocessed);
// In production usage, this would be a host task.
session.add_processed_message(&mut accum, processed)?;
}
None => {
trace!("{key:?} Pre-processing complete. Current state: {accum:?}")
}
}
}

debug!("{key:?}: Finalizing the round");

match session.finalize_round(rng, accum)? {
RoundOutcome::Finished(report) => break Ok(report),
RoundOutcome::AnotherRound {
session: new_session,
cached_messages: new_cached_messages,
} => {
session = new_session;
cached_messages = new_cached_messages;
}
}
}
}

async fn message_dispatcher<SP>(
txs: BTreeMap<SP::Verifier, mpsc::Sender<MessageIn<SP>>>,
rx: mpsc::Receiver<MessageOut<SP>>,
) where
SP: SessionParameters,
{
let mut rx = rx;
let mut messages = Vec::<MessageOut<SP>>::new();
loop {
let msg = match rx.recv().await {
Some(msg) => msg,
None => break,
};
messages.push(msg);

while let Ok(msg) = rx.try_recv() {
messages.push(msg)
}

while !messages.is_empty() {
// Pull a random message from the list,
// to increase the chances that they are delivered out of order.
let message_idx = rand::thread_rng().gen_range(0..messages.len());
let outgoing = messages.swap_remove(message_idx);

txs[&outgoing.to]
.send(MessageIn {
from: outgoing.from,
message: outgoing.message,
})
.await
.unwrap();

// Give up execution so that the tasks could process messages.
sleep(Duration::from_millis(0)).await;

if let Ok(msg) = rx.try_recv() {
messages.push(msg);
};
}
}
}

async fn run_nodes<P, SP>(sessions: Vec<Session<P, SP>>) -> Vec<SessionReport<P, SP>>
where
P: Protocol<SP::Verifier> + Send,
SP: SessionParameters,
P::Result: Send,
SP::Signer: Send,
{
let num_parties = sessions.len();

let (dispatcher_tx, dispatcher_rx) = mpsc::channel::<MessageOut<SP>>(100);

let channels = (0..num_parties).map(|_| mpsc::channel::<MessageIn<SP>>(100));
let (txs, rxs): (Vec<_>, Vec<_>) = channels.unzip();
let tx_map = sessions
.iter()
.map(|session| session.verifier())
.zip(txs.into_iter())
.collect();

let dispatcher_task = message_dispatcher(tx_map, dispatcher_rx);
let dispatcher = tokio::spawn(dispatcher_task);

let handles = rxs
.into_iter()
.zip(sessions.into_iter())
.map(|(rx, session)| {
let node_task = run_session(dispatcher_tx.clone(), rx, session);
tokio::spawn(node_task)
})
.collect::<Vec<_>>();

// Drop the last copy of the dispatcher's incoming channel so that it can finish.
drop(dispatcher_tx);

let mut results = Vec::with_capacity(num_parties);
for handle in handles {
results.push(handle.await.unwrap().unwrap());
}

dispatcher.await.unwrap();

results
}

#[tokio::test]
async fn async_run() {
// The kind of Session we need to run the `SimpleProtocol`.
type SimpleSession = Session<SimpleProtocol, TestSessionParams<BinaryFormat>>;

async fn async_run(offload_processing: bool) {
// Create 3 parties
let signers = (0..3).map(TestSigner::new).collect::<Vec<_>>();
let all_ids = signers
.iter()
.map(|signer| signer.verifying_key())
.collect::<BTreeSet<_>>();
let session_id = SessionId::random::<TestSessionParams<BinaryFormat>>(&mut OsRng);

// Create 4 `Session`s
let sessions = signers
// Create 3 entry points
let entry_points = signers
.into_iter()
.map(|signer| {
let entry_point = SimpleProtocolEntryPoint::new(all_ids.clone());
SimpleSession::new(&mut OsRng, session_id.clone(), signer, entry_point).unwrap()
(signer, entry_point)
})
.collect::<Vec<_>>();

// Run the protocol
run_nodes(sessions).await;
run_async::<_, TestSessionParams<BinaryFormat>>(&mut OsRng, entry_points, offload_processing)
.await
.unwrap();
}

#[tokio::test]
async fn async_run_no_offload() {
async_run(false).await
}

#[tokio::test]
async fn async_run_with_offload() {
async_run(true).await
}
7 changes: 7 additions & 0 deletions manul/Cargo.toml
@@ -27,6 +27,7 @@ rand = { version = "0.8", default-features = false, optional = true }
serde-persistent-deserializer = { version = "0.3", optional = true }
postcard = { version = "1", default-features = false, features = ["alloc"], optional = true }
serde_json = { version = "1", default-features = false, features = ["alloc"], optional = true }
tokio = { version = "1", default-features = false, features = ["sync", "rt", "macros", "time"], optional = true }

[dev-dependencies]
impls = "1"
@@ -43,6 +44,7 @@ tracing = { version = "0.1", default-features = false, features = ["std"] }

[features]
dev = ["rand", "postcard", "serde_json", "tracing/std", "serde-persistent-deserializer"]
tokio = ["dep:tokio"]

[package.metadata.docs.rs]
all-features = true
@@ -52,3 +54,8 @@ rustdoc-args = ["--cfg", "docsrs"]
name = "empty_rounds"
harness = false
required-features = ["dev"]

[[bench]]
name = "async_session"
harness = false
required-features = ["dev", "tokio"]
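
For downstream users the new functionality stays opt-in. A hypothetical consumer's Cargo.toml would enable it roughly like this (the manul version number is illustrative):

[dependencies]
# The "tokio" feature pulls in the optional tokio dependency declared above.
manul = { version = "0.2", features = ["tokio"] }
tokio = { version = "1", features = ["rt", "macros"] }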