|
1 | 1 | //! TODO(matheus23) docs
|
2 | 2 | use std::{any::Any, collections::BTreeMap, sync::Arc};
|
3 | 3 |
|
4 |
| -use anyhow::Result; |
| 4 | +use anyhow::{anyhow, Result}; |
5 | 5 | use futures_buffered::join_all;
|
6 | 6 | use futures_lite::future::Boxed as BoxedFuture;
|
| 7 | +use futures_util::{ |
| 8 | + future::{MapErr, Shared}, |
| 9 | + FutureExt, TryFutureExt, |
| 10 | +}; |
| 11 | +use tokio::task::{JoinError, JoinSet}; |
| 12 | +use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle}; |
| 13 | +use tracing::{debug, error, warn}; |
7 | 14 |
|
8 |
| -use crate::endpoint::Connecting; |
| 15 | +use crate::{endpoint::Connecting, Endpoint}; |
| 16 | + |
| 17 | +/// TODO(matheus23): docs |
| 18 | +#[derive(Clone, Debug)] |
| 19 | +pub struct Router { |
| 20 | + endpoint: Endpoint, |
| 21 | + protocols: Arc<ProtocolMap>, |
| 22 | + // `Router` needs to be `Clone + Send`, and we need to `task.await` in its `shutdown()` impl. |
| 23 | + // So we need |
| 24 | + // - `Shared` so we can `task.await` from all `Node` clones |
| 25 | + // - `MapErr` to map the `JoinError` to a `String`, because `JoinError` is `!Clone` |
| 26 | + // - `AbortOnDropHandle` to make sure that the `task` is cancelled when all `Node`s are dropped |
| 27 | + // (`Shared` acts like an `Arc` around its inner future). |
| 28 | + task: Shared<MapErr<AbortOnDropHandle<()>, JoinErrToStr>>, |
| 29 | + cancel_token: CancellationToken, |
| 30 | +} |
| 31 | + |
| 32 | +type JoinErrToStr = Box<dyn Fn(JoinError) -> String + Send + Sync + 'static>; |
| 33 | + |
| 34 | +/// TODO(matheus23): docs |
| 35 | +#[derive(Debug)] |
| 36 | +pub struct RouterBuilder { |
| 37 | + endpoint: Endpoint, |
| 38 | + protocols: ProtocolMap, |
| 39 | +} |
9 | 40 |
|
10 | 41 | /// Handler for incoming connections.
|
11 | 42 | ///
|
@@ -78,3 +109,186 @@ impl ProtocolMap {
|
78 | 109 | join_all(handlers).await;
|
79 | 110 | }
|
80 | 111 | }
|
| 112 | + |
| 113 | +impl Router { |
| 114 | + /// TODO(matheus23): docs |
| 115 | + pub fn builder(endpoint: Endpoint) -> RouterBuilder { |
| 116 | + RouterBuilder::new(endpoint) |
| 117 | + } |
| 118 | + |
| 119 | + /// Returns a protocol handler for an ALPN. |
| 120 | + /// |
| 121 | + /// This downcasts to the concrete type and returns `None` if the handler registered for `alpn` |
| 122 | + /// does not match the passed type. |
| 123 | + pub fn get_protocol<P: ProtocolHandler>(&self, alpn: &[u8]) -> Option<Arc<P>> { |
| 124 | + self.protocols.get_typed(alpn) |
| 125 | + } |
| 126 | + |
| 127 | + /// TODO(matheus23): docs |
| 128 | + pub fn endpoint(&self) -> &Endpoint { |
| 129 | + &self.endpoint |
| 130 | + } |
| 131 | + |
| 132 | + /// TODO(matheus23): docs |
| 133 | + pub async fn shutdown(self) -> Result<()> { |
| 134 | + // Trigger shutdown of the main run task by activating the cancel token. |
| 135 | + self.cancel_token.cancel(); |
| 136 | + |
| 137 | + // Wait for the main task to terminate. |
| 138 | + self.task.await.map_err(|err| anyhow!(err))?; |
| 139 | + |
| 140 | + Ok(()) |
| 141 | + } |
| 142 | +} |
| 143 | + |
| 144 | +impl RouterBuilder { |
| 145 | + /// TODO(matheus23): docs |
| 146 | + pub fn new(endpoint: Endpoint) -> Self { |
| 147 | + Self { |
| 148 | + endpoint, |
| 149 | + protocols: ProtocolMap::default(), |
| 150 | + } |
| 151 | + } |
| 152 | + |
| 153 | + /// TODO(matheus23): docs |
| 154 | + pub fn accept(mut self, alpn: impl AsRef<[u8]>, handler: Arc<dyn ProtocolHandler>) -> Self { |
| 155 | + self.protocols.insert(alpn.as_ref().to_vec(), handler); |
| 156 | + self |
| 157 | + } |
| 158 | + |
| 159 | + /// Returns the [`Endpoint`] of the node. |
| 160 | + pub fn endpoint(&self) -> &Endpoint { |
| 161 | + &self.endpoint |
| 162 | + } |
| 163 | + |
| 164 | + /// Returns a protocol handler for an ALPN. |
| 165 | + /// |
| 166 | + /// This downcasts to the concrete type and returns `None` if the handler registered for `alpn` |
| 167 | + /// does not match the passed type. |
| 168 | + pub fn get_protocol<P: ProtocolHandler>(&self, alpn: &[u8]) -> Option<Arc<P>> { |
| 169 | + self.protocols.get_typed(alpn) |
| 170 | + } |
| 171 | + |
| 172 | + /// TODO(matheus23): docs |
| 173 | + pub async fn spawn(self) -> Result<Router> { |
| 174 | + // Update the endpoint with our alpns. |
| 175 | + let alpns = self |
| 176 | + .protocols |
| 177 | + .alpns() |
| 178 | + .map(|alpn| alpn.to_vec()) |
| 179 | + .collect::<Vec<_>>(); |
| 180 | + |
| 181 | + let protocols = Arc::new(self.protocols); |
| 182 | + if let Err(err) = self.endpoint.set_alpns(alpns) { |
| 183 | + shutdown(&self.endpoint, protocols.clone()).await; |
| 184 | + return Err(err); |
| 185 | + } |
| 186 | + |
| 187 | + let mut join_set = JoinSet::new(); |
| 188 | + let endpoint = self.endpoint.clone(); |
| 189 | + let protos = protocols.clone(); |
| 190 | + let cancel = CancellationToken::new(); |
| 191 | + let cancel_token = cancel.clone(); |
| 192 | + |
| 193 | + let run_loop_fut = async move { |
| 194 | + let protocols = protos; |
| 195 | + loop { |
| 196 | + tokio::select! { |
| 197 | + biased; |
| 198 | + _ = cancel_token.cancelled() => { |
| 199 | + break; |
| 200 | + }, |
| 201 | + // handle incoming p2p connections. |
| 202 | + incoming = endpoint.accept() => { |
| 203 | + let Some(incoming) = incoming else { |
| 204 | + break; |
| 205 | + }; |
| 206 | + |
| 207 | + let protocols = protocols.clone(); |
| 208 | + join_set.spawn(async move { |
| 209 | + handle_connection(incoming, protocols).await; |
| 210 | + anyhow::Ok(()) |
| 211 | + }); |
| 212 | + }, |
| 213 | + // handle task terminations and quit on panics. |
| 214 | + res = join_set.join_next(), if !join_set.is_empty() => { |
| 215 | + match res { |
| 216 | + Some(Err(outer)) => { |
| 217 | + if outer.is_panic() { |
| 218 | + error!("Task panicked: {outer:?}"); |
| 219 | + break; |
| 220 | + } else if outer.is_cancelled() { |
| 221 | + debug!("Task cancelled: {outer:?}"); |
| 222 | + } else { |
| 223 | + error!("Task failed: {outer:?}"); |
| 224 | + break; |
| 225 | + } |
| 226 | + } |
| 227 | + Some(Ok(Err(inner))) => { |
| 228 | + debug!("Task errored: {inner:?}"); |
| 229 | + } |
| 230 | + _ => {} |
| 231 | + } |
| 232 | + }, |
| 233 | + } |
| 234 | + } |
| 235 | + |
| 236 | + shutdown(&endpoint, protocols).await; |
| 237 | + |
| 238 | + // Abort remaining tasks. |
| 239 | + tracing::info!("Shutting down remaining tasks"); |
| 240 | + join_set.shutdown().await; |
| 241 | + }; |
| 242 | + let task = tokio::task::spawn(run_loop_fut); |
| 243 | + let task = AbortOnDropHandle::new(task) |
| 244 | + .map_err(Box::new(|e: JoinError| e.to_string()) as JoinErrToStr) |
| 245 | + .shared(); |
| 246 | + |
| 247 | + Ok(Router { |
| 248 | + endpoint: self.endpoint, |
| 249 | + protocols, |
| 250 | + task, |
| 251 | + cancel_token: cancel, |
| 252 | + }) |
| 253 | + } |
| 254 | +} |
| 255 | + |
| 256 | +/// Shutdown the different parts of the router concurrently. |
| 257 | +async fn shutdown(endpoint: &Endpoint, protocols: Arc<ProtocolMap>) { |
| 258 | + let error_code = 1u16; |
| 259 | + |
| 260 | + // We ignore all errors during shutdown. |
| 261 | + let _ = tokio::join!( |
| 262 | + // Close the endpoint. |
| 263 | + // Closing the Endpoint is the equivalent of calling Connection::close on all |
| 264 | + // connections: Operations will immediately fail with ConnectionError::LocallyClosed. |
| 265 | + // All streams are interrupted, this is not graceful. |
| 266 | + endpoint.close(error_code.into(), b"provider terminating"), |
| 267 | + // Shutdown protocol handlers. |
| 268 | + protocols.shutdown(), |
| 269 | + ); |
| 270 | +} |
| 271 | + |
| 272 | +async fn handle_connection(incoming: crate::endpoint::Incoming, protocols: Arc<ProtocolMap>) { |
| 273 | + let mut connecting = match incoming.accept() { |
| 274 | + Ok(conn) => conn, |
| 275 | + Err(err) => { |
| 276 | + warn!("Ignoring connection: accepting failed: {err:#}"); |
| 277 | + return; |
| 278 | + } |
| 279 | + }; |
| 280 | + let alpn = match connecting.alpn().await { |
| 281 | + Ok(alpn) => alpn, |
| 282 | + Err(err) => { |
| 283 | + warn!("Ignoring connection: invalid handshake: {err:#}"); |
| 284 | + return; |
| 285 | + } |
| 286 | + }; |
| 287 | + let Some(handler) = protocols.get(&alpn) else { |
| 288 | + warn!("Ignoring connection: unsupported ALPN protocol"); |
| 289 | + return; |
| 290 | + }; |
| 291 | + if let Err(err) = handler.accept(connecting).await { |
| 292 | + warn!("Handling incoming connection ended with error: {err}"); |
| 293 | + } |
| 294 | +} |
0 commit comments