From c35894fd36ab3a5936cadf75985e74e5baa327f5 Mon Sep 17 00:00:00 2001 From: Omer Shamash Date: Tue, 14 Jan 2025 12:29:04 +0200 Subject: [PATCH 01/18] feat: add CPU native compilation and instructions + add un/likely primitives (when RUST supports we should change) + consolidate checksum calculations (now support offload on all layers) --- .cargo/config.toml | 3 + lightway-core/src/utils.rs | 303 +++++++++++++++++++++++++++---------- 2 files changed, 223 insertions(+), 83 deletions(-) diff --git a/.cargo/config.toml b/.cargo/config.toml index bda566bd..e08b7963 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -1,3 +1,6 @@ [target.aarch64-unknown-linux-gnu] linker = "aarch64-linux-gnu-gcc" runner = ["qemu-aarch64-static"] # use qemu user emulation for cargo run and test + +[build] +rustflags = ["-C", "target-cpu=native"] \ No newline at end of file diff --git a/lightway-core/src/utils.rs b/lightway-core/src/utils.rs index 3215572f..13939c1f 100644 --- a/lightway-core/src/utils.rs +++ b/lightway-core/src/utils.rs @@ -9,17 +9,56 @@ use std::net::Ipv4Addr; use std::ops; use tracing::warn; +// #[cfg(target_arch = "x86_64")] +// use std::arch::x86_64::*; + +// // Check if AVX2 is available on the current CPU (unused until we support IPv6) +// #[inline(always)] +// fn has_avx2() -> bool { +// #[cfg(target_arch = "x86_64")] +// { +// is_x86_feature_detected!("avx2") +// } +// #[cfg(not(target_arch = "x86_64"))] +// { +// false +// } +// } + +/** + * HOT/COLD path implementation until RUST adds + * https://github.com/rust-lang/rust/issues/26179 + */ + +#[inline] +#[cold] +fn cold() {} + +#[inline] +pub(crate) fn likely(b: bool) -> bool { + if !b { cold() } + b +} + +#[inline] +pub(crate) fn unlikely(b: bool) -> bool { + if b { cold() } + b +} + +/// Validate if a buffer contains a valid IPv4 packet pub(crate) fn ipv4_is_valid_packet(buf: &[u8]) -> bool { - if buf.is_empty() { + if buf.len() < 20 { + // IPv4 header is at least 20 bytes return false; } let 
first_byte = buf[0]; let ip_version = first_byte >> 4; - ip_version == 4 } -// Structure to calculate incremental checksum +/// Structure to calculate incremental checksum +#[derive(Clone, Copy)] struct Checksum(u16); impl ops::Deref for Checksum { @@ -33,122 +72,220 @@ impl ops::Sub for Checksum { type Output = Checksum; fn sub(self, rhs: u16) -> Checksum { let (n, of) = self.0.overflowing_sub(rhs); - Checksum(match of { - true => n - 1, - false => n, - }) + Checksum(if of { n.wrapping_sub(1) } else { n }) } } +/// Structure to handle checksum updates when modifying IP addresses +struct ChecksumUpdate(Vec<(u16, u16)>); + impl Checksum { - // Based on RFC-1624 [Eqn. 4] + /// Update checksum when replacing one word with another + /// Based on RFC-1624 [Eqn. 4] fn update_word(self, old_word: u16, new_word: u16) -> Self { self - !old_word - new_word } + /// Apply multiple checksum updates fn update(self, updates: &ChecksumUpdate) -> Self { - updates.0.iter().fold(self, |c, x| c.update_word(x.0, x.1)) + updates + .0 + .iter() + .fold(self, |c, &(old, new)| c.update_word(old, new)) } -} -struct ChecksumUpdate(Vec<(u16, u16)>); + // AVX2-accelerated checksum update (unused until we support IPv6) + // #[allow(unsafe_code)] + // #[cfg(target_arch = "x86_64")] + // #[target_feature(enable = "avx2")] + // unsafe fn update_avx2(self, updates: &ChecksumUpdate) -> Self { + // let mut sum = u32::from(self.0); + + // // Process 8 words at a time using AVX2 + // for chunk in updates.0.chunks(8) { + // // Pre-allocate with known size + // let mut old_words = Vec::with_capacity(8); + // let mut new_words = Vec::with_capacity(8); + + // // Fill vectors with data or zeros + // for i in 0..8 { + // if let Some(&(old, new)) = chunk.get(i) { + // old_words.push(i32::from(old)); + // new_words.push(i32::from(new)); + // } else { + // old_words.push(0); + // new_words.push(0); + // } + // } + + // // SAFETY: Vectors are guaranteed to have exactly 8 elements + // unsafe { + // // Load data 
into AVX2 registers + // let old_vec = _mm256_set_epi32( + // old_words[7], + // old_words[6], + // old_words[5], + // old_words[4], + // old_words[3], + // old_words[2], + // old_words[1], + // old_words[0], + // ); + // let new_vec = _mm256_set_epi32( + // new_words[7], + // new_words[6], + // new_words[5], + // new_words[4], + // new_words[3], + // new_words[2], + // new_words[1], + // new_words[0], + // ); + + // // Compute NOT(old) + new using AVX2 + // let not_old = _mm256_xor_si256(old_vec, _mm256_set1_epi32(-1)); + // let sum_vec = _mm256_add_epi32(not_old, new_vec); + + // // Horizontal sum + // let hadd = _mm256_hadd_epi32(sum_vec, sum_vec); + // let hadd = _mm256_hadd_epi32(hadd, hadd); + + // sum = sum.wrapping_add(_mm256_extract_epi32(hadd, 0) as u32); + // } + // } + + // // Fold 32-bit sum to 16 bits + // while sum > 0xFFFF { + // sum = (sum & 0xFFFF) + (sum >> 16); + // } + + // Checksum(sum as u16) + // } +} impl ChecksumUpdate { + /// Create checksum update data from IP address change fn from_ipv4_address(old: Ipv4Addr, new: Ipv4Addr) -> Self { - let mut result = vec![]; - let old: [u8; 4] = old.octets(); - let new: [u8; 4] = new.octets(); - for i in 0..2 { - let old_word = u16::from_be_bytes([old[i * 2], old[i * 2 + 1]]); - let new_word = u16::from_be_bytes([new[i * 2], new[i * 2 + 1]]); - result.push((old_word, new_word)); - } - Self(result) + let old_bytes = old.octets(); + let new_bytes = new.octets(); + + // Convert to u16 pairs for checksum calculation + let old_words = [ + u16::from_be_bytes([old_bytes[0], old_bytes[1]]), + u16::from_be_bytes([old_bytes[2], old_bytes[3]]), + ]; + let new_words = [ + u16::from_be_bytes([new_bytes[0], new_bytes[1]]), + u16::from_be_bytes([new_bytes[2], new_bytes[3]]), + ]; + + Self(vec![ + (old_words[0], new_words[0]), + (old_words[1], new_words[1]), + ]) } } -fn tcp_adjust_packet_checksum(mut packet: MutableIpv4Packet, updates: ChecksumUpdate) { - let packet = MutableTcpPacket::new(packet.payload_mut()); - 
let Some(mut packet) = packet else { - warn!("Invalid packet size (less than Tcp header)!"); +/// Update transport protocol checksums after IP address changes +fn update_transport_checksums(packet: &mut MutableIpv4Packet, updates: ChecksumUpdate) { + // Skip if this is not the first fragment + if packet.get_fragment_offset() != 0 { return; - }; - - let checksum = Checksum(packet.get_checksum()); - let checksum = checksum.update(&updates); - packet.set_checksum(*checksum); -} - -fn udp_adjust_packet_checksum(mut packet: MutableIpv4Packet, updates: ChecksumUpdate) { - let packet = MutableUdpPacket::new(packet.payload_mut()); - let Some(mut packet) = packet else { - warn!("Invalid packet size (less than Udp header)!"); - return; - }; - - let checksum = Checksum(packet.get_checksum()); + } - // UDP checksums are optional, and we should respect that when doing NAT - if *checksum != 0 { - let checksum = checksum.update(&updates); - packet.set_checksum(checksum.0); + match packet.get_next_level_protocol() { + IpNextHeaderProtocols::Tcp => update_tcp_checksum(packet, updates), + IpNextHeaderProtocols::Udp => update_udp_checksum(packet, updates), + IpNextHeaderProtocols::Icmp => {} // ICMP doesn't need checksum update for IP changes + protocol => { + if unlikely(true) { + warn!(protocol = ?protocol, "Unknown protocol, skipping checksum update"); + } + }, } } -fn ipv4_adjust_packet_checksum(mut packet: MutableIpv4Packet, updates: ChecksumUpdate) { - let checksum = Checksum(packet.get_checksum()); - let checksum = checksum.update(&updates); - packet.set_checksum(*checksum); - - // In case of fragmented packets, TCP/UDP header will be present only in the first fragment. 
- // So skip updating the checksum, if it is not the first fragment (i.e frag_offset != 0) - if 0 != packet.get_fragment_offset() { - return; +fn update_tcp_checksum(packet: &mut MutableIpv4Packet, updates: ChecksumUpdate) { + if likely(MutableTcpPacket::new(packet.payload_mut()).is_some()) { + let mut tcp_packet = MutableTcpPacket::new(packet.payload_mut()).unwrap(); + let checksum = tcp_packet.get_checksum(); + // Only update if checksum is present (not 0) + if checksum != 0 { + let checksum = Checksum(checksum).update(&updates); + tcp_packet.set_checksum(*checksum); + } + } else { + warn!("Invalid packet size (less than TCP header)!"); } +} - let transport_protocol = packet.get_next_level_protocol(); - match transport_protocol { - IpNextHeaderProtocols::Tcp => tcp_adjust_packet_checksum(packet, updates), - IpNextHeaderProtocols::Udp => udp_adjust_packet_checksum(packet, updates), - IpNextHeaderProtocols::Icmp => {} - protocol => { - warn!(protocol = ?protocol, "Unknown protocol, skipping checksum adjust") +fn update_udp_checksum(packet: &mut MutableIpv4Packet, updates: ChecksumUpdate) { + if likely(MutableUdpPacket::new(packet.payload_mut()).is_some()) { + let mut udp_packet = MutableUdpPacket::new(packet.payload_mut()).unwrap(); + let checksum = udp_packet.get_checksum(); + // Only update if checksum is present (not 0) + if checksum != 0 { + let checksum = Checksum(checksum).update(&updates); + udp_packet.set_checksum(*checksum); } + } else { + warn!("Invalid packet size (less than UDP header)!"); } } -/// Utility function to update source ip address in ipv4 packet buffer -/// Nop if buf is not a valid IPv4 packet -pub fn ipv4_update_source(buf: &mut [u8], ip: Ipv4Addr) { - let packet = MutableIpv4Packet::new(buf); - let Some(mut packet) = packet else { - warn!("Invalid packet size (less than Ipv4 header)!"); +#[derive(Clone, Copy)] +enum IpField { + Source, + Destination, +} + +// NOTE: the field is compile-time known, so gets optimized, this is for better 
maintenance +#[inline(always)] +fn ipv4_update_field(buf: &mut [u8], new_ip: Ipv4Addr, field: IpField) { + let Some(mut packet) = MutableIpv4Packet::new(buf) else { + if unlikely(true) { + warn!("Failed to create IPv4 packet!"); + } return; }; - let old = packet.get_source(); - // Set new source only after getting old source ip address - packet.set_source(ip); - - ipv4_adjust_packet_checksum(packet, ChecksumUpdate::from_ipv4_address(old, ip)); -} + // Get old IP before updating + let old_ip = match field { + IpField::Source => packet.get_source(), + IpField::Destination => packet.get_destination(), + }; -/// Utility function to update destination ip address in ipv4 packet buffer -/// Nop if buf is not a valid IPv4 packet -pub fn ipv4_update_destination(buf: &mut [u8], ip: Ipv4Addr) { - let packet = MutableIpv4Packet::new(buf); - let Some(mut packet) = packet else { - warn!("Invalid packet size (less than Ipv4 header)!"); - return; + // Update IP field + match field { + IpField::Source => packet.set_source(new_ip), + IpField::Destination => packet.set_destination(new_ip), }; - let old = packet.get_destination(); - // Set new destination only after getting old destination ip address - packet.set_destination(ip); + // Update checksums + let updates = ChecksumUpdate::from_ipv4_address(old_ip, new_ip); + let checksum = packet.get_checksum(); + if checksum != 0 { + let checksum = Checksum(checksum).update(&updates); + packet.set_checksum(*checksum); + } + + // Update transport protocol checksums + update_transport_checksums(&mut packet, updates); +} + +/// Update source IP address in an IPv4 packet +#[inline] +pub fn ipv4_update_source(buf: &mut [u8], new_ip: Ipv4Addr) { + ipv4_update_field(buf, new_ip, IpField::Source) +} - ipv4_adjust_packet_checksum(packet, ChecksumUpdate::from_ipv4_address(old, ip)); +/// Update destination IP address in an IPv4 packet +#[inline] +pub fn ipv4_update_destination(buf: &mut [u8], new_ip: Ipv4Addr) { + ipv4_update_field(buf, new_ip, 
IpField::Destination) } +/// Clamp TCP MSS option if present in a TCP SYN packet pub fn tcp_clamp_mss(pkt: &mut [u8], mss: u16) -> Option { let mut ipv4_packet = MutableIpv4Packet::new(pkt)?; @@ -177,7 +314,7 @@ pub fn tcp_clamp_mss(pkt: &mut [u8], mss: u16) -> Option { } [bytes[0], bytes[1]] = mss.to_be_bytes(); - tcp_adjust_packet_checksum(ipv4_packet, ChecksumUpdate(vec![(existing_mss, mss)])); + update_tcp_checksum(&mut ipv4_packet, ChecksumUpdate(vec![(existing_mss, mss)])); return Some(existing_mss); } let start = std::cmp::min(option.packet_size(), option_raw.len()); From 82b53f490584ca7232badcc72c05e6b44f0368e3 Mon Sep 17 00:00:00 2001 From: Omer Shamash Date: Tue, 14 Jan 2025 12:36:37 +0200 Subject: [PATCH 02/18] feat: in-place replacement for std::Mutex usage + introduce parking_lot for better performence lock primitives + introduce DashMap for faster atomic hashmap management + adapt API usage to new primitives (no need to unwrap) --- Cargo.lock | 71 +++++++++++++++++++++++ lightway-core/Cargo.toml | 1 + lightway-core/src/connection.rs | 5 +- lightway-core/src/connection/builders.rs | 2 +- lightway-core/src/context.rs | 3 +- lightway-server/Cargo.toml | 2 + lightway-server/src/connection.rs | 13 ++--- lightway-server/src/connection_manager.rs | 34 +++++------ 8 files changed, 99 insertions(+), 32 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7584deae..e0173d59 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -592,6 +592,20 @@ dependencies = [ "syn 2.0.96", ] +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "delegate" version = "0.12.0" @@ -960,6 +974,12 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + [[package]] name = "hashbrown" version = "0.15.2" @@ -1276,6 +1296,7 @@ dependencies = [ "more-asserts", "num_enum", "once_cell", + "parking_lot", "pnet", "rand", "rand_core", @@ -1307,6 +1328,7 @@ dependencies = [ "bytesize", "clap", "ctrlc", + "dashmap", "delegate", "educe", "ipnet", @@ -1317,6 +1339,7 @@ dependencies = [ "metrics", "metrics-util", "more-asserts", + "parking_lot", "pnet", "ppp", "pwhash", @@ -1342,6 +1365,16 @@ version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" +[[package]] +name = "lock_api" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +dependencies = [ + "autocfg", + "scopeguard", +] + [[package]] name = "log" version = "0.4.25" @@ -1579,6 +1612,29 @@ version = "2.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" +[[package]] +name = "parking_lot" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets", +] + [[package]] name = "paste" version = "1.0.15" @@ -1867,6 +1923,15 @@ dependencies = [ "bitflags", 
] +[[package]] +name = "redox_syscall" +version = "0.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03a862b389f93e68874fbf580b9de08dd02facb9a788ebadaf4a3fd33cf58834" +dependencies = [ + "bitflags", +] + [[package]] name = "regex" version = "1.11.1" @@ -1935,6 +2000,12 @@ version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + [[package]] name = "serde" version = "1.0.217" diff --git a/lightway-core/Cargo.toml b/lightway-core/Cargo.toml index d4037be0..4c8cebee 100644 --- a/lightway-core/Cargo.toml +++ b/lightway-core/Cargo.toml @@ -29,6 +29,7 @@ metrics.workspace = true more-asserts.workspace = true num_enum = "0.7.0" once_cell = "1.19.0" +parking_lot = "0.12" pnet.workspace = true rand.workspace = true rand_core = "0.6.4" diff --git a/lightway-core/src/connection.rs b/lightway-core/src/connection.rs index e8c1d558..d84f2968 100644 --- a/lightway-core/src/connection.rs +++ b/lightway-core/src/connection.rs @@ -6,13 +6,14 @@ mod io_adapter; mod key_update; use bytes::{Bytes, BytesMut}; +use parking_lot::Mutex; use rand::Rng; use std::borrow::Cow; use std::net::AddrParseError; use std::num::{NonZeroU16, Wrapping}; use std::{ net::SocketAddr, - sync::{Arc, Mutex}, + sync::Arc, time::{Duration, Instant}, }; use thiserror::Error; @@ -936,7 +937,7 @@ impl Connection { ref mut pending_session_id, .. 
} => { - let new_session_id = rng.lock().unwrap().r#gen(); + let new_session_id = rng.lock().r#gen(); self.session.io_cb_mut().set_session_id(new_session_id); diff --git a/lightway-core/src/connection/builders.rs b/lightway-core/src/connection/builders.rs index 7a55338d..9f44d11d 100644 --- a/lightway-core/src/connection/builders.rs +++ b/lightway-core/src/connection/builders.rs @@ -282,7 +282,7 @@ impl<'a, AppState: Send + 'static> ServerConnectionBuilder<'a, AppState> { let auth = ctx.auth.clone(); let ip_pool = ctx.ip_pool.clone(); - let session_id = ctx.rng.lock().unwrap().r#gen(); + let session_id = ctx.rng.lock().r#gen(); let outside_mtu = MAX_OUTSIDE_MTU; let outside_plugins = ctx.outside_plugins.build()?; diff --git a/lightway-core/src/context.rs b/lightway-core/src/context.rs index 6dc21580..9b648227 100644 --- a/lightway-core/src/context.rs +++ b/lightway-core/src/context.rs @@ -1,8 +1,9 @@ pub mod ip_pool; mod server_auth; +use parking_lot::Mutex; use rand::SeedableRng; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use thiserror::Error; use crate::{ diff --git a/lightway-server/Cargo.toml b/lightway-server/Cargo.toml index 8b5bec3f..92c88753 100644 --- a/lightway-server/Cargo.toml +++ b/lightway-server/Cargo.toml @@ -24,6 +24,7 @@ bytes.workspace = true bytesize.workspace = true clap.workspace = true ctrlc.workspace = true +dashmap = "6.1.0" delegate.workspace = true educe.workspace = true ipnet.workspace = true @@ -33,6 +34,7 @@ lightway-app-utils.workspace = true lightway-core = { workspace = true, features = ["postquantum"] } metrics.workspace = true metrics-util = "0.18.0" +parking_lot = "0.12.3" pnet.workspace = true ppp = "2.2.0" pwhash = "1.0.0" diff --git a/lightway-server/src/connection.rs b/lightway-server/src/connection.rs index efe3b430..49b6ff55 100644 --- a/lightway-server/src/connection.rs +++ b/lightway-server/src/connection.rs @@ -1,8 +1,9 @@ use bytes::BytesMut; use delegate::delegate; +use parking_lot::Mutex; use std::{ 
net::{Ipv4Addr, SocketAddr}, - sync::{Arc, Mutex, Weak}, + sync::{Arc, Weak}, }; use tracing::{trace, warn}; @@ -80,11 +81,9 @@ impl Connection { conn.lw_conn .lock() - .unwrap() .app_state_mut() .conn .set(Arc::downgrade(&conn)) - .unwrap(); let mut join_set = tokio::task::JoinSet::new(); ticker_task.spawn(Arc::downgrade(&conn), &mut join_set); @@ -93,7 +92,7 @@ impl Connection { } delegate! { - to self.lw_conn.lock().unwrap() { + to self.lw_conn.lock() { pub fn tls_protocol_version(&self) -> ProtocolVersion; pub fn connection_type(&self) -> ConnectionType; pub fn session_id(&self) -> SessionId; @@ -145,7 +144,7 @@ impl Connection { } pub fn begin_session_id_rotation(self: &Arc) { - let mut conn = self.lw_conn.lock().unwrap(); + let mut conn = self.lw_conn.lock(); // A rotation is already in flight, nothing to be done this // time. @@ -171,13 +170,13 @@ impl Connection { // Use this only during shutdown, after clearing all connections from // connection_manager pub fn lw_disconnect(self: Arc) -> ConnectionResult<()> { - self.lw_conn.lock().unwrap().disconnect() + self.lw_conn.lock().disconnect() } pub fn disconnect(&self) -> ConnectionResult<()> { metrics::connection_closed(); self.manager.remove_connection(self); - self.lw_conn.lock().unwrap().disconnect() + self.lw_conn.lock().disconnect() } } diff --git a/lightway-server/src/connection_manager.rs b/lightway-server/src/connection_manager.rs index 9682b77c..a4c957a9 100644 --- a/lightway-server/src/connection_manager.rs +++ b/lightway-server/src/connection_manager.rs @@ -1,11 +1,12 @@ mod connection_map; +use dashmap::DashMap; use delegate::delegate; +use parking_lot::Mutex; use std::{ - collections::HashMap, net::SocketAddr, sync::{ - Arc, Mutex, Weak, + Arc, Weak, atomic::{AtomicUsize, Ordering}, }, }; @@ -107,7 +108,7 @@ async fn evict_expired_connections(manager: Weak) { pub(crate) struct ConnectionManager { ctx: ServerContext, connections: Mutex>, - pending_session_id_rotations: Mutex>>, + 
pending_session_id_rotations: DashMap>, /// Total number of sessions there have ever been total_sessions: AtomicUsize, } @@ -230,7 +231,7 @@ impl ConnectionManager { let conn_manager = Arc::new(Self { ctx, connections: Mutex::new(Default::default()), - pending_session_id_rotations: Mutex::new(Default::default()), + pending_session_id_rotations: Default::default(), total_sessions: Default::default(), }); @@ -252,7 +253,7 @@ impl ConnectionManager { } pub(crate) fn pending_session_id_rotations_count(&self) -> usize { - self.pending_session_id_rotations.lock().unwrap().len() + self.pending_session_id_rotations.len() } pub(crate) fn create_streaming_connection( @@ -269,7 +270,7 @@ impl ConnectionManager { outside_io, )?; // TODO: what if addr was already present? - self.connections.lock().unwrap().insert(&conn)?; + self.connections.lock().insert(&conn)?; Ok(conn) } @@ -302,7 +303,7 @@ impl ConnectionManager { where F: FnOnce() -> OutsideIOSendCallbackArg, { - match self.connections.lock().unwrap().lookup(addr, session_id) { + match self.connections.lock().lookup(addr, session_id) { connection_map::Entry::Occupied(c) => { if session_id == SessionId::EMPTY || c.session_id() == session_id { let update_peer_address = addr != c.peer_addr(); @@ -330,8 +331,6 @@ impl ConnectionManager { // Maybe this is a pending session rotation if let Some(c) = self .pending_session_id_rotations - .lock() - .unwrap() .get(&session_id) { let update_peer_address = addr != c.peer_addr(); @@ -349,19 +348,18 @@ impl ConnectionManager { self: &Arc, addr: SocketAddr, ) -> Option> { - self.connections.lock().unwrap().find_by(addr) + self.connections.lock().find_by(addr) } pub(crate) fn set_peer_addr(&self, conn: &Arc, new_addr: SocketAddr) { let old_addr = conn.set_peer_addr(new_addr); self.connections .lock() - .unwrap() .update_socketaddr_for_connection(old_addr, new_addr); } pub(crate) fn remove_connection(&self, conn: &Connection) { - self.connections.lock().unwrap().remove(conn) + 
self.connections.lock().remove(conn) } pub(crate) fn begin_session_id_rotation( @@ -370,8 +368,6 @@ impl ConnectionManager { new_session_id: SessionId, ) { self.pending_session_id_rotations - .lock() - .unwrap() .insert(new_session_id, conn.clone()); metrics::udp_session_rotation_begin(); @@ -384,12 +380,9 @@ impl ConnectionManager { new: SessionId, ) { self.pending_session_id_rotations - .lock() - .unwrap() .remove(&new); self.connections .lock() - .unwrap() .update_session_id_for_connection(old, new); metrics::udp_session_rotation_finalized(); @@ -398,7 +391,6 @@ impl ConnectionManager { pub(crate) fn online_connection_activity(&self) -> Vec { self.connections .lock() - .unwrap() .iter_connections() .filter_map(|c| match c.state() { State::Online => Some(c.activity()), @@ -411,7 +403,7 @@ impl ConnectionManager { fn evict_idle_connections(&self) { tracing::trace!("Aging connections"); - for conn in self.connections.lock().unwrap().iter_connections() { + for conn in self.connections.lock().iter_connections() { let age = conn.activity().last_outside_data_received.elapsed(); if age > CONNECTION_MAX_IDLE_AGE { tracing::info!(session = ?conn.session_id(), age = ?age, "Disconnecting idle connection"); @@ -431,7 +423,7 @@ impl ConnectionManager { fn evict_expired_connections(&self) { tracing::trace!("Expiring connections"); - for conn in self.connections.lock().unwrap().iter_connections() { + for conn in self.connections.lock().iter_connections() { let Ok(expired) = conn.authentication_expired() else { continue; }; @@ -449,7 +441,7 @@ impl ConnectionManager { } pub(crate) fn close_all_connections(&self) { - let connections = self.connections.lock().unwrap().remove_connections(); + let connections = self.connections.lock().remove_connections(); for conn in connections { let _ = conn.lw_disconnect(); } From 1312151d6e4697339c44bed1f54238a92deda896 Mon Sep 17 00:00:00 2001 From: Omer Shamash Date: Tue, 14 Jan 2025 14:33:06 +0200 Subject: [PATCH 03/18] lint and format --- 
lightway-core/src/utils.rs | 17 ++++++++++------- lightway-server/src/connection.rs | 1 + lightway-server/src/connection_manager.rs | 8 ++------ 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/lightway-core/src/utils.rs b/lightway-core/src/utils.rs index 13939c1f..a4fcec9b 100644 --- a/lightway-core/src/utils.rs +++ b/lightway-core/src/utils.rs @@ -25,10 +25,9 @@ use tracing::warn; // } // } -/** - * HOT/COLD path implementation until RUST adds - * https://github.com/rust-lang/rust/issues/26179 - */ +//! +// HOT/COLD path implementation until RUST adds +// https://github.com/rust-lang/rust/issues/26179 #[inline] #[cold] @@ -36,13 +35,17 @@ fn cold() {} #[inline] pub(crate) fn likely(b: bool) -> bool { - if !b { cold() } + if !b { + cold() + } b } #[inline] pub(crate) fn unlikely(b: bool) -> bool { - if b { cold() } + if b { + cold() + } b } @@ -201,7 +204,7 @@ fn update_transport_checksums(packet: &mut MutableIpv4Packet, updates: ChecksumU if unlikely(true) { warn!(protocol = ?protocol, "Unknown protocol, skipping checksum update"); } - }, + } } } diff --git a/lightway-server/src/connection.rs b/lightway-server/src/connection.rs index 49b6ff55..302dc9ee 100644 --- a/lightway-server/src/connection.rs +++ b/lightway-server/src/connection.rs @@ -84,6 +84,7 @@ impl Connection { .app_state_mut() .conn .set(Arc::downgrade(&conn)) + .unwrap(); let mut join_set = tokio::task::JoinSet::new(); ticker_task.spawn(Arc::downgrade(&conn), &mut join_set); diff --git a/lightway-server/src/connection_manager.rs b/lightway-server/src/connection_manager.rs index a4c957a9..a15ff9da 100644 --- a/lightway-server/src/connection_manager.rs +++ b/lightway-server/src/connection_manager.rs @@ -329,10 +329,7 @@ impl ConnectionManager { } connection_map::Entry::Vacant(_e) => { // Maybe this is a pending session rotation - if let Some(c) = self - .pending_session_id_rotations - .get(&session_id) - { + if let Some(c) = self.pending_session_id_rotations.get(&session_id) { let 
update_peer_address = addr != c.peer_addr(); return Ok((c.clone(), update_peer_address)); @@ -379,8 +376,7 @@ impl ConnectionManager { old: SessionId, new: SessionId, ) { - self.pending_session_id_rotations - .remove(&new); + self.pending_session_id_rotations.remove(&new); self.connections .lock() .update_session_id_for_connection(old, new); From afa3b99515181dfbf70938d7d45943ce6761505a Mon Sep 17 00:00:00 2001 From: Omer Shamash Date: Tue, 14 Jan 2025 14:44:25 +0200 Subject: [PATCH 04/18] fix test/lint --- lightway-core/src/utils.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lightway-core/src/utils.rs b/lightway-core/src/utils.rs index a4fcec9b..7b4ada34 100644 --- a/lightway-core/src/utils.rs +++ b/lightway-core/src/utils.rs @@ -25,7 +25,6 @@ use tracing::warn; // } // } -//! // HOT/COLD path implementation until RUST adds // https://github.com/rust-lang/rust/issues/26179 @@ -391,8 +390,9 @@ mod tests { ]; #[test_case(&[] => false; "empty")] - #[test_case(&[0x40] => true; "v4")] - #[test_case(&[0x60] => false; "v6")] + #[test_case(&[0x40; 19] => false; "buffer too small")] + #[test_case(&[0x45; 20] => true; "minimum valid v4")] + #[test_case(&[0x60; 20] => false; "v6 header")] #[test_case(SOURCE_1_DEST_1 => true; "SOURCE_1_TO_DEST_1")] #[test_case(SOURCE_1_DEST_2 => true; "SOURCE_1_TO_DEST_2")] #[test_case(SOURCE_2_DEST_1 => true; "SOURCE_2_TO_DEST_1")] From b955150bb80d2a717ba93b42d005354c609b0b93 Mon Sep 17 00:00:00 2001 From: Omer Shamash Date: Tue, 14 Jan 2025 14:56:16 +0200 Subject: [PATCH 05/18] fixup: force hashbrown version to allow dashmap --- Cargo.lock | 5 +++++ lightway-server/Cargo.toml | 1 + 2 files changed, 6 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index e0173d59..a61e729b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -979,6 +979,10 @@ name = "hashbrown" version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +dependencies = [ + "ahash", + "allocator-api2", +] [[package]] name = "hashbrown" @@ -1331,6 +1335,7 @@ dependencies = [ "dashmap", "delegate", "educe", + "hashbrown 0.14.5", "ipnet", "jsonwebtoken", "libc", diff --git a/lightway-server/Cargo.toml b/lightway-server/Cargo.toml index 92c88753..7010cb72 100644 --- a/lightway-server/Cargo.toml +++ b/lightway-server/Cargo.toml @@ -27,6 +27,7 @@ ctrlc.workspace = true dashmap = "6.1.0" delegate.workspace = true educe.workspace = true +hashbrown = "=0.14.5" # Force older version for dashmap compatibility ipnet.workspace = true jsonwebtoken = "9.3.0" libc.workspace = true From 7d279b7f25294c4b8be732d1f090a62aad094090 Mon Sep 17 00:00:00 2001 From: Omer Shamash Date: Tue, 14 Jan 2025 15:07:59 +0200 Subject: [PATCH 06/18] fixup: match test to runtime validation code --- lightway-core/tests/connection.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lightway-core/tests/connection.rs b/lightway-core/tests/connection.rs index d67845ab..54616674 100644 --- a/lightway-core/tests/connection.rs +++ b/lightway-core/tests/connection.rs @@ -377,7 +377,7 @@ async fn client( assert!(matches!(client.state(), State::Online)); assert!(message_sent); - assert_eq!(&buf[..], b"\x40Hello World!"); + assert_eq!(&buf[..], b"\x40Hello World!But Bigger"); let curve = client.current_curve().unwrap(); assert_eq!(curve, pqc.expected_curve()); @@ -429,7 +429,9 @@ async fn client( // work as the first byte too, but be more // explicit to avoid a confusing surprise for some // future developer). 
- let mut buf: BytesMut = BytesMut::from(&b"\x40Hello World!"[..]); + // Hi, future developer here, this is an invalid ipv4 size, + // now the code sends 20 bytes so we don't fail any validations + let mut buf: BytesMut = BytesMut::from(&b"\x40Hello World!But Bigger"[..]); eprintln!("Sending message: {buf:?}"); client.inside_data_received(&mut buf).expect("Send my message"); message_sent = true; From 1cd0ada10244bbbe1288f4de374d4a8414d74af2 Mon Sep 17 00:00:00 2001 From: Omer Shamash Date: Tue, 14 Jan 2025 15:48:07 +0200 Subject: [PATCH 07/18] revert dashmap (dup versions) - move to hashbrown --- Cargo.lock | 27 +---------------------- lightway-server/Cargo.toml | 3 +-- lightway-server/src/connection_manager.rs | 13 ++++++----- 3 files changed, 9 insertions(+), 34 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a61e729b..4cb0e15e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -592,20 +592,6 @@ dependencies = [ "syn 2.0.96", ] -[[package]] -name = "dashmap" -version = "6.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" -dependencies = [ - "cfg-if", - "crossbeam-utils", - "hashbrown 0.14.5", - "lock_api", - "once_cell", - "parking_lot_core", -] - [[package]] name = "delegate" version = "0.12.0" @@ -974,16 +960,6 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" -[[package]] -name = "hashbrown" -version = "0.14.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" -dependencies = [ - "ahash", - "allocator-api2", -] - [[package]] name = "hashbrown" version = "0.15.2" @@ -1332,10 +1308,9 @@ dependencies = [ "bytesize", "clap", "ctrlc", - "dashmap", "delegate", "educe", - "hashbrown 0.14.5", + "hashbrown 0.15.2", "ipnet", "jsonwebtoken", "libc", diff --git 
a/lightway-server/Cargo.toml b/lightway-server/Cargo.toml index 7010cb72..5b545d39 100644 --- a/lightway-server/Cargo.toml +++ b/lightway-server/Cargo.toml @@ -24,10 +24,9 @@ bytes.workspace = true bytesize.workspace = true clap.workspace = true ctrlc.workspace = true -dashmap = "6.1.0" delegate.workspace = true educe.workspace = true -hashbrown = "=0.14.5" # Force older version for dashmap compatibility +hashbrown = "0.15.2" ipnet.workspace = true jsonwebtoken = "9.3.0" libc.workspace = true diff --git a/lightway-server/src/connection_manager.rs b/lightway-server/src/connection_manager.rs index a15ff9da..30203baf 100644 --- a/lightway-server/src/connection_manager.rs +++ b/lightway-server/src/connection_manager.rs @@ -1,7 +1,7 @@ mod connection_map; -use dashmap::DashMap; use delegate::delegate; +use hashbrown::HashMap; use parking_lot::Mutex; use std::{ net::SocketAddr, @@ -108,7 +108,7 @@ async fn evict_expired_connections(manager: Weak) { pub(crate) struct ConnectionManager { ctx: ServerContext, connections: Mutex>, - pending_session_id_rotations: DashMap>, + pending_session_id_rotations: Mutex>>, /// Total number of sessions there have ever been total_sessions: AtomicUsize, } @@ -231,7 +231,7 @@ impl ConnectionManager { let conn_manager = Arc::new(Self { ctx, connections: Mutex::new(Default::default()), - pending_session_id_rotations: Default::default(), + pending_session_id_rotations: Mutex::new(Default::default()), total_sessions: Default::default(), }); @@ -253,7 +253,7 @@ impl ConnectionManager { } pub(crate) fn pending_session_id_rotations_count(&self) -> usize { - self.pending_session_id_rotations.len() + self.pending_session_id_rotations.lock().len() } pub(crate) fn create_streaming_connection( @@ -329,7 +329,7 @@ impl ConnectionManager { } connection_map::Entry::Vacant(_e) => { // Maybe this is a pending session rotation - if let Some(c) = self.pending_session_id_rotations.get(&session_id) { + if let Some(c) = 
self.pending_session_id_rotations.lock().get(&session_id) { let update_peer_address = addr != c.peer_addr(); return Ok((c.clone(), update_peer_address)); @@ -365,6 +365,7 @@ impl ConnectionManager { new_session_id: SessionId, ) { self.pending_session_id_rotations + .lock() .insert(new_session_id, conn.clone()); metrics::udp_session_rotation_begin(); @@ -376,7 +377,7 @@ impl ConnectionManager { old: SessionId, new: SessionId, ) { - self.pending_session_id_rotations.remove(&new); + self.pending_session_id_rotations.lock().remove(&new); self.connections .lock() .update_session_id_for_connection(old, new); From bc2ab35796323a5a012beea4c3d051b84e90efb8 Mon Sep 17 00:00:00 2001 From: Omer Shamash Date: Mon, 27 Jan 2025 01:00:20 +0200 Subject: [PATCH 08/18] feat: support managing IOURING directly --- Cargo.lock | 1 + lightway-app-utils/Cargo.toml | 4 +- lightway-app-utils/src/iouring.rs | 627 ++++++++++++------------------ lightway-app-utils/src/metrics.rs | 7 + lightway-app-utils/src/tun.rs | 5 +- lightway-server/Cargo.toml | 1 + 6 files changed, 259 insertions(+), 386 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4cb0e15e..1a4db3be 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1218,6 +1218,7 @@ dependencies = [ "libc", "lightway-core", "metrics", + "parking_lot", "pnet", "serde", "serde_with", diff --git a/lightway-app-utils/Cargo.toml b/lightway-app-utils/Cargo.toml index df8da5a2..712db5d6 100644 --- a/lightway-app-utils/Cargo.toml +++ b/lightway-app-utils/Cargo.toml @@ -11,7 +11,8 @@ readme = "README.md" [features] default = [ "tokio" ] -io-uring = [ "dep:io-uring", "dep:tokio", "dep:tokio-eventfd" ] +io-uring = [ "dep:io-uring", "dep:tokio", "dep:tokio-eventfd", "dep:parking_lot" ] +iouring-bufsize = [ "io-uring" ] tokio = [ "dep:tokio", "dep:tokio-stream" ] [lints] @@ -28,6 +29,7 @@ ipnet.workspace = true libc.workspace = true lightway-core.workspace = true metrics.workspace = true +parking_lot = { version = "0.12.3", optional = true } serde.workspace = 
true serde_with = "3.4.0" serde_yaml = "0.9.34" diff --git a/lightway-app-utils/src/iouring.rs b/lightway-app-utils/src/iouring.rs index d96108d0..bbe1b806 100644 --- a/lightway-app-utils/src/iouring.rs +++ b/lightway-app-utils/src/iouring.rs @@ -1,63 +1,110 @@ -use anyhow::{Context, Result, anyhow}; -use bytes::{BufMut, Bytes, BytesMut}; -use lightway_core::IOCallbackResult; -use thiserror::Error; - use crate::metrics; -use io_uring::{ - Builder, IoUring, SubmissionQueue, Submitter, cqueue::Entry as CEntry, opcode, - squeue::Entry as SEntry, types::Fixed, -}; +use anyhow::Result; +use bytes::BytesMut; +use io_uring::{opcode, types, IoUring}; +use libc::iovec; +use lightway_core::IOCallbackResult; +use parking_lot::Mutex; use std::{ - os::fd::{AsRawFd, RawFd}, - sync::Arc, + cell::UnsafeCell, + os::unix::io::{AsRawFd, RawFd}, + sync::{ + atomic::{AtomicU32, AtomicUsize, Ordering}, + Arc, + }, thread, time::Duration, }; -use tokio::{ - io::AsyncReadExt, - sync::{Mutex, mpsc}, -}; -use tokio_eventfd::EventFd; -const REGISTERED_FD_INDEX: u32 = 0; +const DEFAULT_BUFFER_SIZE: usize = 2048; -/// IO-uring Struct -pub struct IOUring { - /// Any struct corresponds to a file descriptor - owned_fd: Arc, +#[cfg(not(feature = "iouring-bufsize"))] +const BUFFER_SIZE: usize = DEFAULT_BUFFER_SIZE; - tx_queue: mpsc::Sender, - rx_queue: Mutex>, +#[cfg(feature = "iouring-bufsize")] +const MAX_BUFFER_SIZE: usize = 65536; +#[cfg(feature = "iouring-bufsize")] +const BUFFER_SIZE: usize = { + let size = std::env!("IOURING_BUFFER_SIZE") + .parse::() + .expect("IOURING_BUFFER_SIZE must be a valid usize"); + assert!(size <= MAX_BUFFER_SIZE, "Buffer size cannot exceed 64KB"); + size +}; + +#[repr(align(128))] +struct Buffer { + data: UnsafeCell<[u8; BUFFER_SIZE]>, + state: AtomicU32, // 0 = free, 1 = in_flight, 2 = completed + length: AtomicU32, } -/// An error from read/write operation -#[derive(Debug, Error)] -pub enum IOUringError { - /// A recv error occurred - #[error("Recv Error")] - 
RecvError, +#[allow(unsafe_code)] +unsafe impl Send for Buffer {} +#[allow(unsafe_code)] +unsafe impl Sync for Buffer {} + +impl Buffer { + fn new() -> Self { + Self { + data: UnsafeCell::new([0u8; BUFFER_SIZE]), + state: AtomicU32::new(0), + length: AtomicU32::new(0), + } + } +} - /// A send error occurred - #[error("Send Error")] - SendError, +struct BufferPool { + buffers: Vec, + read_idx: AtomicUsize, } -pub type IOUringResult = std::result::Result; +/// IO-uring Struct +pub struct IOUring { + owned_fd: Arc, + rx_pool: Arc, + tx_pool: Arc, + ring: Arc, + submission_lock: Arc>, + write_index: AtomicUsize, +} +#[allow(unsafe_code)] impl IOUring { /// Create `IOUring` struct pub async fn new( owned_fd: Arc, ring_size: usize, - channel_size: usize, + _channel_size: usize, mtu: usize, sqpoll_idle_time: Duration, ) -> Result { + assert!(mtu <= BUFFER_SIZE); + + let rx_pool = Arc::new(BufferPool { + buffers: (0..ring_size / 2).map(|_| Buffer::new()).collect(), + read_idx: AtomicUsize::new(0), + }); + + let tx_pool = Arc::new(BufferPool { + buffers: (0..ring_size / 2).map(|_| Buffer::new()).collect(), + read_idx: AtomicUsize::new(0), + }); + + let ring = Arc::new( + IoUring::builder() + .setup_sqpoll(sqpoll_idle_time.as_millis() as u32) + .build(ring_size as u32)?, + ); + + let submission_lock = Arc::new(Mutex::new(())); + + let rx_pool_clone = rx_pool.clone(); + let tx_pool_clone = tx_pool.clone(); + let ring_clone = ring.clone(); + let lock_clone = submission_lock.clone(); let fd = owned_fd.as_raw_fd(); - let (tx_queue_sender, tx_queue_receiver) = mpsc::channel(channel_size); - let (rx_queue_sender, rx_queue_receiver) = mpsc::channel(channel_size); thread::Builder::new() .name("io_uring-main".to_string()) .spawn(move || { @@ -67,21 +114,20 @@ impl IOUring { .expect("Failed building Tokio Runtime") .block_on(iouring_task( fd, - ring_size, - mtu, - sqpoll_idle_time, - tx_queue_receiver, - rx_queue_sender, + rx_pool_clone, + tx_pool_clone, + ring_clone, + lock_clone, )) 
- .inspect_err(|err| { - tracing::error!("i/o uring task stopped: {:?}", err); - }) })?; Ok(Self { owned_fd, - tx_queue: tx_queue_sender, - rx_queue: Mutex::new(rx_queue_receiver), + rx_pool, + tx_pool, + ring, + submission_lock, + write_index: AtomicUsize::new(0), }) } @@ -90,375 +136,190 @@ impl IOUring { &self.owned_fd } - /// Receive packet from Tun device - pub async fn recv(&self) -> IOUringResult { - self.rx_queue - .lock() - .await - .recv() - .await - .ok_or(IOUringError::RecvError) - } - - /// Try Send packet to Tun device + /// Send packet on Tun device (push to RING and submit) pub fn try_send(&self, buf: BytesMut) -> IOCallbackResult { - let buf_len = buf.len(); - let try_send_res = self.tx_queue.try_send(buf.freeze()); - match try_send_res { - Ok(()) => IOCallbackResult::Ok(buf_len), - Err(mpsc::error::TrySendError::Full(_)) => IOCallbackResult::WouldBlock, - Err(_) => { - use std::io::{Error, ErrorKind}; - IOCallbackResult::Err(Error::new(ErrorKind::Other, IOUringError::SendError)) - } + let len = buf.len(); + if len > BUFFER_SIZE { + return IOCallbackResult::WouldBlock; } - } -} - -#[derive(Debug)] -enum SlotIdx { - Tx(isize), - Rx(isize), -} - -impl SlotIdx { - fn from_user_data(u: u64) -> Self { - let u = u as isize; - if u < 0 { Self::Rx(!u) } else { Self::Tx(u) } - } - fn idx(&self) -> usize { - match *self { - SlotIdx::Tx(idx) => idx as usize, - SlotIdx::Rx(idx) => idx as usize, + let write_idx = + self.write_index.fetch_add(1, Ordering::AcqRel) % self.tx_pool.buffers.len(); + let buffer = &self.tx_pool.buffers[write_idx]; + + // Check if buffer is free (state = 0) + if buffer + .state + .compare_exchange( + 0, + 1, // free -> in_flight + Ordering::AcqRel, + Ordering::Relaxed, + ) + .is_err() + { + // Out of buffers, need kernel to work faster + // consider a bigger queue if we see this counter + metrics::tun_iouring_tx_err(); + return IOCallbackResult::WouldBlock; } - } - fn user_data(&self) -> u64 { - match *self { - SlotIdx::Tx(idx) => idx 
as u64, - SlotIdx::Rx(idx) => (!idx) as u64, - } - } -} + unsafe { (*buffer.data.get())[..len].copy_from_slice(&buf) }; + buffer.length.store(len as u32, Ordering::Release); -struct RxState { - sender: Option>, - buf: BytesMut, -} - -fn push_one_tx_event_to( - buf: Bytes, - sq: &mut SubmissionQueue, - bufs: &mut [Option], - slot: SlotIdx, -) -> std::result::Result<(), SlotIdx> { - let sqe = opcode::Write::new(Fixed(REGISTERED_FD_INDEX), buf.as_ptr(), buf.len() as _) + let write_op = opcode::WriteFixed::new( + types::Fd(self.owned_fd.as_raw_fd()), + buffer.data.get() as *mut u8, + len as _, + write_idx as _, + ) .build() - .user_data(slot.user_data()); - - #[allow(unsafe_code)] - // SAFETY: sqe points to a buffer on the heap, owned - // by a `Bytes` in `bufs[slot]`, we will not reuse - // `bufs[slot]` until `slot` is returned to the slots vector. - if unsafe { sq.push(&sqe) }.is_err() { - return Err(slot); - } - - // SAFETY: By construction instances of SlotIdx are always in bounds. - #[allow(unsafe_code)] - unsafe { - *bufs.get_unchecked_mut(slot.idx()) = Some(buf) - }; - - Ok(()) -} - -fn push_tx_events_to( - sbmt: &Submitter, - sq: &mut SubmissionQueue, - txq: &mut mpsc::Receiver, - slots: &mut Vec, - bufs: &mut [Option], -) -> Result<()> { - while !slots.is_empty() { - if sq.is_full() { - match sbmt.submit() { - Ok(_) => (), - Err(ref err) if err.raw_os_error() == Some(libc::EBUSY) => break, - Err(err) => { - return Err(anyhow!(err)).context("Push TX events failed for sq submit"); - } - } - } - sq.sync(); - - match txq.try_recv() { - Ok(buf) => { - let slot = slots.pop().expect("no tx slots left"); // we are inside `!slots.is_empty()`. 
- if let Err(slot) = push_one_tx_event_to(buf, sq, bufs, slot) { - slots.push(slot); - break; - } - } - Err(mpsc::error::TryRecvError::Empty) => { - break; - } - Err(err) => { - return Err(anyhow!(err)).context("Push TX events failed for try_recv"); + // NOTE: we set the index starting from after the RX_POOL part + .user_data((self.rx_pool.buffers.len() + write_idx) as u64); + + let _guard = self.submission_lock.lock(); + unsafe { + match self.ring.submission_shared().push(&write_op) { + Ok(_) => IOCallbackResult::Ok(len), + Err(_) => IOCallbackResult::WouldBlock, } } } - Ok(()) -} -fn push_rx_events_to( - sbmt: &Submitter, - sq: &mut SubmissionQueue, - slots: &mut Vec, - state: &mut [RxState], -) -> Result<()> { - loop { - if sq.is_full() { - match sbmt.submit() { - Ok(_) => (), - Err(ref err) if err.raw_os_error() == Some(libc::EBUSY) => break, - Err(err) => { - return Err(anyhow!(err)).context("Push RX events failed for sq submit"); - } - } - } - sq.sync(); - - match slots.pop() { - Some(slot) => { - // SAFETY: By construction instances of SlotIdx are always in bounds. - #[allow(unsafe_code)] - let state = unsafe { state.get_unchecked_mut(slot.idx()) }; - - // queue a new rx - let sqe = opcode::Read::new( - Fixed(REGISTERED_FD_INDEX), - state.buf.as_mut_ptr(), - state.buf.capacity() as _, - ) - .build() - .user_data(slot.user_data()); - #[allow(unsafe_code)] - // SAFETY: sqe points to a buffer on the heap, owned - // by a `BytesMut` in `rx_bufs[slot]`, we will not reuse - // `rx_bufs[slot]` until `slot` is returned to the slots vector. 
- if unsafe { sq.push(&sqe) }.is_err() { - slots.push(slot); - break; - } - } - None => break, + /// Receive packet from Tun device + pub async fn recv(&self) -> IOCallbackResult { + let idx = self.rx_pool.read_idx.load(Ordering::Relaxed) % self.rx_pool.buffers.len(); + let buffer = &self.rx_pool.buffers[idx]; + + if buffer + .state + .compare_exchange(2, 0, Ordering::AcqRel, Ordering::Relaxed) + .is_ok() + { + let len = buffer.length.load(Ordering::Acquire) as usize; + let mut new_buf = BytesMut::with_capacity(len); + unsafe { new_buf.extend_from_slice(&(*buffer.data.get())[..len]) }; + self.rx_pool.read_idx.fetch_add(1, Ordering::Release); + return IOCallbackResult::Ok(new_buf); } + IOCallbackResult::WouldBlock } - - Ok(()) } +#[allow(unsafe_code)] async fn iouring_task( fd: RawFd, - ring_size: usize, - mtu: usize, - sqpoll_idle_time: Duration, - mut tx_queue: mpsc::Receiver, - rx_queue: mpsc::Sender, + rx_pool: Arc, + tx_pool: Arc, + ring: Arc, + submission_lock: Arc>, ) -> Result<()> { - let mut event_fd: EventFd = EventFd::new(0, false)?; - let mut builder: Builder = IoUring::builder(); - - if sqpoll_idle_time.as_millis() > 0 { - let idle_time: u32 = sqpoll_idle_time - .as_millis() - .try_into() - .with_context(|| "invalid sqpoll idle time")?; - // This setting makes CPU go 100% when there is continuous traffic - builder.setup_sqpoll(idle_time); // Needs 5.13 - } - let mut ring = builder - .build(ring_size as u32) - .inspect_err(|e| tracing::error!("iouring setup failed: {e}"))?; - - let (sbmt, mut sq, mut cq) = ring.split(); - - // Register event-fd to cqe entries - sbmt.register_eventfd(event_fd.as_raw_fd())?; - sbmt.register_files(&[fd])?; - - // Using half of total io-uring size for rx and half for tx - let nr_tx_rx_slots = (ring_size / 2) as isize; - tracing::info!( - ring_size, - nr_tx_rx_slots, - ?sqpoll_idle_time, - "uring main task" - ); - - let mut rx_slots: Vec<_> = (0..nr_tx_rx_slots).map(SlotIdx::Rx).collect(); - let mut rx_state: Vec<_> = 
rx_slots - .iter() - .map(|_| RxState { - sender: None, - buf: BytesMut::with_capacity(mtu), - }) - .collect(); - for state in rx_state.iter_mut() { - state.sender = Some(rx_queue.clone().reserve_owned().await?) - } - - let rx_sq_entries: Vec<_> = rx_slots - .drain(..) - .map(|slot| { - let state = &mut rx_state[slot.idx()]; - opcode::Read::new( - Fixed(REGISTERED_FD_INDEX), - state.buf.as_mut_ptr(), - state.buf.capacity() as _, + let rx_len = rx_pool.buffers.len(); + let mut iovecs = Vec::with_capacity(rx_len + tx_pool.buffers.len()); + + iovecs.extend(rx_pool.buffers.iter().map(|buf| iovec { + iov_base: buf.data.get() as *mut libc::c_void, + iov_len: BUFFER_SIZE, + })); + + iovecs.extend(tx_pool.buffers.iter().map(|buf| iovec { + iov_base: buf.data.get() as *mut libc::c_void, + iov_len: BUFFER_SIZE, + })); + + unsafe { ring.submitter().register_buffers(&iovecs)? }; + + // Initial submission of read operations + { + let _guard = submission_lock.lock(); + let mut sq = unsafe { ring.submission_shared() }; + for idx in 0..rx_len { + let read_op = opcode::ReadFixed::new( + types::Fd(fd), + rx_pool.buffers[idx].data.get() as *mut u8, + BUFFER_SIZE as _, + idx as _, ) .build() - .user_data(slot.user_data()) - }) - .collect(); + .user_data(idx as u64); - // SAFETY: sqe points to a buffer on the heap, owned - // by a `BytesMut` in `rx_bufs[slot]`, we will not reuse - // `rx_bufs[slot]` until `slot` is returned to the slots vector. - #[allow(unsafe_code)] - unsafe { - let entries = rx_sq_entries; - // This call should not fail since the SubmissionQueue should be empty now - sq.push_multiple(&entries)? - }; - sq.sync(); - - let mut tx_slots: Vec<_> = (0..nr_tx_rx_slots).map(SlotIdx::Tx).collect(); - let mut tx_bufs = vec![None; tx_slots.len()]; + unsafe { sq.push(&read_op)? 
}; + rx_pool.buffers[idx].state.store(1, Ordering::Release); + } + } - tracing::info!("Entering i/o uring loop"); + let in_flight_reads = AtomicUsize::new(rx_len); loop { - let _ = sbmt.submit()?; - - cq.sync(); - - if cq.is_empty() && tx_queue.is_empty() { - let mut completed_number: [u8; 8] = [0; 8]; - tokio::select! { - // There is no "wait until the queue contains - // something" method so we have to actually receive - // and treat that as a special case. - Some(buf) = tx_queue.recv(), if !tx_slots.is_empty() && !sq.is_full() => { - - let slot = tx_slots.pop().expect("no tx slots left"); // we are inside `!slots.is_empty()` guard. - if let Err(slot) = push_one_tx_event_to(buf, &mut sq, &mut tx_bufs, slot) { - tx_slots.push(slot); + ring.submit_and_wait(1)?; + + let mut reads_to_resubmit = false; + for cqe in unsafe { ring.completion_shared() } { + let user_data = cqe.user_data() as usize; + let result = cqe.result(); + + if user_data < rx_len { + let idx = user_data; + in_flight_reads.fetch_sub(1, Ordering::Release); + + if result > 0 { + rx_pool.buffers[idx] + .length + .store(result as u32, Ordering::Release); + rx_pool.buffers[idx].state.store(2, Ordering::Release); + + // Check if we need to resubmit batch + if in_flight_reads.load(Ordering::Acquire) < (rx_len / 4) { + reads_to_resubmit = true; + } + } else { + // Error or EOF case for read + rx_pool.buffers[idx].state.store(0, Ordering::Release); + reads_to_resubmit = true; + if result < 0 { + tracing::error!( + "Read operation failed: {}", + std::io::Error::from_raw_os_error(-result) + ); + metrics::tun_iouring_rx_err(); } } - - Ok(a) = event_fd.read(&mut completed_number) => { - assert_eq!(a, 8); - }, - - }; - cq.sync(); + } else { + // Write completion + let tx_idx = user_data - rx_len; + if result <= 0 { + tracing::error!( + "Write operation failed: {}", + std::io::Error::from_raw_os_error(-result) + ); + metrics::tun_iouring_tx_err(); + } + tx_pool.buffers[tx_idx].state.store(0, Ordering::Release); + 
} } - // fill tx slots - push_tx_events_to(&sbmt, &mut sq, &mut tx_queue, &mut tx_slots, &mut tx_bufs)?; - - // refill rx slots - push_rx_events_to(&sbmt, &mut sq, &mut rx_slots, &mut rx_state)?; - - sq.sync(); - - for cqe in &mut cq { - let res = cqe.result(); - let slot = SlotIdx::from_user_data(cqe.user_data()); - - match slot { - SlotIdx::Rx(_) => { - if res > 0 { - // SAFETY: By construction instances of SlotIdx are always in bounds. - #[allow(unsafe_code)] - let RxState { - sender: maybe_sender, - buf, - } = unsafe { rx_state.get_unchecked_mut(slot.idx()) }; - - let mut buf = std::mem::replace(buf, BytesMut::with_capacity(mtu)); - - // SAFETY: We trust that the read operation - // returns the correct number of bytes received. - #[allow(unsafe_code)] - unsafe { - buf.advance_mut(res as _); - } - - if let Some(sender) = maybe_sender.take() { - let sender = sender.send(buf); - maybe_sender.replace(sender.reserve_owned().await?); - } else { - panic!("inflight rx state with no sender!"); - }; - } else if res != -libc::EAGAIN { - metrics::tun_iouring_rx_err(); - }; + // Batch resubmit of reads if needed + if reads_to_resubmit { + let _guard = submission_lock.lock(); + let mut sq = unsafe { ring.submission_shared() }; + + for idx in 0..rx_len { + if rx_pool.buffers[idx].state.load(Ordering::Acquire) == 0 { + let read_op = opcode::ReadFixed::new( + types::Fd(fd), + rx_pool.buffers[idx].data.get() as *mut u8, + BUFFER_SIZE as _, + idx as _, + ) + .build() + .user_data(idx as u64); - rx_slots.push(slot); - } - SlotIdx::Tx(_) => { - if res <= 0 { - tracing::info!("rx slot {slot:?} completed with {res}"); + if unsafe { sq.push(&read_op) }.is_ok() { + rx_pool.buffers[idx].state.store(1, Ordering::Release); + in_flight_reads.fetch_add(1, Ordering::Release); } - // handle tx complete, we just need to drop the buffer - // SAFETY: By construction instances of SlotIdx are always in bounds. 
- #[allow(unsafe_code)] - unsafe { - *tx_bufs.get_unchecked_mut(slot.idx()) = None - }; - tx_slots.push(slot); } } } } } - -#[cfg(test)] -mod tests { - use super::*; - use test_case::test_case; - - #[test_case(SlotIdx::Tx(0) => 0x0000_0000_0000_0000)] - #[test_case(SlotIdx::Tx(10) => 0x0000_0000_0000_000a)] - #[test_case(SlotIdx::Tx(isize::MAX) => 0x7fff_ffff_ffff_ffff)] - #[test_case(SlotIdx::Rx(0) => 0x0000_0000_0000_0000)] - #[test_case(SlotIdx::Rx(10) => 0x0000_0000_0000_000a)] - #[test_case(SlotIdx::Rx(isize::MAX) => 0x7fff_ffff_ffff_ffff)] - fn slotid_idx(id: SlotIdx) -> usize { - id.idx() - } - - #[test_case(SlotIdx::Tx(0) => 0x0000_0000_0000_0000)] - #[test_case(SlotIdx::Tx(10) => 0x0000_0000_0000_000a)] - #[test_case(SlotIdx::Tx(isize::MAX) => 0x7fff_ffff_ffff_ffff)] - #[test_case(SlotIdx::Rx(0) => 0xffff_ffff_ffff_ffff)] - #[test_case(SlotIdx::Rx(10) => 0xffff_ffff_ffff_fff5)] - #[test_case(SlotIdx::Rx(isize::MAX) => 0x8000_0000_0000_0000)] - fn slotid_user_data(id: SlotIdx) -> u64 { - id.user_data() - } - - #[test_case(0x0000_0000_0000_0000 => matches SlotIdx::Tx(0))] - #[test_case(0x0000_0000_0000_000a => matches SlotIdx::Tx(10))] - #[test_case(0x7fff_ffff_ffff_ffff => matches SlotIdx::Tx(isize::MAX))] - #[test_case(0xffff_ffff_ffff_ffff => matches SlotIdx::Rx(0))] - #[test_case(0xffff_ffff_ffff_fff5 => matches SlotIdx::Rx(10))] - #[test_case(0x8000_0000_0000_0000 => matches SlotIdx::Rx(isize::MAX))] - fn slotid_from(u: u64) -> SlotIdx { - SlotIdx::from_user_data(u) - } -} diff --git a/lightway-app-utils/src/metrics.rs b/lightway-app-utils/src/metrics.rs index ce336daa..299d6f9c 100644 --- a/lightway-app-utils/src/metrics.rs +++ b/lightway-app-utils/src/metrics.rs @@ -1,9 +1,16 @@ use metrics::{Counter, counter}; use std::sync::LazyLock; +static METRIC_TUN_IOURING_TX_ERR: LazyLock = + LazyLock::new(|| counter!("tun_iouring_tx_err")); static METRIC_TUN_IOURING_RX_ERR: LazyLock = LazyLock::new(|| counter!("tun_iouring_rx_err")); +/// Count iouring TX 
entries which complete with an error +pub(crate) fn tun_iouring_tx_err() { + METRIC_TUN_IOURING_TX_ERR.increment(1) +} + /// Count iouring RX entries which complete with an error pub(crate) fn tun_iouring_rx_err() { METRIC_TUN_IOURING_RX_ERR.increment(1) diff --git a/lightway-app-utils/src/tun.rs b/lightway-app-utils/src/tun.rs index 2b57db8c..df54ce9f 100644 --- a/lightway-app-utils/src/tun.rs +++ b/lightway-app-utils/src/tun.rs @@ -168,8 +168,9 @@ impl TunIoUring { /// Recv from Tun pub async fn recv_buf(&self) -> IOCallbackResult { match self.tun_io_uring.recv().await { - Ok(pkt) => IOCallbackResult::Ok(pkt), - Err(e) => { + IOCallbackResult::Ok(pkt) => IOCallbackResult::Ok(pkt), + IOCallbackResult::WouldBlock => IOCallbackResult::WouldBlock, + IOCallbackResult::Err(e) => { use std::io::{Error, ErrorKind}; IOCallbackResult::Err(Error::new(ErrorKind::Other, e)) } diff --git a/lightway-server/Cargo.toml b/lightway-server/Cargo.toml index 5b545d39..4a73ecb1 100644 --- a/lightway-server/Cargo.toml +++ b/lightway-server/Cargo.toml @@ -12,6 +12,7 @@ license.workspace = true default = ["io-uring"] debug = ["lightway-core/debug"] io-uring = ["lightway-app-utils/io-uring"] +iouring-bufsize = ["lightway-app-utils/iouring-bufsize"] [lints] workspace = true From da417876ad1a65c6a50f5099051c1a4c658a28df Mon Sep 17 00:00:00 2001 From: Omer Shamash Date: Thu, 30 Jan 2025 12:09:48 +0200 Subject: [PATCH 09/18] fixup: remove native-cpu flag (unused for now) --- .cargo/config.toml | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/.cargo/config.toml b/.cargo/config.toml index e08b7963..04cacc9b 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -1,6 +1,3 @@ [target.aarch64-unknown-linux-gnu] linker = "aarch64-linux-gnu-gcc" -runner = ["qemu-aarch64-static"] # use qemu user emulation for cargo run and test - -[build] -rustflags = ["-C", "target-cpu=native"] \ No newline at end of file +runner = ["qemu-aarch64-static"] # use qemu user emulation for cargo 
run and test \ No newline at end of file From 774e04f66fbb1eb2602c5fabef9ecc150c8c49a2 Mon Sep 17 00:00:00 2001 From: Omer Shamash Date: Sun, 2 Feb 2025 10:17:30 +0200 Subject: [PATCH 10/18] fixup: examples + cargo.lock redo --- lightway-app-utils/examples/udprelay.rs | 16 +++++++++++++--- lightway-app-utils/src/iouring.rs | 4 ++-- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/lightway-app-utils/examples/udprelay.rs b/lightway-app-utils/examples/udprelay.rs index bd89bed8..957acbdc 100644 --- a/lightway-app-utils/examples/udprelay.rs +++ b/lightway-app-utils/examples/udprelay.rs @@ -206,14 +206,18 @@ struct TunIOUring { impl TunIOUring { async fn new(tun: Tun, ring_size: usize, channel_size: usize) -> Result { - let tun_iouring = IOUring::new( + let tun_iouring = match IOUring::new( Arc::new(WrappedTun(tun)), ring_size, channel_size, TUN_MTU, Duration::from_millis(100), ) - .await?; + .await + { + Ok(it) => it, + Err(err) => return Err(err), + }; Ok(Self { tun_iouring }) } @@ -231,7 +235,13 @@ impl TunAdapter for TunIOUring { } async fn recv_from_tun(&self) -> Result { - self.tun_iouring.recv().await.map_err(anyhow::Error::msg) + match self.tun_iouring.recv().await { + IOCallbackResult::Ok(pkt) => Ok(pkt), + IOCallbackResult::WouldBlock => { + Err(std::io::Error::from(std::io::ErrorKind::WouldBlock).into()) + } + IOCallbackResult::Err(err) => Err(err.into()), + } } } diff --git a/lightway-app-utils/src/iouring.rs b/lightway-app-utils/src/iouring.rs index bbe1b806..a0e258b7 100644 --- a/lightway-app-utils/src/iouring.rs +++ b/lightway-app-utils/src/iouring.rs @@ -1,7 +1,7 @@ use crate::metrics; use anyhow::Result; use bytes::BytesMut; -use io_uring::{opcode, types, IoUring}; +use io_uring::{IoUring, opcode, types}; use libc::iovec; use lightway_core::IOCallbackResult; use parking_lot::Mutex; @@ -9,8 +9,8 @@ use std::{ cell::UnsafeCell, os::unix::io::{AsRawFd, RawFd}, sync::{ - atomic::{AtomicU32, AtomicUsize, Ordering}, Arc, + atomic::{AtomicU32, 
AtomicUsize, Ordering}, }, thread, time::Duration, From 563388e6d220b2afea51b935886748b682344cc2 Mon Sep 17 00:00:00 2001 From: Omer Shamash Date: Tue, 4 Feb 2025 11:17:29 +0200 Subject: [PATCH 11/18] fixup: undo iouring-bufsize reverse-dependnacy --- lightway-app-utils/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightway-app-utils/Cargo.toml b/lightway-app-utils/Cargo.toml index 712db5d6..0604ea06 100644 --- a/lightway-app-utils/Cargo.toml +++ b/lightway-app-utils/Cargo.toml @@ -12,7 +12,7 @@ readme = "README.md" [features] default = [ "tokio" ] io-uring = [ "dep:io-uring", "dep:tokio", "dep:tokio-eventfd", "dep:parking_lot" ] -iouring-bufsize = [ "io-uring" ] +iouring-bufsize = [] # Rust does bi-directional dependencies tokio = [ "dep:tokio", "dep:tokio-stream" ] [lints] From 8b79d44e8e4c24ac6ec9155ada55e071bfa54b03 Mon Sep 17 00:00:00 2001 From: Omer Shamash Date: Thu, 6 Feb 2025 02:20:35 +0200 Subject: [PATCH 12/18] feat: next-gen implementation, wayyyy better now --- lightway-app-utils/Cargo.toml | 1 - lightway-app-utils/src/iouring.rs | 433 +++++++++++++++++++----------- lightway-server/Cargo.toml | 1 - 3 files changed, 273 insertions(+), 162 deletions(-) diff --git a/lightway-app-utils/Cargo.toml b/lightway-app-utils/Cargo.toml index 0604ea06..9d677679 100644 --- a/lightway-app-utils/Cargo.toml +++ b/lightway-app-utils/Cargo.toml @@ -12,7 +12,6 @@ readme = "README.md" [features] default = [ "tokio" ] io-uring = [ "dep:io-uring", "dep:tokio", "dep:tokio-eventfd", "dep:parking_lot" ] -iouring-bufsize = [] # Rust does bi-directional dependencies tokio = [ "dep:tokio", "dep:tokio-stream" ] [lints] diff --git a/lightway-app-utils/src/iouring.rs b/lightway-app-utils/src/iouring.rs index a0e258b7..e574dee9 100644 --- a/lightway-app-utils/src/iouring.rs +++ b/lightway-app-utils/src/iouring.rs @@ -6,57 +6,67 @@ use libc::iovec; use lightway_core::IOCallbackResult; use parking_lot::Mutex; use std::{ - cell::UnsafeCell, 
os::unix::io::{AsRawFd, RawFd}, sync::{ Arc, - atomic::{AtomicU32, AtomicUsize, Ordering}, + atomic::{AtomicBool, AtomicUsize, Ordering}, }, thread, time::Duration, }; +use tokio::sync::Notify; +use tokio_eventfd::EventFd; + +#[repr(u64)] +enum IOUringActionID { + RecycleBuffers = 0x10001000, + ReceivedBuffer = 0xfeedfeed, + RecyclePending = 0xdead1000, +} +const RX_BUFFER_GROUP: u16 = 0xdead; -const DEFAULT_BUFFER_SIZE: usize = 2048; - -#[cfg(not(feature = "iouring-bufsize"))] -const BUFFER_SIZE: usize = DEFAULT_BUFFER_SIZE; +/// Wrapper for raw pointer that guarantees Send + Sync safety +/// Safety: The underlying memory is owned by the Arc'd BufferPool and outlives any pointer usage +struct BufferPtr(*mut u8); +#[allow(unsafe_code)] +unsafe impl Send for BufferPtr {} +#[allow(unsafe_code)] +unsafe impl Sync for BufferPtr {} -#[cfg(feature = "iouring-bufsize")] -const MAX_BUFFER_SIZE: usize = 65536; -#[cfg(feature = "iouring-bufsize")] -const BUFFER_SIZE: usize = { - let size = std::env!("IOURING_BUFFER_SIZE") - .parse::() - .expect("IOURING_BUFFER_SIZE must be a valid usize"); - assert!(size <= MAX_BUFFER_SIZE, "Buffer size cannot exceed 64KB"); - size -}; +impl BufferPtr { + fn as_ptr(&self) -> *mut u8 { + self.0 + } +} -#[repr(align(128))] -struct Buffer { - data: UnsafeCell<[u8; BUFFER_SIZE]>, - state: AtomicU32, // 0 = free, 1 = in_flight, 2 = completed - length: AtomicU32, +struct BufferPool { + data: Vec, // contiguous block of memory (all buffers) + lengths: Vec, + states: Vec, // 0 (false) = free, 1 (true) = in-use + usage_idx: AtomicUsize, + buffer_size: usize, } -#[allow(unsafe_code)] -unsafe impl Send for Buffer {} -#[allow(unsafe_code)] -unsafe impl Sync for Buffer {} +impl BufferPool { + fn new(entry_size: usize, pool_size: usize) -> Self { + // Ensure BUFFER_SIZE is multiple of 128-bit/16-byte (less cache-miss) + let buffer_size = (entry_size + 15) & !15; -impl Buffer { - fn new() -> Self { Self { - data: UnsafeCell::new([0u8; BUFFER_SIZE]), - 
state: AtomicU32::new(0), - length: AtomicU32::new(0), + data: vec![0u8; buffer_size * pool_size], + lengths: (0..pool_size).map(|_| AtomicUsize::new(0)).collect(), + states: (0..pool_size).map(|_| AtomicBool::new(false)).collect(), + usage_idx: AtomicUsize::new(0), + buffer_size, } } -} -struct BufferPool { - buffers: Vec, - read_idx: AtomicUsize, + fn get_buffer(&self, idx: usize) -> (BufferPtr, &AtomicUsize, &AtomicBool) { + #[allow(unsafe_code)] + let ptr = unsafe { self.data.as_ptr().add(idx * self.buffer_size) as *mut u8 }; + + (BufferPtr(ptr), &self.lengths[idx], &self.states[idx]) + } } /// IO-uring Struct @@ -64,9 +74,11 @@ pub struct IOUring { owned_fd: Arc, rx_pool: Arc, tx_pool: Arc, + rx_notify: Arc, + rx_eventfd: EventFd, + rx_provide_buffers: Arc, ring: Arc, submission_lock: Arc>, - write_index: AtomicUsize, } #[allow(unsafe_code)] @@ -79,17 +91,11 @@ impl IOUring { mtu: usize, sqpoll_idle_time: Duration, ) -> Result { - assert!(mtu <= BUFFER_SIZE); + // NOTE: it's probably a good idea for now to allocate rx/tx/ring at the same size + // this is because the VPN use-case usually has MTU-sized buffers going in-and-out - let rx_pool = Arc::new(BufferPool { - buffers: (0..ring_size / 2).map(|_| Buffer::new()).collect(), - read_idx: AtomicUsize::new(0), - }); - - let tx_pool = Arc::new(BufferPool { - buffers: (0..ring_size / 2).map(|_| Buffer::new()).collect(), - read_idx: AtomicUsize::new(0), - }); + let rx_pool = Arc::new(BufferPool::new(mtu, ring_size)); + let tx_pool = Arc::new(BufferPool::new(mtu, ring_size)); let ring = Arc::new( IoUring::builder() @@ -97,13 +103,55 @@ impl IOUring { .build(ring_size as u32)?, ); + let rx_notify = Arc::new(Notify::new()); + let rx_eventfd = EventFd::new(0, false)?; + let rx_provide_buffers = Arc::new(AtomicBool::new(false)); + + // NOTE: for now this ensures we only create 1 kthread per tunnel, and not 2 (rx/tx) + // we can opt to change this going forward, or redo the structure to not need a lock let 
submission_lock = Arc::new(Mutex::new(())); + // We can provide the buffers without a lock, as we still havn't shared the ownership + let fd = owned_fd.as_raw_fd(); + unsafe { + let mut sq = ring.submission_shared(); + sq.push( + &opcode::ProvideBuffers::new( + rx_pool.data.as_ptr() as *mut u8, + mtu as i32, + ring_size as u16, + RX_BUFFER_GROUP, + 0, + ) + .build() + .user_data(IOUringActionID::RecycleBuffers as u64), + )?; + sq.push( + &opcode::RecvMulti::new(types::Fd(fd), RX_BUFFER_GROUP) + .build() + .user_data(IOUringActionID::ReceivedBuffer as u64), + )?; + + // A bit ineffective vs. calculate offset directly, but more maintainable + let tx_iovecs: Vec<_> = (0..ring_size) + .map(|idx| { + let (ptr, _, _) = tx_pool.get_buffer(idx); + iovec { + iov_base: ptr.as_ptr() as *mut libc::c_void, + iov_len: mtu, + } + }) + .collect(); + ring.submitter().register_buffers(&tx_iovecs)?; + } + let rx_pool_clone = rx_pool.clone(); let tx_pool_clone = tx_pool.clone(); let ring_clone = ring.clone(); let lock_clone = submission_lock.clone(); - let fd = owned_fd.as_raw_fd(); + let notify_clone = rx_notify.clone(); + let eventfd = rx_eventfd.as_raw_fd(); + let provide_buffers = rx_provide_buffers.clone(); thread::Builder::new() .name("io_uring-main".to_string()) @@ -116,6 +164,9 @@ impl IOUring { fd, rx_pool_clone, tx_pool_clone, + notify_clone, + eventfd, + provide_buffers, ring_clone, lock_clone, )) @@ -125,9 +176,11 @@ impl IOUring { owned_fd, rx_pool, tx_pool, + rx_notify, + rx_eventfd, + rx_provide_buffers, ring, submission_lock, - write_index: AtomicUsize::new(0), }) } @@ -138,21 +191,26 @@ impl IOUring { /// Send packet on Tun device (push to RING and submit) pub fn try_send(&self, buf: BytesMut) -> IOCallbackResult { + // For semantics, see recv() function below + let idx = self + .tx_pool + .usage_idx + .fetch_update(Ordering::AcqRel, Ordering::Acquire, |idx| { + Some((idx + 1) % self.rx_pool.states.len()) + }) + .unwrap(); + let (buffer, length, state) = 
self.tx_pool.get_buffer(idx); + let len = buf.len(); - if len > BUFFER_SIZE { + if len > length.load(Ordering::Relaxed) { return IOCallbackResult::WouldBlock; } - let write_idx = - self.write_index.fetch_add(1, Ordering::AcqRel) % self.tx_pool.buffers.len(); - let buffer = &self.tx_pool.buffers[write_idx]; - // Check if buffer is free (state = 0) - if buffer - .state + if state .compare_exchange( - 0, - 1, // free -> in_flight + false, + true, // free -> in-use Ordering::AcqRel, Ordering::Relaxed, ) @@ -164,19 +222,22 @@ impl IOUring { return IOCallbackResult::WouldBlock; } - unsafe { (*buffer.data.get())[..len].copy_from_slice(&buf) }; - buffer.length.store(len as u32, Ordering::Release); + unsafe { std::slice::from_raw_parts_mut(buffer.as_ptr(), len).copy_from_slice(&buf) }; + length.store(len, Ordering::Release); + // NOTE: IOUringActionID values have to be bigger then the ring-size + // this is because we use here as data for send_fixed operations let write_op = opcode::WriteFixed::new( types::Fd(self.owned_fd.as_raw_fd()), - buffer.data.get() as *mut u8, + buffer.as_ptr(), len as _, - write_idx as _, + idx as _, ) .build() // NOTE: we set the index starting from after the RX_POOL part - .user_data((self.rx_pool.buffers.len() + write_idx) as u64); + .user_data(idx as u64); + // Safely queue submission let _guard = self.submission_lock.lock(); unsafe { match self.ring.submission_shared().push(&write_op) { @@ -188,136 +249,188 @@ impl IOUring { /// Receive packet from Tun device pub async fn recv(&self) -> IOCallbackResult { - let idx = self.rx_pool.read_idx.load(Ordering::Relaxed) % self.rx_pool.buffers.len(); - let buffer = &self.rx_pool.buffers[idx]; - - if buffer - .state - .compare_exchange(2, 0, Ordering::AcqRel, Ordering::Relaxed) - .is_ok() - { - let len = buffer.length.load(Ordering::Acquire) as usize; - let mut new_buf = BytesMut::with_capacity(len); - unsafe { new_buf.extend_from_slice(&(*buffer.data.get())[..len]) }; - 
self.rx_pool.read_idx.fetch_add(1, Ordering::Release); - return IOCallbackResult::Ok(new_buf); + // NOTE: Explanation on why these semantics were used: + // Flow: + // 1. The current value is loaded + // 2. Our closure is called with that value + // 3. A compare-and-swap (CAS) operation attempts to update with our new value + // + // The calculation of (X+1 % len) happens INSIDE closure, after the load but before the CAS. + // So if multiple threads are running concurrently: + // - Thread A loads value X + // - Thread B loads value X (before A's CAS completes) + // - Both calculate X+1 % len + // - We need AcqRel to ensure threads don't set values on top of each-other. + // - First thread's CAS should succeed as no value changed + // - Second thread's CAS should fail because the value changed + // - Second thread would retry, so we need Acquire on fetch to see Thread A's value + let idx = self + .rx_pool + .usage_idx + .fetch_update(Ordering::AcqRel, Ordering::Acquire, |idx| { + Some((idx + 1) % self.rx_pool.states.len()) + }) + .unwrap(); + let (buffer, length, state) = self.rx_pool.get_buffer(idx); + + loop { + // NOTE: unlike the above case, here we can use Relaxed ordering for better performance. 
+ // This is because we don't use the value in a closure, so we don't care for ensuring it's current value + if state + .compare_exchange(true, false, Ordering::AcqRel, Ordering::Relaxed) + .is_ok() + { + let len = length.load(Ordering::Acquire); + let mut new_buf = BytesMut::with_capacity(len); + unsafe { + new_buf.extend_from_slice(std::slice::from_raw_parts(buffer.as_ptr(), len)) + }; + return IOCallbackResult::Ok(new_buf); + } + // IO-Bound wait for available buffers + self.rx_notify.notified().await; + + // Check if kernel needs more buffers (and ensure only one notification is sent) + if self + .rx_provide_buffers + .compare_exchange(true, false, Ordering::AcqRel, Ordering::Relaxed) + .is_ok() + { + let val = 1u64; + unsafe { + if libc::write( + self.rx_eventfd.as_raw_fd(), + &val as *const u64 as *const _, + 8, + ) < 0 + { + let err = std::io::Error::last_os_error(); + tracing::error!("Failed to write to eventfd: {}", err); + // The following is a prayer to god to hopefully succeed next time around + self.rx_provide_buffers.store(true, Ordering::Release); + } + } + } } - IOCallbackResult::WouldBlock } } #[allow(unsafe_code)] async fn iouring_task( - fd: RawFd, + tun_fd: RawFd, rx_pool: Arc, tx_pool: Arc, + rx_notify: Arc, + rx_eventfd: RawFd, + rx_provide_buffers: Arc, ring: Arc, submission_lock: Arc>, ) -> Result<()> { - let rx_len = rx_pool.buffers.len(); - let mut iovecs = Vec::with_capacity(rx_len + tx_pool.buffers.len()); - - iovecs.extend(rx_pool.buffers.iter().map(|buf| iovec { - iov_base: buf.data.get() as *mut libc::c_void, - iov_len: BUFFER_SIZE, - })); - - iovecs.extend(tx_pool.buffers.iter().map(|buf| iovec { - iov_base: buf.data.get() as *mut libc::c_void, - iov_len: BUFFER_SIZE, - })); - - unsafe { ring.submitter().register_buffers(&iovecs)? 
}; + let mut eventfd_buf = [0u64; 1]; // Buffer for eventfd read (8 bytes) - // Initial submission of read operations - { + // Submit initial read for eventfd (needs to be here for buffer to be on stack of the task) + unsafe { let _guard = submission_lock.lock(); - let mut sq = unsafe { ring.submission_shared() }; - for idx in 0..rx_len { - let read_op = opcode::ReadFixed::new( - types::Fd(fd), - rx_pool.buffers[idx].data.get() as *mut u8, - BUFFER_SIZE as _, - idx as _, + let mut sq = ring.submission_shared(); + sq.push( + &opcode::Read::new( + types::Fd(rx_eventfd), + eventfd_buf.as_mut_ptr() as *mut u8, + 8, ) .build() - .user_data(idx as u64); - - unsafe { sq.push(&read_op)? }; - rx_pool.buffers[idx].state.store(1, Ordering::Release); - } + .user_data(IOUringActionID::RecyclePending as u64), + )?; } - let in_flight_reads = AtomicUsize::new(rx_len); - loop { ring.submit_and_wait(1)?; - let mut reads_to_resubmit = false; for cqe in unsafe { ring.completion_shared() } { - let user_data = cqe.user_data() as usize; - let result = cqe.result(); - - if user_data < rx_len { - let idx = user_data; - in_flight_reads.fetch_sub(1, Ordering::Release); - - if result > 0 { - rx_pool.buffers[idx] - .length - .store(result as u32, Ordering::Release); - rx_pool.buffers[idx].state.store(2, Ordering::Release); - - // Check if we need to resubmit batch - if in_flight_reads.load(Ordering::Acquire) < (rx_len / 4) { - reads_to_resubmit = true; + match cqe.user_data() { + x if x == IOUringActionID::RecycleBuffers as u64 => { + // Buffer provision completed + tracing::debug!("Buffer provision completed"); + } + + x if x == IOUringActionID::RecyclePending as u64 => { + if cqe.result() > 0 { + // Got notification we need more buffers + // NOTE: This approach is very good for cases we have constant data-flow + // we can only load the buffers for kernel when our read-threads are done with existing data, + // if our read-threads would block for too long elsewhere it would back-pressure the 
NIF device + let _guard = submission_lock.lock(); + unsafe { + let mut sq = ring.submission_shared(); + + // Make sure kernel can use all buffers again + sq.push( + &opcode::ProvideBuffers::new( + rx_pool.data.as_ptr() as *mut u8, + rx_pool.buffer_size as i32, + rx_pool.states.len() as u16, + RX_BUFFER_GROUP, + 0, + ) + .build() + .user_data(IOUringActionID::RecycleBuffers as u64), + )?; + + sq.push( + &opcode::RecvMulti::new(types::Fd(tun_fd), RX_BUFFER_GROUP) + .build() + .user_data(IOUringActionID::ReceivedBuffer as u64), + )?; + + // Resubmit eventfd read + sq.push( + &opcode::Read::new( + types::Fd(rx_eventfd), + eventfd_buf.as_mut_ptr() as *mut u8, + 8, + ) + .build() + .user_data(IOUringActionID::RecyclePending as u64), + )?; + } } - } else { - // Error or EOF case for read - rx_pool.buffers[idx].state.store(0, Ordering::Release); - reads_to_resubmit = true; + } + + x if x == IOUringActionID::ReceivedBuffer as u64 => { + let result = cqe.result(); if result < 0 { tracing::error!( - "Read operation failed: {}", + "Receive failed: {}", std::io::Error::from_raw_os_error(-result) ); metrics::tun_iouring_rx_err(); + continue; } - } - } else { - // Write completion - let tx_idx = user_data - rx_len; - if result <= 0 { - tracing::error!( - "Write operation failed: {}", - std::io::Error::from_raw_os_error(-result) - ); - metrics::tun_iouring_tx_err(); - } - tx_pool.buffers[tx_idx].state.store(0, Ordering::Release); - } - } - // Batch resubmit of reads if needed - if reads_to_resubmit { - let _guard = submission_lock.lock(); - let mut sq = unsafe { ring.submission_shared() }; - - for idx in 0..rx_len { - if rx_pool.buffers[idx].state.load(Ordering::Acquire) == 0 { - let read_op = opcode::ReadFixed::new( - types::Fd(fd), - rx_pool.buffers[idx].data.get() as *mut u8, - BUFFER_SIZE as _, - idx as _, - ) - .build() - .user_data(idx as u64); + let buf_id = io_uring::cqueue::buffer_select(cqe.flags()).unwrap(); + let (_, length, state) = rx_pool.get_buffer(buf_id as _); + 
+ length.store(result as usize, Ordering::Release); + state.store(true, Ordering::Release); // Mark as ready-for-user + rx_notify.notify_waiters(); - if unsafe { sq.push(&read_op) }.is_ok() { - rx_pool.buffers[idx].state.store(1, Ordering::Release); - in_flight_reads.fetch_add(1, Ordering::Release); + if !io_uring::cqueue::more(cqe.flags()) { + rx_provide_buffers.store(true, Ordering::Release); + } + } + + idx => { + // TX completion + let result = cqe.result(); + if result < 0 { + tracing::error!( + "Send failed: {}", + std::io::Error::from_raw_os_error(-result) + ); + metrics::tun_iouring_tx_err(); } + let (_, _, state) = tx_pool.get_buffer(idx as _); + state.store(false, Ordering::Release); // mark as available for send } } } diff --git a/lightway-server/Cargo.toml b/lightway-server/Cargo.toml index 4a73ecb1..5b545d39 100644 --- a/lightway-server/Cargo.toml +++ b/lightway-server/Cargo.toml @@ -12,7 +12,6 @@ license.workspace = true default = ["io-uring"] debug = ["lightway-core/debug"] io-uring = ["lightway-app-utils/io-uring"] -iouring-bufsize = ["lightway-app-utils/iouring-bufsize"] [lints] workspace = true From 4ef37f56d57eed3641afc2760e85966da4b05ebe Mon Sep 17 00:00:00 2001 From: Omer Shamash Date: Mon, 10 Feb 2025 13:56:22 +0200 Subject: [PATCH 13/18] lint: fix lint-err (safety, split-unsafe, scoping) --- lightway-app-utils/src/iouring.rs | 254 +++++++++++++++++------------- 1 file changed, 148 insertions(+), 106 deletions(-) diff --git a/lightway-app-utils/src/iouring.rs b/lightway-app-utils/src/iouring.rs index e574dee9..49d80923 100644 --- a/lightway-app-utils/src/iouring.rs +++ b/lightway-app-utils/src/iouring.rs @@ -25,12 +25,14 @@ enum IOUringActionID { } const RX_BUFFER_GROUP: u16 = 0xdead; -/// Wrapper for raw pointer that guarantees Send + Sync safety -/// Safety: The underlying memory is owned by the Arc'd BufferPool and outlives any pointer usage +/// A wrapper around a raw pointer that guarantees thread safety through Arc ownership struct 
BufferPtr(*mut u8); + #[allow(unsafe_code)] +// Safety: The pointer is owned by Arc which ensures exclusive access unsafe impl Send for BufferPtr {} #[allow(unsafe_code)] +// Safety: The pointer is owned by Arc which ensures synchronized access unsafe impl Sync for BufferPtr {} impl BufferPtr { @@ -39,6 +41,7 @@ impl BufferPtr { } } +/// A pool of buffers with an underlying contiguous memory block struct BufferPool { data: Vec, // contiguous block of memory (all buffers) lengths: Vec, @@ -62,6 +65,7 @@ impl BufferPool { } fn get_buffer(&self, idx: usize) -> (BufferPtr, &AtomicUsize, &AtomicBool) { + // Safety: Index is bounds-checked by the caller, and buffer_size ensures no overflow #[allow(unsafe_code)] let ptr = unsafe { self.data.as_ptr().add(idx * self.buffer_size) as *mut u8 }; @@ -81,6 +85,7 @@ pub struct IOUring { submission_lock: Arc>, } +// Safety: IOUring implementation does direct memory manipulations for performance benefits #[allow(unsafe_code)] impl IOUring { /// Create `IOUring` struct @@ -113,45 +118,61 @@ impl IOUring { // We can provide the buffers without a lock, as we still havn't shared the ownership let fd = owned_fd.as_raw_fd(); - unsafe { - let mut sq = ring.submission_shared(); - sq.push( - &opcode::ProvideBuffers::new( - rx_pool.data.as_ptr() as *mut u8, - mtu as i32, - ring_size as u16, - RX_BUFFER_GROUP, - 0, - ) - .build() - .user_data(IOUringActionID::RecycleBuffers as u64), - )?; - sq.push( - &opcode::RecvMulti::new(types::Fd(fd), RX_BUFFER_GROUP) + + // Scope submission-queue operations to avoid borrowing ring + { + // Safety: Ring submission can be used without locks at this point + let mut sq = unsafe { ring.submission_shared() }; + + // Safety: Buffer memory is owned by rx_pool and outlives the usage + unsafe { + sq.push( + &opcode::ProvideBuffers::new( + rx_pool.data.as_ptr() as *mut u8, + mtu as i32, + ring_size as u16, + RX_BUFFER_GROUP, + 0, + ) .build() - .user_data(IOUringActionID::ReceivedBuffer as u64), - )?; - - // A bit 
ineffective vs. calculate offset directly, but more maintainable - let tx_iovecs: Vec<_> = (0..ring_size) - .map(|idx| { - let (ptr, _, _) = tx_pool.get_buffer(idx); - iovec { - iov_base: ptr.as_ptr() as *mut libc::c_void, - iov_len: mtu, - } - }) - .collect(); - ring.submitter().register_buffers(&tx_iovecs)?; + .user_data(IOUringActionID::RecycleBuffers as u64), + )? + }; + + // Safety: Ring is initialized and file descriptor is valid + unsafe { + sq.push( + &opcode::RecvMulti::new(types::Fd(fd), RX_BUFFER_GROUP) + .build() + .user_data(IOUringActionID::ReceivedBuffer as u64), + )? + }; } - let rx_pool_clone = rx_pool.clone(); - let tx_pool_clone = tx_pool.clone(); - let ring_clone = ring.clone(); - let lock_clone = submission_lock.clone(); - let notify_clone = rx_notify.clone(); - let eventfd = rx_eventfd.as_raw_fd(); - let provide_buffers = rx_provide_buffers.clone(); + // A bit ineffective vs. calculate offset directly, but more maintainable + let tx_iovecs: Vec<_> = (0..ring_size) + .map(|idx| { + let (ptr, _, _) = tx_pool.get_buffer(idx); + iovec { + iov_base: ptr.as_ptr() as *mut libc::c_void, + iov_len: mtu, + } + }) + .collect(); + + // Safety: tx_iovecs point to valid memory owned by tx_pool + unsafe { ring.submitter().register_buffers(&tx_iovecs)? 
}; + + let config = IOUringTaskConfig { + tun_fd: fd, + rx_pool: rx_pool.clone(), + tx_pool: tx_pool.clone(), + rx_notify: rx_notify.clone(), + rx_eventfd: rx_eventfd.as_raw_fd(), + rx_provide_buffers: rx_provide_buffers.clone(), + ring: ring.clone(), + submission_lock: submission_lock.clone(), + }; thread::Builder::new() .name("io_uring-main".to_string()) @@ -160,16 +181,7 @@ impl IOUring { .enable_io() .build() .expect("Failed building Tokio Runtime") - .block_on(iouring_task( - fd, - rx_pool_clone, - tx_pool_clone, - notify_clone, - eventfd, - provide_buffers, - ring_clone, - lock_clone, - )) + .block_on(iouring_task(config)) })?; Ok(Self { @@ -222,6 +234,7 @@ impl IOUring { return IOCallbackResult::WouldBlock; } + // Safety: Buffer is allocated with sufficient size and ownership is checked via state unsafe { std::slice::from_raw_parts_mut(buffer.as_ptr(), len).copy_from_slice(&buf) }; length.store(len, Ordering::Release); @@ -238,11 +251,16 @@ impl IOUring { .user_data(idx as u64); // Safely queue submission - let _guard = self.submission_lock.lock(); - unsafe { - match self.ring.submission_shared().push(&write_op) { - Ok(_) => IOCallbackResult::Ok(len), - Err(_) => IOCallbackResult::WouldBlock, + { + let _guard = self.submission_lock.lock(); + // Safety: protected by lock above + let mut sq = unsafe { self.ring.submission_shared() }; + // Safety: entry uses buffers from rx_pool which outlive task using them + unsafe { + match sq.push(&write_op) { + Ok(_) => IOCallbackResult::Ok(len), + Err(_) => IOCallbackResult::WouldBlock, + } } } } @@ -282,6 +300,8 @@ impl IOUring { { let len = length.load(Ordering::Acquire); let mut new_buf = BytesMut::with_capacity(len); + + // Safety: Buffer is allocated with sufficient size and ownership is checked via state unsafe { new_buf.extend_from_slice(std::slice::from_raw_parts(buffer.as_ptr(), len)) }; @@ -296,8 +316,9 @@ impl IOUring { .compare_exchange(true, false, Ordering::AcqRel, Ordering::Relaxed) .is_ok() { - let val = 
1u64; + // Safety: buffer is defined on stack, event-fd outlives task using it unsafe { + let val = 1u64; if libc::write( self.rx_eventfd.as_raw_fd(), &val as *const u64 as *const _, 8, @@ -315,8 +336,8 @@ } } -#[allow(unsafe_code)] -async fn iouring_task( +/// Task variables +struct IOUringTaskConfig { tun_fd: RawFd, rx_pool: Arc, tx_pool: Arc, @@ -325,28 +346,38 @@ rx_provide_buffers: Arc, ring: Arc, submission_lock: Arc>, -) -> Result<()> { +} + +// Safety: To manage ring completion and results efficiently requires direct memory manipulations +#[allow(unsafe_code)] +async fn iouring_task(config: IOUringTaskConfig) -> Result<()> { let mut eventfd_buf = [0u64; 1]; // Buffer for eventfd read (8 bytes) // Submit initial read for eventfd (needs to be here for buffer to be on stack of the task) - unsafe { - let _guard = submission_lock.lock(); - let mut sq = ring.submission_shared(); - sq.push( - &opcode::Read::new( - types::Fd(rx_eventfd), - eventfd_buf.as_mut_ptr() as *mut u8, - 8, - ) - .build() - .user_data(IOUringActionID::RecyclePending as u64), - )?; + { + let _guard = config.submission_lock.lock(); + // Safety: protected by above lock + let mut sq = unsafe { config.ring.submission_shared() }; + // Safety: event-fd outlives the task, queue protected by lock + unsafe { + sq.push( + &opcode::Read::new( + types::Fd(config.rx_eventfd), + eventfd_buf.as_mut_ptr() as *mut u8, + 8, + ) + .build() + .user_data(IOUringActionID::RecyclePending as u64), + )? 
+ }; } loop { - ring.submit_and_wait(1)?; + // Work once we have at least 1 task to perform + config.ring.submit_and_wait(1)?; - for cqe in unsafe { ring.completion_shared() } { + // Safety: only task is using the completion-queue (concept should not change) + for cqe in unsafe { config.ring.completion_shared() } { match cqe.user_data() { x if x == IOUringActionID::RecycleBuffers as u64 => { // Buffer provision completed @@ -359,39 +390,50 @@ async fn iouring_task( // NOTE: This approach is very good for cases we have constant data-flow // we can only load the buffers for kernel when our read-threads are done with existing data, // if our read-threads would block for too long elsewhere it would back-pressure the NIF device - let _guard = submission_lock.lock(); - unsafe { - let mut sq = ring.submission_shared(); - - // Make sure kernel can use all buffers again - sq.push( - &opcode::ProvideBuffers::new( - rx_pool.data.as_ptr() as *mut u8, - rx_pool.buffer_size as i32, - rx_pool.states.len() as u16, - RX_BUFFER_GROUP, - 0, - ) - .build() - .user_data(IOUringActionID::RecycleBuffers as u64), - )?; - - sq.push( - &opcode::RecvMulti::new(types::Fd(tun_fd), RX_BUFFER_GROUP) + let _guard = config.submission_lock.lock(); + + // Safety: protected by above lock + let mut sq = unsafe { config.ring.submission_shared() }; + + // Make sure kernel can use all buffers again + { + // Safety: buffers are mapped from rx_pool which outlives this task + unsafe { + sq.push( + &opcode::ProvideBuffers::new( + config.rx_pool.data.as_ptr() as *mut u8, + config.rx_pool.buffer_size as i32, + config.rx_pool.states.len() as u16, + RX_BUFFER_GROUP, + 0, + ) + .build() + .user_data(IOUringActionID::RecycleBuffers as u64), + )? 
+ }; + // Safety: buffer-group originates from rx_pool which outlives this task + unsafe { + sq.push( + &opcode::RecvMulti::new( + types::Fd(config.tun_fd), + RX_BUFFER_GROUP, + ) .build() .user_data(IOUringActionID::ReceivedBuffer as u64), - )?; - - // Resubmit eventfd read - sq.push( - &opcode::Read::new( - types::Fd(rx_eventfd), - eventfd_buf.as_mut_ptr() as *mut u8, - 8, - ) - .build() - .user_data(IOUringActionID::RecyclePending as u64), - )?; + )? + }; + // Safety: Event-fd outlives the task, buffer is task-bound (stack) + unsafe { + sq.push( + &opcode::Read::new( + types::Fd(config.rx_eventfd), + eventfd_buf.as_mut_ptr() as *mut u8, + 8, + ) + .build() + .user_data(IOUringActionID::RecyclePending as u64), + )? + }; } } } @@ -408,14 +450,14 @@ async fn iouring_task( } let buf_id = io_uring::cqueue::buffer_select(cqe.flags()).unwrap(); - let (_, length, state) = rx_pool.get_buffer(buf_id as _); + let (_, length, state) = config.rx_pool.get_buffer(buf_id as _); length.store(result as usize, Ordering::Release); state.store(true, Ordering::Release); // Mark as ready-for-user - rx_notify.notify_waiters(); + config.rx_notify.notify_waiters(); if !io_uring::cqueue::more(cqe.flags()) { - rx_provide_buffers.store(true, Ordering::Release); + config.rx_provide_buffers.store(true, Ordering::Release); } } @@ -429,7 +471,7 @@ async fn iouring_task( ); metrics::tun_iouring_tx_err(); } - let (_, _, state) = tx_pool.get_buffer(idx as _); + let (_, _, state) = config.tx_pool.get_buffer(idx as _); state.store(false, Ordering::Release); // mark as available for send } } From 758ccaf22b0dccb36909c5bda1bab79a035d0897 Mon Sep 17 00:00:00 2001 From: Omer Shamash Date: Tue, 11 Feb 2025 14:25:42 +0200 Subject: [PATCH 14/18] test: strace to find where tun-sock originated --- lightway-app-utils/src/iouring.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lightway-app-utils/src/iouring.rs b/lightway-app-utils/src/iouring.rs index 49d80923..a13c448d 100644 --- 
a/lightway-app-utils/src/iouring.rs +++ b/lightway-app-utils/src/iouring.rs @@ -141,11 +141,11 @@ impl IOUring { // Safety: Ring is initialized and file descriptor is valid unsafe { - sq.push( - &opcode::RecvMulti::new(types::Fd(fd), RX_BUFFER_GROUP) + let op = opcode::RecvMulti::new(types::Fd(fd), RX_BUFFER_GROUP) .build() - .user_data(IOUringActionID::ReceivedBuffer as u64), - )? + .user_data(IOUringActionID::ReceivedBuffer as u64); + tracing::debug!("Started recv-multi: {}", op); + sq.push(&op)?; }; } From e52ba729ac1ced91acf4fd377a6050e14e4dceb9 Mon Sep 17 00:00:00 2001 From: Omer Shamash Date: Mon, 17 Feb 2025 17:49:56 +0200 Subject: [PATCH 15/18] feat: multi-kernel version (+rlimit fix) --- lightway-app-utils/src/iouring.rs | 425 +++++++++++++++++++++++++++--- tests/e2e/docker-compose.yml | 8 + 2 files changed, 395 insertions(+), 38 deletions(-) diff --git a/lightway-app-utils/src/iouring.rs b/lightway-app-utils/src/iouring.rs index a13c448d..2782eb2a 100644 --- a/lightway-app-utils/src/iouring.rs +++ b/lightway-app-utils/src/iouring.rs @@ -1,12 +1,13 @@ use crate::metrics; -use anyhow::Result; +use anyhow::{Context, Result}; use bytes::BytesMut; -use io_uring::{IoUring, opcode, types}; +use io_uring::{opcode, squeue::PushError, types, IoUring}; use libc::iovec; use lightway_core::IOCallbackResult; use parking_lot::Mutex; use std::{ - os::unix::io::{AsRawFd, RawFd}, + alloc::{alloc_zeroed, dealloc, Layout}, + os::{fd::AsRawFd, unix::io::RawFd}, sync::{ Arc, atomic::{AtomicBool, AtomicUsize, Ordering}, @@ -17,6 +18,225 @@ use std::{ use tokio::sync::Notify; use tokio_eventfd::EventFd; +// ------------------------------------------------------------- +// - IMPLEMENT read-multishot and RUNTIME variations - +// ------------------------------------------------------------- + +use io_uring::squeue::Entry; + +pub const IORING_OP_READ_MULTISHOT: u8 = 49; + +#[repr(C)] +pub struct CustomSQE { + pub opcode: u8, + pub flags: u8, + pub ioprio: u16, + pub fd: i32, + pub 
off_or_addr2: Union1, + pub addr_or_splice_off_in: Union2, + pub len: u32, + pub msg_flags: Union3, + pub user_data: u64, + pub buf_index: PackedU16, // Note: this is packed! + pub personality: u16, + pub splice_fd: Union5, + pub __pad2: [u64; 2], // The final union covers 16 bytes +} + +#[repr(C)] +pub union Union1 { + pub off: u64, + pub addr2: u64, + pub cmd_op: std::mem::ManuallyDrop, +} + +#[repr(C)] +pub struct CmdOp { + pub cmd_op: u32, + pub __pad1: u32, +} + +#[repr(C)] +pub union Union2 { + pub addr: u64, + pub splice_off_in: u64, + pub level_optname: std::mem::ManuallyDrop, +} + +#[repr(C)] +pub struct SockLevel { + pub level: u32, + pub optname: u32, +} + +#[repr(C)] +pub union Union3 { + pub rw_flags: i32, + pub fsync_flags: u32, + pub poll_events: u16, + pub poll32_events: u32, + pub sync_range_flags: u32, + pub msg_flags: u32, + pub timeout_flags: u32, + pub accept_flags: u32, + pub cancel_flags: u32, + pub open_flags: u32, + pub statx_flags: u32, + pub fadvise_advice: u32, + pub splice_flags: u32, + pub rename_flags: u32, + pub unlink_flags: u32, + pub hardlink_flags: u32, + pub xattr_flags: u32, + pub msg_ring_flags: u32, + pub uring_cmd_flags: u32, + pub waitid_flags: u32, + pub futex_flags: u32, + pub install_fd_flags: u32, + pub nop_flags: u32, +} + +#[repr(C, packed)] +pub struct PackedU16 { + pub buf_index: u16, +} + +#[repr(C)] +pub union Union5 { + pub splice_fd_in: i32, + pub file_index: u32, + pub optlen: u32, + pub addr_len_stuff: std::mem::ManuallyDrop, +} + +#[repr(C)] +pub struct AddrLenPad { + pub addr_len: u16, + pub __pad3: [u16; 1], +} + +impl Default for CustomSQE { + fn default() -> Self { + // Safety: memzero is ok + #[allow(unsafe_code)] + unsafe { + std::mem::zeroed() + } + } +} + +pub struct ReadMulti { + fd: i32, + buf_group: u16, + flags: i32, +} + +impl ReadMulti { + #[inline] + pub fn new(fd: i32, buf_group: u16) -> Self { + ReadMulti { + fd, + buf_group, + flags: 0, + } + } + + #[inline] + pub fn build(self) -> Entry { + 
let sqe = CustomSQE { + opcode: IORING_OP_READ_MULTISHOT as _, + flags: io_uring::squeue::Flags::BUFFER_SELECT.bits(), + fd: self.fd, + buf_index: PackedU16 { + buf_index: self.buf_group, + }, + msg_flags: Union3 { + msg_flags: self.flags as _, + }, + ..Default::default() + }; + + // Safety: CustomSQE has identical memory layout to io_uring_sqe + #[allow(unsafe_code)] + unsafe { + std::mem::transmute(sqe) + } + } +} + +// Static for one-time initialization +static INITIALIZED: AtomicBool = AtomicBool::new(false); +static SUPPORTED: AtomicBool = AtomicBool::new(false); + +#[cold] +fn initialize_kernel_check() -> bool { + let supported = std::fs::read_to_string("/proc/sys/kernel/osrelease") + .ok() + .and_then(|v| { + let version_numbers = v.split('-').next()?; + let parts: Vec<_> = version_numbers.split('.').collect(); + if parts.len() >= 2 { + Some((parts[0].parse::().ok()?, parts[1].parse::().ok()?)) + } else { + None + } + }) + .map_or(false, |(major, minor)| { + major > 6 || (major == 6 && minor >= 7) + }); + + SUPPORTED.store(supported, Ordering::Release); + INITIALIZED.store(true, Ordering::Release); + supported +} + +#[inline(always)] +pub fn kernel_supports_multishot() -> bool { + // Fast path - just load if initialized + if INITIALIZED.load(Ordering::Acquire) { + SUPPORTED.load(Ordering::Acquire) + } else { + // Slow path - do initialization + initialize_kernel_check() + } +} + +// Safety: SQE operations are always unsafe +/// Inline operation to ensure we queue reads without impacting runtime (multi-kernel) +#[inline(always)] +#[allow(unsafe_code)] +pub unsafe fn queue_reads( + sq: &mut io_uring::SubmissionQueue<'_>, + fd: i32, + n_entries: usize, + buf_group: u16, + user_data: u64, +) -> Result<(), PushError> { + if kernel_supports_multishot() { + tracing::debug!("Kernel supports - adding MULTISHOT_READ"); + // Safety: Ring is initialized and file descriptor is valid + unsafe { + let op = ReadMulti::new(fd, buf_group).build().user_data(user_data); + 
sq.push(&op) + } + } else { + tracing::debug!("NO Kernel support - adding {} READ", n_entries); + let mut ops = Vec::with_capacity(n_entries); + for _ in 0..n_entries { + let op = opcode::Read::new(types::Fd(fd), std::ptr::null_mut(), 0) + .buf_group(buf_group) + .build() + .flags(io_uring::squeue::Flags::BUFFER_SELECT) + .user_data(user_data); + ops.push(op); + } + // Safety: Ring is initialized and file descriptor is valid + unsafe { sq.push_multiple(&ops) } + } +} + +// ------------------------------------------------------------- + #[repr(u64)] enum IOUringActionID { RecycleBuffers = 0x10001000, @@ -25,6 +245,9 @@ enum IOUringActionID { } const RX_BUFFER_GROUP: u16 = 0xdead; +// Required 32MB for io-uring to function properly +const REQUIRED_RLIMIT_MEMLOCK_MAX: u64 = 32 * 1024 * 1024; + /// A wrapper around a raw pointer that guarantees thread safety through Arc ownership struct BufferPtr(*mut u8); @@ -41,35 +264,101 @@ impl BufferPtr { } } +struct PageAlignedBuffer { + ptr: *mut u8, + layout: Layout, + entry_size: usize, + num_entries: usize, +} + +#[allow(unsafe_code)] +// Safety: The pointer is owned by Arc which ensures exclusive access +unsafe impl Send for PageAlignedBuffer {} +#[allow(unsafe_code)] +// Safety: The pointer is owned by Arc which ensures synchronized access +unsafe impl Sync for PageAlignedBuffer {} + +impl PageAlignedBuffer { + fn new(entry_size: usize, num_entries: usize) -> Self { + #[allow(unsafe_code)] + // Safety: libc is not safe, variable is fine + let page_size = unsafe { libc::sysconf(libc::_SC_PAGESIZE) } as usize; + + // Round up entry_size to 16-byte alignment first + let aligned_entry_size = (entry_size + 15) & !15; + + // Calculate how many entries fit in one page + let entries_per_page = page_size / aligned_entry_size; + + // Calculate total pages needed + let pages_needed = num_entries.div_ceil(entries_per_page); + let total_size = pages_needed * page_size; + + let layout = Layout::from_size_align(total_size, 
page_size).expect("Invalid layout"); + + // Safety: allocate per layout selected (no aligned-allocator in rust) + #[allow(unsafe_code)] + let ptr = unsafe { alloc_zeroed(layout) }; + if ptr.is_null() { + std::alloc::handle_alloc_error(layout); + } + + Self { + ptr, + layout, + entry_size: aligned_entry_size, + num_entries, + } + } + + fn get_ptr(&self, idx: usize) -> *mut u8 { + assert!(idx < self.num_entries); + // Safety: asserted size within boundary before + #[allow(unsafe_code)] + unsafe { + self.ptr.add(idx * self.entry_size) + } + } + + fn as_ptr(&self) -> *mut u8 { + self.ptr + } +} + +impl Drop for PageAlignedBuffer { + fn drop(&mut self) { + // Safety: we know what layout we allocated (saved) + #[allow(unsafe_code)] + unsafe { + dealloc(self.ptr, self.layout); + } + } +} + /// A pool of buffers with an underlying contiguous memory block struct BufferPool { - data: Vec, // contiguous block of memory (all buffers) + data: PageAlignedBuffer, lengths: Vec, states: Vec, // 0 (false) = free, 1 (true) = in-use usage_idx: AtomicUsize, - buffer_size: usize, } impl BufferPool { fn new(entry_size: usize, pool_size: usize) -> Self { - // Ensure BUFFER_SIZE is multiple of 128-bit/16-byte (less cache-miss) - let buffer_size = (entry_size + 15) & !15; - Self { - data: vec![0u8; buffer_size * pool_size], + data: PageAlignedBuffer::new(entry_size, pool_size), lengths: (0..pool_size).map(|_| AtomicUsize::new(0)).collect(), states: (0..pool_size).map(|_| AtomicBool::new(false)).collect(), usage_idx: AtomicUsize::new(0), - buffer_size, } } fn get_buffer(&self, idx: usize) -> (BufferPtr, &AtomicUsize, &AtomicBool) { - // Safety: Index is bounds-checked by the caller, and buffer_size ensures no overflow - #[allow(unsafe_code)] - let ptr = unsafe { self.data.as_ptr().add(idx * self.buffer_size) as *mut u8 }; - - (BufferPtr(ptr), &self.lengths[idx], &self.states[idx]) + ( + BufferPtr(self.data.get_ptr(idx)), + &self.lengths[idx], + &self.states[idx], + ) } } @@ -99,6 +388,12 @@ 
impl IOUring { // NOTE: it's probably a good idea for now to allocate rx/tx/ring at the same size // this is because the VPN use-case usually has MTU-sized buffers going in-and-out + tracing::debug!( + "INIT io-uring, estimated memory (user | kernel): {}Mb | {}Mb", + (2 * (size_of::() + (mtu * ring_size))) / 1024 / 1024, + (ring_size * (16 + (2 * 64)) + 8192) / 1024 / 1024, + ); + let rx_pool = Arc::new(BufferPool::new(mtu, ring_size)); let tx_pool = Arc::new(BufferPool::new(mtu, ring_size)); @@ -124,11 +419,12 @@ impl IOUring { // Safety: Ring submission can be used without locks at this point let mut sq = unsafe { ring.submission_shared() }; + tracing::debug!("Sending PROVIDE_BUFFERS"); // Safety: Buffer memory is owned by rx_pool and outlives the usage unsafe { sq.push( &opcode::ProvideBuffers::new( - rx_pool.data.as_ptr() as *mut u8, + rx_pool.data.as_ptr(), mtu as i32, ring_size as u16, RX_BUFFER_GROUP, @@ -141,15 +437,16 @@ impl IOUring { // Safety: Ring is initialized and file descriptor is valid unsafe { - let op = opcode::RecvMulti::new(types::Fd(fd), RX_BUFFER_GROUP) - .build() - .user_data(IOUringActionID::ReceivedBuffer as u64); - tracing::debug!("Started recv-multi: {}", op); - sq.push(&op)?; + queue_reads( + &mut sq, + fd, + ring_size, + RX_BUFFER_GROUP, + IOUringActionID::ReceivedBuffer as _, + )? }; } - // A bit ineffective vs. 
calculate offset directly, but more maintainable let tx_iovecs: Vec<_> = (0..ring_size) .map(|idx| { let (ptr, _, _) = tx_pool.get_buffer(idx); @@ -160,8 +457,31 @@ impl IOUring { }) .collect(); + // Safety: memory for libc calls + let mut rlim: libc::rlimit = unsafe { std::mem::zeroed() }; + // Safety: fetch memory limitations defined + unsafe { + libc::getrlimit(libc::RLIMIT_MEMLOCK, &mut rlim); + } + + // Check memory usage needed + if rlim.rlim_max < REQUIRED_RLIMIT_MEMLOCK_MAX { + tracing::info!("RLIMIT too low ({}), adjusting", rlim.rlim_max); + rlim.rlim_max = REQUIRED_RLIMIT_MEMLOCK_MAX; + // Safety: rlimit API requires unsafe block + if unsafe { libc::setrlimit(libc::RLIMIT_MEMLOCK, &rlim) } != 0 { + tracing::warn!( + "Failed to set RLIMIT_MEMLOCK: {}", + std::io::Error::last_os_error() + ); + } + } + // Safety: tx_iovecs point to valid memory owned by tx_pool unsafe { ring.submitter().register_buffers(&tx_iovecs)? }; + ring.submitter() + .register_files(&[fd]) + .expect("io-uring support"); let config = IOUringTaskConfig { tun_fd: fd, @@ -174,6 +494,8 @@ impl IOUring { submission_lock: submission_lock.clone(), }; + // NOTE: currently we don't implement any Drop for class, it will require changes + // so until then, we can also ignore the need to close the FDs in rx_eventfd and owned_fd thread::Builder::new() .name("io_uring-main".to_string()) .spawn(move || { @@ -182,7 +504,8 @@ impl IOUring { .build() .expect("Failed building Tokio Runtime") .block_on(iouring_task(config)) - })?; + }) + .context("io_uring-task")?; Ok(Self { owned_fd, @@ -203,6 +526,7 @@ impl IOUring { /// Send packet on Tun device (push to RING and submit) pub fn try_send(&self, buf: BytesMut) -> IOCallbackResult { + tracing::debug!("try_send {} bytes", buf.len()); // For semantics, see recv() function below let idx = self .tx_pool @@ -214,7 +538,12 @@ impl IOUring { let (buffer, length, state) = self.tx_pool.get_buffer(idx); let len = buf.len(); - if len > length.load(Ordering::Relaxed) 
{ + if len > self.tx_pool.data.entry_size { + tracing::warn!( + "We dont support buffer-splitting for now (max: {}, got: {})", + self.tx_pool.data.entry_size, + len + ); return IOCallbackResult::WouldBlock; } @@ -241,7 +570,7 @@ impl IOUring { // NOTE: IOUringActionID values have to be bigger then the ring-size // this is because we use here as data for send_fixed operations let write_op = opcode::WriteFixed::new( - types::Fd(self.owned_fd.as_raw_fd()), + types::Fixed(self.owned_fd.as_raw_fd() as _), buffer.as_ptr(), len as _, idx as _, @@ -250,16 +579,25 @@ impl IOUring { // NOTE: we set the index starting from after the RX_POOL part .user_data(idx as u64); + tracing::debug!("queuing WRITE_FIXED on buf-id {}", idx); + // Safely queue submission { let _guard = self.submission_lock.lock(); // Safety: protected by lock above let mut sq = unsafe { self.ring.submission_shared() }; - // Safety: entry uses buffers from rx_pool which outlive task using them + // Safety: entry uses buffers from tx_pool which outlive task using them unsafe { match sq.push(&write_op) { - Ok(_) => IOCallbackResult::Ok(len), - Err(_) => IOCallbackResult::WouldBlock, + Ok(_) => { + tracing::debug!("Successfully queued write for buffer {}", idx); + IOCallbackResult::Ok(len) + } + Err(_) => { + tracing::warn!("Failed to queue send"); + metrics::tun_iouring_tx_err(); + IOCallbackResult::WouldBlock + } } } } @@ -291,6 +629,7 @@ impl IOUring { .unwrap(); let (buffer, length, state) = self.rx_pool.get_buffer(idx); + tracing::debug!("recv blocking until buf-id {} is available", idx); loop { // NOTE: unlike the above case, here we can use Relaxed ordering for better performance. 
// This is because we don't use the value in a closure, so we don't care for ensuring it's current value @@ -301,6 +640,8 @@ impl IOUring { let len = length.load(Ordering::Acquire); let mut new_buf = BytesMut::with_capacity(len); + tracing::debug!("recv, got {} bytes", len); + // Safety: Buffer is allocated with sufficient size and ownership is checked via state unsafe { new_buf.extend_from_slice(std::slice::from_raw_parts(buffer.as_ptr(), len)) @@ -316,6 +657,7 @@ impl IOUring { .compare_exchange(true, false, Ordering::AcqRel, Ordering::Relaxed) .is_ok() { + tracing::debug!("Need to notify task out-of-buffers"); // Safety: buffer is defined on stack, event-fd outlives task using it unsafe { let val = 1u64; @@ -353,6 +695,8 @@ struct IOUringTaskConfig { async fn iouring_task(config: IOUringTaskConfig) -> Result<()> { let mut eventfd_buf = [0u64; 1]; // Buffer for eventfd read (8 bytes) + tracing::debug!("Started iouring_task, queuing eventfd read"); + // Submit initial read for eventfd (needs to be here for buffer to be on stack of the task) { let _guard = config.submission_lock.lock(); @@ -385,7 +729,9 @@ async fn iouring_task(config: IOUringTaskConfig) -> Result<()> { } x if x == IOUringActionID::RecyclePending as u64 => { - if cqe.result() > 0 { + let cqe_res = cqe.result(); + tracing::debug!("Out of buffers ({}) - recycling", cqe_res); + if cqe_res > 0 { // Got notification we need more buffers // NOTE: This approach is very good for cases we have constant data-flow // we can only load the buffers for kernel when our read-threads are done with existing data, @@ -397,12 +743,13 @@ async fn iouring_task(config: IOUringTaskConfig) -> Result<()> { // Make sure kernel can use all buffers again { + let rx_ring_size = config.rx_pool.states.len(); // Safety: buffers are mapped from rx_pool which outlives this task unsafe { sq.push( &opcode::ProvideBuffers::new( - config.rx_pool.data.as_ptr() as *mut u8, - config.rx_pool.buffer_size as i32, + 
config.rx_pool.data.as_ptr(), + rx_ring_size as i32, config.rx_pool.states.len() as u16, RX_BUFFER_GROUP, 0, @@ -413,13 +760,12 @@ async fn iouring_task(config: IOUringTaskConfig) -> Result<()> { }; // Safety: buffer-group originates from rx_pool which outlives this task unsafe { - sq.push( - &opcode::RecvMulti::new( - types::Fd(config.tun_fd), - RX_BUFFER_GROUP, - ) - .build() - .user_data(IOUringActionID::ReceivedBuffer as u64), + queue_reads( + &mut sq, + config.tun_fd, + rx_ring_size, + RX_BUFFER_GROUP, + IOUringActionID::ReceivedBuffer as _, )? }; // Safety: Event-fd outlives the task, buffer is task-bound (stack) @@ -452,6 +798,8 @@ async fn iouring_task(config: IOUringTaskConfig) -> Result<()> { let buf_id = io_uring::cqueue::buffer_select(cqe.flags()).unwrap(); let (_, length, state) = config.rx_pool.get_buffer(buf_id as _); + tracing::debug!("recv {} bytes, saving to buf-id {}", result, buf_id); + length.store(result as usize, Ordering::Release); state.store(true, Ordering::Release); // Mark as ready-for-user config.rx_notify.notify_waiters(); @@ -471,6 +819,7 @@ async fn iouring_task(config: IOUringTaskConfig) -> Result<()> { ); metrics::tun_iouring_tx_err(); } + tracing::debug!("sent {} bytes from buf-id {}", result, idx); let (_, _, state) = config.tx_pool.get_buffer(idx as _); state.store(false, Ordering::Release); // mark as available for send } diff --git a/tests/e2e/docker-compose.yml b/tests/e2e/docker-compose.yml index 453a8349..7ad94165 100644 --- a/tests/e2e/docker-compose.yml +++ b/tests/e2e/docker-compose.yml @@ -24,6 +24,10 @@ services: stop_grace_period: 10s cap_add: - NET_ADMIN + ulimits: + memlock: + soft: -1 + hard: -1 devices: - "/dev/net/tun:/dev/net/tun" networks: @@ -54,6 +58,10 @@ services: - net.ipv4.conf.all.promote_secondaries=1 cap_add: - NET_ADMIN + ulimits: + memlock: + soft: -1 + hard: -1 devices: - "/dev/net/tun:/dev/net/tun" depends_on: From 66903ed6f191bd7ea9c42fc1b6f556f4f1e158cf Mon Sep 17 00:00:00 2001 From: Omer Shamash 
Date: Sun, 23 Feb 2025 14:07:06 +0200 Subject: [PATCH 16/18] feat: add internal-testing cargo, remove dead-code --- Cargo.lock | 44 ++++ Cargo.toml | 1 + lightway-app-utils/src/iouring.rs | 207 ++++++---------- test-iouring/Cargo.toml | 59 +++++ test-iouring/src/main.rs | 395 ++++++++++++++++++++++++++++++ 5 files changed, 569 insertions(+), 137 deletions(-) create mode 100644 test-iouring/Cargo.toml create mode 100644 test-iouring/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index 1a4db3be..f661c259 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2269,6 +2269,50 @@ dependencies = [ "test-case-core", ] +[[package]] +name = "test-iouring" +version = "0.1.0" +dependencies = [ + "anyhow", + "async-trait", + "average", + "bytes", + "bytesize", + "clap", + "ctrlc", + "delegate", + "educe", + "hashbrown 0.15.2", + "io-uring", + "ipnet", + "jsonwebtoken", + "libc", + "lightway-app-utils", + "lightway-core", + "metrics", + "metrics-util", + "more-asserts", + "parking_lot", + "pnet", + "ppp", + "pwhash", + "rand", + "serde", + "serde_json", + "socket2", + "strum", + "test-case", + "thiserror 2.0.11", + "time", + "tokio", + "tokio-stream", + "tracing", + "tracing-log", + "tracing-subscriber", + "tun", + "twelf", +] + [[package]] name = "thiserror" version = "1.0.69" diff --git a/Cargo.toml b/Cargo.toml index febb033c..ddaa5232 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ members = [ "lightway-app-utils", "lightway-client", "lightway-server", + "test-iouring", ] resolver = "3" diff --git a/lightway-app-utils/src/iouring.rs b/lightway-app-utils/src/iouring.rs index 2782eb2a..18ae3fb5 100644 --- a/lightway-app-utils/src/iouring.rs +++ b/lightway-app-utils/src/iouring.rs @@ -7,7 +7,7 @@ use lightway_core::IOCallbackResult; use parking_lot::Mutex; use std::{ alloc::{alloc_zeroed, dealloc, Layout}, - os::{fd::AsRawFd, unix::io::RawFd}, + os::fd::AsRawFd, sync::{ Arc, atomic::{AtomicBool, AtomicUsize, Ordering}, @@ -16,12 +16,13 @@ use std::{ time::Duration, }; use 
tokio::sync::Notify; -use tokio_eventfd::EventFd; // ------------------------------------------------------------- // - IMPLEMENT read-multishot and RUNTIME variations - // ------------------------------------------------------------- +// NOTE: temp until this is merged: https://github.com/tokio-rs/io-uring/pull/317 + use io_uring::squeue::Entry; pub const IORING_OP_READ_MULTISHOT: u8 = 49; @@ -241,7 +242,6 @@ pub unsafe fn queue_reads( enum IOUringActionID { RecycleBuffers = 0x10001000, ReceivedBuffer = 0xfeedfeed, - RecyclePending = 0xdead1000, } const RX_BUFFER_GROUP: u16 = 0xdead; @@ -368,8 +368,6 @@ pub struct IOUring { rx_pool: Arc, tx_pool: Arc, rx_notify: Arc, - rx_eventfd: EventFd, - rx_provide_buffers: Arc, ring: Arc, submission_lock: Arc>, } @@ -391,7 +389,7 @@ impl IOUring { tracing::debug!( "INIT io-uring, estimated memory (user | kernel): {}Mb | {}Mb", (2 * (size_of::() + (mtu * ring_size))) / 1024 / 1024, - (ring_size * (16 + (2 * 64)) + 8192) / 1024 / 1024, + (ring_size * 2 * (16 + (2 * 64)) + 8192) / 1024 / 1024, ); let rx_pool = Arc::new(BufferPool::new(mtu, ring_size)); @@ -400,12 +398,10 @@ impl IOUring { let ring = Arc::new( IoUring::builder() .setup_sqpoll(sqpoll_idle_time.as_millis() as u32) - .build(ring_size as u32)?, + .build((ring_size * 2) as u32)?, ); let rx_notify = Arc::new(Notify::new()); - let rx_eventfd = EventFd::new(0, false)?; - let rx_provide_buffers = Arc::new(AtomicBool::new(false)); // NOTE: for now this ensures we only create 1 kthread per tunnel, and not 2 (rx/tx) // we can opt to change this going forward, or redo the structure to not need a lock @@ -484,14 +480,10 @@ impl IOUring { .expect("io-uring support"); let config = IOUringTaskConfig { - tun_fd: fd, rx_pool: rx_pool.clone(), tx_pool: tx_pool.clone(), rx_notify: rx_notify.clone(), - rx_eventfd: rx_eventfd.as_raw_fd(), - rx_provide_buffers: rx_provide_buffers.clone(), ring: ring.clone(), - submission_lock: submission_lock.clone(), }; // NOTE: currently we don't 
implement any Drop for class, it will require changes @@ -512,8 +504,6 @@ impl IOUring { rx_pool, tx_pool, rx_notify, - rx_eventfd, - rx_provide_buffers, ring, submission_lock, }) @@ -569,15 +559,10 @@ impl IOUring { // NOTE: IOUringActionID values have to be bigger then the ring-size // this is because we use here as data for send_fixed operations - let write_op = opcode::WriteFixed::new( - types::Fixed(self.owned_fd.as_raw_fd() as _), - buffer.as_ptr(), - len as _, - idx as _, - ) - .build() - // NOTE: we set the index starting from after the RX_POOL part - .user_data(idx as u64); + let write_op = + opcode::WriteFixed::new(types::Fixed(0), buffer.as_ptr(), len as _, idx as _) + .build() + .user_data(idx as u64); tracing::debug!("queuing WRITE_FIXED on buf-id {}", idx); @@ -588,11 +573,22 @@ impl IOUring { let mut sq = unsafe { self.ring.submission_shared() }; // Safety: entry uses buffers from tx_pool which outlive task using them unsafe { + // let res = libc::write( + // self.owned_fd.as_raw_fd(), + // buffer.as_ptr() as *const libc::c_void, + // len, + // ); + // tracing::debug!("write (sync) results: {}", res); + // if res > 0 { + // return IOCallbackResult::Ok(res as usize); + // } + + // let err = std::io::Error::last_os_error(); + // tracing::error!("write faild: {}", err); + // IOCallbackResult::Err(err) + match sq.push(&write_op) { - Ok(_) => { - tracing::debug!("Successfully queued write for buffer {}", idx); - IOCallbackResult::Ok(len) - } + Ok(_) => IOCallbackResult::Ok(len), Err(_) => { tracing::warn!("Failed to queue send"); metrics::tun_iouring_tx_err(); @@ -637,6 +633,41 @@ impl IOUring { .compare_exchange(true, false, Ordering::AcqRel, Ordering::Relaxed) .is_ok() { + // Last buffer - need to reload + // NOTE: this is why io_uring is not really practical in a lot of use-cases... 
+ if idx + 1 == self.rx_pool.data.num_entries { + let _guard = self.submission_lock.lock(); + // Safety: protected by lock above + let mut sq = unsafe { self.ring.submission_shared() }; + let rx_ring_size = self.rx_pool.states.len(); + // Safety: buffers are mapped from rx_pool which outlives this task + unsafe { + sq.push( + &opcode::ProvideBuffers::new( + self.rx_pool.data.as_ptr(), + rx_ring_size as i32, + self.rx_pool.states.len() as u16, + RX_BUFFER_GROUP, + 0, + ) + .build() + .user_data(IOUringActionID::RecycleBuffers as u64), + ) + .expect("iouring queue should work") + }; + // Safety: buffer-group originates from rx_pool which outlives this task + unsafe { + queue_reads( + &mut sq, + self.owned_fd.as_raw_fd(), + rx_ring_size, + RX_BUFFER_GROUP, + IOUringActionID::ReceivedBuffer as _, + ) + .expect("iouring queue should work") + }; + } + let len = length.load(Ordering::Acquire); let mut new_buf = BytesMut::with_capacity(len); @@ -650,76 +681,29 @@ impl IOUring { } // IO-Bound wait for available buffers self.rx_notify.notified().await; - - // Check if kernel needs more buffers (and ensure only one notification is sent) - if self - .rx_provide_buffers - .compare_exchange(true, false, Ordering::AcqRel, Ordering::Relaxed) - .is_ok() - { - tracing::debug!("Need to notify task out-of-buffers"); - // Safety: buffer is defined on stack, event-fd outlives task using it - unsafe { - let val = 1u64; - if libc::write( - self.rx_eventfd.as_raw_fd(), - &val as *const u64 as *const _, - 8, - ) < 0 - { - let err = std::io::Error::last_os_error(); - tracing::error!("Failed to write to eventfd: {}", err); - // The following is a prayer to god to hopefully succeed next time around - self.rx_provide_buffers.store(true, Ordering::Release); - } - } - } } } } /// Task variables struct IOUringTaskConfig { - tun_fd: RawFd, rx_pool: Arc, tx_pool: Arc, rx_notify: Arc, - rx_eventfd: RawFd, - rx_provide_buffers: Arc, ring: Arc, - submission_lock: Arc>, } // Safety: To manage ring 
completion and results effeciantly requires direct memory manipulations #[allow(unsafe_code)] async fn iouring_task(config: IOUringTaskConfig) -> Result<()> { - let mut eventfd_buf = [0u64; 1]; // Buffer for eventfd read (8 bytes) - - tracing::debug!("Started iouring_task, queuing eventfd read"); - - // Submit initial read for eventfd (needs to be here for buffer to be on stack of the task) - { - let _guard = config.submission_lock.lock(); - // Safety: protected by above lock - let mut sq = unsafe { config.ring.submission_shared() }; - // Safety: event-fd outlives the task, queue protected by lock - unsafe { - sq.push( - &opcode::Read::new( - types::Fd(config.rx_eventfd), - eventfd_buf.as_mut_ptr() as *mut u8, - 8, - ) - .build() - .user_data(IOUringActionID::RecyclePending as u64), - )? - }; - } + tracing::debug!("Started iouring_task"); loop { // Work once we have at least 1 task to perform config.ring.submit_and_wait(1)?; + tracing::debug!("iotask woke up"); + // Safety: only task is using the completion-queue (concept should not change) for cqe in unsafe { config.ring.completion_shared() } { match cqe.user_data() { @@ -728,62 +712,6 @@ async fn iouring_task(config: IOUringTaskConfig) -> Result<()> { tracing::debug!("Buffer provision completed"); } - x if x == IOUringActionID::RecyclePending as u64 => { - let cqe_res = cqe.result(); - tracing::debug!("Out of buffers ({}) - recycling", cqe_res); - if cqe_res > 0 { - // Got notification we need more buffers - // NOTE: This approach is very good for cases we have constant data-flow - // we can only load the buffers for kernel when our read-threads are done with existing data, - // if our read-threads would block for too long elsewhere it would back-pressure the NIF device - let _guard = config.submission_lock.lock(); - - // Safety: protected by above lock - let mut sq = unsafe { config.ring.submission_shared() }; - - // Make sure kernel can use all buffers again - { - let rx_ring_size = config.rx_pool.states.len(); 
- // Safety: buffers are mapped from rx_pool which outlives this task - unsafe { - sq.push( - &opcode::ProvideBuffers::new( - config.rx_pool.data.as_ptr(), - rx_ring_size as i32, - config.rx_pool.states.len() as u16, - RX_BUFFER_GROUP, - 0, - ) - .build() - .user_data(IOUringActionID::RecycleBuffers as u64), - )? - }; - // Safety: buffer-group originates from rx_pool which outlives this task - unsafe { - queue_reads( - &mut sq, - config.tun_fd, - rx_ring_size, - RX_BUFFER_GROUP, - IOUringActionID::ReceivedBuffer as _, - )? - }; - // Safety: Event-fd outlives the task, buffer is task-bound (stack) - unsafe { - sq.push( - &opcode::Read::new( - types::Fd(config.rx_eventfd), - eventfd_buf.as_mut_ptr() as *mut u8, - 8, - ) - .build() - .user_data(IOUringActionID::RecyclePending as u64), - )? - }; - } - } - } - x if x == IOUringActionID::ReceivedBuffer as u64 => { let result = cqe.result(); if result < 0 { @@ -804,9 +732,14 @@ async fn iouring_task(config: IOUringTaskConfig) -> Result<()> { state.store(true, Ordering::Release); // Mark as ready-for-user config.rx_notify.notify_waiters(); - if !io_uring::cqueue::more(cqe.flags()) { - config.rx_provide_buffers.store(true, Ordering::Release); - } + // TODO: consider below implementation in the future + // issue with this is that we have to gurentee no in-flight buffers ! + // see the comment under `recv` function, we can consider a buffer migration. 
+ // NOTE: Here if we use new kernels we can auto-opt for multishot via: + // if !io_uring::cqueue::more(cqe.flags()) { + // let opt = ReadMulti::new(fd, buf_group).build().user_data(IOUringActionID::ReceivedBuffer); + // unsafe { sq.push(&opt) }; + // } } idx => { diff --git a/test-iouring/Cargo.toml b/test-iouring/Cargo.toml new file mode 100644 index 00000000..588bff46 --- /dev/null +++ b/test-iouring/Cargo.toml @@ -0,0 +1,59 @@ +[package] +name = "test-iouring" +version = "0.1.0" +repository = "https://github.com/expressvpn/lightway" +edition = "2021" +authors = ["lightway-developers@expressvpn.com"] +license = "AGPL-3.0-only" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[features] +default = ["io-uring"] +debug = ["lightway-core/debug"] +io-uring = ["lightway-app-utils/io-uring", "dep:io-uring"] + +[lints] +workspace = true + +[dependencies] +anyhow.workspace = true +async-trait.workspace = true +average = "0.15.1" +bytes.workspace = true +bytesize.workspace = true +clap.workspace = true +ctrlc.workspace = true +delegate.workspace = true +educe.workspace = true +hashbrown = "0.15.2" +ipnet.workspace = true +jsonwebtoken = "9.3.0" +libc.workspace = true +lightway-app-utils.workspace = true +lightway-core = { workspace = true, features = ["postquantum"] } +metrics.workspace = true +metrics-util = "0.18.0" +parking_lot = "0.12.3" +pnet.workspace = true +ppp = "2.2.0" +pwhash = "1.0.0" +rand.workspace = true +serde.workspace = true +serde_json = "1.0.128" +socket2.workspace = true +strum = { version = "0.26.3", features = ["derive"] } +thiserror.workspace = true +time = "0.3.29" +tokio.workspace = true +tokio-stream = { workspace = true, features = ["time"] } +tracing.workspace = true +tracing-log = "0.2.0" +tracing-subscriber = { workspace = true, features = ["json"] } +twelf.workspace = true +tun = { version = "0.7", features = ["async"] } +io-uring = { version = "0.7.0", optional = true } + 
+[dev-dependencies] +more-asserts.workspace = true +test-case.workspace = true diff --git a/test-iouring/src/main.rs b/test-iouring/src/main.rs new file mode 100644 index 00000000..3fc9761a --- /dev/null +++ b/test-iouring/src/main.rs @@ -0,0 +1,395 @@ +// use anyhow::{Context, Result}; +// use bytes::BytesMut; +// use lightway_app_utils::TunIoUring; +// use lightway_core::IOCallbackResult; +// use std::{env, process::Command, time::Duration}; +// use tokio::{net::UdpSocket, time::sleep}; +// use tracing::{debug, error, info, warn}; +// use tun::{Configuration, Layer}; + +// struct NetNSSetup { +// ns1_name: String, +// ns2_name: String, +// veth1_name: String, +// veth2_name: String, +// ns1_ip: String, +// ns2_ip: String, +// } + +// // Keep NetNSSetup struct as is, but add logging to its methods +// impl NetNSSetup { +// fn new(prefix: &str) -> Self { +// Self { +// ns1_name: format!("{}_ns1", prefix), +// ns2_name: format!("{}_ns2", prefix), +// veth1_name: format!("{}_veth1", prefix), +// veth2_name: format!("{}_veth2", prefix), +// ns1_ip: "10.0.0.1/24".to_string(), +// ns2_ip: "10.0.0.2/24".to_string(), +// } +// } + +// fn setup(&self) -> Result<()> { +// info!("Creating network namespaces"); + +// debug!("Creating namespace {}", self.ns1_name); +// Command::new("ip") +// .args(["netns", "add", &self.ns1_name]) +// .status() +// .context("Failed to create ns1")?; + +// debug!("Creating namespace {}", self.ns2_name); +// Command::new("ip") +// .args(["netns", "add", &self.ns2_name]) +// .status() +// .context("Failed to create ns2")?; + +// info!("Creating veth pair"); +// Command::new("ip") +// .args([ +// "link", +// "add", +// &self.veth1_name, +// "type", +// "veth", +// "peer", +// "name", +// &self.veth2_name, +// ]) +// .status() +// .context("Failed to create veth pair")?; + +// info!("Moving interfaces to namespaces"); +// Command::new("ip") +// .args(["link", "set", &self.veth1_name, "netns", &self.ns1_name]) +// .status() +// .context("Failed to 
move veth1")?; + +// Command::new("ip") +// .args(["link", "set", &self.veth2_name, "netns", &self.ns2_name]) +// .status() +// .context("Failed to move veth2")?; + +// info!("Configuring IP addresses"); +// Command::new("ip") +// .args([ +// "netns", +// "exec", +// &self.ns1_name, +// "ip", +// "addr", +// "add", +// &self.ns1_ip, +// "dev", +// &self.veth1_name, +// ]) +// .status() +// .context("Failed to set ns1 IP")?; + +// Command::new("ip") +// .args([ +// "netns", +// "exec", +// &self.ns2_name, +// "ip", +// "addr", +// "add", +// &self.ns2_ip, +// "dev", +// &self.veth2_name, +// ]) +// .status() +// .context("Failed to set ns2 IP")?; + +// info!("Bringing up interfaces"); +// Command::new("ip") +// .args([ +// "netns", +// "exec", +// &self.ns1_name, +// "ip", +// "link", +// "set", +// &self.veth1_name, +// "up", +// ]) +// .status() +// .context("Failed to bring up veth1")?; + +// Command::new("ip") +// .args([ +// "netns", +// "exec", +// &self.ns2_name, +// "ip", +// "link", +// "set", +// &self.veth2_name, +// "up", +// ]) +// .status() +// .context("Failed to bring up veth2")?; + +// Ok(()) +// } + +// fn cleanup(&self) -> Result<()> { +// info!("Cleaning up network namespaces"); + +// debug!("Deleting namespace {}", self.ns1_name); +// Command::new("ip") +// .args(["netns", "del", &self.ns1_name]) +// .status() +// .context("Failed to delete ns1")?; + +// debug!("Deleting namespace {}", self.ns2_name); +// Command::new("ip") +// .args(["netns", "del", &self.ns2_name]) +// .status() +// .context("Failed to delete ns2")?; + +// Ok(()) +// } +// } + +// struct Tunnel { +// io_uring: TunIoUring, +// transport: UdpSocket, +// client_endpoint: Option, +// } + +// impl Tunnel { +// async fn new() -> Result<()> { +// info!("Initializing tunnel device"); + +// // Setup TUN device +// let mut config = Configuration::default(); +// config.tun_name("tun0"); +// config.layer(Layer::L3); +// config.mtu(1500); +// config.up(); + +// debug!("Creating TUN device 
with config: {:?}", config); +// let io_uring = TunIoUring::new(config, 1024, Duration::from_secs(1)) +// .await +// .context("Failed to create TUN device")?; + +// // Configure TUN IP +// info!("Configuring TUN device IP"); +// Command::new("ip") +// .args(["addr", "add", "10.0.0.10/24", "dev", "tun0"]) +// .status() +// .context("Failed to set TUN IP")?; + +// Command::new("ip") +// .args(["link", "set", "dev", "tun0", "up"]) +// .status() +// .context("Failed to bring up TUN")?; + +// // Create UDP transport +// info!("Creating UDP transport socket"); +// let transport = UdpSocket::bind("0.0.0.0:4789") +// .await +// .context("Failed to bind UDP socket")?; +// info!("UDP transport listening on port 4789"); + +// let mut tunnel = Self { +// io_uring, +// transport, +// client_endpoint: None, +// }; + +// info!("Starting tunnel operation"); +// tunnel.run().await +// } + +// async fn run(&mut self) -> Result<()> { +// let mut udp_buf = [0u8; 2000]; + +// loop { +// tokio::select! { +// tun_result = self.io_uring.recv_buf() => { +// match tun_result { +// IOCallbackResult::Ok(buf) => { +// if let Some(client) = self.client_endpoint { +// debug!("Forwarding {} bytes from TUN to client {}", buf.len(), client); +// if let Err(e) = self.transport.send_to(&buf, client).await { +// error!("Failed to send to client {}: {}", client, e); +// } +// } +// } +// IOCallbackResult::WouldBlock => { +// debug!("TUN receive would block"); +// sleep(Duration::from_millis(10)).await; +// } +// IOCallbackResult::Err(e) => { +// error!("TUN receive error: {}", e); +// } +// } +// } + +// udp_result = self.transport.recv_from(&mut udp_buf) => { +// match udp_result { +// Ok((size, addr)) => { +// if self.client_endpoint.is_none() { +// info!("New client connected from {}", addr); +// self.client_endpoint = Some(addr); +// } + +// if Some(addr) == self.client_endpoint { +// debug!("Received {} bytes from client {}", size, addr); +// let packet = BytesMut::from(&udp_buf[..size]); +// 
match self.io_uring.try_send(packet) { +// IOCallbackResult::Ok(n) => { +// debug!("Wrote {} bytes to TUN", n); +// } +// IOCallbackResult::WouldBlock => { +// warn!("TUN send would block"); +// sleep(Duration::from_millis(10)).await; +// } +// IOCallbackResult::Err(e) => { +// error!("TUN send error: {}", e); +// } +// } +// } else { +// warn!("Ignored packet from unknown client {}", addr); +// } +// } +// Err(e) => { +// error!("UDP receive error: {}", e); +// } +// } +// } +// } +// } +// } +// } + +// #[tokio::main] +// async fn main() -> Result<()> { +// // Initialize logging +// tracing_subscriber::fmt() +// .with_max_level(tracing::Level::DEBUG) +// .init(); + +// info!("Starting VPN tunnel server"); + +// // Check if we're running inside namespace +// if env::args().any(|arg| arg == "--in-namespace") { +// info!("Running in namespace, initializing tunnel"); +// Tunnel::new().await +// } else { +// info!("Setting up network namespaces"); + +// let ns_setup = NetNSSetup::new("test"); +// debug!( +// "Created namespace setup with: ns1={}, ns2={}", +// ns_setup.ns1_name, ns_setup.ns2_name +// ); + +// ns_setup +// .setup() +// .context("Failed to setup network namespaces")?; +// info!("Network namespaces configured successfully"); + +// info!("Starting tunnel in namespace test_ns1"); +// let status = Command::new("ip") +// .args([ +// "netns", +// "exec", +// "test_ns1", +// &env::current_exe()?.to_string_lossy(), +// "--in-namespace", +// ]) +// .status() +// .context("Failed to start tunnel in namespace")?; + +// info!("Tunnel exited with status: {}", status); + +// if !status.success() { +// error!("Tunnel failed with status: {}", status); +// } + +// info!("Cleaning up network namespaces"); +// ns_setup +// .cleanup() +// .context("Failed to cleanup network namespaces")?; +// info!("Cleanup complete"); + +// Ok(()) +// } +// } + +use anyhow::Result; +use io_uring::{opcode, types, IoUring}; +use libc::iovec; +use std::os::fd::AsRawFd; +use 
tun::{Configuration, Layer}; + +fn main() -> Result<()> { + // Create TUN device + let mut config = Configuration::default(); + config.tun_name("tun0"); + config.layer(Layer::L3); + config.up(); + + let tun = tun::create(&config)?; + println!("Created TUN device: tun0"); + + // Create a test IP packet (very basic IPv4) + let packet = [ + 0x45, 0x00, 0x00, 0x14, // IPv4, len=20 + 0x00, 0x00, 0x40, 0x00, // DF flag + 0x40, 0x01, 0x00, 0x00, // TTL=64, proto=ICMP + 0x0a, 0x00, 0x00, 0x0a, // src: 10.0.0.10 + 0x0a, 0x00, 0x00, 0x02, // dst: 10.0.0.2 + ]; + + let iov = iovec { + iov_base: packet.as_ptr() as *mut libc::c_void, + iov_len: packet.len(), + }; + + // Setup ring + let mut ring = IoUring::new(8)?; + println!("Created IO_URING"); + + // Register buffer and file + #[allow(unsafe_code)] + unsafe { + ring.submitter().register_buffers(std::slice::from_ref(&iov))?; + println!("Registered buffer"); + ring.submitter().register_files(&[tun.as_raw_fd()])?; + println!("Registered TUN fd"); + } + + // Create WriteFixed operation + let write_op = opcode::WriteFixed::new( + types::Fixed(0), // registered file index + packet.as_ptr(), // buffer pointer + packet.len() as _, // length + 0, // registered buffer index + ) + .build() + .user_data(100); + + println!("Created write operation"); + + // Queue operation + #[allow(unsafe_code)] + unsafe { + ring.submission() + .push(&write_op)?; + } + println!("Queued operation"); + + // Submit and wait for completion + ring.submit_and_wait(1)?; + println!("Submitted and waiting"); + + // Check completion + let cqe = ring.completion().next().expect("completion queue empty"); + println!("Write result: {}", cqe.result()); + + Ok(()) +} \ No newline at end of file From 4e6c30e1782813b0efe8024d1baaccfad59085bb Mon Sep 17 00:00:00 2001 From: Omer Shamash Date: Sun, 23 Feb 2025 14:26:26 +0200 Subject: [PATCH 17/18] edition-fixes --- lightway-app-utils/src/iouring.rs | 8 +++----- test-iouring/src/main.rs | 20 +++++++++++--------- 2 files 
changed, 14 insertions(+), 14 deletions(-) diff --git a/lightway-app-utils/src/iouring.rs b/lightway-app-utils/src/iouring.rs index 18ae3fb5..a5a92ebc 100644 --- a/lightway-app-utils/src/iouring.rs +++ b/lightway-app-utils/src/iouring.rs @@ -1,12 +1,12 @@ use crate::metrics; use anyhow::{Context, Result}; use bytes::BytesMut; -use io_uring::{opcode, squeue::PushError, types, IoUring}; +use io_uring::{IoUring, opcode, squeue::PushError, types}; use libc::iovec; use lightway_core::IOCallbackResult; use parking_lot::Mutex; use std::{ - alloc::{alloc_zeroed, dealloc, Layout}, + alloc::{Layout, alloc_zeroed, dealloc}, os::fd::AsRawFd, sync::{ Arc, @@ -182,9 +182,7 @@ fn initialize_kernel_check() -> bool { None } }) - .map_or(false, |(major, minor)| { - major > 6 || (major == 6 && minor >= 7) - }); + .is_some_and(|(major, minor)| major > 6 || (major == 6 && minor >= 7)); SUPPORTED.store(supported, Ordering::Release); INITIALIZED.store(true, Ordering::Release); diff --git a/test-iouring/src/main.rs b/test-iouring/src/main.rs index 3fc9761a..9d931a50 100644 --- a/test-iouring/src/main.rs +++ b/test-iouring/src/main.rs @@ -353,11 +353,13 @@ fn main() -> Result<()> { // Setup ring let mut ring = IoUring::new(8)?; println!("Created IO_URING"); - + // Register buffer and file #[allow(unsafe_code)] + // Safety: we manage buffer lifecycle unsafe { - ring.submitter().register_buffers(std::slice::from_ref(&iov))?; + ring.submitter() + .register_buffers(std::slice::from_ref(&iov))?; println!("Registered buffer"); ring.submitter().register_files(&[tun.as_raw_fd()])?; println!("Registered TUN fd"); @@ -365,10 +367,10 @@ fn main() -> Result<()> { // Create WriteFixed operation let write_op = opcode::WriteFixed::new( - types::Fixed(0), // registered file index - packet.as_ptr(), // buffer pointer - packet.len() as _, // length - 0, // registered buffer index + types::Fixed(0), // registered file index + packet.as_ptr(), // buffer pointer + packet.len() as _, // length + 0, // 
registered buffer index ) .build() .user_data(100); @@ -377,9 +379,9 @@ fn main() -> Result<()> { // Queue operation #[allow(unsafe_code)] + // Safety: io_uring crate works unsafe { - ring.submission() - .push(&write_op)?; + ring.submission().push(&write_op)?; } println!("Queued operation"); @@ -392,4 +394,4 @@ fn main() -> Result<()> { println!("Write result: {}", cqe.result()); Ok(()) -} \ No newline at end of file +} From 5dbdad5ba31b6990717d0055cc67008b865e0dc7 Mon Sep 17 00:00:00 2001 From: Omer Shamash Date: Tue, 25 Feb 2025 11:04:34 +0200 Subject: [PATCH 18/18] remove test-iouring (CI issues) this way it'll be kept in commit-history for reference --- Cargo.lock | 44 ----- Cargo.toml | 1 - test-iouring/Cargo.toml | 59 ------ test-iouring/src/main.rs | 397 --------------------------------------- 4 files changed, 501 deletions(-) delete mode 100644 test-iouring/Cargo.toml delete mode 100644 test-iouring/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index f661c259..1a4db3be 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2269,50 +2269,6 @@ dependencies = [ "test-case-core", ] -[[package]] -name = "test-iouring" -version = "0.1.0" -dependencies = [ - "anyhow", - "async-trait", - "average", - "bytes", - "bytesize", - "clap", - "ctrlc", - "delegate", - "educe", - "hashbrown 0.15.2", - "io-uring", - "ipnet", - "jsonwebtoken", - "libc", - "lightway-app-utils", - "lightway-core", - "metrics", - "metrics-util", - "more-asserts", - "parking_lot", - "pnet", - "ppp", - "pwhash", - "rand", - "serde", - "serde_json", - "socket2", - "strum", - "test-case", - "thiserror 2.0.11", - "time", - "tokio", - "tokio-stream", - "tracing", - "tracing-log", - "tracing-subscriber", - "tun", - "twelf", -] - [[package]] name = "thiserror" version = "1.0.69" diff --git a/Cargo.toml b/Cargo.toml index ddaa5232..febb033c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,6 @@ members = [ "lightway-app-utils", "lightway-client", "lightway-server", - "test-iouring", ] resolver = "3" diff --git 
a/test-iouring/Cargo.toml b/test-iouring/Cargo.toml deleted file mode 100644 index 588bff46..00000000 --- a/test-iouring/Cargo.toml +++ /dev/null @@ -1,59 +0,0 @@ -[package] -name = "test-iouring" -version = "0.1.0" -repository = "https://github.com/expressvpn/lightway" -edition = "2021" -authors = ["lightway-developers@expressvpn.com"] -license = "AGPL-3.0-only" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[features] -default = ["io-uring"] -debug = ["lightway-core/debug"] -io-uring = ["lightway-app-utils/io-uring", "dep:io-uring"] - -[lints] -workspace = true - -[dependencies] -anyhow.workspace = true -async-trait.workspace = true -average = "0.15.1" -bytes.workspace = true -bytesize.workspace = true -clap.workspace = true -ctrlc.workspace = true -delegate.workspace = true -educe.workspace = true -hashbrown = "0.15.2" -ipnet.workspace = true -jsonwebtoken = "9.3.0" -libc.workspace = true -lightway-app-utils.workspace = true -lightway-core = { workspace = true, features = ["postquantum"] } -metrics.workspace = true -metrics-util = "0.18.0" -parking_lot = "0.12.3" -pnet.workspace = true -ppp = "2.2.0" -pwhash = "1.0.0" -rand.workspace = true -serde.workspace = true -serde_json = "1.0.128" -socket2.workspace = true -strum = { version = "0.26.3", features = ["derive"] } -thiserror.workspace = true -time = "0.3.29" -tokio.workspace = true -tokio-stream = { workspace = true, features = ["time"] } -tracing.workspace = true -tracing-log = "0.2.0" -tracing-subscriber = { workspace = true, features = ["json"] } -twelf.workspace = true -tun = { version = "0.7", features = ["async"] } -io-uring = { version = "0.7.0", optional = true } - -[dev-dependencies] -more-asserts.workspace = true -test-case.workspace = true diff --git a/test-iouring/src/main.rs b/test-iouring/src/main.rs deleted file mode 100644 index 9d931a50..00000000 --- a/test-iouring/src/main.rs +++ /dev/null @@ -1,397 +0,0 @@ -// use anyhow::{Context, 
Result}; -// use bytes::BytesMut; -// use lightway_app_utils::TunIoUring; -// use lightway_core::IOCallbackResult; -// use std::{env, process::Command, time::Duration}; -// use tokio::{net::UdpSocket, time::sleep}; -// use tracing::{debug, error, info, warn}; -// use tun::{Configuration, Layer}; - -// struct NetNSSetup { -// ns1_name: String, -// ns2_name: String, -// veth1_name: String, -// veth2_name: String, -// ns1_ip: String, -// ns2_ip: String, -// } - -// // Keep NetNSSetup struct as is, but add logging to its methods -// impl NetNSSetup { -// fn new(prefix: &str) -> Self { -// Self { -// ns1_name: format!("{}_ns1", prefix), -// ns2_name: format!("{}_ns2", prefix), -// veth1_name: format!("{}_veth1", prefix), -// veth2_name: format!("{}_veth2", prefix), -// ns1_ip: "10.0.0.1/24".to_string(), -// ns2_ip: "10.0.0.2/24".to_string(), -// } -// } - -// fn setup(&self) -> Result<()> { -// info!("Creating network namespaces"); - -// debug!("Creating namespace {}", self.ns1_name); -// Command::new("ip") -// .args(["netns", "add", &self.ns1_name]) -// .status() -// .context("Failed to create ns1")?; - -// debug!("Creating namespace {}", self.ns2_name); -// Command::new("ip") -// .args(["netns", "add", &self.ns2_name]) -// .status() -// .context("Failed to create ns2")?; - -// info!("Creating veth pair"); -// Command::new("ip") -// .args([ -// "link", -// "add", -// &self.veth1_name, -// "type", -// "veth", -// "peer", -// "name", -// &self.veth2_name, -// ]) -// .status() -// .context("Failed to create veth pair")?; - -// info!("Moving interfaces to namespaces"); -// Command::new("ip") -// .args(["link", "set", &self.veth1_name, "netns", &self.ns1_name]) -// .status() -// .context("Failed to move veth1")?; - -// Command::new("ip") -// .args(["link", "set", &self.veth2_name, "netns", &self.ns2_name]) -// .status() -// .context("Failed to move veth2")?; - -// info!("Configuring IP addresses"); -// Command::new("ip") -// .args([ -// "netns", -// "exec", -// 
&self.ns1_name, -// "ip", -// "addr", -// "add", -// &self.ns1_ip, -// "dev", -// &self.veth1_name, -// ]) -// .status() -// .context("Failed to set ns1 IP")?; - -// Command::new("ip") -// .args([ -// "netns", -// "exec", -// &self.ns2_name, -// "ip", -// "addr", -// "add", -// &self.ns2_ip, -// "dev", -// &self.veth2_name, -// ]) -// .status() -// .context("Failed to set ns2 IP")?; - -// info!("Bringing up interfaces"); -// Command::new("ip") -// .args([ -// "netns", -// "exec", -// &self.ns1_name, -// "ip", -// "link", -// "set", -// &self.veth1_name, -// "up", -// ]) -// .status() -// .context("Failed to bring up veth1")?; - -// Command::new("ip") -// .args([ -// "netns", -// "exec", -// &self.ns2_name, -// "ip", -// "link", -// "set", -// &self.veth2_name, -// "up", -// ]) -// .status() -// .context("Failed to bring up veth2")?; - -// Ok(()) -// } - -// fn cleanup(&self) -> Result<()> { -// info!("Cleaning up network namespaces"); - -// debug!("Deleting namespace {}", self.ns1_name); -// Command::new("ip") -// .args(["netns", "del", &self.ns1_name]) -// .status() -// .context("Failed to delete ns1")?; - -// debug!("Deleting namespace {}", self.ns2_name); -// Command::new("ip") -// .args(["netns", "del", &self.ns2_name]) -// .status() -// .context("Failed to delete ns2")?; - -// Ok(()) -// } -// } - -// struct Tunnel { -// io_uring: TunIoUring, -// transport: UdpSocket, -// client_endpoint: Option, -// } - -// impl Tunnel { -// async fn new() -> Result<()> { -// info!("Initializing tunnel device"); - -// // Setup TUN device -// let mut config = Configuration::default(); -// config.tun_name("tun0"); -// config.layer(Layer::L3); -// config.mtu(1500); -// config.up(); - -// debug!("Creating TUN device with config: {:?}", config); -// let io_uring = TunIoUring::new(config, 1024, Duration::from_secs(1)) -// .await -// .context("Failed to create TUN device")?; - -// // Configure TUN IP -// info!("Configuring TUN device IP"); -// Command::new("ip") -// .args(["addr", 
"add", "10.0.0.10/24", "dev", "tun0"]) -// .status() -// .context("Failed to set TUN IP")?; - -// Command::new("ip") -// .args(["link", "set", "dev", "tun0", "up"]) -// .status() -// .context("Failed to bring up TUN")?; - -// // Create UDP transport -// info!("Creating UDP transport socket"); -// let transport = UdpSocket::bind("0.0.0.0:4789") -// .await -// .context("Failed to bind UDP socket")?; -// info!("UDP transport listening on port 4789"); - -// let mut tunnel = Self { -// io_uring, -// transport, -// client_endpoint: None, -// }; - -// info!("Starting tunnel operation"); -// tunnel.run().await -// } - -// async fn run(&mut self) -> Result<()> { -// let mut udp_buf = [0u8; 2000]; - -// loop { -// tokio::select! { -// tun_result = self.io_uring.recv_buf() => { -// match tun_result { -// IOCallbackResult::Ok(buf) => { -// if let Some(client) = self.client_endpoint { -// debug!("Forwarding {} bytes from TUN to client {}", buf.len(), client); -// if let Err(e) = self.transport.send_to(&buf, client).await { -// error!("Failed to send to client {}: {}", client, e); -// } -// } -// } -// IOCallbackResult::WouldBlock => { -// debug!("TUN receive would block"); -// sleep(Duration::from_millis(10)).await; -// } -// IOCallbackResult::Err(e) => { -// error!("TUN receive error: {}", e); -// } -// } -// } - -// udp_result = self.transport.recv_from(&mut udp_buf) => { -// match udp_result { -// Ok((size, addr)) => { -// if self.client_endpoint.is_none() { -// info!("New client connected from {}", addr); -// self.client_endpoint = Some(addr); -// } - -// if Some(addr) == self.client_endpoint { -// debug!("Received {} bytes from client {}", size, addr); -// let packet = BytesMut::from(&udp_buf[..size]); -// match self.io_uring.try_send(packet) { -// IOCallbackResult::Ok(n) => { -// debug!("Wrote {} bytes to TUN", n); -// } -// IOCallbackResult::WouldBlock => { -// warn!("TUN send would block"); -// sleep(Duration::from_millis(10)).await; -// } -// IOCallbackResult::Err(e) 
=> { -// error!("TUN send error: {}", e); -// } -// } -// } else { -// warn!("Ignored packet from unknown client {}", addr); -// } -// } -// Err(e) => { -// error!("UDP receive error: {}", e); -// } -// } -// } -// } -// } -// } -// } - -// #[tokio::main] -// async fn main() -> Result<()> { -// // Initialize logging -// tracing_subscriber::fmt() -// .with_max_level(tracing::Level::DEBUG) -// .init(); - -// info!("Starting VPN tunnel server"); - -// // Check if we're running inside namespace -// if env::args().any(|arg| arg == "--in-namespace") { -// info!("Running in namespace, initializing tunnel"); -// Tunnel::new().await -// } else { -// info!("Setting up network namespaces"); - -// let ns_setup = NetNSSetup::new("test"); -// debug!( -// "Created namespace setup with: ns1={}, ns2={}", -// ns_setup.ns1_name, ns_setup.ns2_name -// ); - -// ns_setup -// .setup() -// .context("Failed to setup network namespaces")?; -// info!("Network namespaces configured successfully"); - -// info!("Starting tunnel in namespace test_ns1"); -// let status = Command::new("ip") -// .args([ -// "netns", -// "exec", -// "test_ns1", -// &env::current_exe()?.to_string_lossy(), -// "--in-namespace", -// ]) -// .status() -// .context("Failed to start tunnel in namespace")?; - -// info!("Tunnel exited with status: {}", status); - -// if !status.success() { -// error!("Tunnel failed with status: {}", status); -// } - -// info!("Cleaning up network namespaces"); -// ns_setup -// .cleanup() -// .context("Failed to cleanup network namespaces")?; -// info!("Cleanup complete"); - -// Ok(()) -// } -// } - -use anyhow::Result; -use io_uring::{opcode, types, IoUring}; -use libc::iovec; -use std::os::fd::AsRawFd; -use tun::{Configuration, Layer}; - -fn main() -> Result<()> { - // Create TUN device - let mut config = Configuration::default(); - config.tun_name("tun0"); - config.layer(Layer::L3); - config.up(); - - let tun = tun::create(&config)?; - println!("Created TUN device: tun0"); - - // Create a 
test IP packet (very basic IPv4) - let packet = [ - 0x45, 0x00, 0x00, 0x14, // IPv4, len=20 - 0x00, 0x00, 0x40, 0x00, // DF flag - 0x40, 0x01, 0x00, 0x00, // TTL=64, proto=ICMP - 0x0a, 0x00, 0x00, 0x0a, // src: 10.0.0.10 - 0x0a, 0x00, 0x00, 0x02, // dst: 10.0.0.2 - ]; - - let iov = iovec { - iov_base: packet.as_ptr() as *mut libc::c_void, - iov_len: packet.len(), - }; - - // Setup ring - let mut ring = IoUring::new(8)?; - println!("Created IO_URING"); - - // Register buffer and file - #[allow(unsafe_code)] - // Safety: we manage buffer lifecycle - unsafe { - ring.submitter() - .register_buffers(std::slice::from_ref(&iov))?; - println!("Registered buffer"); - ring.submitter().register_files(&[tun.as_raw_fd()])?; - println!("Registered TUN fd"); - } - - // Create WriteFixed operation - let write_op = opcode::WriteFixed::new( - types::Fixed(0), // registered file index - packet.as_ptr(), // buffer pointer - packet.len() as _, // length - 0, // registered buffer index - ) - .build() - .user_data(100); - - println!("Created write operation"); - - // Queue operation - #[allow(unsafe_code)] - // Safety: io_uring crate works - unsafe { - ring.submission().push(&write_op)?; - } - println!("Queued operation"); - - // Submit and wait for completion - ring.submit_and_wait(1)?; - println!("Submitted and waiting"); - - // Check completion - let cqe = ring.completion().next().expect("completion queue empty"); - println!("Write result: {}", cqe.result()); - - Ok(()) -}