Skip to content

Commit 23b26aa

Browse files
committed
refactor: move Router and RouterBuilder into protocol.rs
1 parent fa1331e commit 23b26aa

File tree

4 files changed

+217
-228
lines changed

4 files changed

+217
-228
lines changed

iroh/src/discovery/local_swarm_discovery.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@
2020
//! .filter(|remote| {
2121
//! remote.sources().iter().any(|(source, duration)| {
2222
//! if let Source::Discovery { name } = source {
23-
//! name == iroh::discovery::local_swarm_discovery::NAME
24-
//! && *duration <= recent
23+
//! name == iroh::discovery::local_swarm_discovery::NAME && *duration <= recent
2524
//! } else {
2625
//! false
2726
//! }

iroh/src/lib.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,6 @@ pub mod endpoint;
242242
mod magicsock;
243243
pub mod metrics;
244244
pub mod protocol;
245-
pub mod router;
246245
pub mod ticket;
247246
pub mod tls;
248247

iroh/src/protocol.rs

Lines changed: 216 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,42 @@
11
//! TODO(matheus23) docs
22
use std::{any::Any, collections::BTreeMap, sync::Arc};
33

4-
use anyhow::Result;
4+
use anyhow::{anyhow, Result};
55
use futures_buffered::join_all;
66
use futures_lite::future::Boxed as BoxedFuture;
7+
use futures_util::{
8+
future::{MapErr, Shared},
9+
FutureExt, TryFutureExt,
10+
};
11+
use tokio::task::{JoinError, JoinSet};
12+
use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle};
13+
use tracing::{debug, error, warn};
714

8-
use crate::endpoint::Connecting;
15+
use crate::{endpoint::Connecting, Endpoint};
16+
17+
/// TODO(matheus23): docs
18+
#[derive(Clone, Debug)]
19+
pub struct Router {
20+
endpoint: Endpoint,
21+
protocols: Arc<ProtocolMap>,
22+
// `Router` needs to be `Clone + Send`, and we need to `task.await` in its `shutdown()` impl.
23+
// So we need
24+
// - `Shared` so we can `task.await` from all `Node` clones
25+
// - `MapErr` to map the `JoinError` to a `String`, because `JoinError` is `!Clone`
26+
// - `AbortOnDropHandle` to make sure that the `task` is cancelled when all `Node`s are dropped
27+
// (`Shared` acts like an `Arc` around its inner future).
28+
task: Shared<MapErr<AbortOnDropHandle<()>, JoinErrToStr>>,
29+
cancel_token: CancellationToken,
30+
}
31+
32+
type JoinErrToStr = Box<dyn Fn(JoinError) -> String + Send + Sync + 'static>;
33+
34+
/// TODO(matheus23): docs
35+
#[derive(Debug)]
36+
pub struct RouterBuilder {
37+
endpoint: Endpoint,
38+
protocols: ProtocolMap,
39+
}
940

