Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 159 additions & 4 deletions src/fd/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use core::task::Poll::{Pending, Ready};
use core::time::Duration;

use async_trait::async_trait;
use num_enum::{IntoPrimitive, TryFromPrimitive};
#[cfg(feature = "net")]
use smoltcp::wire::{IpEndpoint, IpListenEndpoint};

Expand All @@ -14,6 +15,8 @@ use crate::errno::Errno;
use crate::executor::block_on;
use crate::fs::{FileAttr, SeekWhence};
use crate::io;
#[cfg(feature = "net")]
use crate::syscalls::socket::{Ipproto, SOL_SOCKET, socklen_t};

mod eventfd;
#[cfg(any(feature = "net", feature = "vsock"))]
Expand Down Expand Up @@ -43,9 +46,157 @@ pub(crate) enum ListenEndpoint {
}

#[allow(dead_code)]
#[derive(Debug, PartialEq)]
#[derive(Debug, PartialEq, Eq)]
pub(crate) enum SocketOption {
TcpNoDelay,
TcpOption(SocketOptionTcp),
SocketOption(SocketOptionSocket),
}

#[cfg(feature = "net")]
impl SocketOption {
pub fn from_level_optname(level: i32, optname: i32) -> Option<SocketOption> {
if level == SOL_SOCKET {
SocketOptionSocket::try_from(optname)
.ok()
.map(SocketOption::SocketOption)
} else {
let protocol = u8::try_from(level)
.ok()
.and_then(|proto| Ipproto::try_from(proto).ok())?;

match protocol {
Ipproto::Tcp => SocketOptionTcp::try_from(optname)
.ok()
.map(SocketOption::TcpOption),
_ => None,
}
}
}
}

#[cfg(feature = "net")]
pub struct SocketOptionValue {
optval: *const core::ffi::c_void,
optlen: socklen_t,
}

#[cfg(not(feature = "net"))]
pub struct SocketOptionValue;

unsafe impl Send for SocketOptionValue {}

#[cfg(feature = "net")]
impl SocketOptionValue {
pub fn new(optval: *const core::ffi::c_void, optlen: socklen_t) -> Self {
Self { optval, optlen }
}
}

#[cfg(feature = "net")]
impl TryFrom<&SocketOptionValue> for i32 {
type Error = Errno;

fn try_from(value: &SocketOptionValue) -> Result<Self, Self::Error> {
if value.optval.is_null() {
return Err(Errno::Inval);
}

if value.optlen != size_of::<i32>() as u32 {
return Err(Errno::Inval);
}

let value = unsafe { *value.optval.cast::<i32>() };
Ok(value)
}
}

#[cfg(feature = "net")]
impl TryFrom<&SocketOptionValue> for bool {
type Error = Errno;

fn try_from(value: &SocketOptionValue) -> Result<Self, Self::Error> {
let value: i32 = value.try_into()?;
Ok(value != 0)
}
}

#[cfg(feature = "net")]
pub struct SocketOptionValueWriter {
optval: *mut core::ffi::c_void,
optlen: *mut socklen_t,
touched: bool,
}

#[cfg(not(feature = "net"))]
pub struct SocketOptionValueWriter;

unsafe impl Send for SocketOptionValueWriter {}

#[cfg(feature = "net")]
impl SocketOptionValueWriter {
/// Create a wrapper that will contain a getsockopt result, passing the return buffer and its
/// size as given by the caller.
pub fn new(optval: *mut core::ffi::c_void, optlen: *mut socklen_t) -> Self {
Self {
optval,
optlen,
touched: false,
}
}

fn set_value<T>(&mut self, value: T) -> Result<(), Errno> {
if self.optval.is_null() || self.optlen.is_null() {
return Err(Errno::Inval);
}

let optlen = unsafe { *self.optlen };
let min_size = size_of::<T>();
if min_size < optlen as usize {
Err(Errno::Fault)
} else {
unsafe { *self.optlen = min_size as socklen_t };
let value_ptr = self.optval.cast::<T>();
unsafe {
*value_ptr = value;
}
self.touched = true;
Ok(())
}
}

pub fn write_i32(&mut self, target: i32) -> Result<(), Errno> {
self.set_value(target)
}

pub fn write_bool(&mut self, target: bool) -> Result<(), Errno> {
self.write_i32(if target { 1 } else { 0 })
}

pub fn nullify_untouched(&mut self) {
if !self.touched && !self.optlen.is_null() {
unsafe {
*self.optlen = 0;
}
}
}
}

#[derive(TryFromPrimitive, IntoPrimitive, PartialEq, Eq, Clone, Copy, Debug)]
#[repr(i32)]
#[non_exhaustive]
pub(crate) enum SocketOptionTcp {
#[doc(alias = "TCP_NODELAY")]
TcpNoDelay = 1,
}

