Skip to content

Commit

Permalink
update shutdown logic
Browse files Browse the repository at this point in the history
  • Loading branch information
dignifiedquire committed Mar 25, 2024
1 parent 5d06842 commit 6344bf8
Show file tree
Hide file tree
Showing 13 changed files with 51 additions and 74 deletions.
12 changes: 2 additions & 10 deletions iroh-cli/src/commands/start.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,26 +96,18 @@ where
.instrument(info_span!("command"))
});

let node2 = node.clone();
tokio::select! {
biased;
// always abort on signal-c
_ = tokio::signal::ctrl_c(), if run_type != RunType::SingleCommandNoAbort => {
command_task.abort();
node.shutdown();
node.await?;
node.shutdown().await?;
}
// abort if the command task finishes (will run forever if not in single-command mode)
res = &mut command_task => {
node.shutdown();
let _ = node.await;
let _ = node.shutdown().await;
res??;
}
// abort if the node future completes (shutdown called or error)
res = node2 => {
command_task.abort();
res?;
}
}
Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion iroh/examples/collection-provide.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,6 @@ async fn main() -> anyhow::Result<()> {
println!("\tcargo run --example collection-fetch {}", ticket);
// wait for the node to finish, this will block indefinitely
// stop with SIGINT (ctrl+c)
node.await?;
node.shutdown().await?;
Ok(())
}
2 changes: 1 addition & 1 deletion iroh/examples/hello-world-provide.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,6 @@ async fn main() -> anyhow::Result<()> {
println!("\t cargo run --example hello-world-fetch {}", ticket);
// wait for the node to finish, this will block indefinitely
// stop with SIGINT (ctrl+c)
node.await?;
node.shutdown().await?;
Ok(())
}
10 changes: 6 additions & 4 deletions iroh/src/client/blobs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ where
.rpc
.server_streaming(BlobConsistencyCheckRequest { repair })
.await?;
Ok(stream.map_err(anyhow::Error::from))
Ok(stream.map(|r| r.map_err(anyhow::Error::from)))
}

/// Download a blob from another node and add it to the local database.
Expand Down Expand Up @@ -258,7 +258,9 @@ where
mode,
};
let stream = self.rpc.server_streaming(req).await?;
Ok(BlobExportProgress::new(stream.map_err(anyhow::Error::from)))
Ok(BlobExportProgress::new(
stream.map(|r| r.map_err(anyhow::Error::from)),
))
}

/// List all complete blobs.
Expand Down Expand Up @@ -624,7 +626,7 @@ impl BlobExportProgress {
impl Stream for BlobExportProgress {
type Item = Result<ExportProgress>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.stream.poll_next_unpin(cx)
Pin::new(&mut self.stream).poll_next(cx)
}
}

