Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(iroh-relay): Rate-limit client connections #2961

Merged
merged 19 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 22 additions & 17 deletions iroh-relay/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

use std::{
net::{Ipv6Addr, SocketAddr},
num::NonZeroU32,
path::{Path, PathBuf},
};

Expand Down Expand Up @@ -411,20 +410,25 @@ async fn build_relay_config(cfg: Config) -> Result<relay::ServerConfig<std::io::
Some(ref limits) => {
let client_rx = match &limits.client {
Some(PerClientRateLimitConfig { rx: Some(rx) }) => {
let mut cfg = ClientConnRateLimit::default();
if let Some(bps) = rx.bytes_per_second {
let v = NonZeroU32::try_from(bps)
.context("bytes_per_second must be non-zero u32")?;
cfg.bytes_per_second = v;
if rx.bytes_per_second.is_none() && rx.max_burst_bytes.is_some() {
bail!("bytes_per_seconds must be specified to enable the rate-limiter");
}
if let Some(burst) = rx.max_burst_bytes {
let v = NonZeroU32::try_from(burst)
.context("max_burst_bytes must be non-zero u32")?;
cfg.max_burst_bytes = v;
match rx.bytes_per_second {
Some(bps) => Some(ClientConnRateLimit {
bytes_per_second: bps
.try_into()
.context("bytes_per_second must be non-zero u32")?,
max_burst_bytes: rx
.max_burst_bytes
.map(|v| {
v.try_into().context("max_burst_bytes must be non-zero u32")
})
.transpose()?,
}),
None => None,
}
cfg
}
Some(PerClientRateLimitConfig { rx: None }) | None => Default::default(),
Some(PerClientRateLimitConfig { rx: None }) | None => None,
};
relay::Limits {
accept_conn_limit: limits.accept_conn_limit,
Expand Down Expand Up @@ -500,6 +504,8 @@ mod metrics {

#[cfg(test)]
mod tests {
use std::num::NonZeroU32;

use testresult::TestResult;

use super::*;
Expand All @@ -516,12 +522,12 @@ mod tests {

let relay = relay_config.relay.expect("no relay config");
assert_eq!(
relay.limits.client_rx.bytes_per_second,
relay.limits.client_rx.expect("ratelimit").bytes_per_second,
NonZeroU32::try_from(400).unwrap()
);
assert_eq!(
relay.limits.client_rx.max_burst_bytes,
NonZeroU32::try_from(800).unwrap()
relay.limits.client_rx.expect("ratelimit").max_burst_bytes,
Some(NonZeroU32::try_from(800).unwrap())
);

Ok(())
Expand All @@ -533,8 +539,7 @@ mod tests {
let relay_config = build_relay_config(config).await?;

let relay = relay_config.relay.expect("no relay config");
assert_eq!(relay.limits.client_rx.bytes_per_second, NonZeroU32::MAX);
assert_eq!(relay.limits.client_rx.max_burst_bytes, NonZeroU32::MAX);
assert!(relay.limits.client_rx.is_none());

Ok(())
}
Expand Down
25 changes: 5 additions & 20 deletions iroh-relay/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,33 +148,16 @@ pub struct Limits {
/// Burst limit for accepting new connection. Unlimited if not set.
pub accept_conn_burst: Option<usize>,
/// Rate limits for incoming traffic from a client connection.
pub client_rx: ClientConnRateLimit,
pub client_rx: Option<ClientConnRateLimit>,
}

/// Per-client rate limit configuration.
#[derive(Debug, Copy, Clone)]
pub struct ClientConnRateLimit {
/// Max number of bytes per second to read from the client connection.
///
/// Defaults to [`NonZeroU32::MAX`], effectively unlimited.
pub bytes_per_second: NonZeroU32,
/// Max number of bytes to read in a single burst.
///
/// Defaults to [`NonZeroU32::MAX`], effectively unlimited.
pub max_burst_bytes: NonZeroU32,
}

impl ClientConnRateLimit {
pub(super) const MAX: ClientConnRateLimit = ClientConnRateLimit {
bytes_per_second: NonZeroU32::MAX,
max_burst_bytes: NonZeroU32::MAX,
};
}

impl Default for ClientConnRateLimit {
fn default() -> Self {
Self::MAX
}
pub max_burst_bytes: Option<NonZeroU32>,
}

/// TLS certificate configuration.
Expand Down Expand Up @@ -284,12 +267,14 @@ impl Server {
None => relay_config.http_bind_addr,
};
let mut builder = http_server::ServerBuilder::new(relay_bind_addr)
.client_rx_ratelimit(relay_config.limits.client_rx)
.headers(headers)
.request_handler(Method::GET, "/", Box::new(root_handler))
.request_handler(Method::GET, "/index.html", Box::new(root_handler))
.request_handler(Method::GET, RELAY_PROBE_PATH, Box::new(probe_handler))
.request_handler(Method::GET, "/robots.txt", Box::new(robots_handler));
if let Some(cfg) = relay_config.limits.client_rx {
builder = builder.client_rx_ratelimit(cfg);
}
let http_addr = match relay_config.tls {
Some(tls_config) => {
let server_config = rustls::ServerConfig::builder_with_provider(Arc::new(
Expand Down
3 changes: 1 addition & 2 deletions iroh-relay/src/server/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,6 @@ mod tests {
server::{
client_conn::ClientConnConfig,
streams::{MaybeTlsStream, RelayedStream},
ClientConnRateLimit,
},
};

Expand All @@ -269,7 +268,7 @@ mod tests {
stream: RelayedStream::Derp(Framed::new(MaybeTlsStream::Test(io), DerpCodec)),
write_timeout: Duration::from_secs(1),
channel_capacity: 10,
rate_limit: ClientConnRateLimit::MAX,
rate_limit: None,
server_channel,
},
Framed::new(test_io, DerpCodec),
Expand Down
44 changes: 25 additions & 19 deletions iroh-relay/src/server/client_conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ pub(super) struct ClientConnConfig {
pub(super) stream: RelayedStream,
pub(super) write_timeout: Duration,
pub(super) channel_capacity: usize,
pub(super) rate_limit: ClientConnRateLimit,
pub(super) rate_limit: Option<ClientConnRateLimit>,
pub(super) server_channel: mpsc::Sender<actor::Message>,
}

Expand Down Expand Up @@ -69,13 +69,17 @@ impl ClientConn {
stream: io,
write_timeout,
channel_capacity,
rate_limit: rate_limit_config,
rate_limit,
server_channel,
} = config;

let quota = governor::Quota::per_second(rate_limit_config.bytes_per_second)
.allow_burst(rate_limit_config.max_burst_bytes);
let rate_limiter = governor::RateLimiter::direct(quota);
let rate_limiter = rate_limit.map(|cfg| {
let mut quota = governor::Quota::per_second(cfg.bytes_per_second);
if let Some(max_burst) = cfg.max_burst_bytes {
quota = quota.allow_burst(max_burst);
}
governor::RateLimiter::direct(quota)
});
let stream = RateLimitedRelayedStream::new(io, rate_limiter);

let done = CancellationToken::new();
Expand Down Expand Up @@ -335,7 +339,7 @@ impl Actor {
#[derive(Debug)]
struct RateLimitedRelayedStream {
inner: RelayedStream,
limiter: Arc<governor::DefaultDirectRateLimiter>,
limiter: Option<Arc<governor::DefaultDirectRateLimiter>>,
state: State,
/// Keeps track if this stream was ever rate-limited.
limited_once: bool,
Expand All @@ -354,10 +358,10 @@ enum State {
}

impl RateLimitedRelayedStream {
fn new(inner: RelayedStream, limiter: governor::DefaultDirectRateLimiter) -> Self {
fn new(inner: RelayedStream, limiter: Option<governor::DefaultDirectRateLimiter>) -> Self {
flub marked this conversation as resolved.
Show resolved Hide resolved
Self {
inner,
limiter: Arc::new(limiter),
limiter: limiter.map(Arc::new),
state: State::Ready,
limited_once: false,
}
Expand All @@ -384,6 +388,10 @@ impl Stream for RateLimitedRelayedStream {
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
let Some(limiter) = self.limiter.clone() else {
flub marked this conversation as resolved.
Show resolved Hide resolved
// If there is no rate-limiter directly poll the inner.
return Pin::new(&mut self.inner).poll_next(cx);
};
loop {
match &mut self.state {
State::Ready => {
Expand All @@ -402,14 +410,16 @@ impl Stream for RateLimitedRelayedStream {
return Poll::Ready(Some(item));
};

match self.limiter.check_n(frame_len) {
match limiter.check_n(frame_len) {
Ok(Ok(_)) => return Poll::Ready(Some(item)),
Ok(Err(_)) => {
// Item is rate-limited.
self.record_rate_limited();
let limiter = self.limiter.clone();
let delay = Box::pin(async move {
limiter.until_n_ready(frame_len).await.ok();
let delay = Box::pin({
let limiter = limiter.clone();
async move {
limiter.until_n_ready(frame_len).await.ok();
}
});
self.state = State::Blocked { delay, item };
continue;
Expand Down Expand Up @@ -509,12 +519,10 @@ mod tests {
let (io, io_rw) = tokio::io::duplex(1024);
let mut io_rw = Framed::new(io_rw, DerpCodec);
let (server_channel_s, mut server_channel_r) = mpsc::channel(10);
let quota = governor::Quota::per_second(NonZeroU32::MAX);
let limiter = governor::RateLimiter::direct(quota);
let stream = RelayedStream::Derp(Framed::new(MaybeTlsStream::Test(io), DerpCodec));

let actor = Actor {
stream: RateLimitedRelayedStream::new(stream, limiter),
stream: RateLimitedRelayedStream::new(stream, None),
timeout: Duration::from_secs(1),
send_queue: send_queue_r,
disco_send_queue: disco_send_queue_r,
Expand Down Expand Up @@ -651,13 +659,11 @@ mod tests {
let (io, io_rw) = tokio::io::duplex(1024);
let mut io_rw = Framed::new(io_rw, DerpCodec);
let (server_channel_s, mut server_channel_r) = mpsc::channel(10);
let quota = governor::Quota::per_second(NonZeroU32::MAX);
let limiter = governor::RateLimiter::direct(quota);
let stream = RelayedStream::Derp(Framed::new(MaybeTlsStream::Test(io), DerpCodec));

println!("-- create client conn");
let actor = Actor {
stream: RateLimitedRelayedStream::new(stream, limiter),
stream: RateLimitedRelayedStream::new(stream, None),
timeout: Duration::from_secs(1),
send_queue: send_queue_r,
disco_send_queue: disco_send_queue_r,
Expand Down Expand Up @@ -733,7 +739,7 @@ mod tests {
let (io_read, io_write) = tokio::io::duplex((LIMIT * MAX_FRAMES) as _);
let mut frame_writer = Framed::new(io_write, DerpCodec);
let stream = RelayedStream::Derp(Framed::new(MaybeTlsStream::Test(io_read), DerpCodec));
let mut stream = RateLimitedRelayedStream::new(stream, limiter);
let mut stream = RateLimitedRelayedStream::new(stream, Some(limiter));

// Prepare a frame to send, assert its size.
let data = Bytes::from_static(b"hello world!!");
Expand Down
7 changes: 2 additions & 5 deletions iroh-relay/src/server/clients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,7 @@ mod tests {
use super::*;
use crate::{
protos::relay::{recv_frame, DerpCodec, Frame, FrameType},
server::{
streams::{MaybeTlsStream, RelayedStream},
ClientConnRateLimit,
},
server::streams::{MaybeTlsStream, RelayedStream},
};

fn test_client_builder(key: NodeId) -> (ClientConnConfig, FramedRead<DuplexStream, DerpCodec>) {
Expand All @@ -250,7 +247,7 @@ mod tests {
stream: RelayedStream::Derp(Framed::new(MaybeTlsStream::Test(io), DerpCodec)),
write_timeout: Duration::from_secs(1),
channel_capacity: 10,
rate_limit: ClientConnRateLimit::MAX,
rate_limit: None,
server_channel,
},
FramedRead::new(test_io, DerpCodec),
Expand Down
16 changes: 8 additions & 8 deletions iroh-relay/src/server/http_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ pub(super) struct ServerBuilder {
///
/// Rate-limiting is enforced on received traffic from individual clients. This
/// configuration applies to a single client connection.
client_rx_ratelimit: ClientConnRateLimit,
client_rx_ratelimit: Option<ClientConnRateLimit>,
}

impl ServerBuilder {
Expand All @@ -171,7 +171,7 @@ impl ServerBuilder {
tls_config: None,
handlers: Default::default(),
headers: HeaderMap::new(),
client_rx_ratelimit: ClientConnRateLimit::MAX,
client_rx_ratelimit: None,
}
}

Expand All @@ -184,9 +184,9 @@ impl ServerBuilder {
/// Sets the per-client rate-limit configuration for incoming data.
///
/// On each client connection the incoming data is rate-limited. By default
/// [`ClientConnRateLimit::MAX`] is enforced.
/// no rate limit is enforced.
pub(super) fn client_rx_ratelimit(mut self, config: ClientConnRateLimit) -> Self {
self.client_rx_ratelimit = config;
self.client_rx_ratelimit = Some(config);
self
}

Expand Down Expand Up @@ -299,7 +299,7 @@ struct Inner {
headers: HeaderMap,
server_channel: mpsc::Sender<Message>,
write_timeout: Duration,
rate_limit: ClientConnRateLimit,
rate_limit: Option<ClientConnRateLimit>,
}

impl RelayService {
Expand Down Expand Up @@ -547,7 +547,7 @@ impl RelayService {
headers: HeaderMap,
server_channel: mpsc::Sender<Message>,
write_timeout: Duration,
rate_limit: ClientConnRateLimit,
rate_limit: Option<ClientConnRateLimit>,
) -> Self {
Self(Arc::new(Inner {
handlers,
Expand Down Expand Up @@ -916,7 +916,7 @@ mod tests {
Default::default(),
server_task.server_channel.clone(),
server_task.write_timeout,
ClientConnRateLimit::MAX,
None,
);

// create client a and connect it to the server
Expand Down Expand Up @@ -996,7 +996,7 @@ mod tests {
Default::default(),
server_task.server_channel.clone(),
server_task.write_timeout,
ClientConnRateLimit::MAX,
None,
);

// create client a and connect it to the server
Expand Down
Loading