From eb74c390aca0ff6fcca499e8a26ce9857b106762 Mon Sep 17 00:00:00 2001 From: Louis Vialar Date: Wed, 10 Sep 2025 14:38:39 +0200 Subject: [PATCH] network: change implementation of sockopt The goal is to permit non boolean options in the future, and also to add the keepalive option which is required by some other software --- src/fd/mod.rs | 163 ++++++++++++++++++++++++++++++++++++- src/fd/socket/tcp.rs | 71 +++++++++++----- src/syscalls/socket/mod.rs | 118 +++++++++++---------------- 3 files changed, 259 insertions(+), 93 deletions(-) diff --git a/src/fd/mod.rs b/src/fd/mod.rs index 5629a37709..3dee5b743c 100644 --- a/src/fd/mod.rs +++ b/src/fd/mod.rs @@ -6,6 +6,7 @@ use core::task::Poll::{Pending, Ready}; use core::time::Duration; use async_trait::async_trait; +use num_enum::{IntoPrimitive, TryFromPrimitive}; #[cfg(feature = "net")] use smoltcp::wire::{IpEndpoint, IpListenEndpoint}; @@ -14,6 +15,8 @@ use crate::errno::Errno; use crate::executor::block_on; use crate::fs::{FileAttr, SeekWhence}; use crate::io; +#[cfg(feature = "net")] +use crate::syscalls::socket::{Ipproto, SOL_SOCKET, socklen_t}; mod eventfd; #[cfg(any(feature = "net", feature = "vsock"))] @@ -43,9 +46,157 @@ pub(crate) enum ListenEndpoint { } #[allow(dead_code)] -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Eq)] pub(crate) enum SocketOption { - TcpNoDelay, + TcpOption(SocketOptionTcp), + SocketOption(SocketOptionSocket), +} + +#[cfg(feature = "net")] +impl SocketOption { + pub fn from_level_optname(level: i32, optname: i32) -> Option { + if level == SOL_SOCKET { + SocketOptionSocket::try_from(optname) + .ok() + .map(SocketOption::SocketOption) + } else { + let protocol = u8::try_from(level) + .ok() + .and_then(|proto| Ipproto::try_from(proto).ok())?; + + match protocol { + Ipproto::Tcp => SocketOptionTcp::try_from(optname) + .ok() + .map(SocketOption::TcpOption), + _ => None, + } + } + } +} + +#[cfg(feature = "net")] +pub struct SocketOptionValue { + optval: *const core::ffi::c_void, + optlen: socklen_t, +} + +#[cfg(not(feature = "net"))] +pub struct SocketOptionValue; + +unsafe impl Send for SocketOptionValue {} + +#[cfg(feature = "net")] +impl SocketOptionValue { + pub fn new(optval: *const core::ffi::c_void, optlen: socklen_t) -> Self { + Self { optval, optlen } + } +} + +#[cfg(feature = "net")] +impl TryFrom<&SocketOptionValue> for i32 { + type Error = Errno; + + fn try_from(value: &SocketOptionValue) -> Result { + if value.optval.is_null() { + return Err(Errno::Inval); + } + + if value.optlen != size_of::() as u32 { + return Err(Errno::Inval); + } + + let value = unsafe { *value.optval.cast::() }; + Ok(value) + } +} + +#[cfg(feature = "net")] +impl TryFrom<&SocketOptionValue> for bool { + type Error = Errno; + + fn try_from(value: &SocketOptionValue) -> Result { + let value: i32 = value.try_into()?; + Ok(value != 0) + } +} + +#[cfg(feature = "net")] +pub struct SocketOptionValueWriter { + optval: *mut core::ffi::c_void, + optlen: *mut socklen_t, + touched: bool, +} + +#[cfg(not(feature = "net"))] +pub struct SocketOptionValueWriter; + +unsafe impl Send for SocketOptionValueWriter {} + +#[cfg(feature = "net")] +impl SocketOptionValueWriter { + /// Create a wrapper that will contain a getsockopt result, passing the return buffer and its + /// size as given by the caller. + pub fn new(optval: *mut core::ffi::c_void, optlen: *mut socklen_t) -> Self { + Self { + optval, + optlen, + touched: false, + } + } + + fn set_value(&mut self, value: T) -> Result<(), Errno> { + if self.optval.is_null() || self.optlen.is_null() { + return Err(Errno::Inval); + } + + let optlen = unsafe { *self.optlen }; + let min_size = size_of::(); + if min_size < optlen as usize { + Err(Errno::Fault) + } else { + unsafe { *self.optlen = min_size as socklen_t }; + let value_ptr = self.optval.cast::(); + unsafe { + *value_ptr = value; + } + self.touched = true; + Ok(()) + } + } + + pub fn write_i32(&mut self, target: i32) -> Result<(), Errno> { + self.set_value(target) + } + + pub fn write_bool(&mut self, target: bool) -> Result<(), Errno> { + self.write_i32(if target { 1 } else { 0 }) + } + + pub fn nullify_untouched(&mut self) { + if !self.touched && !self.optlen.is_null() { + unsafe { + *self.optlen = 0; + } + } + } +} + +#[derive(TryFromPrimitive, IntoPrimitive, PartialEq, Eq, Clone, Copy, Debug)] +#[repr(i32)] +#[non_exhaustive] +pub(crate) enum SocketOptionTcp { + #[doc(alias = "TCP_NODELAY")] + TcpNoDelay = 1, +} + +#[derive(TryFromPrimitive, IntoPrimitive, PartialEq, Eq, Clone, Copy, Debug)] +#[repr(i32)] +#[non_exhaustive] +pub(crate) enum SocketOptionSocket { + #[doc(alias = "SO_REUSEADDR")] + ReuseAddr = 1, + #[doc(alias = "SO_KEEPALIVE")] + KeepAlive = 8, } pub(crate) type FileDescriptor = i32; @@ -259,13 +410,17 @@ pub(crate) trait ObjectInterface: Sync + Send + core::fmt::Debug { /// `setsockopt` sets options on sockets #[cfg(any(feature = "net", feature = "vsock"))] - async fn setsockopt(&self, _opt: SocketOption, _optval: bool) -> io::Result<()> { + async fn setsockopt(&self, _opt: SocketOption, _optval: SocketOptionValue) -> io::Result<()> { Err(Errno::Notsock) } /// `getsockopt` gets options on sockets #[cfg(any(feature = "net", feature = "vsock"))] - async fn getsockopt(&self, _opt: SocketOption) -> io::Result { + async fn getsockopt( + &self, + _opt: SocketOption, + _output: &mut SocketOptionValueWriter, + ) -> io::Result<()> { Err(Errno::Notsock) } diff --git a/src/fd/socket/tcp.rs b/src/fd/socket/tcp.rs index e432f7abd7..d242bd8dab 100644 --- a/src/fd/socket/tcp.rs +++ b/src/fd/socket/tcp.rs @@ -14,7 +14,10 @@ use smoltcp::wire::{IpEndpoint, Ipv4Address, Ipv6Address}; use crate::errno::Errno; use crate::executor::block_on; use crate::executor::network::{Handle, NIC}; -use crate::fd::{self, Endpoint, ListenEndpoint, ObjectInterface, PollEvent, SocketOption}; +use crate::fd::{ + self, Endpoint, ListenEndpoint, ObjectInterface, PollEvent, SocketOption, SocketOptionSocket, + SocketOptionTcp, SocketOptionValue, SocketOptionValueWriter, +}; use crate::syscalls::socket::Af; use crate::{DEFAULT_KEEP_ALIVE_INTERVAL, io}; @@ -415,31 +418,61 @@ impl ObjectInterface for Socket { Ok(()) } - async fn setsockopt(&self, opt: SocketOption, optval: bool) -> io::Result<()> { - if opt == SocketOption::TcpNoDelay { - let mut guard = NIC.lock(); - let nic = guard.as_nic_mut().unwrap(); + async fn setsockopt(&self, opt: SocketOption, optval: SocketOptionValue) -> io::Result<()> { + let mut guard = NIC.lock(); + let nic = guard.as_nic_mut().unwrap(); - for i in self.handle.iter() { - let socket = nic.get_mut_socket::>(*i); - socket.set_nagle_enabled(optval); + match opt { + SocketOption::TcpOption(SocketOptionTcp::TcpNoDelay) => { + let is_enabled = (&optval).try_into()?; + for i in self.handle.iter() { + let socket = nic.get_mut_socket::>(*i); + socket.set_nagle_enabled(is_enabled); + } + Ok(()) } + SocketOption::SocketOption(SocketOptionSocket::KeepAlive) => { + let keepalive: bool = (&optval).try_into()?; + let keepalive = if keepalive { + Some(Duration::from_secs(120)) + } else { + None + }; - Ok(()) - } else { - Err(Errno::Inval) + for i in self.handle.iter() { + let socket = nic.get_mut_socket::>(*i); + socket.set_keep_alive(keepalive); + } + + Ok(()) + } + other => { + warn!("TCP: unsupported option {other:?}"); + Err(Errno::Inval) + } } } - async fn getsockopt(&self, opt: SocketOption) -> io::Result { - if opt == SocketOption::TcpNoDelay { - let mut guard = NIC.lock(); - let nic = guard.as_nic_mut().unwrap(); - let socket = nic.get_mut_socket::>(*self.handle.first().unwrap()); + async fn getsockopt( + &self, + opt: SocketOption, + optval: &mut SocketOptionValueWriter, + ) -> io::Result<()> { + let mut guard = NIC.lock(); + let nic = guard.as_nic_mut().unwrap(); + let socket = nic.get_mut_socket::>(*self.handle.first().unwrap()); - Ok(socket.nagle_enabled()) - } else { - Err(Errno::Inval) + match opt { + SocketOption::TcpOption(SocketOptionTcp::TcpNoDelay) => { + optval.write_bool(socket.nagle_enabled()) + } + SocketOption::SocketOption(SocketOptionSocket::KeepAlive) => { + optval.write_bool(socket.keep_alive().is_some()) + } + other => { + warn!("TCP: unsupported option {other:?}"); + Err(Errno::Inval) + } } } diff --git a/src/syscalls/socket/mod.rs b/src/syscalls/socket/mod.rs index f50338e146..09c03984c0 100644 --- a/src/syscalls/socket/mod.rs +++ b/src/syscalls/socket/mod.rs @@ -26,7 +26,8 @@ use crate::fd::socket::udp; #[cfg(feature = "vsock")] use crate::fd::socket::vsock::{self, VsockEndpoint, VsockListenEndpoint}; use crate::fd::{ - self, Endpoint, ListenEndpoint, ObjectInterface, SocketOption, get_object, insert_object, + self, Endpoint, ListenEndpoint, ObjectInterface, SocketOption, SocketOptionSocket, + SocketOptionValue, SocketOptionValueWriter, get_object, insert_object, }; use crate::syscalls::block_on; @@ -910,44 +911,34 @@ pub unsafe extern "C" fn sys_setsockopt( optval: *const c_void, optlen: socklen_t, ) -> i32 { - if level == SOL_SOCKET && optname == SO_REUSEADDR { + let option = SocketOption::from_level_optname(level, optname); + let Some(option) = option else { + warn!( + "setsockopt: unsupported option level={level:x} optname={optname:x}, faking success." + ); return 0; - } - - let Ok(Ok(level)) = u8::try_from(level).map(Ipproto::try_from) else { - return -i32::from(Errno::Inval); }; - debug!("sys_setsockopt: {fd}, level {level:?}, optname {optname}"); - - if level == Ipproto::Tcp - && optname == TCP_NODELAY - && optlen == u32::try_from(size_of::()).unwrap() - { - if optval.is_null() { - return -i32::from(Errno::Inval); - } - - let value = unsafe { *optval.cast::() }; - let obj = get_object(fd); - obj.map_or_else( - |e| -i32::from(e), - |v| { - block_on( - async { - v.read() - .await - .setsockopt(SocketOption::TcpNoDelay, value != 0) - .await - }, - None, - ) - .map_or_else(|e| -i32::from(e), |()| 0) - }, - ) - } else { - -i32::from(Errno::Inval) + if option == SocketOption::SocketOption(SocketOptionSocket::ReuseAddr) { + return 0; } + + let obj = get_object(fd); + obj.map_or_else( + |e| -i32::from(e), + |v| { + block_on( + async { + v.read() + .await + .setsockopt(option, SocketOptionValue::new(optval, optlen)) + .await + }, + None, + ) + .map_or_else(|e| -i32::from(e), |()| 0) + }, + ) } #[hermit_macro::system(errno)] @@ -959,45 +950,32 @@ pub unsafe extern "C" fn sys_getsockopt( optval: *mut c_void, optlen: *mut socklen_t, ) -> i32 { - let Ok(Ok(level)) = u8::try_from(level).map(Ipproto::try_from) else { - return -i32::from(Errno::Inval); + let option = SocketOption::from_level_optname(level, optname); + let mut value = SocketOptionValueWriter::new(optval, optlen); + + let Some(option) = option else { + warn!( + "getsockopt: unsupported option level={level:x} optname={optname:x}, faking success." + ); + value.nullify_untouched(); + return 0; }; debug!("sys_getsockopt: {fd}, level {level:?}, optname {optname}"); + let obj = get_object(fd); + let result = obj.map_or_else( + |e| -i32::from(e), + |v| { + block_on( + async { v.read().await.getsockopt(option, &mut value).await }, + None, + ) + .map_or_else(|e| -i32::from(e), |()| 0) + }, + ); - if level == Ipproto::Tcp && optname == TCP_NODELAY { - if optval.is_null() || optlen.is_null() { - return -i32::from(Errno::Inval); - } - - let optval = unsafe { &mut *optval.cast::() }; - let optlen = unsafe { &mut *optlen }; - let obj = get_object(fd); - obj.map_or_else( - |e| -i32::from(e), - |v| { - block_on( - async { v.read().await.getsockopt(SocketOption::TcpNoDelay).await }, - None, - ) - .map_or_else( - |e| -i32::from(e), - |value| { - if value { - *optval = 1; - } else { - *optval = 0; - } - *optlen = core::mem::size_of::().try_into().unwrap(); - - 0 - }, - ) - }, - ) - } else { - -i32::from(Errno::Inval) - } + value.nullify_untouched(); + result } #[hermit_macro::system(errno)]