diff --git a/Cargo.lock b/Cargo.lock index 794071b3c6..fd404cc7ef 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2952,6 +2952,7 @@ dependencies = [ "smallvec", "socket2", "stun-rs", + "testresult", "thiserror 2.0.3", "time", "tokio", diff --git a/iroh-relay/Cargo.toml b/iroh-relay/Cargo.toml index b94904e4d7..967e406544 100644 --- a/iroh-relay/Cargo.toml +++ b/iroh-relay/Cargo.toml @@ -98,6 +98,7 @@ clap = { version = "4", features = ["derive"] } crypto_box = { version = "0.9.1", features = ["serde", "chacha20"] } proptest = "1.2.0" rand_chacha = "0.3.1" +testresult = "0.4.0" tokio = { version = "1", features = [ "io-util", "sync", diff --git a/iroh-relay/src/client/conn.rs b/iroh-relay/src/client/conn.rs index fde38b0f2c..020727743b 100644 --- a/iroh-relay/src/client/conn.rs +++ b/iroh-relay/src/client/conn.rs @@ -510,7 +510,7 @@ pub(crate) async fn send_packet + Unpin>( }; if let Some(rate_limiter) = rate_limiter { if rate_limiter.check_n(frame.len()).is_err() { - tracing::warn!("dropping send: rate limit reached"); + tracing::debug!("dropping send: rate limit reached"); return Ok(()); } } @@ -521,12 +521,7 @@ pub(crate) async fn send_packet + Unpin>( } pub(crate) struct RateLimiter { - inner: governor::RateLimiter< - governor::state::direct::NotKeyed, - governor::state::InMemoryState, - governor::clock::DefaultClock, - governor::middleware::NoOpMiddleware, - >, + inner: governor::DefaultDirectRateLimiter, } impl RateLimiter { diff --git a/iroh-relay/src/main.rs b/iroh-relay/src/main.rs index aad6b95352..94d81f095f 100644 --- a/iroh-relay/src/main.rs +++ b/iroh-relay/src/main.rs @@ -8,11 +8,11 @@ use std::{ path::{Path, PathBuf}, }; -use anyhow::{anyhow, bail, Context as _, Result}; +use anyhow::{bail, Context as _, Result}; use clap::Parser; use iroh_relay::{ defaults::{DEFAULT_HTTPS_PORT, DEFAULT_HTTP_PORT, DEFAULT_METRICS_PORT, DEFAULT_STUN_PORT}, - server as relay, + server::{self as relay, ClientConnRateLimit}, }; use serde::{Deserialize, Serialize}; use tokio_rustls_acme::{caches::DirCache, AcmeConfig}; @@ -282,6 +282,29 @@ struct Limits { accept_conn_limit: Option, /// Burst limit for accepting new connection. Unlimited if not set. accept_conn_burst: Option, + /// Rate limiting configuration per client. + client: Option, +} + +/// Rate limit configuration for each connected client. +/// +/// The rate limiting uses a token-bucket style algorithm: +/// +/// - The base rate limit uses a steady-stream rate of bytes allowed. +/// - Additionally a burst quota allows sending bytes over this steady-stream rate +/// limit, as long as the maximum burst quota is not exceeded. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +struct PerClientRateLimitConfig { + /// Rate limit configuration for the incoming data from the client. + rx: Option, +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +struct RateLimitConfig { + /// Maximum number of bytes per second. + bytes_per_second: Option, + /// Maximum number of bytes to read in a single burst. 
+ max_burst_bytes: Option, } impl Config { @@ -295,13 +318,14 @@ impl Config { if config_path.exists() { Self::read_from_file(&config_path).await } else { - let config = Config::default(); - config.write_to_file(&config_path).await?; - - Ok(config) + Ok(Config::default()) } } + fn from_str(config: &str) -> Result { + toml::from_str(config).context("config must be valid toml") + } + async fn read_from_file(path: impl AsRef) -> Result { if !path.as_ref().is_file() { bail!("config-path must be a file"); @@ -309,27 +333,7 @@ impl Config { let config_ser = tokio::fs::read_to_string(&path) .await .context("unable to read config")?; - let config: Self = toml::from_str(&config_ser).context("config file must be valid toml")?; - - Ok(config) - } - - /// Write the content of this configuration to the provided path. - async fn write_to_file(&self, path: impl AsRef) -> Result<()> { - let p = path - .as_ref() - .parent() - .ok_or_else(|| anyhow!("invalid config file path, no parent"))?; - // TODO: correct permissions (0777 for dir, 0600 for file) - tokio::fs::create_dir_all(p) - .await - .with_context(|| format!("unable to create config-path dir: {}", p.display()))?; - let config_ser = toml::to_string(self).context("unable to serialize configuration")?; - tokio::fs::write(path, config_ser) - .await - .context("unable to write config file")?; - - Ok(()) + Self::from_str(&config_ser) } } @@ -402,17 +406,37 @@ async fn build_relay_config(cfg: Config) -> Result None, }; - let limits = relay::Limits { - accept_conn_limit: cfg - .limits - .as_ref() - .map(|l| l.accept_conn_limit) - .unwrap_or_default(), - accept_conn_burst: cfg - .limits - .as_ref() - .map(|l| l.accept_conn_burst) - .unwrap_or_default(), + let limits = match cfg.limits { + Some(ref limits) => { + let client_rx = match &limits.client { + Some(PerClientRateLimitConfig { rx: Some(rx) }) => { + if rx.bytes_per_second.is_none() && rx.max_burst_bytes.is_some() { + bail!("bytes_per_seconds must be specified to enable the rate-limiter"); + } + match rx.bytes_per_second { + Some(bps) => Some(ClientConnRateLimit { + bytes_per_second: bps + .try_into() + .context("bytes_per_second must be non-zero u32")?, + max_burst_bytes: rx + .max_burst_bytes + .map(|v| { + v.try_into().context("max_burst_bytes must be non-zero u32") + }) + .transpose()?, + }), + None => None, + } + } + Some(PerClientRateLimitConfig { rx: None }) | None => None, + }; + relay::Limits { + accept_conn_limit: limits.accept_conn_limit, + accept_conn_burst: limits.accept_conn_burst, + client_rx, + } + } + None => Default::default(), }; let relay_config = relay::RelayConfig { http_bind_addr: cfg.http_bind_addr(), @@ -477,3 +501,46 @@ mod metrics { } } } + +#[cfg(test)] +mod tests { + use std::num::NonZeroU32; + + use testresult::TestResult; + + use super::*; + + #[tokio::test] + async fn test_rate_limit_config() -> TestResult { + let config = " + [limits.client.rx] + bytes_per_second = 400 + max_burst_bytes = 800 + "; + let config = Config::from_str(config)?; + let relay_config = build_relay_config(config).await?; + + let relay = relay_config.relay.expect("no relay config"); + assert_eq!( + relay.limits.client_rx.expect("ratelimit").bytes_per_second, + NonZeroU32::try_from(400).unwrap() + ); + assert_eq!( + relay.limits.client_rx.expect("ratelimit").max_burst_bytes, + Some(NonZeroU32::try_from(800).unwrap()) + ); + + Ok(()) + } + + #[tokio::test] + async fn test_rate_limit_default() -> TestResult { + let config = Config::from_str("")?; + let relay_config = build_relay_config(config).await?; 
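Reviewer note: the config-to-`ClientConnRateLimit` mapping above validates that a burst quota without a base rate is rejected and that both values are non-zero. A minimal, standalone sketch of that pattern, parsing the same `[limits.client.rx]` keys with `serde`/`toml` (struct and function names here are illustrative, not part of the crate):

```rust
use std::num::NonZeroU32;

use anyhow::{bail, Context, Result};
use serde::Deserialize;

/// Illustrative mirror of the `[limits.client.rx]` table.
#[derive(Debug, Deserialize)]
struct RateLimitConfig {
    bytes_per_second: Option<u32>,
    max_burst_bytes: Option<u32>,
}

fn to_nonzero_limits(cfg: &RateLimitConfig) -> Result<Option<(NonZeroU32, Option<NonZeroU32>)>> {
    // A burst quota without a base rate is a configuration error.
    if cfg.bytes_per_second.is_none() && cfg.max_burst_bytes.is_some() {
        bail!("bytes_per_second must be specified to enable the rate-limiter");
    }
    let Some(bps) = cfg.bytes_per_second else {
        return Ok(None); // rate limiting disabled
    };
    Ok(Some((
        NonZeroU32::new(bps).context("bytes_per_second must be non-zero")?,
        cfg.max_burst_bytes
            .map(|v| NonZeroU32::new(v).context("max_burst_bytes must be non-zero"))
            .transpose()?,
    )))
}

fn main() -> Result<()> {
    let cfg: RateLimitConfig = toml::from_str(
        "
        bytes_per_second = 400
        max_burst_bytes = 800
        ",
    )?;
    let (bps, burst) = to_nonzero_limits(&cfg)?.expect("limits configured");
    assert_eq!(bps.get(), 400);
    assert_eq!(burst.map(NonZeroU32::get), Some(800));
    Ok(())
}
```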
+ + let relay = relay_config.relay.expect("no relay config"); + assert!(relay.limits.client_rx.is_none()); + + Ok(()) + } +} diff --git a/iroh-relay/src/protos/disco.rs b/iroh-relay/src/protos/disco.rs index 3f54269f83..64af8187d7 100644 --- a/iroh-relay/src/protos/disco.rs +++ b/iroh-relay/src/protos/disco.rs @@ -10,6 +10,7 @@ pub(crate) const MAGIC_LEN: usize = MAGIC.as_bytes().len(); pub(crate) const KEY_LEN: usize = 32; const MESSAGE_HEADER_LEN: usize = MAGIC_LEN + KEY_LEN; + /// Reports whether p looks like it's a packet containing an encrypted disco message. pub fn looks_like_disco_wrapper(p: &[u8]) -> bool { if p.len() < MESSAGE_HEADER_LEN { diff --git a/iroh-relay/src/protos/relay.rs b/iroh-relay/src/protos/relay.rs index ef50a9bb07..4997dc2102 100644 --- a/iroh-relay/src/protos/relay.rs +++ b/iroh-relay/src/protos/relay.rs @@ -18,6 +18,9 @@ use tokio_util::codec::{Decoder, Encoder}; /// including its on-wire framing overhead) pub const MAX_PACKET_SIZE: usize = 64 * 1024; +/// The maximum frame size. +/// +/// This is also the minimum burst size that a rate-limiter has to accept. const MAX_FRAME_SIZE: usize = 1024 * 1024; /// The Relay magic number, sent in the FrameType::ClientInfo frame upon initial connection. @@ -200,9 +203,14 @@ pub(crate) async fn recv_client_key> + Un } } +/// The protocol for the relay server. +/// +/// This is a framed protocol, using [`tokio_util::codec`] to turn the streams of bytes into +/// [`Frame`]s. #[derive(Debug, Default, Clone)] pub(crate) struct DerpCodec; +/// The frames in the [`DerpCodec`]. #[derive(Debug, Clone, PartialEq, Eq)] pub(crate) enum Frame { ClientInfo { @@ -279,6 +287,12 @@ impl Frame { } } + /// Serialized length with frame header. + #[cfg(feature = "server")] + pub(crate) fn len_with_header(&self) -> usize { + self.len() + HEADER_LEN + } + /// Tries to decode a frame received over websockets. /// /// Specifically, bytes received from a binary websocket message frame. diff --git a/iroh-relay/src/server.rs b/iroh-relay/src/server.rs index c8400c79a6..80fac29e9e 100644 --- a/iroh-relay/src/server.rs +++ b/iroh-relay/src/server.rs @@ -16,7 +16,7 @@ //! - HTTPS `/generate_204`: Used for net_report probes. //! - STUN: UDP port for STUN requests/responses. -use std::{fmt, future::Future, net::SocketAddr, pin::Pin, sync::Arc}; +use std::{fmt, future::Future, net::SocketAddr, num::NonZeroU32, pin::Pin, sync::Arc}; use anyhow::{anyhow, bail, Context, Result}; use futures_lite::StreamExt; @@ -140,12 +140,24 @@ pub struct TlsConfig { } /// Rate limits. +// TODO: accept_conn_limit and accept_conn_burst are not currently implemented. #[derive(Debug, Default)] pub struct Limits { /// Rate limit for accepting new connection. Unlimited if not set. pub accept_conn_limit: Option, /// Burst limit for accepting new connection. Unlimited if not set. pub accept_conn_burst: Option, + /// Rate limits for incoming traffic from a client connection. + pub client_rx: Option, +} + +/// Per-client rate limit configuration. +#[derive(Debug, Copy, Clone)] +pub struct ClientConnRateLimit { + /// Max number of bytes per second to read from the client connection. + pub bytes_per_second: NonZeroU32, + /// Max number of bytes to read in a single burst. + pub max_burst_bytes: Option, } /// TLS certificate configuration. 
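Reviewer note: the new doc comment on `MAX_FRAME_SIZE` ("the minimum burst size that a rate-limiter has to accept") matters because governor refuses any single check whose cost exceeds the configured burst capacity. A small sketch of that failure mode, with values chosen only for illustration:

```rust
use std::num::NonZeroU32;

use governor::{Quota, RateLimiter};

fn main() {
    // Steady rate of 1 KiB/s; the burst capacity defaults to the same 1 KiB.
    let quota = Quota::per_second(NonZeroU32::new(1024).unwrap());
    let limiter = RateLimiter::direct(quota);

    // A "frame" larger than the burst capacity can never be admitted:
    // check_n reports InsufficientCapacity instead of rate-limiting it.
    let oversized = NonZeroU32::new(4096).unwrap();
    assert!(limiter.check_n(oversized).is_err());

    // A frame within the burst capacity passes the capacity check;
    // the inner result then says whether it may proceed now or must wait.
    let ok = NonZeroU32::new(512).unwrap();
    assert!(limiter.check_n(ok).is_ok());
}
```

This is why a configured `max_burst_bytes` smaller than `MAX_FRAME_SIZE` can leave frames that are only let through by the explicit fallback in the stream implementation further below.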
@@ -260,6 +272,9 @@ impl Server { .request_handler(Method::GET, "/index.html", Box::new(root_handler)) .request_handler(Method::GET, RELAY_PROBE_PATH, Box::new(probe_handler)) .request_handler(Method::GET, "/robots.txt", Box::new(robots_handler)); + if let Some(cfg) = relay_config.limits.client_rx { + builder = builder.client_rx_ratelimit(cfg); + } let http_addr = match relay_config.tls { Some(tls_config) => { let server_config = rustls::ServerConfig::builder_with_provider(Arc::new( diff --git a/iroh-relay/src/server/actor.rs b/iroh-relay/src/server/actor.rs index 9bd7f72d25..9ef38a0964 100644 --- a/iroh-relay/src/server/actor.rs +++ b/iroh-relay/src/server/actor.rs @@ -183,12 +183,8 @@ impl Actor { } Message::CreateClient(client_builder) => { inc!(Metrics, accepts); - - trace!( - node_id = client_builder.node_id.fmt_short(), - "create client" - ); let node_id = client_builder.node_id; + trace!(node_id = node_id.fmt_short(), "create client"); // build and register client, starting up read & write loops for the client // connection @@ -272,6 +268,7 @@ mod tests { stream: RelayedStream::Derp(Framed::new(MaybeTlsStream::Test(io), DerpCodec)), write_timeout: Duration::from_secs(1), channel_capacity: 10, + rate_limit: None, server_channel, }, Framed::new(test_io, DerpCodec), diff --git a/iroh-relay/src/server/client_conn.rs b/iroh-relay/src/server/client_conn.rs index c0567910f8..afb4ed1efd 100644 --- a/iroh-relay/src/server/client_conn.rs +++ b/iroh-relay/src/server/client_conn.rs @@ -1,16 +1,17 @@ //! The server-side representation of an ongoing client relaying connection. -use std::time::Duration; +use std::{future::Future, num::NonZeroU32, pin::Pin, sync::Arc, task::Poll, time::Duration}; use anyhow::{Context, Result}; use bytes::Bytes; -use futures_lite::StreamExt; -use futures_util::SinkExt; +use futures_lite::FutureExt; +use futures_sink::Sink; +use futures_util::{SinkExt, Stream, StreamExt}; use iroh_base::key::NodeId; use iroh_metrics::{inc, inc_by}; use tokio::sync::mpsc; use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle}; -use tracing::{info, trace, warn, Instrument}; +use tracing::{error, info, instrument, trace, warn, Instrument}; use crate::{ protos::{ @@ -21,6 +22,7 @@ use crate::{ actor::{self, Packet}, metrics::Metrics, streams::RelayedStream, + ClientConnRateLimit, }, }; @@ -31,6 +33,7 @@ pub(super) struct ClientConnConfig { pub(super) stream: RelayedStream, pub(super) write_timeout: Duration, pub(super) channel_capacity: usize, + pub(super) rate_limit: Option, pub(super) server_channel: mpsc::Sender, } @@ -66,9 +69,22 @@ impl ClientConn { stream: io, write_timeout, channel_capacity, + rate_limit, server_channel, } = config; + let stream = match rate_limit { + Some(cfg) => { + let mut quota = governor::Quota::per_second(cfg.bytes_per_second); + if let Some(max_burst) = cfg.max_burst_bytes { + quota = quota.allow_burst(max_burst); + } + let limiter = governor::RateLimiter::direct(quota); + RateLimitedRelayedStream::new(io, limiter) + } + None => RateLimitedRelayedStream::unlimited(io), + }; + let done = CancellationToken::new(); let client_id = (key, conn_num); let (send_queue_s, send_queue_r) = mpsc::channel(channel_capacity); @@ -77,7 +93,7 @@ impl ClientConn { let (peer_gone_s, peer_gone_r) = mpsc::channel(channel_capacity); let actor = Actor { - stream: io, + stream, timeout: write_timeout, send_queue: send_queue_r, disco_send_queue: disco_send_queue_r, @@ -162,7 +178,7 @@ impl ClientConn { #[derive(Debug)] struct Actor { /// IO Stream to talk to the client 
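Reviewer note: the quota construction in `ClientConn::new` above follows governor's consuming builder: the steady rate also sets the default burst, and `allow_burst` only overrides the bucket size. A standalone sketch of the same mapping (the helper and struct names are illustrative):

```rust
use std::num::NonZeroU32;

use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};

/// Illustrative stand-in for `ClientConnRateLimit`.
struct RxLimit {
    bytes_per_second: NonZeroU32,
    max_burst_bytes: Option<NonZeroU32>,
}

fn build_limiter(cfg: &RxLimit) -> DefaultDirectRateLimiter {
    // Steady-stream rate; this is also the burst size unless overridden below.
    let mut quota = Quota::per_second(cfg.bytes_per_second);
    if let Some(burst) = cfg.max_burst_bytes {
        // Allow short bursts above the steady rate, up to this many bytes.
        quota = quota.allow_burst(burst);
    }
    RateLimiter::direct(quota)
}
```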
- stream: RelayedStream, + stream: RateLimitedRelayedStream, /// Maximum time we wait to complete a write to the client timeout: Duration, /// Packets queued to send to the client @@ -317,10 +333,186 @@ impl Actor { } } +/// Rate limiter for reading from a [`RelayedStream`]. +/// +/// The writes to the sink are not rate limited. +/// +/// This potentially buffers one frame if the rate limiter does not allows this frame. +/// While the frame is buffered the undernlying stream is no longer polled. +#[derive(Debug)] +struct RateLimitedRelayedStream { + inner: RelayedStream, + limiter: Option>, + state: State, + /// Keeps track if this stream was ever rate-limited. + limited_once: bool, +} + +#[derive(derive_more::Debug)] +enum State { + #[debug("Blocked")] + Blocked { + /// Future which will complete when the item can be yielded. + delay: Pin + Send + Sync>>, + /// Item to yield when the `delay` future completes. + item: anyhow::Result, + }, + Ready, +} + +impl RateLimitedRelayedStream { + fn new(inner: RelayedStream, limiter: governor::DefaultDirectRateLimiter) -> Self { + Self { + inner, + limiter: Some(Arc::new(limiter)), + state: State::Ready, + limited_once: false, + } + } + + fn unlimited(inner: RelayedStream) -> Self { + Self { + inner, + limiter: None, + state: State::Ready, + limited_once: false, + } + } +} + +impl RateLimitedRelayedStream { + /// Records metrics about being rate-limited. + fn record_rate_limited(&mut self) { + // TODO: add a label for the frame type. + inc!(Metrics, frames_rx_ratelimited_total); + if !self.limited_once { + inc!(Metrics, conns_rx_ratelimited_total); + self.limited_once = true; + } + } +} + +impl Stream for RateLimitedRelayedStream { + type Item = anyhow::Result; + + #[instrument(name = "rate_limited_relayed_stream", skip_all)] + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let Some(ref limiter) = self.limiter else { + // If there is no rate-limiter directly poll the inner. + return Pin::new(&mut self.inner).poll_next(cx); + }; + let limiter = limiter.clone(); + loop { + match &mut self.state { + State::Ready => { + // Poll inner for a new item. + match Pin::new(&mut self.inner).poll_next(cx) { + Poll::Ready(Some(item)) => { + match &item { + Ok(frame) => { + // How many bytes does this frame consume? + let Ok(frame_len) = + TryInto::::try_into(frame.len_with_header()) + .and_then(TryInto::::try_into) + else { + error!("frame len not NonZeroU32, is MAX_FRAME_SIZE too large?"); + // Let this frame through so to not completely break. + return Poll::Ready(Some(item)); + }; + + match limiter.check_n(frame_len) { + Ok(Ok(_)) => return Poll::Ready(Some(item)), + Ok(Err(_)) => { + // Item is rate-limited. + self.record_rate_limited(); + let delay = Box::pin({ + let limiter = limiter.clone(); + async move { + limiter.until_n_ready(frame_len).await.ok(); + } + }); + self.state = State::Blocked { delay, item }; + continue; + } + Err(_insufficient_capacity) => { + error!( + "frame larger than bucket capacity: \ + configuration error: \ + max_burst_bytes < MAX_FRAME_SIZE?" + ); + // Let this frame through so to not completely break. + return Poll::Ready(Some(item)); + } + } + } + Err(_) => { + // Yielding errors is not rate-limited. + return Poll::Ready(Some(item)); + } + } + } + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => return Poll::Pending, + } + } + State::Blocked { delay, .. 
} => { + match delay.poll(cx) { + Poll::Ready(_) => { + match std::mem::replace(&mut self.state, State::Ready) { + State::Ready => unreachable!(), + State::Blocked { item, .. } => { + // Yield the item directly, rate-limit has already been + // accounted for by awaiting the future. + return Poll::Ready(Some(item)); + } + } + } + Poll::Pending => return Poll::Pending, + } + } + } + } + } +} + +impl Sink for RateLimitedRelayedStream { + type Error = std::io::Error; + + fn poll_ready( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + Pin::new(&mut self.inner).poll_ready(cx) + } + + fn start_send(mut self: Pin<&mut Self>, item: Frame) -> std::result::Result<(), Self::Error> { + Pin::new(&mut self.inner).start_send(item) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_close( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + Pin::new(&mut self.inner).poll_close(cx) + } +} + #[cfg(test)] mod tests { use anyhow::bail; + use bytes::Bytes; use iroh_base::key::SecretKey; + use testresult::TestResult; use tokio_util::codec::Framed; use super::*; @@ -340,9 +532,10 @@ mod tests { let (io, io_rw) = tokio::io::duplex(1024); let mut io_rw = Framed::new(io_rw, DerpCodec); let (server_channel_s, mut server_channel_r) = mpsc::channel(10); + let stream = RelayedStream::Derp(Framed::new(MaybeTlsStream::Test(io), DerpCodec)); let actor = Actor { - stream: RelayedStream::Derp(Framed::new(MaybeTlsStream::Test(io), DerpCodec)), + stream: RateLimitedRelayedStream::unlimited(stream), timeout: Duration::from_secs(1), send_queue: send_queue_r, disco_send_queue: disco_send_queue_r, @@ -479,10 +672,11 @@ mod tests { let (io, io_rw) = tokio::io::duplex(1024); let mut io_rw = Framed::new(io_rw, DerpCodec); let (server_channel_s, mut server_channel_r) = mpsc::channel(10); + let stream = RelayedStream::Derp(Framed::new(MaybeTlsStream::Test(io), DerpCodec)); println!("-- create client conn"); let actor = Actor { - stream: RelayedStream::Derp(Framed::new(MaybeTlsStream::Test(io), DerpCodec)), + stream: RateLimitedRelayedStream::unlimited(stream), timeout: Duration::from_secs(1), send_queue: send_queue_r, disco_send_queue: disco_send_queue_r, @@ -542,4 +736,65 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_rate_limit() -> TestResult { + let _logging = iroh_test::logging::setup(); + + const LIMIT: u32 = 50; + const MAX_FRAMES: u32 = 100; + + // Rate limiter allowing LIMIT bytes/s + let quota = governor::Quota::per_second(NonZeroU32::try_from(LIMIT)?); + let limiter = governor::RateLimiter::direct(quota); + + // Build the rate limited stream. + let (io_read, io_write) = tokio::io::duplex((LIMIT * MAX_FRAMES) as _); + let mut frame_writer = Framed::new(io_write, DerpCodec); + let stream = RelayedStream::Derp(Framed::new(MaybeTlsStream::Test(io_read), DerpCodec)); + let mut stream = RateLimitedRelayedStream::new(stream, limiter); + + // Prepare a frame to send, assert its size. + let data = Bytes::from_static(b"hello world!!"); + let target = SecretKey::generate().public(); + let frame = Frame::SendPacket { + dst_key: target, + packet: data.clone(), + }; + let frame_len = frame.len_with_header(); + assert_eq!(frame_len, LIMIT as usize); + + // Send a frame, it should arrive. 
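Reviewer note: the `State::Ready`/`State::Blocked` machinery above exists because `poll_next` cannot simply `.await` the limiter. In a plain async function the same token-bucket wait collapses to a few lines; the following is a simplified sketch of the idea, not the crate's implementation, and it elides the metrics and the oversized-frame fallbacks:

```rust
use std::num::NonZeroU32;

use anyhow::Result;
use futures_util::{Stream, StreamExt};
use governor::DefaultDirectRateLimiter;

/// Reads byte chunks, waiting whenever the token bucket runs dry.
async fn read_rate_limited<S>(mut inner: S, limiter: &DefaultDirectRateLimiter) -> Result<()>
where
    S: Stream<Item = Result<Vec<u8>>> + Unpin,
{
    while let Some(chunk) = inner.next().await {
        let chunk = chunk?;
        if let Some(cost) = u32::try_from(chunk.len()).ok().and_then(NonZeroU32::new) {
            // Waits until `cost` bytes worth of tokens are available, then debits them.
            limiter.until_n_ready(cost).await?;
        }
        // ... hand the chunk on to the actor here ...
    }
    Ok(())
}
```

The poll-based version buffers the single rate-limited frame in `State::Blocked` instead, so the inner stream is not polled again until the boxed `until_n_ready` future completes.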
+ info!("-- send packet"); + frame_writer.send(frame.clone()).await?; + frame_writer.flush().await?; + let recv_frame = tokio::time::timeout(Duration::from_millis(500), stream.next()) + .await + .expect("timeout") + .expect("option") + .expect("ok"); + assert_eq!(recv_frame, frame); + + // Next frame does not arrive. + info!("-- send packet"); + frame_writer.send(frame.clone()).await?; + frame_writer.flush().await?; + let res = tokio::time::timeout(Duration::from_millis(100), stream.next()).await; + assert!(res.is_err(), "expecting a timeout"); + info!("-- timeout happened"); + + // Wait long enough. + info!("-- sleep"); + tokio::time::sleep(Duration::from_secs(1)).await; + + // Frame arrives. + let recv_frame = tokio::time::timeout(Duration::from_millis(500), stream.next()) + .await + .expect("timeout") + .expect("option") + .expect("ok"); + assert_eq!(recv_frame, frame); + + Ok(()) + } } diff --git a/iroh-relay/src/server/clients.rs b/iroh-relay/src/server/clients.rs index 98c96b6ef9..4c08fd3619 100644 --- a/iroh-relay/src/server/clients.rs +++ b/iroh-relay/src/server/clients.rs @@ -63,8 +63,8 @@ impl Clients { conn_num } + /// Builds the client handler and starts the read & write loops for the connection. pub async fn register(&mut self, client_config: ClientConnConfig) { - // this builds the client handler & starts the read & write loops to that client connection let key = client_config.node_id; trace!("registering client: {:?}", key); let conn_num = self.next_conn_num(); @@ -247,6 +247,7 @@ mod tests { stream: RelayedStream::Derp(Framed::new(MaybeTlsStream::Test(io), DerpCodec)), write_timeout: Duration::from_secs(1), channel_capacity: 10, + rate_limit: None, server_channel, }, FramedRead::new(test_io, DerpCodec), diff --git a/iroh-relay/src/server/http_server.rs b/iroh-relay/src/server/http_server.rs index 90c864f9c2..884b83a2f6 100644 --- a/iroh-relay/src/server/http_server.rs +++ b/iroh-relay/src/server/http_server.rs @@ -35,6 +35,7 @@ use crate::{ client_conn::ClientConnConfig, metrics::Metrics, streams::{MaybeTlsStream, RelayedStream}, + ClientConnRateLimit, }, }; @@ -75,7 +76,7 @@ fn downcast_upgrade(upgraded: Upgraded) -> Result<(MaybeTlsStream, Bytes)> { /// /// Created using [`ServerBuilder::spawn`]. #[derive(Debug)] -pub struct Server { +pub(super) struct Server { addr: SocketAddr, http_server_task: AbortOnDropHandle<()>, cancel_server_loop: CancellationToken, @@ -86,14 +87,14 @@ impl Server { /// /// The server runs in the background as several async tasks. This allows controlling /// the server, in particular it allows gracefully shutting down the server. - pub fn handle(&self) -> ServerHandle { + pub(super) fn handle(&self) -> ServerHandle { ServerHandle { cancel_token: self.cancel_server_loop.clone(), } } /// Closes the underlying relay server and the HTTP(S) server tasks. - pub fn shutdown(&self) { + pub(super) fn shutdown(&self) { self.cancel_server_loop.cancel(); } @@ -102,12 +103,12 @@ impl Server { /// This is the root of all the tasks for the server. Aborting it will abort all the /// other tasks for the server. Awaiting it will complete when all the server tasks are /// completed. - pub fn task_handle(&mut self) -> &mut AbortOnDropHandle<()> { + pub(super) fn task_handle(&mut self) -> &mut AbortOnDropHandle<()> { &mut self.http_server_task } /// Returns the local address of this server. 
- pub fn addr(&self) -> SocketAddr { + pub(super) fn addr(&self) -> SocketAddr { self.addr } } @@ -116,24 +117,24 @@ impl Server { /// /// This does not allow access to the task but can communicate with it. #[derive(Debug, Clone)] -pub struct ServerHandle { +pub(super) struct ServerHandle { cancel_token: CancellationToken, } impl ServerHandle { /// Gracefully shut down the server. - pub fn shutdown(&self) { + pub(super) fn shutdown(&self) { self.cancel_token.cancel() } } /// Configuration to use for the TLS connection #[derive(Debug, Clone)] -pub struct TlsConfig { +pub(super) struct TlsConfig { /// The server config - pub config: Arc, + pub(super) config: Arc, /// The kind - pub acceptor: TlsAcceptor, + pub(super) acceptor: TlsAcceptor, } /// Builder for the Relay HTTP Server. @@ -141,7 +142,7 @@ pub struct TlsConfig { /// Defaults to handling relay requests on the "/relay" (and "/derp" for backwards compatibility) endpoint. /// Other HTTP endpoints can be added using [`ServerBuilder::request_handler`]. #[derive(derive_more::Debug)] -pub struct ServerBuilder { +pub(super) struct ServerBuilder { /// The ip + port combination for this server. addr: SocketAddr, /// Optional tls configuration/TlsAcceptor combination. @@ -155,27 +156,42 @@ pub struct ServerBuilder { handlers: Handlers, /// Headers to use for HTTP responses. headers: HeaderMap, + /// Rate-limiting configuration for an individual client connection. + /// + /// Rate-limiting is enforced on received traffic from individual clients. This + /// configuration applies to a single client connection. + client_rx_ratelimit: Option, } impl ServerBuilder { /// Creates a new [ServerBuilder]. - pub fn new(addr: SocketAddr) -> Self { + pub(super) fn new(addr: SocketAddr) -> Self { Self { addr, tls_config: None, handlers: Default::default(), headers: HeaderMap::new(), + client_rx_ratelimit: None, } } /// Serves all requests content using TLS. - pub fn tls_config(mut self, config: Option) -> Self { + pub(super) fn tls_config(mut self, config: Option) -> Self { self.tls_config = config; self } + /// Sets the per-client rate-limit configuration for incoming data. + /// + /// On each client connection the incoming data is rate-limited. By default + /// no rate limit is enforced. + pub(super) fn client_rx_ratelimit(mut self, config: ClientConnRateLimit) -> Self { + self.client_rx_ratelimit = Some(config); + self + } + /// Adds a custom handler for a specific Method & URI. - pub fn request_handler( + pub(super) fn request_handler( mut self, method: Method, uri_path: &'static str, @@ -186,7 +202,7 @@ impl ServerBuilder { } /// Adds HTTP headers to responses. - pub fn headers(mut self, headers: HeaderMap) -> Self { + pub(super) fn headers(mut self, headers: HeaderMap) -> Self { for (k, v) in headers.iter() { self.headers.insert(k.clone(), v.clone()); } @@ -194,13 +210,14 @@ impl ServerBuilder { } /// Builds and spawns an HTTP(S) Relay Server. - pub async fn spawn(self) -> Result { + pub(super) async fn spawn(self) -> Result { let server_task = ServerActorTask::spawn(); let service = RelayService::new( self.handlers, self.headers, server_task.server_channel.clone(), server_task.write_timeout, + self.client_rx_ratelimit, ); let addr = self.addr; @@ -272,7 +289,21 @@ impl ServerBuilder { } } +/// The hyper Service that serves the actual relay endpoints. 
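Reviewer note: the new `client_rx_ratelimit` setter follows the existing consuming-builder style of `ServerBuilder`: it takes `mut self`, stores an `Option`, and returns `Self`, so the default stays "no rate limit". A stripped-down sketch of that shape with illustrative names:

```rust
use std::num::NonZeroU32;

/// Illustrative stand-in for the per-client rate-limit settings.
#[derive(Debug, Copy, Clone)]
struct RateLimit {
    bytes_per_second: NonZeroU32,
}

#[derive(Debug, Default)]
struct Builder {
    // `None` means "no rate limit", which remains the default.
    client_rx_ratelimit: Option<RateLimit>,
}

impl Builder {
    fn client_rx_ratelimit(mut self, cfg: RateLimit) -> Self {
        self.client_rx_ratelimit = Some(cfg);
        self
    }
}

fn main() {
    let limit = RateLimit {
        bytes_per_second: NonZeroU32::new(1024).unwrap(),
    };
    // Only applied when the config asks for it, mirroring the server wiring above.
    let builder = Builder::default().client_rx_ratelimit(limit);
    assert!(builder.client_rx_ratelimit.is_some());
    assert!(Builder::default().client_rx_ratelimit.is_none());
}
```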
+#[derive(Clone, Debug)] +struct RelayService(Arc); + +#[derive(Debug)] +struct Inner { + handlers: Handlers, + headers: HeaderMap, + server_channel: mpsc::Sender, + write_timeout: Duration, + rate_limit: Option, +} + impl RelayService { + /// Upgrades the HTTP connection to the relay protocol, runs relay client. fn call_client_conn( &self, mut req: Request, @@ -327,7 +358,7 @@ impl RelayService { None }; - debug!("upgrading protocol: {:?}", protocol); + debug!(?protocol, "upgrading connection"); // Setup a future that will eventually receive the upgraded // connection and talk a new protocol, and spawn the future @@ -340,19 +371,18 @@ impl RelayService { async move { match hyper::upgrade::on(&mut req).await { Ok(upgraded) => { - if let Err(e) = + if let Err(err) = this.0.relay_connection_handler(protocol, upgraded).await { warn!( - "upgrade to \"{}\": io error: {:?}", - e, - protocol.upgrade_header() + ?protocol, + "error accepting upgraded connection: {err:#}", ); } else { - debug!("upgrade to \"{}\" success", protocol.upgrade_header()); + debug!(?protocol, "upgraded connection completed"); }; } - Err(e) => warn!("upgrade error: {:?}", e), + Err(err) => warn!("upgrade error: {err:#}"), } } .instrument(debug_span!("handler")), @@ -385,41 +415,28 @@ impl Service> for RelayService { type Future = Pin> + Send>>; fn call(&self, req: Request) -> Self::Future { - // if the request hits the relay endpoint - // or /derp for backwards compat + // Create a client if the request hits the relay endpoint. if matches!( (req.method(), req.uri().path()), (&hyper::Method::GET, LEGACY_RELAY_PATH | RELAY_PATH) ) { let this = self.clone(); - // otherwise handle the relay connection as normal return Box::pin(async move { this.call_client_conn(req).await.map_err(Into::into) }); } + // Otherwise handle the relay connection as normal. - // check all other possible endpoints + // Check all other possible endpoints. let uri = req.uri().clone(); if let Some(res) = self.0.handlers.get(&(req.method().clone(), uri.path())) { let f = res(req, self.0.default_response()); return Box::pin(async move { f }); } - // otherwise return 404 + // Otherwise return 404 let res = self.0.not_found_fn(req, self.0.default_response()); Box::pin(async move { res }) } } -/// The hyper Service that servers the actual relay endpoints -#[derive(Clone, Debug)] -struct RelayService(Arc); - -#[derive(Debug)] -struct Inner { - handlers: Handlers, - headers: HeaderMap, - server_channel: mpsc::Sender, - write_timeout: Duration, -} - impl Inner { fn default_response(&self) -> ResponseBuilder { let mut response = Response::builder(); @@ -443,6 +460,10 @@ impl Inner { } /// The server HTTP handler to do HTTP upgrades. + /// + /// This handler runs while doing the connection upgrade handshake. Once the connection + /// is upgraded it sends the stream to the relay server which takes it over. After + /// having sent off the connection this handler returns. async fn relay_connection_handler(&self, protocol: Protocol, upgraded: Upgraded) -> Result<()> { debug!(?protocol, "relay_connection upgraded"); let (io, read_buf) = downcast_upgrade(upgraded)?; @@ -496,6 +517,7 @@ impl Inner { stream: io, write_timeout: self.write_timeout, channel_capacity: PER_CLIENT_SEND_QUEUE_DEPTH, + rate_limit: self.rate_limit, server_channel: self.server_channel.clone(), }; trace!("accept: create client"); @@ -511,7 +533,7 @@ impl Inner { /// TLS Certificate Authority acceptor. 
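Reviewer note: the logging changes above move from interpolating values into the message to tracing's structured fields (`?protocol`) and anyhow's alternate form (`{err:#}`), which prints the full context chain on one line. A small sketch of the difference; `tracing_subscriber` is assumed here only to make the output visible:

```rust
use anyhow::{Context, Result};
use tracing::warn;

fn read_config() -> Result<String> {
    std::fs::read_to_string("/nonexistent/relay.toml").context("unable to read config")
}

fn main() {
    tracing_subscriber::fmt().init();

    let protocol = "websocket";
    if let Err(err) = read_config() {
        // `?protocol` records a structured field via Debug;
        // `{err:#}` renders "unable to read config: No such file or directory (os error 2)".
        warn!(?protocol, "error accepting upgraded connection: {err:#}");
    }
}
```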
#[derive(Clone, derive_more::Debug)] -pub enum TlsAcceptor { +pub(super) enum TlsAcceptor { /// Uses Let's Encrypt as the Certificate Authority. This is used in production. LetsEncrypt(#[debug("tokio_rustls_acme::AcmeAcceptor")] AcmeAcceptor), /// Manually added tls acceptor. Generally used for tests or for when we've passed in @@ -525,12 +547,14 @@ impl RelayService { headers: HeaderMap, server_channel: mpsc::Sender, write_timeout: Duration, + rate_limit: Option, ) -> Self { Self(Arc::new(Inner { handlers, headers, server_channel, write_timeout, + rate_limit, })) } @@ -892,6 +916,7 @@ mod tests { Default::default(), server_task.server_channel.clone(), server_task.write_timeout, + None, ); // create client a and connect it to the server @@ -971,6 +996,7 @@ mod tests { Default::default(), server_task.server_channel.clone(), server_task.write_timeout, + None, ); // create client a and connect it to the server diff --git a/iroh-relay/src/server/metrics.rs b/iroh-relay/src/server/metrics.rs index 970294ec73..93e8247725 100644 --- a/iroh-relay/src/server/metrics.rs +++ b/iroh-relay/src/server/metrics.rs @@ -42,6 +42,11 @@ pub struct Metrics { /// Number of `FrameType::Unknown` received pub unknown_frames: Counter, + /// Number of frames received from client connection which have been rate-limited. + pub frames_rx_ratelimited_total: Counter, + /// Number of client connections which have had any frames rate-limited. + pub conns_rx_ratelimited_total: Counter, + /* * Metrics about peers */ @@ -91,6 +96,12 @@ impl Default for Metrics { got_ping: Counter::new("Number of times the server has received a Ping from a client."), sent_pong: Counter::new("Number of times the server has sent a Pong to a client."), unknown_frames: Counter::new("Number of unknown frames sent to this server."), + frames_rx_ratelimited_total: Counter::new( + "Number of frames received from client connection which have been rate-limited.", + ), + conns_rx_ratelimited_total: Counter::new( + "Number of client connections which have had any frames rate-limited.", + ), /* * Metrics about peers diff --git a/iroh-relay/src/server/streams.rs b/iroh-relay/src/server/streams.rs index e7b867bafc..f7a28990fd 100644 --- a/iroh-relay/src/server/streams.rs +++ b/iroh-relay/src/server/streams.rs @@ -14,6 +14,9 @@ use tokio_util::codec::Framed; use crate::protos::relay::{DerpCodec, Frame}; +/// A Stream and Sink for [`Frame`]s connected to a single relay client. +/// +/// The stream receives message from the client while the sink sends them to the client. #[derive(Debug)] pub(crate) enum RelayedStream { Derp(Framed),
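Reviewer note: the two new counters distinguish the total number of rate-limited frames from the number of connections that were ever rate-limited; the latter is guarded by the `limited_once` flag so each connection is counted at most once. A library-agnostic sketch of that pattern, with plain atomics standing in for the metrics crate:

```rust
use std::sync::atomic::{AtomicU64, Ordering};

static FRAMES_RX_RATELIMITED_TOTAL: AtomicU64 = AtomicU64::new(0);
static CONNS_RX_RATELIMITED_TOTAL: AtomicU64 = AtomicU64::new(0);

struct ConnStats {
    /// Whether this connection has already been counted as rate-limited.
    limited_once: bool,
}

impl ConnStats {
    fn record_rate_limited(&mut self) {
        // Every rate-limited frame is counted ...
        FRAMES_RX_RATELIMITED_TOTAL.fetch_add(1, Ordering::Relaxed);
        // ... but the connection itself only the first time it is limited.
        if !self.limited_once {
            CONNS_RX_RATELIMITED_TOTAL.fetch_add(1, Ordering::Relaxed);
            self.limited_once = true;
        }
    }
}

fn main() {
    let mut conn = ConnStats { limited_once: false };
    conn.record_rate_limited();
    conn.record_rate_limited();
    assert_eq!(FRAMES_RX_RATELIMITED_TOTAL.load(Ordering::Relaxed), 2);
    assert_eq!(CONNS_RX_RATELIMITED_TOTAL.load(Ordering::Relaxed), 1);
}
```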