1041
/// Handler for incoming connections.
1142
///
@@ -78,3 +109,186 @@ impl ProtocolMap {
78109
join_all(handlers).await;
79110
}
80111
}
112+
113+
impl Router {
114+
/// TODO(matheus23): docs
115+
pub fn builder(endpoint: Endpoint) -> RouterBuilder {
116+
RouterBuilder::new(endpoint)
117+
}
118+
119+
/// Returns a protocol handler for an ALPN.
120+
///
121+
/// This downcasts to the concrete type and returns `None` if the handler registered for `alpn`
122+
/// does not match the passed type.
123+
pub fn get_protocol<P: ProtocolHandler>(&self, alpn: &[u8]) -> Option<Arc<P>> {
124+
self.protocols.get_typed(alpn)
125+
}
126+
127+
/// TODO(matheus23): docs
128+
pub fn endpoint(&self) -> &Endpoint {
129+
&self.endpoint
130+
}
131+
132+
/// TODO(matheus23): docs
133+
pub async fn shutdown(self) -> Result<()> {
134+
// Trigger shutdown of the main run task by activating the cancel token.
135+
self.cancel_token.cancel();
136+
137+
// Wait for the main task to terminate.
138+
self.task.await.map_err(|err| anyhow!(err))?;
139+
140+
Ok(())
141+
}
142+
}
143+
144+
impl RouterBuilder {
145+
/// TODO(matheus23): docs
146+
pub fn new(endpoint: Endpoint) -> Self {
147+
Self {
148+
endpoint,
149+
protocols: ProtocolMap::default(),
150+
}
151+
}
152+
153+
/// TODO(matheus23): docs
154+
pub fn accept(mut self, alpn: impl AsRef<[u8]>, handler: Arc<dyn ProtocolHandler>) -> Self {
155+
self.protocols.insert(alpn.as_ref().to_vec(), handler);
156+
self
157+
}
158+
159+
/// Returns the [`Endpoint`] of the node.
160+
pub fn endpoint(&self) -> &Endpoint {
161+
&self.endpoint
162+
}
163+
164+
/// Returns a protocol handler for an ALPN.
165+
///
166+
/// This downcasts to the concrete type and returns `None` if the handler registered for `alpn`
167+
/// does not match the passed type.
168+
pub fn get_protocol<P: ProtocolHandler>(&self, alpn: &[u8]) -> Option<Arc<P>> {
169+
self.protocols.get_typed(alpn)
170+
}
171+
172+
/// TODO(matheus23): docs
173+
pub async fn spawn(self) -> Result<Router> {
174+
// Update the endpoint with our alpns.
175+
let alpns = self
176+
.protocols
177+
.alpns()
178+
.map(|alpn| alpn.to_vec())
179+
.collect::<Vec<_>>();
180+
181+
let protocols = Arc::new(self.protocols);
182+
if let Err(err) = self.endpoint.set_alpns(alpns) {
183+
shutdown(&self.endpoint, protocols.clone()).await;
184+
return Err(err);
185+
}
186+
187+
let mut join_set = JoinSet::new();
188+
let endpoint = self.endpoint.clone();
189+
let protos = protocols.clone();
190+
let cancel = CancellationToken::new();
191+
let cancel_token = cancel.clone();
192+
193+
let run_loop_fut = async move {
194+
let protocols = protos;
195+
loop {
196+
tokio::select! {
197+
biased;
198+
_ = cancel_token.cancelled() => {
199+
break;
200+
},
201+
// handle incoming p2p connections.
202+
incoming = endpoint.accept() => {
203+
let Some(incoming) = incoming else {
204+
break;
205+
};
206+
207+
let protocols = protocols.clone();
208+
join_set.spawn(async move {
209+
handle_connection(incoming, protocols).await;
210+
anyhow::Ok(())
211+
});
212+
},
213+
// handle task terminations and quit on panics.
214+
res = join_set.join_next(), if !join_set.is_empty() => {
215+
match res {
216+
Some(Err(outer)) => {
217+
if outer.is_panic() {
218+
error!("Task panicked: {outer:?}");
219+
break;
220+
} else if outer.is_cancelled() {
221+
debug!("Task cancelled: {outer:?}");
222+
} else {
223+
error!("Task failed: {outer:?}");
224+
break;
225+
}
226+
}
227+
Some(Ok(Err(inner))) => {
228+
debug!("Task errored: {inner:?}");
229+
}
230+
_ => {}
231+
}
232+
},
233+
}
234+
}
235+
236+
shutdown(&endpoint, protocols).await;
237+
238+
// Abort remaining tasks.
239+
tracing::info!("Shutting down remaining tasks");
240+
join_set.shutdown().await;
241+
};
242+
let task = tokio::task::spawn(run_loop_fut);
243+
let task = AbortOnDropHandle::new(task)
244+
.map_err(Box::new(|e: JoinError| e.to_string()) as JoinErrToStr)
245+
.shared();
246+
247+
Ok(Router {
248+
endpoint: self.endpoint,
249+
protocols,
250+
task,
251+
cancel_token: cancel,
252+
})
253+
}
254+
}
255+
256+
/// Shutdown the different parts of the router concurrently.
257+
async fn shutdown(endpoint: &Endpoint, protocols: Arc<ProtocolMap>) {
258+
let error_code = 1u16;
259+
260+
// We ignore all errors during shutdown.
261+
let _ = tokio::join!(
262+
// Close the endpoint.
263+
// Closing the Endpoint is the equivalent of calling Connection::close on all
264+
// connections: Operations will immediately fail with ConnectionError::LocallyClosed.
265+
// All streams are interrupted, this is not graceful.
266+
endpoint.close(error_code.into(), b"provider terminating"),
267+
// Shutdown protocol handlers.
268+
protocols.shutdown(),
269+
);
270+
}
271+
272+
async fn handle_connection(incoming: crate::endpoint::Incoming, protocols: Arc<ProtocolMap>) {
273+
let mut connecting = match incoming.accept() {
274+
Ok(conn) => conn,
275+
Err(err) => {
276+
warn!("Ignoring connection: accepting failed: {err:#}");
277+
return;
278+
}
279+
};
280+
let alpn = match connecting.alpn().await {
281+
Ok(alpn) => alpn,
282+
Err(err) => {
283+
warn!("Ignoring connection: invalid handshake: {err:#}");
284+
return;
285+
}
286+
};
287+
let Some(handler) = protocols.get(&alpn) else {
288+
warn!("Ignoring connection: unsupported ALPN protocol");
289+
return;
290+
};
291+
if let Err(err) = handler.accept(connecting).await {
292+
warn!("Handling incoming connection ended with error: {err}");
293+
}
294+
}

0 commit comments

Comments
 (0)