diff --git a/net/ipv4/Kconfig b/net/ipv4/Kconfig
index 8e94ed7c56a0ea..d2ea9811c131c7 100644
--- a/net/ipv4/Kconfig
+++ b/net/ipv4/Kconfig
@@ -466,6 +466,22 @@ config INET_DIAG_DESTROY
	  had been disconnected.
	  If unsure, say N.
 
+config RUST_SOCK_ABSTRACTIONS
+	bool "INET: Rust sock abstractions"
+	depends on RUST
+	help
+	  Adds Rust abstractions for working with `struct sock`s.
+
+	  If unsure, say N.
+
+config RUST_TCP_ABSTRACTIONS
+	bool "TCP: Rust abstractions"
+	depends on RUST_SOCK_ABSTRACTIONS
+	help
+	  Adds support for writing Rust kernel modules that integrate with TCP.
+
+	  If unsure, say N.
+
 menuconfig TCP_CONG_ADVANCED
	bool "TCP: advanced congestion control"
	help
@@ -493,6 +509,15 @@ config TCP_CONG_BIC
	  increase provides TCP friendliness.
	  See http://www.csc.ncsu.edu/faculty/rhee/export/bitcp/
 
+config TCP_CONG_BIC_RUST
+	tristate "Binary Increase Congestion (BIC) control (Rust rewrite)"
+	depends on RUST_TCP_ABSTRACTIONS
+	help
+	  Rust rewrite of the original implementation of Binary Increase
+	  Congestion (BIC) control.
+
+	  If unsure, say N.
+
 config TCP_CONG_CUBIC
	tristate "CUBIC TCP"
	default y
@@ -501,6 +526,15 @@ config TCP_CONG_CUBIC
	  among other techniques.
	  See http://www.csc.ncsu.edu/faculty/rhee/export/bitcp/cubic-paper.pdf
 
+config TCP_CONG_CUBIC_RUST
+	tristate "CUBIC TCP (Rust rewrite)"
+	depends on RUST_TCP_ABSTRACTIONS
+	help
+	  Rust rewrite of the original implementation of TCP CUBIC congestion
+	  control.
+
+	  If unsure, say N.
+
 config TCP_CONG_WESTWOOD
	tristate "TCP Westwood+"
	default m
@@ -688,9 +722,15 @@ choice
	config DEFAULT_BIC
		bool "Bic" if TCP_CONG_BIC=y
 
+	config DEFAULT_BIC_RUST
+		bool "Bic (Rust)" if TCP_CONG_BIC_RUST=y
+
	config DEFAULT_CUBIC
		bool "Cubic" if TCP_CONG_CUBIC=y
 
+	config DEFAULT_CUBIC_RUST
+		bool "Cubic (Rust)" if TCP_CONG_CUBIC_RUST=y
+
	config DEFAULT_HTCP
		bool "Htcp" if TCP_CONG_HTCP=y
 
@@ -729,7 +769,9 @@ config TCP_CONG_CUBIC
 config DEFAULT_TCP_CONG
	string
	default "bic" if DEFAULT_BIC
+	default "bic_rust" if DEFAULT_BIC_RUST
	default "cubic" if DEFAULT_CUBIC
+	default "cubic_rust" if DEFAULT_CUBIC_RUST
	default "htcp" if DEFAULT_HTCP
	default "hybla" if DEFAULT_HYBLA
	default "vegas" if DEFAULT_VEGAS
diff --git a/net/ipv4/Makefile b/net/ipv4/Makefile
index ec36d2ec059e80..8aecd5fa55e96d 100644
--- a/net/ipv4/Makefile
+++ b/net/ipv4/Makefile
@@ -46,8 +46,10 @@ obj-$(CONFIG_INET_UDP_DIAG) += udp_diag.o
 obj-$(CONFIG_INET_RAW_DIAG) += raw_diag.o
 obj-$(CONFIG_TCP_CONG_BBR) += tcp_bbr.o
 obj-$(CONFIG_TCP_CONG_BIC) += tcp_bic.o
+obj-$(CONFIG_TCP_CONG_BIC_RUST) += tcp_bic_rust.o
 obj-$(CONFIG_TCP_CONG_CDG) += tcp_cdg.o
 obj-$(CONFIG_TCP_CONG_CUBIC) += tcp_cubic.o
+obj-$(CONFIG_TCP_CONG_CUBIC_RUST) += tcp_cubic_rust.o
 obj-$(CONFIG_TCP_CONG_DCTCP) += tcp_dctcp.o
 obj-$(CONFIG_TCP_CONG_WESTWOOD) += tcp_westwood.o
 obj-$(CONFIG_TCP_CONG_HSTCP) += tcp_highspeed.o
diff --git a/net/ipv4/tcp_bic.c b/net/ipv4/tcp_bic.c
index 58358bf92e1b8a..757033a32fd78b 100644
--- a/net/ipv4/tcp_bic.c
+++ b/net/ipv4/tcp_bic.c
@@ -16,6 +16,7 @@
 
 #include
 #include
+#include
 #include
 
 #define BICTCP_BETA_SCALE    1024	/* Scale factor beta calculation
@@ -55,6 +56,7 @@ struct bictcp {
	u32	epoch_start;	/* beginning of an epoch */
 #define ACK_RATIO_SHIFT	4
	u32	delayed_ack;	/* estimate the ratio of Packets/ACKs << 4 */
+	u64	start_time;
 };
 
 static inline void bictcp_reset(struct bictcp *ca)
@@ -65,6 +67,7 @@ static inline void bictcp_reset(struct bictcp *ca)
	ca->last_time = 0;
	ca->epoch_start = 0;
	ca->delayed_ack = 2 << ACK_RATIO_SHIFT;
+	ca->start_time = ktime_get_boot_fast_ns();
 }
 
 static void bictcp_init(struct sock *sk)
@@ -75,6 +78,19 @@ static void bictcp_init(struct sock *sk)
 
	if (initial_ssthresh)
		tcp_sk(sk)->snd_ssthresh = initial_ssthresh;
+
+	pr_info("Socket created: start %llu\n", ca->start_time);
+}
+
+static void bictcp_release(struct sock *sk)
+{
+	struct bictcp *ca = inet_csk_ca(sk);
+
+	pr_info(
+		"Socket destroyed: start %llu, end %llu\n",
+		ca->start_time,
+		ktime_get_boot_fast_ns()
+	);
 }
 
 /*
@@ -147,11 +163,23 @@ static void bictcp_cong_avoid(struct sock *sk, u32 ack, u32 acked)
 
	if (tcp_in_slow_start(tp)) {
		acked = tcp_slow_start(tp, acked);
-		if (!acked)
+		if (!acked) {
+			pr_info(
+				"New cwnd: %u, time %llu, ssthresh %u, start %llu, ss 1\n",
+				tp->snd_cwnd, ktime_get_boot_fast_ns(),
+				tp->snd_ssthresh, ca->start_time
+			);
			return;
+		}
	}
	bictcp_update(ca, tcp_snd_cwnd(tp));
	tcp_cong_avoid_ai(tp, ca->cnt, acked);
+
+	pr_info(
+		"New cwnd: %u, time %llu, ssthresh %u, start %llu, ss 0\n",
+		tp->snd_cwnd, ktime_get_boot_fast_ns(),
+		tp->snd_ssthresh, ca->start_time
+	);
 }
 
 /*
@@ -163,6 +191,12 @@ static u32 bictcp_recalc_ssthresh(struct sock *sk)
	const struct tcp_sock *tp = tcp_sk(sk);
	struct bictcp *ca = inet_csk_ca(sk);
 
+	pr_info(
+		"Enter fast retransmit: time %llu, start %llu\n",
+		ktime_get_boot_fast_ns(),
+		ca->start_time
+	);
+
	ca->epoch_start = 0;	/* end of epoch */
 
	/* Wmax and fast convergence */
@@ -180,8 +214,20 @@ static u32 bictcp_recalc_ssthresh(struct sock *sk)
 
 static void bictcp_state(struct sock *sk, u8 new_state)
 {
-	if (new_state == TCP_CA_Loss)
+	if (new_state == TCP_CA_Loss) {
+		struct bictcp *ca = inet_csk_ca(sk);
+		u64 tmp = ca->start_time;
+
+		pr_info(
+			"Retransmission timeout fired: time %llu, start %llu\n",
+			ktime_get_boot_fast_ns(),
+			ca->start_time
+		);
+
		bictcp_reset(inet_csk_ca(sk));
+
+		ca->start_time = tmp;
+	}
 }
 
 /* Track delayed acknowledgment ratio using sliding window
@@ -201,6 +247,7 @@ static void bictcp_acked(struct sock *sk, const struct ack_sample *sample)
 
 static struct tcp_congestion_ops bictcp __read_mostly = {
	.init		= bictcp_init,
+	.release	= bictcp_release,
	.ssthresh	= bictcp_recalc_ssthresh,
	.cong_avoid	= bictcp_cong_avoid,
	.set_state	= bictcp_state,
diff --git a/net/ipv4/tcp_bic_rust.rs b/net/ipv4/tcp_bic_rust.rs
new file mode 100644
index 00000000000000..adbcd03d3b1dcd
--- /dev/null
+++ b/net/ipv4/tcp_bic_rust.rs
@@ -0,0 +1,312 @@
+// SPDX-License-Identifier: GPL-2.0
+
+//! Binary Increase Congestion control (BIC). Based on:
+//!   Binary Increase Congestion Control (BIC) for Fast Long-Distance
+//!   Networks - Lisong Xu, Khaled Harfoush, and Injong Rhee
+//!   IEEE INFOCOM 2004, Hong Kong, China, 2004, pp. 2514-2524 vol.4
+//!   doi: 10.1109/INFCOM.2004.1354672
+//! Link: https://doi.org/10.1109/INFCOM.2004.1354672
+//! Link: https://web.archive.org/web/20160417213452/http://netsrv.csc.ncsu.edu/export/bitcp.pdf
+
+use core::cmp::{max, min};
+use core::num::NonZeroU32;
+use kernel::c_str;
+use kernel::net::tcp::cong::{self, module_cca};
+use kernel::prelude::*;
+use kernel::time;
+
+const ACK_RATIO_SHIFT: u32 = 4;
+
+// TODO: Convert to module parameters once they are available.
+/// The initial value of ssthresh for new connections. Setting this to `None`
+/// implies `i32::MAX`.
+const INITIAL_SSTHRESH: Option<u32> = None;
+/// If cwnd is larger than this threshold, BIC engages; otherwise normal TCP
+/// increase/decrease will be performed.
+const LOW_WINDOW: u32 = 14;
+/// In binary search, go to point: `cwnd + (W_max - cwnd) / BICTCP_B`.
+// TODO: Convert to `new::(x).unwrap()` once `const_option` is stabilised.
+// SAFETY: This will panic at compile time when passing zero.
+const BICTCP_B: NonZeroU32 = unsafe { NonZeroU32::new_unchecked(4) };
+/// The maximum increment, i.e., `S_max`. This is used during additive
+/// increase. After crossing `W_max`, slow start is performed until passing
+/// `MAX_INCREMENT * (BICTCP_B - 1)`.
+// TODO: Convert to `new::(x).unwrap()` once `const_option` is stabilised.
+// SAFETY: This will panic at compile time when passing zero.
+const MAX_INCREMENT: NonZeroU32 = unsafe { NonZeroU32::new_unchecked(16) };
+/// The number of RTTs it takes to get from `W_max - BICTCP_B` to `W_max` (and
+/// from `W_max` to `W_max + BICTCP_B`). This is not part of the original paper
+/// and results in a slow additive increase across `W_max`.
+const SMOOTH_PART: u32 = 20;
+/// Whether to use fast convergence. This is a heuristic to increase the
+/// release of bandwidth by existing flows to speed up the convergence to a
+/// steady state when a new flow joins the link.
+const FAST_CONVERGENCE: bool = true;
+/// Factor for multiplicative decrease. In fast retransmit we have:
+/// `cwnd = cwnd * BETA/BETA_SCALE`
+/// and if fast convergence is active:
+/// `W_max = cwnd * (1 + BETA/BETA_SCALE)/2`
+/// instead of `W_max = cwnd`.
+const BETA: u32 = 819;
+/// Used to calculate beta in [0, 1] with integer arithmetic.
+// TODO: Convert to `new::(x).unwrap()` once `const_option` is stabilised.
+// SAFETY: This will panic at compile time when passing zero.
+const BETA_SCALE: NonZeroU32 = unsafe { NonZeroU32::new_unchecked(1024) };
+/// The minimum amount of time that has to pass between two updates of the
+/// cwnd.
+const MIN_UPDATE_INTERVAL: time::Msecs32 = time::MSEC_PER_SEC / 32;
+
+module_cca! {
+    type: Bic,
+    name: "tcp_bic_rust",
+    author: "Rust for Linux Contributors",
+    description: "Binary Increase Congestion control (BIC) algorithm, Rust implementation",
+    license: "GPL v2",
+}
+
+struct Bic {}
+
+#[vtable]
+impl cong::Algorithm for Bic {
+    type Data = BicState;
+
+    const NAME: &'static CStr = c_str!("bic_rust");
+
+    fn pkts_acked(sk: &mut cong::Sock<'_, Self>, sample: &cong::AckSample) {
+        if let Ok(cong::State::Open) = sk.inet_csk().ca_state() {
+            let ca = sk.inet_csk_ca_mut();
+
+            // Track delayed acknowledgment ratio using sliding window:
+            // ratio = (15*ratio + sample) / 16
+            ca.delayed_ack = ca.delayed_ack.wrapping_add(
+                sample
+                    .pkts_acked()
+                    .wrapping_sub(ca.delayed_ack >> ACK_RATIO_SHIFT),
+            );
+        }
+    }
+
+    fn ssthresh(sk: &mut cong::Sock<'_, Self>) -> u32 {
+        let cwnd = sk.tcp_sk().snd_cwnd();
+        let ca = sk.inet_csk_ca_mut();
+
+        pr_info!(
+            // TODO: remove
+            "Enter fast retransmit: time {}, start {}",
+            time::ktime_get_boot_fast_ns(),
+            ca.start_time
+        );
+
+        // Epoch has ended.
+        ca.epoch_start = 0;
+        ca.last_max_cwnd = if cwnd < ca.last_max_cwnd && FAST_CONVERGENCE {
+            (cwnd * (BETA_SCALE.get() + BETA)) / (2 * BETA_SCALE.get())
+        } else {
+            cwnd
+        };
+
+        if cwnd <= LOW_WINDOW {
+            // Act like normal TCP.
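+            // That is, halve cwnd (multiplicative decrease by 0.5), floored
+            // at two segments.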
+            max(cwnd >> 1, 2)
+        } else {
+            max((cwnd * BETA) / BETA_SCALE, 2)
+        }
+    }
+
+    fn cong_avoid(sk: &mut cong::Sock<'_, Self>, _ack: u32, mut acked: u32) {
+        if !sk.tcp_is_cwnd_limited() {
+            return;
+        }
+
+        let tp = sk.tcp_sk_mut();
+
+        if tp.in_slow_start() {
+            acked = tp.slow_start(acked);
+            if acked == 0 {
+                pr_info!(
+                    // TODO: remove
+                    "New cwnd {}, time {}, ssthresh {}, start {}, ss 1",
+                    sk.tcp_sk().snd_cwnd(),
+                    time::ktime_get_boot_fast_ns(),
+                    sk.tcp_sk().snd_ssthresh(),
+                    sk.inet_csk_ca().start_time
+                );
+                return;
+            }
+        }
+
+        let cwnd = tp.snd_cwnd();
+        let cnt = sk.inet_csk_ca_mut().update(cwnd);
+        sk.tcp_sk_mut().cong_avoid_ai(cnt, acked);
+
+        pr_info!(
+            // TODO: remove
+            "New cwnd {}, time {}, ssthresh {}, start {}, ss 0",
+            sk.tcp_sk().snd_cwnd(),
+            time::ktime_get_boot_fast_ns(),
+            sk.tcp_sk().snd_ssthresh(),
+            sk.inet_csk_ca().start_time
+        );
+    }
+
+    fn set_state(sk: &mut cong::Sock<'_, Self>, new_state: cong::State) {
+        if matches!(new_state, cong::State::Loss) {
+            pr_info!(
+                // TODO: remove
+                "Retransmission timeout fired: time {}, start {}",
+                time::ktime_get_boot_fast_ns(),
+                sk.inet_csk_ca().start_time
+            );
+            sk.inet_csk_ca_mut().reset()
+        }
+    }
+
+    fn undo_cwnd(sk: &mut cong::Sock<'_, Self>) -> u32 {
+        pr_info!(
+            // TODO: remove
+            "Undo cwnd reduction: time {}, start {}",
+            time::ktime_get_boot_fast_ns(),
+            sk.inet_csk_ca().start_time
+        );
+
+        cong::reno::undo_cwnd(sk)
+    }
+
+    fn init(sk: &mut cong::Sock<'_, Self>) {
+        if let Some(ssthresh) = INITIAL_SSTHRESH {
+            sk.tcp_sk_mut().set_snd_ssthresh(ssthresh);
+        }
+
+        // TODO: remove
+        pr_info!("Socket created: start {}", sk.inet_csk_ca().start_time);
+    }
+
+    // TODO: remove
+    fn release(sk: &mut cong::Sock<'_, Self>) {
+        pr_info!(
+            "Socket destroyed: start {}, end {}",
+            sk.inet_csk_ca().start_time,
+            time::ktime_get_boot_fast_ns()
+        );
+    }
+}
+
+/// Internal state of each instance of the algorithm.
+struct BicState {
+    /// During congestion avoidance, cwnd is increased at most every `cnt`
+    /// acknowledged packets, i.e., the average increase per acknowledged
+    /// packet is proportional to `1 / cnt`.
+    // NOTE: The C impl initialises this to zero. It then ensures that zero is
+    // never passed to `cong_avoid_ai`, which could divide by it. Make it
+    // explicit in the types that zero is not a valid value.
+    cnt: NonZeroU32,
+    /// Last maximum `snd_cwnd`, i.e., `W_max`.
+    last_max_cwnd: u32,
+    /// The last `snd_cwnd`.
+    last_cwnd: u32,
+    /// Time when `last_cwnd` was updated.
+    last_time: time::Msecs32,
+    /// Records the beginning of an epoch.
+    epoch_start: time::Msecs32,
+    /// Estimates the ratio of `packets/ACK << 4`. This allows us to adjust
+    /// cwnd per packet when a receiver is sending a single ACK for multiple
+    /// received packets.
+    delayed_ack: u32,
+    /// Time when the algorithm was initialised.
+    // TODO: remove
+    start_time: time::Nsecs,
+}
+
+impl Default for BicState {
+    fn default() -> Self {
+        Self {
+            // NOTE: Initialising this to 1 deviates from the C code. It does
+            // not change the behaviour of the algorithm.
+            cnt: NonZeroU32::MIN,
+            last_max_cwnd: 0,
+            last_cwnd: 0,
+            last_time: 0,
+            epoch_start: 0,
+            delayed_ack: 2 << ACK_RATIO_SHIFT,
+            // TODO: remove
+            start_time: time::ktime_get_boot_fast_ns(),
+        }
+    }
+}
+
+impl BicState {
+    /// Computes the congestion window to use. Returns the new `cnt`.
+    ///
+    /// This governs the behavior of the algorithm during congestion avoidance.
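+    ///
+    /// For example (hypothetical values): with `last_max_cwnd = 100`,
+    /// `cwnd = 80`, and `BICTCP_B = 4`, the distance to the old maximum is
+    /// `dist = (100 - 80) / 4 = 5`, which selects the binary search branch:
+    /// `cnt = 80 / 5 = 16`. That is, before the delayed-ACK correction at the
+    /// end, cwnd grows by one segment roughly every 16 ACKs, jumping towards
+    /// the midpoint within about one RTT.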
+    fn update(&mut self, cwnd: u32) -> NonZeroU32 {
+        let now = time::ktime_get_boot_fast_ms32();
+
+        // Do nothing if we are invoked too frequently.
+        if self.last_cwnd == cwnd && now.wrapping_sub(self.last_time) <= MIN_UPDATE_INTERVAL {
+            return self.cnt;
+        }
+
+        self.last_cwnd = cwnd;
+        self.last_time = now;
+
+        // Record the beginning of an epoch.
+        if self.epoch_start == 0 {
+            self.epoch_start = now;
+        }
+
+        // Start off like normal TCP.
+        if cwnd <= LOW_WINDOW {
+            self.cnt = NonZeroU32::new(cwnd).unwrap_or(NonZeroU32::MIN);
+            return self.cnt;
+        }
+
+        let mut new_cnt = if cwnd < self.last_max_cwnd {
+            // binary increase
+            let dist: u32 = (self.last_max_cwnd - cwnd) / BICTCP_B;
+
+            if dist > MAX_INCREMENT.get() {
+                // additive increase
+                cwnd / MAX_INCREMENT
+            } else if dist <= 1 {
+                // careful additive increase
+                (cwnd * SMOOTH_PART) / BICTCP_B
+            } else {
+                // binary search
+                cwnd / dist
+            }
+        } else {
+            if cwnd < self.last_max_cwnd + BICTCP_B.get() {
+                // careful additive increase
+                (cwnd * SMOOTH_PART) / BICTCP_B
+            } else if cwnd < self.last_max_cwnd + MAX_INCREMENT.get() * (BICTCP_B.get() - 1) {
+                // slow start
+                (cwnd * (BICTCP_B.get() - 1)) / (cwnd - self.last_max_cwnd)
+            } else {
+                // linear increase
+                cwnd / MAX_INCREMENT
+            }
+        };
+
+        // If in initial slow start or link utilization is very low.
+        if self.last_max_cwnd == 0 {
+            new_cnt = min(new_cnt, 20);
+        }
+
+        // Account for estimated packets/ACK to ensure that we increase per
+        // packet.
+        new_cnt = (new_cnt << ACK_RATIO_SHIFT) / self.delayed_ack;
+
+        self.cnt = NonZeroU32::new(new_cnt).unwrap_or(NonZeroU32::MIN);
+
+        self.cnt
+    }
+
+    fn reset(&mut self) {
+        // TODO: remove
+        let tmp = self.start_time;
+
+        *self = Self::default();
+
+        // TODO: remove
+        self.start_time = tmp;
+    }
+}
diff --git a/net/ipv4/tcp_cubic_rust.rs b/net/ipv4/tcp_cubic_rust.rs
new file mode 100644
index 00000000000000..f93c15f27d77be
--- /dev/null
+++ b/net/ipv4/tcp_cubic_rust.rs
@@ -0,0 +1,510 @@
+// SPDX-License-Identifier: GPL-2.0-only
+
+//! TCP CUBIC congestion control algorithm.
+//!
+//! Based on:
+//!   Sangtae Ha, Injong Rhee, and Lisong Xu. 2008.
+//!   CUBIC: A New TCP-Friendly High-Speed TCP Variant.
+//!   SIGOPS Oper. Syst. Rev. 42, 5 (July 2008), 64–74.
+//!
+//!
+//! CUBIC is also described in [RFC9438](https://www.rfc-editor.org/rfc/rfc9438).
+
+use core::cmp::{max, min};
+use core::num::NonZeroU32;
+use kernel::c_str;
+use kernel::net::tcp;
+use kernel::net::tcp::cong::{self, hystart, hystart::HystartDetect, module_cca};
+use kernel::prelude::*;
+use kernel::time;
+
+const BICTCP_BETA_SCALE: u32 = 1024;
+
+// TODO: Convert to module parameters once they are available. Currently these
+// are the defaults from the C implementation.
+// TODO: Use `NonZeroU32` where appropriate.
+/// Whether to use fast convergence. This is a heuristic to increase the
+/// release of bandwidth by existing flows to speed up the convergence to a
+/// steady state when a new flow joins the link.
+const FAST_CONVERGENCE: bool = true;
+/// The factor for multiplicative decrease of cwnd upon a loss event. Will be
+/// divided by `BICTCP_BETA_SCALE`, approximately 0.7.
+const BETA: u32 = 717;
+/// The initial value of ssthresh for new connections. Setting this to `None`
+/// implies `i32::MAX`.
+const INITIAL_SSTHRESH: Option<u32> = None;
+/// The parameter `C` that scales the cubic term is defined as
+/// `BIC_SCALE / 2^10`. (For C: Dimension: Time^-2, Unit: s^-2).
+const BIC_SCALE: u32 = 41;
+/// In environments where CUBIC grows cwnd less aggressively than normal TCP,
+/// enabling this option causes it to behave like normal TCP instead. This is
+/// the case in short RTT and/or low bandwidth delay product networks.
+const TCP_FRIENDLINESS: bool = true;
+/// Whether to use the [HyStart] slow start algorithm.
+///
+/// [HyStart]: hystart::HyStart
+const HYSTART: bool = true;
+
+impl hystart::HyStart for Cubic {
+    /// Which mechanism to use for deciding when it is time to exit slow start.
+    const DETECT: HystartDetect = HystartDetect::Both;
+    /// Lower bound for cwnd during hybrid slow start.
+    const LOW_WINDOW: u32 = 16;
+    /// Spacing between ACKs indicating an ACK-train.
+    /// (Dimension: Time. Unit: us).
+    const ACK_DELTA: time::Usecs32 = 2000;
+}
+
+// TODO: Those are computed based on the module parameters in the init. Even
+// with module parameters available this will be a bit tricky to do in Rust.
+/// Factor of `8/3 * (1 + beta) / (1 - beta)` that is used in various
+/// calculations. (Dimension: none)
+const BETA_SCALE: u32 = ((8 * (BICTCP_BETA_SCALE + BETA)) / 3) / (BICTCP_BETA_SCALE - BETA);
+/// Factor of `2^10*C/SRTT` where `SRTT = 100ms` that is used in various
+/// calculations. (Dimension: Time^-3, Unit: s^-3).
+const CUBE_RTT_SCALE: u32 = BIC_SCALE * 10;
+/// Factor of `SRTT/C` where `SRTT = 100ms` and `C` from above.
+/// (Dimension: Time^3. Unit: (ms)^3)
+// Note: C uses a custom time unit of 2^-10 s called `BICTCP_HZ`. This
+// implementation consistently uses milliseconds instead.
+const CUBE_FACTOR: u64 = 1_000_000_000 * (1u64 << 10) / (CUBE_RTT_SCALE as u64);
+
+module_cca! {
+    type: Cubic,
+    name: "tcp_cubic_rust",
+    author: "Rust for Linux Contributors",
+    description: "TCP CUBIC congestion control algorithm, Rust implementation",
+    license: "GPL v2",
+}
+
+struct Cubic {}
+
+#[vtable]
+impl cong::Algorithm for Cubic {
+    type Data = CubicState;
+
+    const NAME: &'static CStr = c_str!("cubic_rust");
+
+    fn init(sk: &mut cong::Sock<'_, Self>) {
+        if HYSTART {
+            <Self as hystart::HyStart>::reset(sk)
+        } else if let Some(ssthresh) = INITIAL_SSTHRESH {
+            sk.tcp_sk_mut().set_snd_ssthresh(ssthresh);
+        }
+
+        // TODO: remove
+        pr_info!(
+            "init: socket created: start {}us",
+            sk.inet_csk_ca().hystart_state.start_time
+        );
+    }
+
+    // TODO: remove
+    fn release(sk: &mut cong::Sock<'_, Self>) {
+        pr_info!(
+            "release: socket destroyed: start {}us, end {}us",
+            sk.inet_csk_ca().hystart_state.start_time,
+            time::ktime_get_boot_fast_us32(),
+        );
+    }
+
+    fn cwnd_event(sk: &mut cong::Sock<'_, Self>, ev: cong::Event) {
+        if matches!(ev, cong::Event::TxStart) {
+            // Here we cannot avoid jiffies as the `lsndtime` field is
+            // measured in jiffies.
+            let now = time::jiffies32();
+            let delta: time::Jiffies32 = now.wrapping_sub(sk.tcp_sk().lsndtime());
+
+            if (delta as i32) <= 0 {
+                return;
+            }
+
+            let ca = sk.inet_csk_ca_mut();
+            // Ok, let's switch to SI units.
+            let now = time::ktime_get_boot_fast_ms32();
+            let delta = time::jiffies_to_msecs(delta as time::Jiffies);
+            // TODO: remove
+            pr_debug!("cwnd_event: TxStart, now {}ms, delta {}ms", now, delta);
+            // We were application limited, i.e., idle, for a while. If we are
+            // in congestion avoidance, shift `epoch_start` by the time we
+            // were idle to keep cwnd growth on the cubic curve.
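+            // For example (hypothetical numbers): after 500 ms of idle time,
+            // `epoch_start` is shifted 500 ms into the future, so the cubic
+            // function resumes from the cwnd it had reached before the idle
+            // period instead of a point 500 ms further along the curve.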
+            ca.epoch_start = ca.epoch_start.map(|mut epoch_start| {
+                epoch_start = epoch_start.wrapping_add(delta);
+                if tcp::after(epoch_start, now) {
+                    epoch_start = now;
+                }
+                epoch_start
+            });
+        }
+    }
+
+    fn set_state(sk: &mut cong::Sock<'_, Self>, new_state: cong::State) {
+        if matches!(new_state, cong::State::Loss) {
+            pr_info!(
+                // TODO: remove
+                "set_state: Loss, time {}us, start {}us",
+                time::ktime_get_boot_fast_us32(),
+                sk.inet_csk_ca().hystart_state.start_time
+            );
+            sk.inet_csk_ca_mut().reset();
+            <Self as hystart::HyStart>::reset(sk);
+        }
+    }
+
+    fn pkts_acked(sk: &mut cong::Sock<'_, Self>, sample: &cong::AckSample) {
+        // Some samples do not include RTTs.
+        let Some(rtt_us) = sample.rtt_us() else {
+            // TODO: remove
+            pr_debug!(
+                "pkts_acked: no RTT sample, start {}us",
+                sk.inet_csk_ca().hystart_state.start_time,
+            );
+            return;
+        };
+
+        let epoch_start = sk.inet_csk_ca().epoch_start;
+        // For some time after exiting fast recovery the samples might still
+        // be inaccurate.
+        if epoch_start.is_some_and(|epoch_start| {
+            time::ktime_get_boot_fast_ms32().wrapping_sub(epoch_start) < time::MSEC_PER_SEC
+        }) {
+            // TODO: remove
+            pr_debug!(
+                "pkts_acked: {}ms - {}ms < 1s, too close to epoch_start",
+                time::ktime_get_boot_fast_ms32(),
+                epoch_start.unwrap()
+            );
+            return;
+        }
+
+        let delay = max(1, rtt_us);
+        let cwnd = sk.tcp_sk().snd_cwnd();
+        let in_slow_start = sk.tcp_sk().in_slow_start();
+        let ca = sk.inet_csk_ca_mut();
+
+        // TODO: remove
+        pr_debug!(
+            "pkts_acked: delay {}us, cwnd {}, ss {}",
+            delay,
+            cwnd,
+            in_slow_start
+        );
+
+        // First call after a reset, or the delay has decreased.
+        if ca.hystart_state.delay_min.is_none()
+            || ca
+                .hystart_state
+                .delay_min
+                .is_some_and(|delay_min| delay_min > delay)
+        {
+            ca.hystart_state.delay_min = Some(delay);
+        }
+
+        if in_slow_start && HYSTART && ca.hystart_state.in_hystart::<Self>(cwnd) {
+            hystart::HyStart::update(sk, delay);
+        }
+    }
+
+    fn ssthresh(sk: &mut cong::Sock<'_, Self>) -> u32 {
+        let cwnd = sk.tcp_sk().snd_cwnd();
+        let ca = sk.inet_csk_ca_mut();
+
+        pr_info!(
+            // TODO: remove
+            "ssthresh: time {}us, start {}us",
+            time::ktime_get_boot_fast_us32(),
+            ca.hystart_state.start_time
+        );
+
+        // Epoch has ended.
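+        // A loss event terminates the current epoch; the next `update` call
+        // in congestion avoidance will start a new one.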
+        ca.epoch_start = None;
+        ca.last_max_cwnd = if cwnd < ca.last_max_cwnd && FAST_CONVERGENCE {
+            (cwnd * (BICTCP_BETA_SCALE + BETA)) / (2 * BICTCP_BETA_SCALE)
+        } else {
+            cwnd
+        };
+
+        max((cwnd * BETA) / BICTCP_BETA_SCALE, 2)
+    }
+
+    fn undo_cwnd(sk: &mut cong::Sock<'_, Self>) -> u32 {
+        pr_info!(
+            // TODO: remove
+            "undo_cwnd: time {}us, start {}us",
+            time::ktime_get_boot_fast_us32(),
+            sk.inet_csk_ca().hystart_state.start_time
+        );
+
+        cong::reno::undo_cwnd(sk)
+    }
+
+    fn cong_avoid(sk: &mut cong::Sock<'_, Self>, _ack: u32, mut acked: u32) {
+        if !sk.tcp_is_cwnd_limited() {
+            return;
+        }
+
+        let tp = sk.tcp_sk_mut();
+
+        if tp.in_slow_start() {
+            acked = tp.slow_start(acked);
+            if acked == 0 {
+                pr_info!(
+                    // TODO: remove
+                    "cong_avoid: new cwnd {}, time {}us, ssthresh {}, start {}us, ss 1",
+                    sk.tcp_sk().snd_cwnd(),
+                    time::ktime_get_boot_fast_us32(),
+                    sk.tcp_sk().snd_ssthresh(),
+                    sk.inet_csk_ca().hystart_state.start_time
+                );
+                return;
+            }
+        }
+
+        let cwnd = tp.snd_cwnd();
+        let cnt = sk.inet_csk_ca_mut().update(cwnd, acked);
+        sk.tcp_sk_mut().cong_avoid_ai(cnt, acked);
+
+        pr_info!(
+            // TODO: remove
+            "cong_avoid: new cwnd {}, time {}us, ssthresh {}, start {}us, ss 0",
+            sk.tcp_sk().snd_cwnd(),
+            time::ktime_get_boot_fast_us32(),
+            sk.tcp_sk().snd_ssthresh(),
+            sk.inet_csk_ca().hystart_state.start_time
+        );
+    }
+}
+
+#[allow(non_snake_case)]
+struct CubicState {
+    /// Increase cwnd by one step after `cnt` ACKs.
+    cnt: NonZeroU32,
+    /// `W_last_max`.
+    last_max_cwnd: u32,
+    /// Value of cwnd before it was updated the last time.
+    last_cwnd: u32,
+    /// Time when `last_cwnd` was updated.
+    last_time: time::Msecs32,
+    /// Value of cwnd where the plateau of the cubic function is located.
+    origin_point: u32,
+    /// Time it takes to reach `origin_point`, measured from the beginning of
+    /// an epoch.
+    K: time::Msecs32,
+    /// Time when the current epoch has started. `None` when not in congestion
+    /// avoidance.
+    epoch_start: Option<time::Msecs32>,
+    /// Number of packets that have been ACKed in the current epoch.
+    ack_cnt: u32,
+    /// Estimate for the cwnd of TCP Reno.
+    tcp_cwnd: u32,
+    /// State of the HyStart slow start algorithm.
+    hystart_state: hystart::HyStartState,
+}
+
+impl hystart::HasHyStartState for CubicState {
+    fn hy(&self) -> &hystart::HyStartState {
+        &self.hystart_state
+    }
+
+    fn hy_mut(&mut self) -> &mut hystart::HyStartState {
+        &mut self.hystart_state
+    }
+}
+
+impl Default for CubicState {
+    fn default() -> Self {
+        Self {
+            // NOTE: Initializing this to 1 deviates from the C code. It does
+            // not change the behavior.
+            cnt: NonZeroU32::MIN,
+            last_max_cwnd: 0,
+            last_cwnd: 0,
+            last_time: 0,
+            origin_point: 0,
+            K: 0,
+            epoch_start: None,
+            ack_cnt: 0,
+            tcp_cwnd: 0,
+            hystart_state: hystart::HyStartState::default(),
+        }
+    }
+}
+
+impl CubicState {
+    /// Checks if the current CUBIC increase is less aggressive than normal
+    /// TCP, i.e., if we are in the TCP-friendly region. If so, returns `cnt`
+    /// that increases at the speed of normal TCP.
+    #[inline]
+    fn tcp_friendliness(&mut self, cnt: u32, cwnd: u32) -> u32 {
+        if !TCP_FRIENDLINESS {
+            return cnt;
+        }
+
+        // Estimate cwnd of normal TCP.
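+        // With `BETA = 717`, `BETA_SCALE` evaluates to
+        // `8 * (1024 + 717) / 3 / (1024 - 717) = 15`, so `delta` below is
+        // about `1.9 * cwnd`: the Reno estimate `tcp_cwnd` grows by one
+        // segment for every `delta` ACKed packets.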
+        // cwnd/3 * (1 + BETA)/(1 - BETA)
+        let delta = (cwnd * BETA_SCALE) >> 3;
+        // W_tcp(t) = W_tcp(t_0) + (acks(t) - acks(t_0)) / delta
+        while self.ack_cnt > delta {
+            self.ack_cnt -= delta;
+            self.tcp_cwnd += 1;
+        }
+
+        // TODO: remove
+        pr_info!(
+            "tcp_friendliness: tcp_cwnd {}, cwnd {}, start {}us",
+            self.tcp_cwnd,
+            cwnd,
+            self.hystart_state.start_time,
+        );
+
+        // We are slower than normal TCP.
+        if self.tcp_cwnd > cwnd {
+            let delta = self.tcp_cwnd - cwnd;
+
+            min(cnt, cwnd / delta)
+        } else {
+            cnt
+        }
+    }
+
+    /// Returns the new value of `cnt` to keep the window growing on the cubic
+    /// curve.
+    fn update(&mut self, cwnd: u32, acked: u32) -> NonZeroU32 {
+        let now: time::Msecs32 = time::ktime_get_boot_fast_ms32();
+
+        self.ack_cnt += acked;
+
+        if self.last_cwnd == cwnd && now.wrapping_sub(self.last_time) <= time::MSEC_PER_SEC / 32 {
+            return self.cnt;
+        }
+
+        // We can update the CUBIC function at most once every ms.
+        if self.epoch_start.is_some() && now == self.last_time {
+            let cnt = self.tcp_friendliness(self.cnt.get(), cwnd);
+
+            // SAFETY: 2 != 0. QED.
+            self.cnt = unsafe { NonZeroU32::new_unchecked(max(2, cnt)) };
+
+            return self.cnt;
+        }
+
+        self.last_cwnd = cwnd;
+        self.last_time = now;
+
+        if self.epoch_start.is_none() {
+            self.epoch_start = Some(now);
+            self.ack_cnt = acked;
+            self.tcp_cwnd = cwnd;
+
+            if self.last_max_cwnd <= cwnd {
+                self.K = 0;
+                self.origin_point = cwnd;
+            } else {
+                // K = (SRTT/C * (W_max - cwnd))^1/3
+                self.K = cubic_root(CUBE_FACTOR * ((self.last_max_cwnd - cwnd) as u64));
+                self.origin_point = self.last_max_cwnd;
+            }
+        }
+
+        // PANIC: This is always `Some` at this point.
+        let epoch_start: time::Msecs32 = self.epoch_start.unwrap();
+        let Some(delay_min) = self.hystart_state.delay_min else {
+            pr_err!("update: delay_min was None");
+            return self.cnt;
+        };
+
+        // NOTE: The addition might overflow after 50 days without a loss;
+        // C uses a `u64` here.
+        let t: time::Msecs32 =
+            now.wrapping_sub(epoch_start) + (delay_min / (time::USEC_PER_MSEC as time::Usecs32));
+        let offs: time::Msecs32 = if t < self.K { self.K - t } else { t - self.K };
+
+        // Calculate `c/rtt * (t - K)^3` and change units to seconds. Widen
+        // the type to prevent overflow.
+        let offs = offs as u64;
+        let delta = (((CUBE_RTT_SCALE as u64 * offs * offs * offs) >> 10) / 1_000_000_000) as u32;
+        // Calculate the full cubic function `c/rtt * (t - K)^3 + W_max`.
+        let target = if t < self.K {
+            self.origin_point - delta
+        } else {
+            self.origin_point + delta
+        };
+
+        // TODO: remove
+        pr_info!(
+            "update: now {}ms, epoch_start {}ms, t {}ms, K {}ms, |t - K| {}ms, last_max_cwnd {}, origin_point {}, target {}, start {}us",
+            now,
+            epoch_start,
+            t,
+            self.K,
+            offs,
+            self.last_max_cwnd,
+            self.origin_point,
+            target,
+            self.hystart_state.start_time,
+        );
+
+        let mut cnt = if target > cwnd {
+            cwnd / (target - cwnd)
+        } else {
+            // Effectively keeps cwnd constant for the next RTT.
+            100 * cwnd
+        };
+
+        // In the initial epoch, or after a timeout, we grow at a minimum rate.
+        if self.last_max_cwnd == 0 {
+            cnt = min(cnt, 20);
+        }
+
+        // SAFETY: 2 != 0. QED.
+        self.cnt = unsafe { NonZeroU32::new_unchecked(max(2, self.tcp_friendliness(cnt, cwnd))) };
+
+        self.cnt
+    }
+
+    fn reset(&mut self) {
+        // TODO: remove
+        let tmp = self.hystart_state.start_time;
+
+        *self = Self::default();
+
+        // TODO: remove
+        self.hystart_state.start_time = tmp;
+    }
+}
+
+/// Calculates the cubic root of `a` using a table lookup followed by one
+/// Newton-Raphson iteration.
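+// For `a < 64` the result comes directly from the lookup table. Larger
+// arguments are first reduced by a power of 8 (the cube root halves for every
+// three bits shifted out), seeded from the table, and then refined with a
+// single Newton-Raphson step, approximately `x' = (2x + a/x^2) / 3`.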
+// E[ |(cubic_root(x) - x.cbrt()) / x.cbrt()| ] = 0.71% for x in 1..1_000_000.
+// E[ |(cubic_root(x) - x.cbrt()) / x.cbrt()| ] = 8.87% for x in 1..63.
+// Where everything is `f64` and `.cbrt` is Rust's builtin. No overflow panics
+// in this domain.
+const fn cubic_root(a: u64) -> u32 {
+    const V: [u8; 64] = [
+        0, 54, 54, 54, 118, 118, 118, 118, 123, 129, 134, 138, 143, 147, 151, 156, 157, 161, 164,
+        168, 170, 173, 176, 179, 181, 185, 187, 190, 192, 194, 197, 199, 200, 202, 204, 206, 209,
+        211, 213, 215, 217, 219, 221, 222, 224, 225, 227, 229, 231, 232, 234, 236, 237, 239, 240,
+        242, 244, 245, 246, 248, 250, 251, 252, 254,
+    ];
+
+    let mut b = fls64(a) as u32;
+    if b < 7 {
+        return ((V[a as usize] as u32) + 35) >> 6;
+    }
+
+    b = ((b * 84) >> 8) - 1;
+    let shift = a >> (b * 3);
+
+    let mut x = (((V[shift as usize] as u32) + 10) << b) >> 6;
+    x = 2 * x + (a / ((x * (x - 1)) as u64)) as u32;
+
+    (x * 341) >> 10
+}
+
+/// Find last set bit in a 64-bit word.
+///
+/// The last (most significant) bit is at position 64.
+#[inline]
+const fn fls64(x: u64) -> u8 {
+    (64 - x.leading_zeros()) as u8
+}
diff --git a/rust/bindings/bindings_helper.h b/rust/bindings/bindings_helper.h
index 65b98831b97560..978885d6336272 100644
--- a/rust/bindings/bindings_helper.h
+++ b/rust/bindings/bindings_helper.h
@@ -17,6 +17,7 @@
 #include
 #include
 #include
+#include
 
 /* `bindgen` gets confused at certain things. */
 const size_t RUST_CONST_HELPER_ARCH_SLAB_MINALIGN = ARCH_SLAB_MINALIGN;
diff --git a/rust/helpers.c b/rust/helpers.c
index 70e59efd92bc43..fc01594fafbe8a 100644
--- a/rust/helpers.c
+++ b/rust/helpers.c
@@ -31,6 +31,7 @@
 #include
 #include
 #include
+#include
 
 __noreturn void rust_helper_BUG(void)
 {
@@ -157,6 +158,42 @@ void rust_helper_init_work_with_key(struct work_struct *work, work_func_t func,
 }
 EXPORT_SYMBOL_GPL(rust_helper_init_work_with_key);
 
+bool rust_helper_tcp_in_slow_start(const struct tcp_sock *tp)
+{
+	return tcp_in_slow_start(tp);
+}
+EXPORT_SYMBOL_GPL(rust_helper_tcp_in_slow_start);
+
+bool rust_helper_tcp_is_cwnd_limited(const struct sock *sk)
+{
+	return tcp_is_cwnd_limited(sk);
+}
+EXPORT_SYMBOL_GPL(rust_helper_tcp_is_cwnd_limited);
+
+struct tcp_sock *rust_helper_tcp_sk(struct sock *sk)
+{
+	return tcp_sk(sk);
+}
+EXPORT_SYMBOL_GPL(rust_helper_tcp_sk);
+
+u32 rust_helper_tcp_snd_cwnd(const struct tcp_sock *tp)
+{
+	return tcp_snd_cwnd(tp);
+}
+EXPORT_SYMBOL_GPL(rust_helper_tcp_snd_cwnd);
+
+struct inet_connection_sock *rust_helper_inet_csk(const struct sock *sk)
+{
+	return inet_csk(sk);
+}
+EXPORT_SYMBOL_GPL(rust_helper_inet_csk);
+
+void *rust_helper_inet_csk_ca(struct sock *sk)
+{
+	return inet_csk_ca(sk);
+}
+EXPORT_SYMBOL_GPL(rust_helper_inet_csk_ca);
+
 /*
  * `bindgen` binds the C `size_t` type as the Rust `usize` type, so we can
  * use it in contexts where Rust expects a `usize` like slice (array) indices.
diff --git a/rust/kernel/lib.rs b/rust/kernel/lib.rs
index 1e5f229b82638e..b90282b9962c2d 100644
--- a/rust/kernel/lib.rs
+++ b/rust/kernel/lib.rs
@@ -13,6 +13,7 @@
 #![no_std]
 #![feature(allocator_api)]
+#![feature(associated_type_bounds)]
 #![feature(coerce_unsized)]
 #![feature(dispatch_from_dyn)]
 #![feature(new_uninit)]
@@ -75,6 +76,29 @@ pub trait Module: Sized + Sync {
     fn init(module: &'static ThisModule) -> error::Result<Self>;
 }
 
+/// A module that is pinned and initialised in-place.
+pub trait InPlaceModule: Sync {
+    /// Creates an initialiser for the module.
+    ///
+    /// It is called when the module is loaded.
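+    ///
+    /// In contrast to [`Module::init`], the module is initialised directly in
+    /// its final memory location, so `Self` may contain structures that must
+    /// not move once created, such as a pinned registration.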
+    fn init(module: &'static ThisModule) -> impl init::PinInit<Self, error::Error>;
+}
+
+impl<T: Module> InPlaceModule for T {
+    fn init(module: &'static ThisModule) -> impl init::PinInit<Self, error::Error> {
+        let initer = move |slot: *mut Self| {
+            let m = <Self as Module>::init(module)?;
+
+            // SAFETY: `slot` is valid for write per the contract with `pin_init_from_closure`.
+            unsafe { slot.write(m) };
+            Ok(())
+        };
+
+        // SAFETY: On success, `initer` always fully initialises an instance of `Self`.
+        unsafe { init::pin_init_from_closure(initer) }
+    }
+}
+
 /// Equivalent to `THIS_MODULE` in the C API.
 ///
 /// C header: [`include/linux/export.h`](srctree/include/linux/export.h)
diff --git a/rust/kernel/net.rs b/rust/kernel/net.rs
index fe415cb369d3ac..a17555940d6418 100644
--- a/rust/kernel/net.rs
+++ b/rust/kernel/net.rs
@@ -4,3 +4,7 @@
 
 #[cfg(CONFIG_RUST_PHYLIB_ABSTRACTIONS)]
 pub mod phy;
+#[cfg(CONFIG_RUST_SOCK_ABSTRACTIONS)]
+pub mod sock;
+#[cfg(CONFIG_RUST_TCP_ABSTRACTIONS)]
+pub mod tcp;
diff --git a/rust/kernel/net/sock.rs b/rust/kernel/net/sock.rs
new file mode 100644
index 00000000000000..c4fc539303a88b
--- /dev/null
+++ b/rust/kernel/net/sock.rs
@@ -0,0 +1,183 @@
+// SPDX-License-Identifier: GPL-2.0-only
+
+//! Representation of a C `struct sock`.
+//!
+//! C header: [`include/net/sock.h`](srctree/include/net/sock.h)
+
+#[cfg(CONFIG_RUST_TCP_ABSTRACTIONS)]
+use crate::net::tcp::{self, InetConnectionSock, TcpSock};
+use crate::types::Opaque;
+use core::convert::TryFrom;
+use core::ptr::addr_of;
+
+/// Representation of a C `struct sock`.
+///
+/// Not intended to be used directly by modules. Abstractions should provide a
+/// safe interface to only those operations that are OK to use for the module.
+///
+/// # Invariants
+///
+/// Referencing a `sock` using this struct asserts that you are in a context
+/// where all safe methods defined on this struct are indeed safe to call.
+#[repr(transparent)]
+pub(crate) struct Sock {
+    sk: Opaque<bindings::sock>,
+}
+
+impl Sock {
+    /// Returns a raw pointer to the wrapped `struct sock`.
+    ///
+    /// It is up to the caller to use it correctly.
+    #[inline]
+    pub(crate) fn raw_sk_mut(&mut self) -> *mut bindings::sock {
+        self.sk.get()
+    }
+
+    /// Returns the socket's pacing rate in bytes per second.
+    #[inline]
+    pub(crate) fn sk_pacing_rate(&self) -> u64 {
+        // NOTE: C uses READ_ONCE for this field, thus `read_volatile`.
+        // SAFETY: The struct invariant ensures that we may access this field
+        // without additional synchronization. It is a C unsigned long, so we
+        // can always convert it to a u64 without loss.
+        unsafe { addr_of!((*self.sk.get()).sk_pacing_rate).read_volatile() as u64 }
+    }
+
+    /// Returns the socket's pacing status.
+    #[inline]
+    pub(crate) fn sk_pacing_status(&self) -> Result<Pacing, ()> {
+        // SAFETY: The struct invariant ensures that we may access this field
+        // without additional synchronization.
+        unsafe { Pacing::try_from(*addr_of!((*self.sk.get()).sk_pacing_status)) }
+    }
+
+    /// Returns the socket's maximum GSO segment size to build.
+    #[inline]
+    pub(crate) fn sk_gso_max_size(&self) -> u32 {
+        // SAFETY: The struct invariant ensures that we may access this field
+        // without additional synchronization. It is a C unsigned int, and we
+        // are guaranteed that this will always fit into a u32.
+        unsafe { *addr_of!((*self.sk.get()).sk_gso_max_size) as u32 }
+    }
+
+    /// Returns the [`TcpSock`] that contains the `Sock`.
+    ///
+    /// # Safety
+    ///
+    /// `sk` must be valid for `tcp_sk`.
+    #[inline]
+    #[cfg(CONFIG_RUST_TCP_ABSTRACTIONS)]
+    pub(crate) unsafe fn tcp_sk<'a>(&'a self) -> &'a TcpSock {
+        // SAFETY:
+        // - Downcasting via `tcp_sk` is OK by the function's precondition.
+        // - The cast is OK since `TcpSock` is transparent to `struct tcp_sock`.
+        unsafe { &*(bindings::tcp_sk(self.sk.get()) as *const TcpSock) }
+    }
+
+    /// Returns the [`TcpSock`] that contains the `Sock`.
+    ///
+    /// # Safety
+    ///
+    /// `sk` must be valid for `tcp_sk`.
+    #[inline]
+    #[cfg(CONFIG_RUST_TCP_ABSTRACTIONS)]
+    pub(crate) unsafe fn tcp_sk_mut<'a>(&'a mut self) -> &'a mut TcpSock {
+        // SAFETY:
+        // - Downcasting via `tcp_sk` is OK by the function's precondition.
+        // - The cast is OK since `TcpSock` is transparent to `struct tcp_sock`.
+        unsafe { &mut *(bindings::tcp_sk(self.sk.get()) as *mut TcpSock) }
+    }
+
+    /// Returns the [private data] of the instance of the CCA used by this
+    /// socket.
+    ///
+    /// [private data]: tcp::cong::Algorithm::Data
+    ///
+    /// # Safety
+    ///
+    /// - `sk` must be valid for `inet_csk_ca`,
+    /// - `sk` must use the CCA `T`, the `init` CB of the CCA must have been
+    ///   called, and the `release` CB of the CCA must not have been called.
+    #[inline]
+    #[cfg(CONFIG_RUST_TCP_ABSTRACTIONS)]
+    pub(crate) unsafe fn inet_csk_ca<'a, T: tcp::cong::Algorithm + ?Sized>(
+        &'a self,
+    ) -> &'a T::Data {
+        // SAFETY: By the function's preconditions, calling `inet_csk_ca` is OK
+        // and the returned pointer points to a valid instance of `T::Data`.
+        unsafe { &*(bindings::inet_csk_ca(self.sk.get()) as *const T::Data) }
+    }
+
+    /// Returns the [private data] of the instance of the CCA used by this
+    /// socket.
+    ///
+    /// [private data]: tcp::cong::Algorithm::Data
+    ///
+    /// # Safety
+    ///
+    /// - `sk` must be valid for `inet_csk_ca`,
+    /// - `sk` must use the CCA `T`, the `init` CB of the CCA must have been
+    ///   called, and the `release` CB of the CCA must not have been called.
+    #[inline]
+    #[cfg(CONFIG_RUST_TCP_ABSTRACTIONS)]
+    pub(crate) unsafe fn inet_csk_ca_mut<'a, T: tcp::cong::Algorithm + ?Sized>(
+        &'a mut self,
+    ) -> &'a mut T::Data {
+        // SAFETY: By the function's preconditions, calling `inet_csk_ca` is OK
+        // and the returned pointer points to a valid instance of `T::Data`.
+        unsafe { &mut *(bindings::inet_csk_ca(self.sk.get()) as *mut T::Data) }
+    }
+
+    /// Returns the [`InetConnectionSock`] view of this socket.
+    ///
+    /// # Safety
+    ///
+    /// `sk` must be valid for `inet_csk`.
+    #[inline]
+    #[cfg(CONFIG_RUST_TCP_ABSTRACTIONS)]
+    pub(crate) unsafe fn inet_csk<'a>(&'a self) -> &'a InetConnectionSock {
+        // SAFETY:
+        // - Calling `inet_csk` is OK by the function's precondition.
+        // - The cast is OK since `InetConnectionSock` is transparent to
+        //   `struct inet_connection_sock`.
+        unsafe { &*(bindings::inet_csk(self.sk.get()) as *const InetConnectionSock) }
+    }
+
+    /// Tests if the connection's sending rate is limited by the cwnd.
+    ///
+    /// # Safety
+    ///
+    /// `sk` must be valid for `tcp_is_cwnd_limited`.
+    #[inline]
+    #[cfg(CONFIG_RUST_TCP_ABSTRACTIONS)]
+    pub(crate) unsafe fn tcp_is_cwnd_limited(&self) -> bool {
+        // SAFETY: Calling `tcp_is_cwnd_limited` is OK by the function's
+        // precondition.
+        unsafe { bindings::tcp_is_cwnd_limited(self.sk.get()) }
+    }
+}
+
+/// The socket's pacing status.
+#[repr(u32)]
+#[allow(missing_docs)]
+pub enum Pacing {
+    r#None = bindings::sk_pacing_SK_PACING_NONE,
+    Needed = bindings::sk_pacing_SK_PACING_NEEDED,
+    Fq = bindings::sk_pacing_SK_PACING_FQ,
+}
+
+// TODO: Replace with automatically generated code by bindgen when it becomes
+// possible.
+impl TryFrom<u32> for Pacing {
+    type Error = ();
+
+    fn try_from(val: u32) -> Result<Self, Self::Error> {
+        match val {
+            x if x == Pacing::r#None as u32 => Ok(Pacing::r#None),
+            x if x == Pacing::Needed as u32 => Ok(Pacing::Needed),
+            x if x == Pacing::Fq as u32 => Ok(Pacing::Fq),
+            _ => Err(()),
+        }
+    }
+}
diff --git a/rust/kernel/net/tcp.rs b/rust/kernel/net/tcp.rs
new file mode 100644
index 00000000000000..62002c777411a5
--- /dev/null
+++ b/rust/kernel/net/tcp.rs
@@ -0,0 +1,145 @@
+// SPDX-License-Identifier: GPL-2.0-only
+
+//! Transmission Control Protocol (TCP).
+
+use crate::time;
+use crate::types::Opaque;
+use core::{num, ptr};
+
+pub mod cong;
+
+/// Representation of a `struct inet_connection_sock`.
+///
+/// # Invariants
+///
+/// Referencing an `inet_connection_sock` using this struct asserts that you
+/// are in a context where all safe methods defined on this struct are indeed
+/// safe to call.
+///
+/// C header: [`include/net/inet_connection_sock.h`](srctree/include/net/inet_connection_sock.h)
+#[repr(transparent)]
+pub struct InetConnectionSock {
+    icsk: Opaque<bindings::inet_connection_sock>,
+}
+
+impl InetConnectionSock {
+    /// Returns the congestion control state of this socket.
+    #[inline]
+    pub fn ca_state(&self) -> Result<cong::State, ()> {
+        const CA_STATE_MASK: u8 = 0b11111;
+        // TODO: Replace code to access the bit field with automatically
+        // generated code by bindgen when it becomes possible.
+        // SAFETY: By the type invariants, it is okay to read `icsk_ca_state`,
+        // which is the first member of the bitfield and has a size of five
+        // bits.
+        cong::State::try_from(unsafe {
+            *(ptr::addr_of!((*self.icsk.get())._bitfield_1).cast::<u8>()) & CA_STATE_MASK
+        })
+    }
+}
+
+/// Representation of a `struct tcp_sock`.
+///
+/// # Invariants
+///
+/// Referencing a `tcp_sock` using this struct asserts that you are in a
+/// context where all safe methods defined on this struct are indeed safe to
+/// call.
+///
+/// C header: [`include/linux/tcp.h`](srctree/include/linux/tcp.h)
+#[repr(transparent)]
+pub struct TcpSock {
+    tp: Opaque<bindings::tcp_sock>,
+}
+
+impl TcpSock {
+    /// Returns true iff `snd_cwnd < snd_ssthresh`.
+    #[inline]
+    pub fn in_slow_start(&self) -> bool {
+        // SAFETY: The struct invariant ensures that we may call this function
+        // without additional synchronization.
+        unsafe { bindings::tcp_in_slow_start(self.tp.get()) }
+    }
+
+    /// Performs the standard slow start increment of cwnd.
+    ///
+    /// If this causes the socket to exit slow start, any leftover ACKs are
+    /// returned.
+    #[inline]
+    pub fn slow_start(&mut self, acked: u32) -> u32 {
+        // SAFETY: The struct invariant ensures that we may call this function
+        // without additional synchronization.
+        unsafe { bindings::tcp_slow_start(self.tp.get(), acked) }
+    }
+
+    /// Performs the standard increase of cwnd during congestion avoidance.
+    ///
+    /// The increase per ACK is upper bounded by `1 / w`.
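+    ///
+    /// For example (hypothetical values): with `w = 10`, ten ACKed segments
+    /// increase cwnd by about one segment in total.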
+    #[inline]
+    pub fn cong_avoid_ai(&mut self, w: num::NonZeroU32, acked: u32) {
+        // SAFETY: The struct invariant ensures that we may call this function
+        // without additional synchronization.
+        unsafe { bindings::tcp_cong_avoid_ai(self.tp.get(), w.get(), acked) };
+    }
+
+    /// Returns the connection's current cwnd.
+    #[inline]
+    pub fn snd_cwnd(&self) -> u32 {
+        // SAFETY: The struct invariant ensures that we may call this function
+        // without additional synchronization.
+        unsafe { bindings::tcp_snd_cwnd(self.tp.get()) }
+    }
+
+    /// Returns the connection's current ssthresh.
+    #[inline]
+    pub fn snd_ssthresh(&self) -> u32 {
+        // SAFETY: The struct invariant ensures that we may access this field
+        // without additional synchronization.
+        unsafe { *ptr::addr_of!((*self.tp.get()).snd_ssthresh) }
+    }
+
+    /// Returns the sequence number of the next byte that will be sent.
+    #[inline]
+    pub fn snd_nxt(&self) -> u32 {
+        // SAFETY: The struct invariant ensures that we may access this field
+        // without additional synchronization.
+        unsafe { *ptr::addr_of!((*self.tp.get()).snd_nxt) }
+    }
+
+    /// Returns the sequence number of the first unacknowledged byte.
+    #[inline]
+    pub fn snd_una(&self) -> u32 {
+        // SAFETY: The struct invariant ensures that we may access this field
+        // without additional synchronization.
+        unsafe { *ptr::addr_of!((*self.tp.get()).snd_una) }
+    }
+
+    /// Returns the time when the last packet was received or sent.
+    #[inline]
+    pub fn tcp_mstamp(&self) -> time::Usecs {
+        // SAFETY: The struct invariant ensures that we may access this field
+        // without additional synchronization.
+        unsafe { *ptr::addr_of!((*self.tp.get()).tcp_mstamp) }
+    }
+
+    /// Sets the connection's ssthresh.
+    #[inline]
+    pub fn set_snd_ssthresh(&mut self, new: u32) {
+        // SAFETY: The struct invariant ensures that we may access this field
+        // without additional synchronization.
+        unsafe { *ptr::addr_of_mut!((*self.tp.get()).snd_ssthresh) = new };
+    }
+
+    /// Returns the timestamp of the last sent data packet in 32-bit jiffies.
+    #[inline]
+    pub fn lsndtime(&self) -> time::Jiffies32 {
+        // SAFETY: The struct invariant ensures that we may access this field
+        // without additional synchronization.
+        unsafe { *ptr::addr_of!((*self.tp.get()).lsndtime) as time::Jiffies32 }
+    }
+}
+
+/// Tests if `sqn_1` comes after `sqn_2`.
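+///
+/// # Examples
+///
+/// Illustrative only, not compiled as a doctest:
+///
+/// ```ignore
+/// assert!(after(10, 5));
+/// assert!(!after(5, 10));
+/// // Sequence numbers wrap around: 2 comes "after" values near `u32::MAX`.
+/// assert!(after(2, u32::MAX - 5));
+/// ```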
+#[inline]
+pub fn after(sqn_1: u32, sqn_2: u32) -> bool {
+    (sqn_2.wrapping_sub(sqn_1) as i32) < 0
+}
diff --git a/rust/kernel/net/tcp/cong.rs b/rust/kernel/net/tcp/cong.rs
new file mode 100644
index 00000000000000..a08ec04b946621
--- /dev/null
+++ b/rust/kernel/net/tcp/cong.rs
@@ -0,0 +1,647 @@
+// SPDX-License-Identifier: GPL-2.0-only
+
+//! Congestion control algorithms (CCAs).
+//!
+//! Abstractions for implementing pluggable CCAs in Rust.
+
+use crate::bindings;
+use crate::error::{self, Error, VTABLE_DEFAULT_ERROR};
+use crate::init::PinInit;
+use crate::net::sock;
+use crate::prelude::{pr_err, vtable};
+use crate::str::CStr;
+use crate::time;
+use crate::types::Opaque;
+use crate::ThisModule;
+use crate::{build_assert, build_error, field_size, try_pin_init};
+use core::convert::TryFrom;
+use core::marker::PhantomData;
+use core::pin::Pin;
+use macros::{pin_data, pinned_drop};
+
+use super::{InetConnectionSock, TcpSock};
+
+pub mod hystart;
+
+/// Congestion control algorithm (CCA).
+///
+/// A CCA is implemented as a set of callbacks that are invoked whenever
+/// specific events occur in a connection. Each socket has its own instance of
+/// some CCA. Every instance of a CCA has its own private data that is stored
+/// in the socket and is mutated by the callbacks.
+///
+/// Callbacks that operate on the same instance are guaranteed to run
+/// sequentially, and each callback has exclusive mutable access to the private
+/// data of the instance it operates on.
+#[vtable]
+pub trait Algorithm {
+    /// Private data. Each socket has its own instance.
+    type Data: Default + Send + Sync;
+
+    /// Name of the algorithm.
+    const NAME: &'static CStr;
+
+    /// Called when entering the CWR, Recovery, or Loss states from the Open
+    /// or Disorder states. Returns the new slow start threshold.
+    fn ssthresh(sk: &mut Sock<'_, Self>) -> u32;
+
+    /// Called when one of the events in [`Event`] occurs.
+    fn cwnd_event(_sk: &mut Sock<'_, Self>, _ev: Event) {
+        build_error!(VTABLE_DEFAULT_ERROR);
+    }
+
+    /// Called towards the end of processing an ACK if a cwnd increase is
+    /// possible. Performs a new cwnd calculation and sets it on the socket.
+    // Note: In fact, one of `cong_avoid` and `cong_control` is required.
+    // (see `tcp_validate_congestion_control`)
+    fn cong_avoid(sk: &mut Sock<'_, Self>, ack: u32, acked: u32);
+
+    /// Called before the sender's congestion state is changed.
+    fn set_state(_sk: &mut Sock<'_, Self>, _new_state: State) {
+        build_error!(VTABLE_DEFAULT_ERROR);
+    }
+
+    /// Called when removing ACKed packets from the retransmission queue. Can
+    /// be used for packet ACK accounting.
+    fn pkts_acked(_sk: &mut Sock<'_, Self>, _sample: &AckSample) {
+        build_error!(VTABLE_DEFAULT_ERROR);
+    }
+
+    /// Called to undo a recent cwnd reduction that was found to have been
+    /// unnecessary. Returns the new value of cwnd.
+    fn undo_cwnd(sk: &mut Sock<'_, Self>) -> u32;
+
+    /// Initializes the private data.
+    ///
+    /// When this function is called, [`sk.inet_csk_ca()`] will contain a value
+    /// returned by `Self::Data::default()`.
+    ///
+    /// Only implement this function when you need to perform additional setup
+    /// tasks.
+    ///
+    /// [`sk.inet_csk_ca()`]: Sock::inet_csk_ca
+    fn init(_sk: &mut Sock<'_, Self>) {
+        build_error!(VTABLE_DEFAULT_ERROR);
+    }
+
+    /// Cleans up the private data.
+    ///
+    /// After this function returns, [`sk.inet_csk_ca()`] will be dropped.
+    ///
+    /// Only implement this function when you need to perform additional
+    /// cleanup tasks.
+    ///
+    /// [`sk.inet_csk_ca()`]: Sock::inet_csk_ca
+    fn release(_sk: &mut Sock<'_, Self>) {
+        build_error!(VTABLE_DEFAULT_ERROR);
+    }
+}
+
+pub mod reno {
+    //! TCP Reno congestion control.
+    //!
+    //! Algorithms may choose to invoke these callbacks instead of providing
+    //! their own implementation. This is convenient as a new CCA might have
+    //! the same logic as an existing one in some of its callbacks.
+    use super::{Algorithm, Sock};
+    use crate::bindings;
+
+    /// Implementation of [`undo_cwnd`] that returns `max(snd_cwnd, prior_cwnd)`,
+    /// where `prior_cwnd` is the value of cwnd before the last reduction.
+    ///
+    /// [`undo_cwnd`]: super::Algorithm::undo_cwnd
+    #[inline]
+    pub fn undo_cwnd<T: Algorithm + ?Sized>(sk: &mut Sock<'_, T>) -> u32 {
+        // SAFETY:
+        // - `sk` has been passed to the callback that invoked us,
+        // - it is OK to pass it to the callback of the Reno algorithm as it
+        //   will never touch the private data.
+        unsafe { bindings::tcp_reno_undo_cwnd(sk.sk.raw_sk_mut()) }
+    }
+}
+
+/// Representation of the `struct sock *` that is passed to the callbacks of
+/// the CCA.
+///
+/// Every callback receives a pointer to the socket that it is operating on.
+/// There are certain operations that callbacks are allowed to perform on the
+/// socket, and this type just exposes methods for performing those. This
+/// prevents callbacks from performing arbitrary manipulations on the socket.
+// TODO: Currently all callbacks can perform all operations. However, this
+// might be too permissive, e.g., the `pkts_acked` callback should probably
+// not be changing cwnd...
+///
+/// # Invariants
+///
+/// The wrapped `sk` must have been obtained as the argument to a callback of
+/// the congestion algorithm `T` (other than the `init` cb) and may only be
+/// used for the duration of that callback. In particular:
+///
+/// - `sk` points to a valid `struct sock`.
+/// - `tcp_sk(sk)` points to a valid `struct tcp_sock`.
+/// - The socket uses the CCA `T`.
+/// - `inet_csk_ca(sk)` points to a valid instance of `T::Data`, which belongs
+///   to the instance of the algorithm used by this socket. A callback has
+///   exclusive, mutable access to this data.
+pub struct Sock<'a, T: Algorithm + ?Sized> {
+    sk: &'a mut sock::Sock,
+    _pd: PhantomData<T>,
+}
+
+impl<'a, T: Algorithm + ?Sized> Sock<'a, T> {
+    /// Creates a new `Sock`.
+    ///
+    /// # Safety
+    ///
+    /// - `sk` must have been obtained as the argument to a callback of the
+    ///   congestion algorithm `T`.
+    /// - The CCA's private data must have been initialised.
+    /// - The returned value must not live longer than the duration of the
+    ///   callback.
+    unsafe fn new(sk: *mut bindings::sock) -> Self {
+        // INVARIANTS: Satisfied by the function's precondition.
+        Self {
+            // SAFETY:
+            // - The cast is OK since `sock::Sock` is transparent to
+            //   `struct sock`.
+            // - Dereferencing `sk` is OK since the pointers passed to CCA CBs
+            //   are valid.
+            // - By the function's preconditions, the produced `Self` value
+            //   will only live for the duration of the callback; thus, the
+            //   wrapped reference will always be valid.
+            sk: unsafe { &mut *(sk as *mut sock::Sock) },
+            _pd: PhantomData,
+        }
+    }
+
+    /// Returns the [`TcpSock`] that contains the `Sock`.
+    #[inline]
+    pub fn tcp_sk<'b>(&'b self) -> &'b TcpSock {
+        // SAFETY: By the type invariants, `sk` is valid for `tcp_sk`.
+        unsafe { self.sk.tcp_sk() }
+    }
+
+    /// Returns the [`TcpSock`] that contains the `Sock`.
+    #[inline]
+    pub fn tcp_sk_mut<'b>(&'b mut self) -> &'b mut TcpSock {
+        // SAFETY: By the type invariants, `sk` is valid for `tcp_sk`.
+        unsafe { self.sk.tcp_sk_mut() }
+    }
+
+    /// Returns the [private data] of the instance of the CCA used by this
+    /// socket.
+    ///
+    /// [private data]: Algorithm::Data
+    #[inline]
+    pub fn inet_csk_ca<'b>(&'b self) -> &'b T::Data {
+        // SAFETY: By the type invariants, `sk` is valid for `inet_csk_ca`,
+        // it uses the algorithm `T`, and the private data is valid.
+        unsafe { self.sk.inet_csk_ca::<T>() }
+    }
+
+    /// Returns the [private data] of the instance of the CCA used by this
+    /// socket.
+    ///
+    /// [private data]: Algorithm::Data
+    #[inline]
+    pub fn inet_csk_ca_mut<'b>(&'b mut self) -> &'b mut T::Data {
+        // SAFETY: By the type invariants, `sk` is valid for `inet_csk_ca`,
+        // it uses the algorithm `T`, and the private data is valid.
+        unsafe { self.sk.inet_csk_ca_mut::<T>() }
+    }
+
+    /// Returns the [`InetConnectionSock`] of this socket.
+    #[inline]
+    pub fn inet_csk<'b>(&'b self) -> &'b InetConnectionSock {
+        // SAFETY: By the type invariants, `sk` is valid for `inet_csk`.
+        unsafe { self.sk.inet_csk() }
+    }
+
+    /// Tests if the connection's sending rate is limited by the cwnd.
+    // NOTE: This feels like it should be a method on `TcpSock`, but C defines
+    // it on `struct sock`, so there is not much we can do about it. At least,
+    // if we don't want to reimplement the function (or perform the conversion
+    // from `struct tcp_sock` to `struct sock` just to have C reverse it right
+    // away).
+    #[inline]
+    pub fn tcp_is_cwnd_limited(&self) -> bool {
+        // SAFETY: By the type invariants, `sk` is valid for
+        // `tcp_is_cwnd_limited`.
+        unsafe { self.sk.tcp_is_cwnd_limited() }
+    }
+
+    /// Returns the socket's pacing rate in bytes per second.
+    #[inline]
+    pub fn sk_pacing_rate(&self) -> u64 {
+        self.sk.sk_pacing_rate()
+    }
+
+    /// Returns the socket's pacing status.
+    #[inline]
+    pub fn sk_pacing_status(&self) -> Result<sock::Pacing, ()> {
+        self.sk.sk_pacing_status()
+    }
+
+    /// Returns the socket's maximum GSO segment size to build.
+    #[inline]
+    pub fn sk_gso_max_size(&self) -> u32 {
+        self.sk.sk_gso_max_size()
+    }
+}
+
+/// Representation of the `struct ack_sample *` that is passed to the
+/// `pkts_acked` callback of the CCA.
+///
+/// # Invariants
+///
+/// - `sample` points to a valid `struct ack_sample`,
+/// - all fields of `sample` can be read without additional synchronization.
+pub struct AckSample {
+    sample: *const bindings::ack_sample,
+}
+
+impl AckSample {
+    /// Creates a new `AckSample`.
+    ///
+    /// # Safety
+    ///
+    /// `sample` must have been obtained as the argument to the `pkts_acked`
+    /// callback.
+    unsafe fn new(sample: *const bindings::ack_sample) -> Self {
+        // INVARIANTS: Satisfied by the function's precondition.
+        Self { sample }
+    }
+
+    /// Returns the number of packets that were ACKed.
+    #[inline]
+    pub fn pkts_acked(&self) -> u32 {
+        // SAFETY: By the type invariants, it is OK to read any field.
+        unsafe { (*self.sample).pkts_acked }
+    }
+
+    /// Returns the RTT measurement of this ACK sample.
+    // Note: Some samples might not include an RTT measurement. This is
+    // indicated by a negative value for `rtt_us`; we return `None` in that
+    // case.
+    #[inline]
+    pub fn rtt_us(&self) -> Option<time::Usecs32> {
+        // SAFETY: By the type invariants, it is OK to read any field.
+        match unsafe { (*self.sample).rtt_us } {
+            t if t < 0 => None,
+            t => Some(t as time::Usecs32),
+        }
+    }
+}
+
+/// States of the TCP sender state machine.
+///
+/// The TCP sender's congestion state indicating normal or abnormal situations
+/// in the last round of packets sent. The state is driven by the ACK
+/// information and timer events.
+#[repr(u8)]
+pub enum State {
+    /// Nothing bad has been observed recently. No apparent reordering, packet
+    /// loss, or ECN marks.
+    Open = bindings::tcp_ca_state_TCP_CA_Open as u8,
+    /// The sender enters disordered state when it has received DUPACKs or
+    /// SACKs in the last round of packets sent. This could be due to packet
+    /// loss or reordering but needs further information to confirm packets
+    /// have been lost.
+    Disorder = bindings::tcp_ca_state_TCP_CA_Disorder as u8,
+    /// The sender enters Congestion Window Reduction (CWR) state when it
+    /// has received ACKs with ECN-ECE marks, or has experienced congestion
+    /// or packet discard on the sender host (e.g. qdisc).
+    Cwr = bindings::tcp_ca_state_TCP_CA_CWR as u8,
+    /// The sender is in fast recovery and retransmitting lost packets,
+    /// typically triggered by ACK events.
+    Recovery = bindings::tcp_ca_state_TCP_CA_Recovery as u8,
+    /// The sender is in loss recovery triggered by retransmission timeout.
+    Loss = bindings::tcp_ca_state_TCP_CA_Loss as u8,
+}
+
+// TODO: Replace with automatically generated code by bindgen when it becomes
+// possible.
+impl TryFrom<u8> for State {
+    type Error = ();
+
+    fn try_from(val: u8) -> Result<Self, Self::Error> {
+        match val {
+            x if x == State::Open as u8 => Ok(State::Open),
+            x if x == State::Disorder as u8 => Ok(State::Disorder),
+            x if x == State::Cwr as u8 => Ok(State::Cwr),
+            x if x == State::Recovery as u8 => Ok(State::Recovery),
+            x if x == State::Loss as u8 => Ok(State::Loss),
+            _ => Err(()),
+        }
+    }
+}
+
+/// Events passed to the congestion control interface.
+#[repr(u32)]
+pub enum Event {
+    /// First transmit when no packets are in flight.
+    TxStart = bindings::tcp_ca_event_CA_EVENT_TX_START,
+    /// Congestion window restart.
+    CwndRestart = bindings::tcp_ca_event_CA_EVENT_CWND_RESTART,
+    /// End of congestion recovery.
+    CompleteCwr = bindings::tcp_ca_event_CA_EVENT_COMPLETE_CWR,
+    /// Loss timeout.
+    Loss = bindings::tcp_ca_event_CA_EVENT_LOSS,
+    /// ECT set, but not CE marked.
+    EcnNoCe = bindings::tcp_ca_event_CA_EVENT_ECN_NO_CE,
+    /// Received CE marked IP packet.
+    EcnIsCe = bindings::tcp_ca_event_CA_EVENT_ECN_IS_CE,
+}
+
+// TODO: Replace with automatically generated code by bindgen when it becomes
+// possible.
+impl TryFrom<bindings::tcp_ca_event> for Event {
+    type Error = ();
+
+    fn try_from(ev: bindings::tcp_ca_event) -> Result<Self, Self::Error> {
+        match ev {
+            x if x == Event::TxStart as u32 => Ok(Event::TxStart),
+            x if x == Event::CwndRestart as u32 => Ok(Event::CwndRestart),
+            x if x == Event::CompleteCwr as u32 => Ok(Event::CompleteCwr),
+            x if x == Event::Loss as u32 => Ok(Event::Loss),
+            x if x == Event::EcnNoCe as u32 => Ok(Event::EcnNoCe),
+            x if x == Event::EcnIsCe as u32 => Ok(Event::EcnIsCe),
+            _ => Err(()),
+        }
+    }
+}
+
+#[pin_data(PinnedDrop)]
+struct Registration<T: Algorithm> {
+    #[pin]
+    ops: Opaque<bindings::tcp_congestion_ops>,
+    _pd: PhantomData<T>,
+}
+
+// SAFETY: `Registration` doesn't provide any `&self` methods, so it is safe
+// to pass references to it around.
+unsafe impl<T: Algorithm> Sync for Registration<T> {}
+
+// SAFETY: Both registration and unregistration are implemented in C and safe
+// to be performed from any thread, so `Registration` is `Send`.
+unsafe impl<T: Algorithm> Send for Registration<T> {}
+
+impl<T: Algorithm> Registration<T> {
+    const NAME_FIELD: [i8; 16] = Self::gen_name_field::<16>();
+    // Maximal size of the private data.
+    const ICSK_CA_PRIV_SIZE: usize = field_size!(bindings::inet_connection_sock, icsk_ca_priv);
+    const DATA_SIZE: usize = core::mem::size_of::<T::Data>();
+
+    fn new(module: &'static ThisModule) -> impl PinInit<Self, Error> {
+        try_pin_init!(Self {
+            _pd: PhantomData,
+            ops <- Opaque::try_ffi_init(|ops_ptr: *mut bindings::tcp_congestion_ops| {
+                // SAFETY: `try_ffi_init` guarantees that `ops_ptr` is valid
+                // for write.
+                unsafe { ops_ptr.write(bindings::tcp_congestion_ops::default()) };
+
+                // SAFETY: `try_ffi_init` guarantees that `ops_ptr` is valid
+                // for write, and it has just been initialised above, so it's
+                // also valid for read.
+                let ops = unsafe { &mut *ops_ptr };
+
+                ops.ssthresh = Some(Self::ssthresh_cb);
+                ops.cong_avoid = Some(Self::cong_avoid_cb);
+                ops.undo_cwnd = Some(Self::undo_cwnd_cb);
+                if T::HAS_SET_STATE {
+                    ops.set_state = Some(Self::set_state_cb);
+                }
+                if T::HAS_PKTS_ACKED {
+                    ops.pkts_acked = Some(Self::pkts_acked_cb);
+                }
+                if T::HAS_CWND_EVENT {
+                    ops.cwnd_event = Some(Self::cwnd_event_cb);
+                }
+
+                // Even though it is not mandated by the C side, we
+                // unconditionally set these CBs to ensure that it is always
+                // safe to access the CCA's private data.
+                // Future work could allow the CCA to declare whether it wants
+                // to be able to use the private data.
+/// Events passed to congestion control interface.
+#[repr(u32)]
+pub enum Event {
+    /// First transmit when no packets in flight.
+    TxStart = bindings::tcp_ca_event_CA_EVENT_TX_START,
+    /// Congestion window restart.
+    CwndRestart = bindings::tcp_ca_event_CA_EVENT_CWND_RESTART,
+    /// End of congestion recovery.
+    CompleteCwr = bindings::tcp_ca_event_CA_EVENT_COMPLETE_CWR,
+    /// Loss timeout.
+    Loss = bindings::tcp_ca_event_CA_EVENT_LOSS,
+    /// ECT set, but not CE marked.
+    EcnNoCe = bindings::tcp_ca_event_CA_EVENT_ECN_NO_CE,
+    /// Received CE marked IP packet.
+    EcnIsCe = bindings::tcp_ca_event_CA_EVENT_ECN_IS_CE,
+}
+
+// TODO: Replace with automatically generated code by bindgen when it becomes
+// possible.
+impl TryFrom<bindings::tcp_ca_event> for Event {
+    type Error = ();
+
+    fn try_from(ev: bindings::tcp_ca_event) -> Result<Self, Self::Error> {
+        match ev {
+            x if x == Event::TxStart as u32 => Ok(Event::TxStart),
+            x if x == Event::CwndRestart as u32 => Ok(Event::CwndRestart),
+            x if x == Event::CompleteCwr as u32 => Ok(Event::CompleteCwr),
+            x if x == Event::Loss as u32 => Ok(Event::Loss),
+            x if x == Event::EcnNoCe as u32 => Ok(Event::EcnNoCe),
+            x if x == Event::EcnIsCe as u32 => Ok(Event::EcnIsCe),
+            _ => Err(()),
+        }
+    }
+}
+
+#[pin_data(PinnedDrop)]
+struct Registration<T: Algorithm> {
+    #[pin]
+    ops: Opaque<bindings::tcp_congestion_ops>,
+    _pd: PhantomData<T>,
+}
+
+// SAFETY: `Registration` doesn't provide any `&self` methods, so it is safe
+// to pass references to it around.
+unsafe impl<T: Algorithm> Sync for Registration<T> {}
+
+// SAFETY: Both registration and unregistration are implemented in C and safe
+// to be performed from any thread, so `Registration` is `Send`.
+unsafe impl<T: Algorithm> Send for Registration<T> {}
+
+impl<T: Algorithm> Registration<T> {
+    const NAME_FIELD: [i8; 16] = Self::gen_name_field::<16>();
+    // Maximal size of the private data.
+    const ICSK_CA_PRIV_SIZE: usize = field_size!(bindings::inet_connection_sock, icsk_ca_priv);
+    const DATA_SIZE: usize = core::mem::size_of::<T::Data>();
+
+    fn new(module: &'static ThisModule) -> impl PinInit<Self, Error> {
+        try_pin_init!(Self {
+            _pd: PhantomData,
+            ops <- Opaque::try_ffi_init(|ops_ptr: *mut bindings::tcp_congestion_ops| {
+                // SAFETY: `try_ffi_init` guarantees that `ops_ptr` is valid
+                // for write.
+                unsafe { ops_ptr.write(bindings::tcp_congestion_ops::default()) };
+
+                // SAFETY: `try_ffi_init` guarantees that `ops_ptr` is valid
+                // for write, and it has just been initialised above, so it's
+                // also valid for read.
+                let ops = unsafe { &mut *ops_ptr };
+
+                ops.ssthresh = Some(Self::ssthresh_cb);
+                ops.cong_avoid = Some(Self::cong_avoid_cb);
+                ops.undo_cwnd = Some(Self::undo_cwnd_cb);
+                if T::HAS_SET_STATE {
+                    ops.set_state = Some(Self::set_state_cb);
+                }
+                if T::HAS_PKTS_ACKED {
+                    ops.pkts_acked = Some(Self::pkts_acked_cb);
+                }
+                if T::HAS_CWND_EVENT {
+                    ops.cwnd_event = Some(Self::cwnd_event_cb);
+                }
+
+                // Even though it is not mandated by the C side, we
+                // unconditionally set these CBs to ensure that it is always
+                // safe to access the CCA's private data. Future work could
+                // allow the CCA to declare whether it wants to be able to
+                // use the private data.
+                ops.init = Some(Self::init_cb);
+                ops.release = Some(Self::release_cb);
+
+                ops.owner = module.0;
+                ops.name = Self::NAME_FIELD;
+
+                // SAFETY: Pointers stored in `ops` are static so they will
+                // live for as long as the registration is active (it is
+                // undone in `drop`).
+                error::to_result(unsafe { bindings::tcp_register_congestion_control(ops_ptr) })
+            }),
+        })
+    }
+
+    const fn gen_name_field<const N: usize>() -> [i8; N] {
+        let mut name_field: [i8; N] = [0; N];
+        let mut i = 0;
+
+        while i < T::NAME.len_with_nul() {
+            name_field[i] = T::NAME.as_bytes_with_nul()[i] as i8;
+            i += 1;
+        }
+
+        name_field
+    }
+
+    unsafe extern "C" fn cwnd_event_cb(sk: *mut bindings::sock, ev: bindings::tcp_ca_event) {
+        // SAFETY:
+        // - `sk` was passed to a callback of the CCA `T`.
+        // - `Data` is guaranteed to be initialized since the `init_cb` took
+        //   care of it.
+        // - This value will be dropped at the end of the callback.
+        let mut sk = unsafe { Sock::new(sk) };
+        match Event::try_from(ev) {
+            Ok(ev) => T::cwnd_event(&mut sk, ev),
+            Err(_) => pr_err!("cwnd_event: event was {}", ev),
+        }
+    }
+
+    unsafe extern "C" fn init_cb(sk: *mut bindings::sock) {
+        // Fail the build if the module-defined private data is larger than
+        // the storage that the kernel provides.
+        build_assert!(Self::DATA_SIZE <= Self::ICSK_CA_PRIV_SIZE);
+
+        // SAFETY:
+        // - The `sk` that is passed to this callback is valid for
+        //   `inet_csk_ca`.
+        // - We just checked that there is enough space for the cast to be
+        //   okay.
+        let ca = unsafe { bindings::inet_csk_ca(sk) as *mut T::Data };
+
+        // SAFETY: `ca` is valid for write, as checked above.
+        unsafe { ca.write(T::Data::default()) };
+
+        if T::HAS_INIT {
+            // SAFETY:
+            // - `sk` was passed to a callback of the CCA `T`.
+            // - We just initialized the `Data`.
+            // - This value will be dropped at the end of the callback.
+            let mut sk = unsafe { Sock::new(sk) };
+            T::init(&mut sk)
+        }
+    }
+
+    unsafe extern "C" fn release_cb(sk: *mut bindings::sock) {
+        if T::HAS_RELEASE {
+            // SAFETY:
+            // - `sk` was passed to a callback of the CCA `T`.
+            // - `Data` is guaranteed to be initialized since the `init_cb`
+            //   took care of it.
+            // - This value will be dropped at the end of the callback.
+            let mut sk = unsafe { Sock::new(sk) };
+            T::release(&mut sk)
+        }
+
+        // We have to manually dispose of the private data that we stored
+        // with the kernel.
+        // SAFETY:
+        // - The `sk` passed to callbacks is valid for `inet_csk_ca`.
+        // - `Data` is guaranteed to be initialized since the `init_cb` took
+        //   care of it.
+        // - After we return no other callback will be invoked with this
+        //   socket.
+        unsafe { core::ptr::drop_in_place(bindings::inet_csk_ca(sk) as *mut T::Data) };
+    }
+
+    unsafe extern "C" fn ssthresh_cb(sk: *mut bindings::sock) -> u32 {
+        // SAFETY:
+        // - `sk` was passed to a callback of the CCA `T`.
+        // - `Data` is guaranteed to be initialized since the `init_cb` took
+        //   care of it.
+        // - This value will be dropped at the end of the callback.
+        let mut sk = unsafe { Sock::new(sk) };
+        T::ssthresh(&mut sk)
+    }
+
+    unsafe extern "C" fn cong_avoid_cb(sk: *mut bindings::sock, ack: u32, acked: u32) {
+        // SAFETY:
+        // - `sk` was passed to a callback of the CCA `T`.
+        // - `Data` is guaranteed to be initialized since the `init_cb` took
+        //   care of it.
+        // - This value will be dropped at the end of the callback.
+        let mut sk = unsafe { Sock::new(sk) };
+        T::cong_avoid(&mut sk, ack, acked)
+    }
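+
+    // Illustrative sketch (not part of the abstraction): since `init_cb`
+    // default-initialises the private data and `release_cb` drops it, a CCA
+    // only needs to provide a `Default` impl for its `Data`, e.g. for a
+    // hypothetical `MyCca`:
+    //
+    //     #[derive(Default)]
+    //     struct MyData {
+    //         epoch_start: u32,
+    //         last_max_cwnd: u32,
+    //     }
+    //
+    //     #[vtable]
+    //     impl Algorithm for MyCca {
+    //         type Data = MyData;
+    //         /* ... */
+    //     }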
+    unsafe extern "C" fn set_state_cb(sk: *mut bindings::sock, new_state: u8) {
+        // SAFETY:
+        // - `sk` was passed to a callback of the CCA `T`.
+        // - `Data` is guaranteed to be initialized since the `init_cb` took
+        //   care of it.
+        // - This value will be dropped at the end of the callback.
+        let mut sk = unsafe { Sock::new(sk) };
+        match State::try_from(new_state) {
+            Ok(new_state) => T::set_state(&mut sk, new_state),
+            Err(_) => pr_err!("set_state: new_state was {}", new_state),
+        }
+    }
+
+    unsafe extern "C" fn pkts_acked_cb(
+        sk: *mut bindings::sock,
+        sample: *const bindings::ack_sample,
+    ) {
+        // SAFETY:
+        // - `sk` was passed to a callback of the CCA `T`.
+        // - `Data` is guaranteed to be initialized since the `init_cb` took
+        //   care of it.
+        // - This value will be dropped at the end of the callback.
+        let mut sk = unsafe { Sock::new(sk) };
+        // SAFETY: `sample` was obtained as the argument to the `pkts_acked`
+        // callback.
+        let sample = unsafe { AckSample::new(sample) };
+        T::pkts_acked(&mut sk, &sample)
+    }
+
+    unsafe extern "C" fn undo_cwnd_cb(sk: *mut bindings::sock) -> u32 {
+        // SAFETY:
+        // - `sk` was passed to a callback of the CCA `T`.
+        // - `Data` is guaranteed to be initialized since the `init_cb` took
+        //   care of it.
+        // - This value will be dropped at the end of the callback.
+        let mut sk = unsafe { Sock::new(sk) };
+        T::undo_cwnd(&mut sk)
+    }
+}
+
+#[pinned_drop]
+impl<T: Algorithm> PinnedDrop for Registration<T> {
+    fn drop(self: Pin<&mut Self>) {
+        // SAFETY: The fact that `Self` exists implies that a previous call
+        // to `tcp_register_congestion_control` with `self.ops.get()` was
+        // successful.
+        unsafe { bindings::tcp_unregister_congestion_control(self.ops.get()) };
+    }
+}
+
+/// Kernel module that implements a single CCA `T`.
+#[pin_data]
+pub struct Module<T: Algorithm> {
+    #[pin]
+    reg: Registration<T>,
+}
+
+impl<T: Algorithm> crate::InPlaceModule for Module<T> {
+    fn init(module: &'static ThisModule) -> impl PinInit<Self, Error> {
+        try_pin_init!(Self {
+            reg <- Registration::<T>::new(module),
+        })
+    }
+}
+
+/// Defines a kernel module that implements a single congestion control
+/// algorithm.
+///
+/// # Examples
+///
+/// To start experimenting with your own congestion control algorithm,
+/// implement the [`Algorithm`] trait and use this macro to declare the
+/// module to the rest of the kernel. That's it!
+///
+/// ```ignore
+/// use kernel::{c_str, module_cca};
+/// use kernel::prelude::*;
+/// use kernel::net::tcp::cong::*;
+/// use core::num::NonZeroU32;
+///
+/// struct MyCca {}
+///
+/// #[vtable]
+/// impl Algorithm for MyCca {
+///     type Data = ();
+///
+///     const NAME: &'static CStr = c_str!("my_cca");
+///
+///     fn undo_cwnd(sk: &mut Sock<'_, Self>) -> u32 {
+///         reno::undo_cwnd(sk)
+///     }
+///
+///     fn ssthresh(_sk: &mut Sock<'_, Self>) -> u32 {
+///         2
+///     }
+///
+///     fn cong_avoid(sk: &mut Sock<'_, Self>, _ack: u32, acked: u32) {
+///         sk.tcp_sk_mut().cong_avoid_ai(NonZeroU32::new(1).unwrap(), acked)
+///     }
+/// }
+///
+/// module_cca! {
+///     type: MyCca,
+///     name: "my_cca",
+///     author: "Rust for Linux Contributors",
+///     description: "Sample congestion control algorithm implemented in Rust.",
+///     license: "GPL v2",
+/// }
+/// ```
+#[macro_export]
+macro_rules! module_cca {
+    (type: $type:ty, $($f:tt)*) => {
+        type ModuleType = $crate::net::tcp::cong::Module<$type>;
+        $crate::macros::module! {
+            type: ModuleType,
+            $($f)*
+        }
+    }
+}
+pub use module_cca;
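+
+// Note on the design (illustrative): for a CCA type `MyCca`, `module_cca!`
+// roughly expands to:
+//
+//     type ModuleType = kernel::net::tcp::cong::Module<MyCca>;
+//     module! {
+//         type: ModuleType,
+//         /* name, author, description, license */
+//     }
+//
+// `Module<MyCca>` is then initialised in place via `InPlaceModule::init`,
+// which keeps the `Registration` (and with it the registered
+// `tcp_congestion_ops`) pinned at a stable address for the lifetime of the
+// module.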
diff --git a/rust/kernel/net/tcp/cong/hystart.rs b/rust/kernel/net/tcp/cong/hystart.rs
new file mode 100644
index 00000000000000..5bc847902c5f14
--- /dev/null
+++ b/rust/kernel/net/tcp/cong/hystart.rs
@@ -0,0 +1,265 @@
+// SPDX-License-Identifier: GPL-2.0-only
+
+//! HyStart slow start algorithm.
+//!
+//! Based on:
+//! Sangtae Ha, Injong Rhee,
+//! Taming the elephants: New TCP slow start,
+//! Computer Networks, Volume 55, Issue 9, 2011, Pages 2092-2110,
+//! ISSN 1389-1286.
+
+use crate::net::sock;
+use crate::net::tcp::{self, cong};
+use crate::time;
+use crate::{pr_err, pr_info};
+use core::cmp::min;
+
+/// The heuristic that is used to find the exit point for slow start.
+pub enum HystartDetect {
+    /// Exits slow start when the length of so-called ACK-trains becomes
+    /// equal to the estimated minimum forward path one-way delay.
+    AckTrain = 1,
+    /// Exits slow start when the estimated RTT increase between two
+    /// consecutive rounds exceeds a threshold that is based on the last RTT.
+    Delay = 2,
+    /// Combines both algorithms.
+    Both = 3,
+}
+
+/// Internal state of the [`HyStart`] algorithm.
+pub struct HyStartState {
+    /// Number of ACKs already sampled to determine the RTT of this round.
+    sample_cnt: u8,
+    /// Whether the slow start exit point was found.
+    found: bool,
+    /// Time when the current round has started.
+    round_start: time::Usecs32,
+    /// Sequence number of the byte that marks the end of the current round.
+    end_seq: u32,
+    /// Time when the last ACK was received in this round.
+    last_ack: time::Usecs32,
+    /// The minimum RTT of the current round.
+    curr_rtt: time::Usecs32,
+    /// Estimate of the minimum forward path one-way delay of the link.
+    pub delay_min: Option<time::Usecs32>,
+    /// Time when the connection was created.
+    // TODO: remove
+    pub start_time: time::Usecs32,
+}
+
+impl Default for HyStartState {
+    fn default() -> Self {
+        Self {
+            sample_cnt: 0,
+            found: false,
+            round_start: 0,
+            end_seq: 0,
+            last_ack: 0,
+            curr_rtt: 0,
+            delay_min: None,
+            // TODO: remove
+            start_time: time::ktime_get_boot_fast_us32(),
+        }
+    }
+}
+
+impl HyStartState {
+    /// Returns true iff the algorithm `T` is in hybrid slow start.
+    #[inline]
+    pub fn in_hystart<T: HyStart>(&self, cwnd: u32) -> bool {
+        !self.found && cwnd >= T::LOW_WINDOW
+    }
+}
+
+/// Implement this trait on [`Algorithm::Data`] to use [`HyStart`] for your
+/// CCA.
+///
+/// [`Algorithm::Data`]: cong::Algorithm::Data
+pub trait HasHyStartState {
+    /// Returns the private data of the HyStart algorithm.
+    fn hy(&self) -> &HyStartState;
+
+    /// Returns the private data of the HyStart algorithm, mutably.
+    fn hy_mut(&mut self) -> &mut HyStartState;
+}
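+
+// Illustrative sketch (not part of the abstraction): a CCA embeds a
+// `HyStartState` in its private data and forwards to it, e.g. for a
+// hypothetical `MyData`:
+//
+//     #[derive(Default)]
+//     struct MyData {
+//         hy: HyStartState,
+//         /* CCA-specific fields */
+//     }
+//
+//     impl HasHyStartState for MyData {
+//         fn hy(&self) -> &HyStartState {
+//             &self.hy
+//         }
+//
+//         fn hy_mut(&mut self) -> &mut HyStartState {
+//             &mut self.hy
+//         }
+//     }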
+
+/// Implement this trait on your [`Algorithm`] to use HyStart. You still need
+/// to invoke the [`reset`] and [`update`] methods at the right places.
+///
+/// [`Algorithm`]: cong::Algorithm
+/// [`reset`]: HyStart::reset
+/// [`update`]: HyStart::update
+pub trait HyStart: cong::Algorithm<Data: HasHyStartState> {
+    // TODO: Those constants should be configurable via module parameters.
+    /// Which heuristic to use for deciding when it is time to exit slow
+    /// start.
+    const DETECT: HystartDetect;
+
+    /// Lower bound for cwnd during hybrid slow start.
+    const LOW_WINDOW: u32;
+
+    /// Max spacing between ACKs in an ACK-train.
+    const ACK_DELTA: time::Usecs32;
+
+    /// Number of ACKs to sample at the beginning of each round to estimate
+    /// the RTT of this round.
+    const MIN_SAMPLES: u8 = 8;
+
+    /// Lower bound on the increase in RTT between two consecutive rounds
+    /// that is needed to trigger an exit from slow start.
+    const DELAY_MIN: time::Usecs32 = 4000;
+
+    /// Upper bound on the increase in RTT between two consecutive rounds
+    /// that is needed to trigger an exit from slow start.
+    const DELAY_MAX: time::Usecs32 = 16000;
+
+    /// Corresponds to the function eta from the paper. Returns the increase
+    /// in RTT between consecutive rounds that triggers an exit from slow
+    /// start. `t` is the RTT of the last round.
+    fn delay_thresh(mut t: time::Usecs32) -> time::Usecs32 {
+        t >>= 3;
+
+        if t < Self::DELAY_MIN {
+            Self::DELAY_MIN
+        } else if t > Self::DELAY_MAX {
+            Self::DELAY_MAX
+        } else {
+            t
+        }
+    }
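+
+    // For example (illustrative): with a last-round RTT of t = 48000us the
+    // threshold is 48000 >> 3 = 6000us; with t = 16000us the raw value of
+    // 2000us is clamped up to DELAY_MIN = 4000us; with t = 200000us the raw
+    // value of 25000us is clamped down to DELAY_MAX = 16000us.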
+
+    /// Returns an estimate of the ACK spacing induced by the socket's pacing
+    /// rate, capped at one millisecond.
+    fn ack_delay(sk: &cong::Sock<'_, Self>) -> time::Usecs32 {
+        (match sk.sk_pacing_rate() {
+            0 => 0,
+            rate => min(
+                time::USEC_PER_MSEC,
+                ((sk.sk_gso_max_size() as u64) * 4 * time::USEC_PER_SEC) / rate,
+            ),
+        } as time::Usecs32)
+    }
+
+    /// Called in slow start at the beginning of a new round of incoming
+    /// ACKs.
+    fn reset(sk: &mut cong::Sock<'_, Self>) {
+        let tp = sk.tcp_sk();
+        let now = tp.tcp_mstamp() as time::Usecs32;
+        let snd_nxt = tp.snd_nxt();
+
+        let hy = sk.inet_csk_ca_mut().hy_mut();
+
+        hy.round_start = now;
+        hy.last_ack = now;
+        hy.end_seq = snd_nxt;
+        hy.curr_rtt = u32::MAX;
+        hy.sample_cnt = 0;
+    }
+
+    /// Called in slow start to decide if it is time to exit slow start. Sets
+    /// [`HyStartState`] `found` to true when it is time to exit.
+    fn update(sk: &mut cong::Sock<'_, Self>, delay: time::Usecs32) {
+        // Start of a new round.
+        if tcp::after(sk.tcp_sk().snd_una(), sk.inet_csk_ca().hy().end_seq) {
+            Self::reset(sk);
+        }
+
+        let hy = sk.inet_csk_ca().hy();
+        let Some(delay_min) = hy.delay_min else {
+            // This should not happen.
+            pr_err!("hystart: update: delay_min was None");
+            return;
+        };
+
+        if matches!(Self::DETECT, HystartDetect::Both | HystartDetect::AckTrain) {
+            let tp = sk.tcp_sk();
+            let now = tp.tcp_mstamp() as time::Usecs32;
+
+            // Is this ACK part of a train?
+            // NOTE: The C code performs this comparison on signed integers,
+            // but:
+            // -- for `0 <= now < ca->last_ack <= 0x7F..F` it always passes,
+            // -- for `ca->last_ack = 0x80..0` and `0 <= now <= 0x7F..F` it
+            //    also always passes,
+            // -- for `0x80..0 < ca->last_ack` and `now < 0x80..0` (big
+            //    enough) it also always passes.
+            // If the paper is understood correctly, this is not what is
+            // intended; the unsigned version used here appears to be what is
+            // actually wanted. Reviewers, please correct this if that
+            // reading is wrong.
+            // Commit c54b4b7655447c1f24f6d50779c22eba9ee0fd24 purposefully
+            // introduced the cast.
+            // Link: https://godbolt.org/z/E7ocxae69
+            if now.wrapping_sub(hy.last_ack) <= Self::ACK_DELTA {
+                let threshold = if let Ok(sock::Pacing::r#None) = sk.sk_pacing_status() {
+                    (delay_min + Self::ack_delay(sk)) >> 1
+                } else {
+                    delay_min + Self::ack_delay(sk)
+                };
+
+                // Does the length of this ACK-train indicate it is time to
+                // exit slow start?
+                // NOTE: The C code is a bit odd here: `threshold` is
+                // unsigned, but the lhs is still cast to signed, even though
+                // the usual arithmetic conversions immediately convert it
+                // back to unsigned; thus, everything can simply be done
+                // unsigned.
+                if now.wrapping_sub(hy.round_start) > threshold {
+                    // TODO: change to debug
+                    pr_info!(
+                        "hystart_ack_train ({}us > {}us) delay_min {}us (+ ack_delay {}us) cwnd {}, start {}us",
+                        now.wrapping_sub(hy.round_start),
+                        threshold,
+                        delay_min,
+                        Self::ack_delay(sk),
+                        tp.snd_cwnd(),
+                        hy.start_time
+                    );
+
+                    let tp = sk.tcp_sk_mut();
+
+                    tp.set_snd_ssthresh(tp.snd_cwnd());
+
+                    sk.inet_csk_ca_mut().hy_mut().found = true;
+
+                    // TODO: Update net stats.
+                }
+
+                sk.inet_csk_ca_mut().hy_mut().last_ack = now;
+            }
+        }
+
+        if matches!(Self::DETECT, HystartDetect::Both | HystartDetect::Delay) {
+            let hy = sk.inet_csk_ca_mut().hy_mut();
+
+            // The paper only takes the min RTT of the first `MIN_SAMPLES`
+            // ACKs in a round, but it does no harm to consider later ACKs as
+            // well.
+            if hy.curr_rtt > delay {
+                hy.curr_rtt = delay;
+            }
+
+            if hy.sample_cnt < Self::MIN_SAMPLES {
+                hy.sample_cnt += 1;
+            } else {
+                // Does the increase in RTT indicate it's time to exit slow
+                // start?
+                if hy.curr_rtt > delay_min + Self::delay_thresh(delay_min) {
+                    hy.found = true;
+
+                    // TODO: change to debug
+                    let curr_rtt = hy.curr_rtt;
+                    let start_time = hy.start_time;
+                    pr_info!(
+                        "hystart_delay: {}us > {}us, delay_min {}us (+ delay_thresh {}us), cwnd {}, start {}us",
+                        curr_rtt,
+                        delay_min + Self::delay_thresh(delay_min),
+                        delay_min,
+                        Self::delay_thresh(delay_min),
+                        sk.tcp_sk().snd_cwnd(),
+                        start_time,
+                    );
+                    // TODO: Update net stats.
+
+                    let tp = sk.tcp_sk_mut();
+
+                    tp.set_snd_ssthresh(tp.snd_cwnd());
+                }
+            }
+        }
+    }
+}
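+
+// Illustrative sketch (not part of this file): a CCA opts in by implementing
+// the traits and calling `update` from its ACK processing while in hybrid
+// slow start; the constants shown are hypothetical example values:
+//
+//     impl HyStart for MyCca {
+//         const DETECT: HystartDetect = HystartDetect::Both;
+//         const LOW_WINDOW: u32 = 16;
+//         const ACK_DELTA: time::Usecs32 = 2000;
+//     }
+//
+//     fn on_rtt_sample(sk: &mut cong::Sock<'_, MyCca>, delay: time::Usecs32) {
+//         let cwnd = sk.tcp_sk().snd_cwnd();
+//         if sk.inet_csk_ca().hy().in_hystart::<MyCca>(cwnd) {
+//             MyCca::update(sk, delay);
+//         }
+//     }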
diff --git a/rust/kernel/time.rs b/rust/kernel/time.rs
index 25a896eed4689f..875fafd44f51ee 100644
--- a/rust/kernel/time.rs
+++ b/rust/kernel/time.rs
@@ -8,9 +8,43 @@
 /// The time unit of Linux kernel. One jiffy equals (1/HZ) second.
 pub type Jiffies = core::ffi::c_ulong;
 
+/// Jiffies, but with a fixed width of 32bit.
+pub type Jiffies32 = u32;
+
 /// The millisecond time unit.
 pub type Msecs = core::ffi::c_uint;
 
+/// Milliseconds per second.
+pub const MSEC_PER_SEC: Msecs = 1000;
+
+/// The milliseconds time unit with a fixed width of 32bit.
+///
+/// This is used in networking.
+pub type Msecs32 = u32;
+
+/// The microseconds time unit.
+pub type Usecs = u64;
+
+/// Microseconds per millisecond.
+pub const USEC_PER_MSEC: Usecs = 1000;
+
+/// Microseconds per second.
+pub const USEC_PER_SEC: Usecs = 1_000_000;
+
+/// The microseconds time unit with a fixed width of 32bit.
+///
+/// This is used in networking.
+pub type Usecs32 = u32;
+
+/// The nanosecond time unit.
+pub type Nsecs = u64;
+
+/// Nanoseconds per microsecond.
+pub const NSEC_PER_USEC: Nsecs = 1000;
+
+/// Nanoseconds per millisecond.
+pub const NSEC_PER_MSEC: Nsecs = 1_000_000;
+
 /// Converts milliseconds to jiffies.
 #[inline]
 pub fn msecs_to_jiffies(msecs: Msecs) -> Jiffies {
@@ -18,3 +52,40 @@ pub fn msecs_to_jiffies(msecs: Msecs) -> Jiffies {
     // matter what the argument is.
     unsafe { bindings::__msecs_to_jiffies(msecs) }
 }
+
+/// Converts jiffies to milliseconds.
+#[inline]
+pub fn jiffies_to_msecs(jiffies: Jiffies) -> Msecs {
+    // SAFETY: The `jiffies_to_msecs` function is always safe to call no
+    // matter what the argument is.
+    unsafe { bindings::jiffies_to_msecs(jiffies) }
+}
+
+/// Returns the current time in 32bit jiffies.
+#[inline]
+pub fn jiffies32() -> Jiffies32 {
+    // SAFETY: It is always atomic to read the lower 32bit of jiffies.
+    unsafe { bindings::jiffies as u32 }
+}
+
+/// Returns the time elapsed since system boot, in nanoseconds. Does include
+/// the time the system was suspended.
+#[inline]
+pub fn ktime_get_boot_fast_ns() -> Nsecs {
+    // SAFETY: FFI call without safety requirements.
+    unsafe { bindings::ktime_get_boot_fast_ns() }
+}
+
+/// Returns the time elapsed since system boot, in 32bit microseconds. Does
+/// include the time the system was suspended.
+#[inline]
+pub fn ktime_get_boot_fast_us32() -> Usecs32 {
+    (ktime_get_boot_fast_ns() / NSEC_PER_USEC) as Usecs32
+}
+
+/// Returns the time elapsed since system boot, in 32bit milliseconds. Does
+/// include the time the system was suspended.
+#[inline]
+pub fn ktime_get_boot_fast_ms32() -> Msecs32 {
+    (ktime_get_boot_fast_ns() / NSEC_PER_MSEC) as Msecs32
+}
diff --git a/rust/kernel/types.rs b/rust/kernel/types.rs
index 8aabe348b19473..91ba53a783516a 100644
--- a/rust/kernel/types.rs
+++ b/rust/kernel/types.rs
@@ -240,14 +240,22 @@ impl<T> Opaque<T> {
     /// uninitialized. Additionally, access to the inner `T` requires `unsafe`, so the caller needs
     /// to verify at that point that the inner value is valid.
     pub fn ffi_init(init_func: impl FnOnce(*mut T)) -> impl PinInit<Self> {
+        Self::try_ffi_init(move |slot| {
+            init_func(slot);
+            Ok(())
+        })
+    }
+
+    /// Similar to [`Self::ffi_init`], except that the closure can fail.
+    ///
+    /// To avoid leaks on failure, the closure must drop any fields it has initialised before the
+    /// failure.
+    pub fn try_ffi_init<E>(
+        init_func: impl FnOnce(*mut T) -> Result<(), E>,
+    ) -> impl PinInit<Self, E> {
         // SAFETY: We contain a `MaybeUninit`, so it is OK for the `init_func` to not fully
         // initialize the `T`.
-        unsafe {
-            init::pin_init_from_closure::<_, ::core::convert::Infallible>(move |slot| {
-                init_func(Self::raw_get(slot));
-                Ok(())
-            })
-        }
+        unsafe { init::pin_init_from_closure(|slot| init_func(Self::raw_get(slot))) }
     }
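+
+    // Illustrative sketch (not part of this patch): a fallible C-side
+    // constructor can be wrapped with `try_ffi_init`, assuming a
+    // hypothetical `bindings::foo_init` that returns a C error code; this is
+    // the pattern `Registration::new` in net/tcp/cong.rs uses with
+    // `tcp_register_congestion_control`:
+    //
+    //     let init = Opaque::try_ffi_init(|slot: *mut bindings::foo| {
+    //         // SAFETY: `slot` is valid for write (assumption of this sketch).
+    //         error::to_result(unsafe { bindings::foo_init(slot) })
+    //     });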
 
     /// Returns a raw pointer to the opaque data.
@@ -390,3 +398,36 @@ pub enum Either<L, R> {
     /// Constructs an instance of [`Either`] containing a value of type `R`.
     Right(R),
 }
+
+/// Returns the size of a struct field in bytes.
+///
+/// This macro can be used in const contexts.
+///
+/// # Examples
+///
+/// ```
+/// use kernel::field_size;
+///
+/// struct Foo {
+///     bar: u64,
+///     baz: [i8; 100],
+/// }
+///
+/// assert_eq!(field_size!(Foo, bar), 8);
+/// assert_eq!(field_size!(Foo, baz), 100);
+/// ```
+// Link: https://stackoverflow.com/a/70222282
+#[macro_export]
+macro_rules! field_size {
+    ($t:ty, $field:ident) => {{
+        let m = core::mem::MaybeUninit::<$t>::uninit();
+        // SAFETY: It is OK to dereference invalid pointers inside of
+        // `addr_of!`.
+        let p = unsafe { core::ptr::addr_of!((*m.as_ptr()).$field) };
+
+        const fn size_of_raw<T>(_: *const T) -> usize {
+            core::mem::size_of::<T>()
+        }
+        size_of_raw(p)
+    }};
+}
diff --git a/rust/macros/module.rs b/rust/macros/module.rs
index d62d8710d77ab0..9152bd691c5a6d 100644
--- a/rust/macros/module.rs
+++ b/rust/macros/module.rs
@@ -208,7 +208,7 @@ pub(crate) fn module(ts: TokenStream) -> TokenStream {
             #[used]
             static __IS_RUST_MODULE: () = ();
 
-            static mut __MOD: Option<{type_}> = None;
+            static mut __MOD: core::mem::MaybeUninit<{type_}> = core::mem::MaybeUninit::uninit();
 
             // SAFETY: `__this_module` is constructed by the kernel at load time and will not be
             // freed until the module is unloaded.
@@ -270,23 +270,17 @@ pub(crate) fn module(ts: TokenStream) -> TokenStream {
             }}
 
             fn __init() -> core::ffi::c_int {{
-                match <{type_} as kernel::Module>::init(&THIS_MODULE) {{
-                    Ok(m) => {{
-                        unsafe {{
-                            __MOD = Some(m);
-                        }}
-                        return 0;
-                    }}
-                    Err(e) => {{
-                        return e.to_errno();
-                    }}
+                let initer = <{type_} as kernel::InPlaceModule>::init(&THIS_MODULE);
+                match unsafe {{ initer.__pinned_init(__MOD.as_mut_ptr()) }} {{
+                    Ok(()) => 0,
+                    Err(e) => e.to_errno(),
                 }}
             }}
 
             fn __exit() {{
                 unsafe {{
                     // Invokes `drop()` on `__MOD`, which should be used for cleanup.
-                    __MOD = None;
+                    __MOD.assume_init_drop();
                 }}
             }}
diff --git a/samples/rust/Kconfig b/samples/rust/Kconfig
index b0f74a81c8f9ad..8b9cc6cc7d301d 100644
--- a/samples/rust/Kconfig
+++ b/samples/rust/Kconfig
@@ -30,6 +30,17 @@ config SAMPLE_RUST_PRINT
 
 	  If unsure, say N.
 
+config SAMPLE_RUST_CCA
+	tristate "Congestion control algorithm"
+	depends on RUST_TCP_ABSTRACTIONS
+	help
+	  This option builds the Rust congestion control algorithm sample.
+
+	  To compile this as a module, choose M here:
+	  the module will be called rust_cca.
+
+	  If unsure, say N.
+
 config SAMPLE_RUST_HOSTPROGS
 	bool "Host programs"
 	help
diff --git a/samples/rust/Makefile b/samples/rust/Makefile
index 03086dabbea44f..ee0b9bb7b6ab21 100644
--- a/samples/rust/Makefile
+++ b/samples/rust/Makefile
@@ -2,5 +2,6 @@
 obj-$(CONFIG_SAMPLE_RUST_MINIMAL)		+= rust_minimal.o
 obj-$(CONFIG_SAMPLE_RUST_PRINT)			+= rust_print.o
+obj-$(CONFIG_SAMPLE_RUST_CCA)			+= rust_cca.o
 
 subdir-$(CONFIG_SAMPLE_RUST_HOSTPROGS)		+= hostprogs
diff --git a/samples/rust/rust_cca.rs b/samples/rust/rust_cca.rs
new file mode 100644
index 00000000000000..4c092582112b07
--- /dev/null
+++ b/samples/rust/rust_cca.rs
@@ -0,0 +1,37 @@
+// SPDX-License-Identifier: GPL-2.0
+
+//! Congestion control algorithm example.
+use core::num::NonZeroU32;
+use kernel::net::tcp::cong::*;
+use kernel::prelude::*;
+use kernel::{c_str, module_cca};
+
+struct MyCca {}
+
+#[vtable]
+impl Algorithm for MyCca {
+    type Data = ();
+
+    const NAME: &'static CStr = c_str!("my_cca");
+
+    fn undo_cwnd(sk: &mut Sock<'_, Self>) -> u32 {
+        reno::undo_cwnd(sk)
+    }
+
+    fn ssthresh(_sk: &mut Sock<'_, Self>) -> u32 {
+        2
+    }
+
+    fn cong_avoid(sk: &mut Sock<'_, Self>, _ack: u32, acked: u32) {
+        sk.tcp_sk_mut()
+            .cong_avoid_ai(NonZeroU32::new(1).unwrap(), acked)
+    }
+}
+
+module_cca! {
+    type: MyCca,
+    name: "my_cca",
+    author: "Rust for Linux Contributors",
+    description: "Sample congestion control algorithm implemented in Rust.",
+    license: "GPL v2",
+}
diff --git a/scripts/generate_rust_analyzer.py b/scripts/generate_rust_analyzer.py
index fc52bc41d3e7bd..4a687e36091eb5 100755
--- a/scripts/generate_rust_analyzer.py
+++ b/scripts/generate_rust_analyzer.py
@@ -116,7 +116,7 @@ def is_root_crate(build_file, target):
     # Then, the rest outside of `rust/`.
     #
     # We explicitly mention the top-level folders we want to cover.
-    extra_dirs = map(lambda dir: srctree / dir, ("samples", "drivers"))
+    extra_dirs = map(lambda dir: srctree / dir, ("samples", "drivers", "net"))
     if external_src is not None:
         extra_dirs = [external_src]
     for folder in extra_dirs: