diff --git a/iroh/src/protocol.rs b/iroh/src/protocol.rs index 2c3dd1b245..6349063255 100644 --- a/iroh/src/protocol.rs +++ b/iroh/src/protocol.rs @@ -42,16 +42,12 @@ //! ``` use std::{any::Any, collections::BTreeMap, sync::Arc}; -use anyhow::{anyhow, Result}; +use anyhow::Result; use futures_buffered::join_all; use futures_lite::future::Boxed as BoxedFuture; -use futures_util::{ - future::{MapErr, Shared}, - FutureExt, TryFutureExt, -}; -use tokio::task::{JoinError, JoinSet}; +use tokio::{sync::Mutex, task::JoinSet}; use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle}; -use tracing::{debug, error, warn}; +use tracing::{error, info_span, trace, warn, Instrument}; use crate::{endpoint::Connecting, Endpoint}; @@ -92,17 +88,10 @@ pub struct Router { endpoint: Endpoint, protocols: Arc, // `Router` needs to be `Clone + Send`, and we need to `task.await` in its `shutdown()` impl. - // So we need - // - `Shared` so we can `task.await` from all `Node` clones - // - `MapErr` to map the `JoinError` to a `String`, because `JoinError` is `!Clone` - // - `AbortOnDropHandle` to make sure that the `task` is cancelled when all `Node`s are dropped - // (`Shared` acts like an `Arc` around its inner future). - task: Shared, JoinErrToStr>>, + task: Arc>>>, cancel_token: CancellationToken, } -type JoinErrToStr = Box String + Send + Sync + 'static>; - /// Builder for creating a [`Router`] for accepting protocols. #[derive(Debug)] pub struct RouterBuilder { @@ -201,16 +190,32 @@ impl Router { &self.endpoint } + /// Checks if the router is already shutdown. + pub fn is_shutdown(&self) -> bool { + self.cancel_token.is_cancelled() + } + /// Shuts down the accept loop cleanly. /// + /// When this function returns, all [`ProtocolHandler`]s will be shutdown and + /// `Endpoint::close` will have been called. + /// + /// If already shutdown, it returns `Ok`. + /// /// If some [`ProtocolHandler`] panicked in the accept loop, this will propagate /// that panic into the result here. - pub async fn shutdown(self) -> Result<()> { + pub async fn shutdown(&self) -> Result<()> { + if self.is_shutdown() { + return Ok(()); + } + // Trigger shutdown of the main run task by activating the cancel token. self.cancel_token.cancel(); // Wait for the main task to terminate. - self.task.await.map_err(|err| anyhow!(err))?; + if let Some(task) = self.task.lock().await.take() { + task.await?; + } Ok(()) } @@ -267,6 +272,9 @@ impl RouterBuilder { let cancel_token = cancel.clone(); let run_loop_fut = async move { + // Make sure to cancel the token, if this future ever exits. + let _cancel_guard = cancel_token.clone().drop_guard(); + let protocols = protos; loop { tokio::select! { @@ -274,18 +282,6 @@ impl RouterBuilder { _ = cancel_token.cancelled() => { break; }, - // handle incoming p2p connections. - incoming = endpoint.accept() => { - let Some(incoming) = incoming else { - break; - }; - - let protocols = protocols.clone(); - join_set.spawn(async move { - handle_connection(incoming, protocols).await; - anyhow::Ok(()) - }); - }, // handle task terminations and quit on panics. res = join_set.join_next(), if !join_set.is_empty() => { match res { @@ -294,18 +290,34 @@ impl RouterBuilder { error!("Task panicked: {outer:?}"); break; } else if outer.is_cancelled() { - debug!("Task cancelled: {outer:?}"); + trace!("Task cancelled: {outer:?}"); } else { error!("Task failed: {outer:?}"); break; } } - Some(Ok(Err(inner))) => { - debug!("Task errored: {inner:?}"); + Some(Ok(Some(()))) => { + trace!("Task finished"); + } + Some(Ok(None)) => { + trace!("Task cancelled"); } _ => {} } }, + + // handle incoming p2p connections. + incoming = endpoint.accept() => { + let Some(incoming) = incoming else { + break; + }; + + let protocols = protocols.clone(); + let token = cancel_token.child_token(); + join_set.spawn(async move { + token.run_until_cancelled(handle_connection(incoming, protocols)).await + }.instrument(info_span!("router.accept"))); + }, } } @@ -316,14 +328,12 @@ impl RouterBuilder { join_set.shutdown().await; }; let task = tokio::task::spawn(run_loop_fut); - let task = AbortOnDropHandle::new(task) - .map_err(Box::new(|e: JoinError| e.to_string()) as JoinErrToStr) - .shared(); + let task = AbortOnDropHandle::new(task); Ok(Router { endpoint: self.endpoint, protocols, - task, + task: Arc::new(Mutex::new(Some(task))), cancel_token: cancel, }) }