#[derive(TryFromPrimitive, IntoPrimitive, PartialEq, Eq, Clone, Copy, Debug)]
#[repr(i32)]
#[non_exhaustive]
pub(crate) enum SocketOptionSocket {
#[doc(alias = "SO_REUSEADDR")]
ReuseAddr = 1,
#[doc(alias = "SO_KEEPALIVE")]
KeepAlive = 8,
}

pub(crate) type FileDescriptor = i32;
Expand Down Expand Up @@ -259,13 +410,17 @@ pub(crate) trait ObjectInterface: Sync + Send + core::fmt::Debug {

/// `setsockopt` sets options on sockets
#[cfg(any(feature = "net", feature = "vsock"))]
async fn setsockopt(&self, _opt: SocketOption, _optval: bool) -> io::Result<()> {
async fn setsockopt(&self, _opt: SocketOption, _optval: SocketOptionValue) -> io::Result<()> {
Err(Errno::Notsock)
}

/// `getsockopt` gets options on sockets
#[cfg(any(feature = "net", feature = "vsock"))]
async fn getsockopt(&self, _opt: SocketOption) -> io::Result<bool> {
async fn getsockopt(
&self,
_opt: SocketOption,
_output: &mut SocketOptionValueWriter,
) -> io::Result<()> {
Err(Errno::Notsock)
}

Expand Down
71 changes: 52 additions & 19 deletions src/fd/socket/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ use smoltcp::wire::{IpEndpoint, Ipv4Address, Ipv6Address};
use crate::errno::Errno;
use crate::executor::block_on;
use crate::executor::network::{Handle, NIC};
use crate::fd::{self, Endpoint, ListenEndpoint, ObjectInterface, PollEvent, SocketOption};
use crate::fd::{
self, Endpoint, ListenEndpoint, ObjectInterface, PollEvent, SocketOption, SocketOptionSocket,
SocketOptionTcp, SocketOptionValue, SocketOptionValueWriter,
};
use crate::syscalls::socket::Af;
use crate::{DEFAULT_KEEP_ALIVE_INTERVAL, io};

Expand Down Expand Up @@ -415,31 +418,61 @@ impl ObjectInterface for Socket {
Ok(())
}

async fn setsockopt(&self, opt: SocketOption, optval: bool) -> io::Result<()> {
if opt == SocketOption::TcpNoDelay {
let mut guard = NIC.lock();
let nic = guard.as_nic_mut().unwrap();
async fn setsockopt(&self, opt: SocketOption, optval: SocketOptionValue) -> io::Result<()> {
let mut guard = NIC.lock();
let nic = guard.as_nic_mut().unwrap();

for i in self.handle.iter() {
let socket = nic.get_mut_socket::<tcp::Socket<'_>>(*i);
socket.set_nagle_enabled(optval);
match opt {
SocketOption::TcpOption(SocketOptionTcp::TcpNoDelay) => {
let is_enabled = (&optval).try_into()?;
for i in self.handle.iter() {
let socket = nic.get_mut_socket::<tcp::Socket<'_>>(*i);
socket.set_nagle_enabled(is_enabled);
}
Ok(())
}
SocketOption::SocketOption(SocketOptionSocket::KeepAlive) => {
let keepalive: bool = (&optval).try_into()?;
let keepalive = if keepalive {
Some(Duration::from_secs(120))
} else {
None
};

Ok(())
} else {
Err(Errno::Inval)
for i in self.handle.iter() {
let socket = nic.get_mut_socket::<tcp::Socket<'_>>(*i);
socket.set_keep_alive(keepalive);
}

Ok(())
}
other => {
warn!("TCP: unsupported option {other:?}");
Err(Errno::Inval)
}
}
}

async fn getsockopt(&self, opt: SocketOption) -> io::Result<bool> {
if opt == SocketOption::TcpNoDelay {
let mut guard = NIC.lock();
let nic = guard.as_nic_mut().unwrap();
let socket = nic.get_mut_socket::<tcp::Socket<'_>>(*self.handle.first().unwrap());
async fn getsockopt(
&self,
opt: SocketOption,
optval: &mut SocketOptionValueWriter,
) -> io::Result<()> {
let mut guard = NIC.lock();
let nic = guard.as_nic_mut().unwrap();
let socket = nic.get_mut_socket::<tcp::Socket<'_>>(*self.handle.first().unwrap());

Ok(socket.nagle_enabled())
} else {
Err(Errno::Inval)
match opt {
SocketOption::TcpOption(SocketOptionTcp::TcpNoDelay) => {
optval.write_bool(socket.nagle_enabled())
}
SocketOption::SocketOption(SocketOptionSocket::KeepAlive) => {
optval.write_bool(socket.keep_alive().is_some())
}
other => {
warn!("TCP: unsupported option {other:?}");
Err(Errno::Inval)
}
}
}

Expand Down
Loading
Loading