diff --git a/iroh/src/discovery.rs b/iroh/src/discovery.rs index a8ee7965f6..c23789b578 100644 --- a/iroh/src/discovery.rs +++ b/iroh/src/discovery.rs @@ -116,7 +116,8 @@ use std::{collections::BTreeSet, net::SocketAddr, time::Duration}; use anyhow::{anyhow, ensure, Result}; use futures_lite::stream::{Boxed as BoxStream, StreamExt}; use iroh_base::{NodeAddr, NodeId, RelayUrl}; -use tokio::{sync::oneshot, task::JoinHandle}; +use tokio::sync::oneshot; +use tokio_util::task::AbortOnDropHandle; use tracing::{debug, error_span, warn, Instrument}; use crate::Endpoint; @@ -285,7 +286,7 @@ const MAX_AGE: Duration = Duration::from_secs(10); /// A wrapper around a tokio task which runs a node discovery. pub(super) struct DiscoveryTask { on_first_rx: oneshot::Receiver>, - task: JoinHandle<()>, + task: AbortOnDropHandle<()>, } impl DiscoveryTask { @@ -299,7 +300,10 @@ impl DiscoveryTask { error_span!("discovery", me = %me.fmt_short(), node = %node_id.fmt_short()), ), ); - Ok(Self { task, on_first_rx }) + Ok(Self { + task: AbortOnDropHandle::new(task), + on_first_rx, + }) } /// Starts a discovery task after a delay and only if no path to the node was recently active. @@ -340,7 +344,10 @@ impl DiscoveryTask { error_span!("discovery", me = %me.fmt_short(), node = %node_id.fmt_short()), ), ); - Ok(Some(Self { task, on_first_rx })) + Ok(Some(Self { + task: AbortOnDropHandle::new(task), + on_first_rx, + })) } /// Waits until the discovery task produced at least one result. @@ -350,11 +357,6 @@ impl DiscoveryTask { Ok(()) } - /// Cancels the discovery task. - pub(super) fn cancel(&self) { - self.task.abort(); - } - fn create_stream(ep: &Endpoint, node_id: NodeId) -> Result>> { let discovery = ep .discovery() @@ -400,11 +402,7 @@ impl DiscoveryTask { let mut on_first_tx = Some(on_first_tx); debug!("discovery: start"); loop { - let next = tokio::select! { - _ = ep.cancel_token().cancelled() => break, - next = stream.next() => next - }; - match next { + match stream.next().await { Some(Ok(r)) => { if r.node_addr.is_empty() { debug!(provenance = %r.provenance, "discovery: empty address found"); diff --git a/iroh/src/endpoint.rs b/iroh/src/endpoint.rs index 02493efafd..fe03fd3f35 100644 --- a/iroh/src/endpoint.rs +++ b/iroh/src/endpoint.rs @@ -23,11 +23,9 @@ use std::{ }; use anyhow::{bail, Context, Result}; -use derive_more::Debug; use iroh_base::{NodeAddr, NodeId, PublicKey, RelayUrl, SecretKey}; use iroh_relay::RelayMap; use pin_project::pin_project; -use tokio_util::sync::CancellationToken; use tracing::{debug, instrument, trace, warn}; use url::Url; @@ -92,7 +90,7 @@ pub enum PathSelection { /// new [`NodeId`]. /// /// To create the [`Endpoint`] call [`Builder::bind`]. -#[derive(Debug)] +#[derive(derive_more::Debug)] pub struct Builder { secret_key: Option, relay_mode: RelayMode, @@ -510,7 +508,6 @@ pub struct Endpoint { msock: Handle, endpoint: quinn::Endpoint, rtt_actor: Arc, - cancel_token: CancellationToken, static_config: Arc, } @@ -561,7 +558,6 @@ impl Endpoint { msock, endpoint, rtt_actor: Arc::new(rtt_actor::RttHandle::new()), - cancel_token: CancellationToken::new(), static_config: Arc::new(static_config), }) } @@ -618,10 +614,11 @@ impl Endpoint { let node_id = node_addr.node_id; let direct_addresses = node_addr.direct_addresses.clone(); - // Get the mapped IPv6 address from the magic socket. Quinn will connect to this address. - // Start discovery for this node if it's enabled and we have no valid or verified - // address information for this node. - let (addr, discovery) = self + // Get the mapped IPv6 address from the magic socket. Quinn will connect to this + // address. Start discovery for this node if it's enabled and we have no valid or + // verified address information for this node. Dropping the discovery cancels any + // still running task. + let (addr, _discovery_drop_guard) = self .get_mapping_addr_and_maybe_start_discovery(node_addr) .await .with_context(|| { @@ -636,16 +633,9 @@ impl Endpoint { node_id, addr, direct_addresses ); - // Start connecting via quinn. This will time out after 10 seconds if no reachable address - // is available. - let conn = self.connect_quinn(node_id, alpn, addr).await; - - // Cancel the node discovery task (if still running). - if let Some(discovery) = discovery { - discovery.cancel(); - } - - conn + // Start connecting via quinn. This will time out after 10 seconds if no reachable + // address is available. + self.connect_quinn(node_id, alpn, addr).await } #[instrument( @@ -990,7 +980,6 @@ impl Endpoint { return Ok(()); } - self.cancel_token.cancel(); tracing::debug!("Closing connections"); self.endpoint.close(0u16.into(), b""); self.endpoint.wait_idle().await; @@ -1002,16 +991,11 @@ impl Endpoint { /// Check if this endpoint is still alive, or already closed. pub fn is_closed(&self) -> bool { - self.cancel_token.is_cancelled() && self.msock.is_closed() + self.msock.is_closed() } // # Remaining private methods - /// Expose the internal [`CancellationToken`] to link shutdowns. - pub(crate) fn cancel_token(&self) -> &CancellationToken { - &self.cancel_token - } - /// Return the quic mapped address for this `node_id` and possibly start discovery /// services if discovery is enabled on this magic endpoint. /// @@ -1085,7 +1069,7 @@ impl Endpoint { } /// Future produced by [`Endpoint::accept`]. -#[derive(Debug)] +#[derive(derive_more::Debug)] #[pin_project] pub struct Accept<'a> { #[pin] diff --git a/iroh/src/protocol.rs b/iroh/src/protocol.rs index 38f7c7936f..4aa22d34bf 100644 --- a/iroh/src/protocol.rs +++ b/iroh/src/protocol.rs @@ -248,9 +248,8 @@ impl RouterBuilder { let mut join_set = JoinSet::new(); let endpoint = self.endpoint.clone(); - // We use a child token of the endpoint, to ensure that this is shutdown - // when the endpoint is shutdown, but that we can shutdown ourselves independently. - let cancel = endpoint.cancel_token().child_token(); + // Our own shutdown works with a cancellation token. + let cancel = CancellationToken::new(); let cancel_token = cancel.clone(); let run_loop_fut = async move { @@ -289,7 +288,7 @@ impl RouterBuilder { // handle incoming p2p connections. incoming = endpoint.accept() => { let Some(incoming) = incoming else { - break; + break; // Endpoint is closed. }; let protocols = protocols.clone();