Expand All @@ -633,7 +635,7 @@ impl Future for BlobExportProgress {

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
loop {
match self.stream.poll_next_unpin(cx) {
match Pin::new(&mut self.stream).poll_next(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(None) => {
return Poll::Ready(Err(anyhow!("Response stream ended prematurely")))
Expand Down
29 changes: 11 additions & 18 deletions iroh/src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,9 @@
//!
//! To shut down the node, call [`Node::shutdown`].
use std::fmt::Debug;
use std::future::Future;
use std::net::SocketAddr;
use std::path::Path;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Poll;

use anyhow::{anyhow, Result};
use futures_lite::{future::Boxed as BoxFuture, FutureExt, StreamExt};
Expand All @@ -29,7 +26,7 @@ use iroh_sync::store::Store as DocStore;
use quic_rpc::transport::flume::FlumeConnection;
use quic_rpc::RpcClient;
use tokio::sync::{mpsc, RwLock};
use tokio::task::JoinError;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tokio_util::task::LocalPoolHandle;
use tracing::debug;
Expand Down Expand Up @@ -90,7 +87,7 @@ impl iroh_bytes::provider::EventSender for Callbacks {
#[derive(Debug, Clone)]
pub struct Node<D> {
inner: Arc<NodeInner<D>>,
task: (), // Arc<BoxFuture<anyhow::Result<()>>>,
task: Arc<JoinHandle<()>>,
client: crate::client::mem::Iroh,
}

Expand Down Expand Up @@ -235,12 +232,18 @@ impl<D: BaoStore> Node<D> {
/// Aborts the node.
///
/// This does not gracefully terminate currently: all connections are closed and
/// anything in-transit is lost. The task will stop running and awaiting this
/// [`Node`] will complete.
/// anything in-transit is lost. The task will stop running.
/// If this is the last copy of the `Node`, this will finish once the task is
/// fully shutdown.
///
/// The shutdown behaviour will become more graceful in the future.
pub fn shutdown(&self) {
pub async fn shutdown(self) -> Result<()> {
self.inner.cancel_token.cancel();

if let Ok(task) = Arc::try_unwrap(self.task) {
task.await?;
}
Ok(())
}

/// Returns a token that can be used to cancel the node.
Expand All @@ -249,16 +252,6 @@ impl<D: BaoStore> Node<D> {
}
}

/// The future completes when the spawned tokio task finishes.
impl<D> Future for Node<D> {
type Output = Result<(), Arc<JoinError>>;

fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
// Pin::new(&mut self.task).poll(cx)
todo!()
}
}

impl<D> std::ops::Deref for Node<D> {
type Target = crate::client::mem::Iroh;

Expand Down
12 changes: 2 additions & 10 deletions iroh/src/node/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::{
};

use anyhow::{bail, Context, Result};
use futures_lite::{FutureExt, StreamExt};
use futures_lite::StreamExt;
use iroh_base::key::SecretKey;
use iroh_bytes::{
downloader::Downloader,
Expand Down Expand Up @@ -382,17 +382,9 @@ where
)
};

/*let task = Arc::new(
async move {
task.await?;
anyhow::Ok(())
}
.boxed(),
);*/

let node = Node {
inner,
task: (),
task: Arc::new(task),
client,
};

Expand Down
2 changes: 1 addition & 1 deletion iroh/src/node/rpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::time::Duration;

use anyhow::{anyhow, Result};
use futures_buffered::BufferedStreamExt;
use futures_lite::{FutureExt, Stream, StreamExt};
use futures_lite::{Stream, StreamExt};
use genawaiter::sync::{Co, Gen};
use iroh_base::rpc::RpcResult;
use iroh_bytes::export::ExportProgress;
Expand Down
21 changes: 11 additions & 10 deletions iroh/src/sync_engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
use std::{io, sync::Arc};

use anyhow::Result;
use futures_lite::{future::Boxed as BoxFuture, FutureExt, Stream, StreamExt};
use futures_lite::{Stream, StreamExt};
use iroh_bytes::downloader::Downloader;
use iroh_bytes::{store::EntryStatus, Hash};
use iroh_gossip::net::Gossip;
use iroh_net::{key::PublicKey, MagicEndpoint, NodeAddr};
use iroh_sync::{actor::SyncHandle, ContentStatus, ContentStatusCallback, Entry, NamespaceId};
use serde::{Deserialize, Serialize};
use tokio::sync::{mpsc, oneshot};
use tokio::task::JoinHandle;
use tracing::{error, error_span, Instrument};

mod gossip;
Expand Down Expand Up @@ -42,8 +43,8 @@ pub struct SyncEngine {
pub(crate) endpoint: MagicEndpoint,
pub(crate) sync: SyncHandle,
to_live_actor: mpsc::Sender<ToLiveActor>,
#[debug("Arc<BoxFuture<()>>")]
tasks_fut: (), // Arc<BoxFuture<()>>,
#[debug("Arc<JoinHandle<()>>")]
tasks: Arc<JoinHandle<()>>,
#[debug("ContentStatusCallback")]
content_status_cb: ContentStatusCallback,
}
Expand Down Expand Up @@ -107,7 +108,7 @@ impl SyncEngine {
}
.instrument(error_span!("sync", %me)),
);
let tasks_fut = async move {
let tasks = tokio::task::spawn(async move {
if let Err(err) = live_actor_task.await {
error!("Error while joining actor task: {err:?}");
}
Expand All @@ -117,14 +118,13 @@ impl SyncEngine {
error!("Error while joining gossip recv task task: {err:?}");
}
}
}
.boxed();
});

Self {
endpoint,
sync,
to_live_actor: live_actor_tx,
tasks_fut: (),
tasks: Arc::new(tasks),
content_status_cb,
}
}
Expand Down Expand Up @@ -223,10 +223,11 @@ impl SyncEngine {
}

/// Shutdown the sync engine.
pub async fn shutdown(&self) -> Result<()> {
pub async fn shutdown(self) -> Result<()> {
self.to_live_actor.send(ToLiveActor::Shutdown).await?;
// TODO
// self.tasks_fut.clone().await;
if let Ok(tasks) = Arc::try_unwrap(self.tasks) {
tasks.await?;
}
Ok(())
}
}
Expand Down
2 changes: 1 addition & 1 deletion iroh/src/sync_engine/gossip.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::collections::HashSet;

use anyhow::{anyhow, Context, Result};
use futures_lite::{FutureExt, StreamExt};
use futures_lite::StreamExt;
use iroh_gossip::{
net::{Event, Gossip},
proto::TopicId,
Expand Down
1 change: 0 additions & 1 deletion iroh/src/sync_engine/rpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use iroh_sync::{Author, NamespaceSecret};
use tokio_stream::StreamExt;

use crate::rpc_protocol::{DocGetSyncPeersRequest, DocGetSyncPeersResponse};
use crate::sync_engine::LiveEvent;
use crate::{
rpc_protocol::{
AuthorCreateRequest, AuthorCreateResponse, AuthorListRequest, AuthorListResponse,
Expand Down
6 changes: 2 additions & 4 deletions iroh/tests/gc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,7 @@ async fn gc_basics() -> Result<()> {
step(&evs).await;
assert_eq!(bao_store.entry_status(&h2).await?, EntryStatus::NotFound);

node.shutdown();
node.await?;
node.shutdown().await?;
Ok(())
}

Expand Down Expand Up @@ -180,8 +179,7 @@ async fn gc_hashseq_impl() -> Result<()> {
assert_eq!(bao_store.entry_status(&h2).await?, EntryStatus::NotFound);
assert_eq!(bao_store.entry_status(&hr).await?, EntryStatus::NotFound);

node.shutdown();
node.await?;
node.shutdown().await?;
Ok(())
}

Expand Down
18 changes: 9 additions & 9 deletions iroh/tests/provide.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{
time::{Duration, Instant},
};

use anyhow::{anyhow, Context, Result};
use anyhow::{anyhow, Result};
use bytes::Bytes;
use futures_lite::FutureExt;
use iroh::{
Expand Down Expand Up @@ -264,8 +264,7 @@ where
.await
.expect("duration expired");

node.shutdown();
node.await?;
node.shutdown().await?;

assert_events(events, num_blobs + 1);

Expand Down Expand Up @@ -315,7 +314,7 @@ async fn test_server_close() {
let child_hash = db.insert(b"hello there");
let collection = Collection::from_iter([("hello", child_hash)]);
let hash = db.insert_many(collection.to_blobs()).unwrap();
let mut node = test_node(db).spawn().await.unwrap();
let node = test_node(db).spawn().await.unwrap();
let node_addr = node.local_endpoint_addresses().await.unwrap();
let peer_id = node.node_id();

Expand All @@ -338,11 +337,12 @@ async fn test_server_close() {
loop {
tokio::select! {
biased;
res = &mut node => break res.context("provider failed"),
maybe_event = events_recv.recv() => {
match maybe_event {
Some(event) => match event {
Event::ByteProvide(provider::Event::TransferCompleted { .. }) => node.shutdown(),
Event::ByteProvide(provider::Event::TransferCompleted { .. }) => {
return node.shutdown().await;
},
Event::ByteProvide(provider::Event::TransferAborted { .. }) => {
break Err(anyhow!("transfer aborted"));
}
Expand All @@ -354,9 +354,9 @@ async fn test_server_close() {
}
}
})
.await
.expect("supervisor timeout")
.expect("supervisor failed");
.await
.expect("supervisor timeout")
.expect("supervisor failed");
}

/// create an in memory test database containing the given entries and an iroh collection of all entries
Expand Down
8 changes: 4 additions & 4 deletions iroh/tests/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ async fn sync_simple() -> Result<()> {
.await;

for node in nodes {
node.shutdown();
node.shutdown().await?;
}
Ok(())
}
Expand All @@ -138,7 +138,7 @@ async fn sync_subscribe_no_sync() -> Result<()> {
matches!(event, Some(Ok(LiveEvent::InsertLocal { .. }))),
"expected InsertLocal but got {event:?}"
);
node.shutdown();
node.shutdown().await?;
Ok(())
}

Expand Down Expand Up @@ -391,7 +391,7 @@ async fn sync_full_basic() -> Result<()> {

info!("shutdown");
for node in nodes {
node.shutdown();
node.shutdown().await?;
}

Ok(())
Expand Down Expand Up @@ -880,7 +880,7 @@ async fn doc_delete() -> Result<()> {
tokio::time::sleep(Duration::from_millis(200)).await;
let bytes = client.blobs.read_to_bytes(hash).await;
assert!(bytes.is_err());
node.shutdown();
node.shutdown().await?;
Ok(())
}

Expand Down

0 comments on commit 6344bf8

Please sign in to comment.