diff --git a/iroh-net/src/magicsock/udp_conn.rs b/iroh-net/src/magicsock/udp_conn.rs index 2f993d8950..7267cf529b 100644 --- a/iroh-net/src/magicsock/udp_conn.rs +++ b/iroh-net/src/magicsock/udp_conn.rs @@ -72,11 +72,11 @@ impl AsyncUdpSocket for UdpConn { } fn max_transmit_segments(&self) -> usize { - self.io.max_transmit_segments().unwrap_or_default() + self.io.max_gso_segments().unwrap_or_default() } fn max_receive_segments(&self) -> usize { - self.io.max_receive_segments().unwrap_or_default() + self.io.gro_segments().unwrap_or_default() } } diff --git a/net-tools/netwatch/src/udp.rs b/net-tools/netwatch/src/udp.rs index 8f59ee7891..91af26f2e8 100644 --- a/net-tools/netwatch/src/udp.rs +++ b/net-tools/netwatch/src/udp.rs @@ -18,11 +18,9 @@ use super::IpFamily; /// Wrapper around a tokio UDP socket. #[derive(Debug)] pub struct UdpSocket { - socket: RwLock>, + socket: RwLock, recv_waker: AtomicWaker, send_waker: AtomicWaker, - /// The addr we are binding to. - addr: SocketAddr, /// Set to true, when an error occurred, that means we need to rebind the socket. is_broken: AtomicBool, } @@ -81,27 +79,17 @@ impl UdpSocket { /// Rebind the underlying socket. pub fn rebind(&self) -> Result<()> { - debug!("rebinding {}", self.addr); - // Remove old socket - let mut guard = self.socket.write().unwrap(); { - let Some(socket) = guard.take() else { - bail!("cannot rebind closed socket"); - }; - drop(socket); - } + let mut guard = self.socket.write().unwrap(); + guard.rebind()?; - // Prepare new socket - let new_socket = inner_bind(self.addr)?; + // Clear errors + self.is_broken + .store(false, std::sync::atomic::Ordering::SeqCst); - // Insert new socket - guard.replace(new_socket); - - // Clear errors - self.is_broken - .store(false, std::sync::atomic::Ordering::SeqCst); + drop(guard); + } - drop(guard); // wakeup self.wake_all(); @@ -109,21 +97,27 @@ impl UdpSocket { } fn bind_raw(addr: impl Into) -> Result { - let mut addr = addr.into(); - let socket = inner_bind(addr)?; - // update to use selected port - addr.set_port(socket.0.local_addr()?.port()); + let socket = SocketState::bind(addr.into())?; Ok(UdpSocket { - socket: RwLock::new(Some(socket)), + socket: RwLock::new(socket), recv_waker: AtomicWaker::default(), send_waker: AtomicWaker::default(), - addr, is_broken: AtomicBool::new(false), }) } - /// TODO + /// Receives a single datagram message on the socket from the remote address + /// to which it is connected. On success, returns the number of bytes read. + /// + /// The function must be called with valid byte array `buf` of sufficient + /// size to hold the message bytes. If a message is too long to fit in the + /// supplied buffer, excess bytes may be discarded. + /// + /// The [`connect`] method will connect this socket to a remote address. + /// This method will fail if the socket is not connected. + /// + /// [`connect`]: method@Self::connect pub fn recv<'a, 'b>(&'b self, buffer: &'a mut [u8]) -> RecvFut<'a, 'b> { RecvFut { socket: self, @@ -131,7 +125,13 @@ impl UdpSocket { } } - /// TODO + /// Receives a single datagram message on the socket. On success, returns + /// the number of bytes read and the origin. + /// + /// The function must be called with valid byte array `buf` of sufficient + /// size to hold the message bytes. If a message is too long to fit in the + /// supplied buffer, excess bytes may be discarded. + /// pub fn recv_from<'a, 'b>(&'b self, buffer: &'a mut [u8]) -> RecvFromFut<'a, 'b> { RecvFromFut { socket: self, @@ -139,7 +139,18 @@ impl UdpSocket { } } - /// TODO + /// Sends data on the socket to the remote address that the socket is + /// connected to. + /// + /// The [`connect`] method will connect this socket to a remote address. + /// This method will fail if the socket is not connected. + /// + /// [`connect`]: method@Self::connect + /// + /// # Return + /// + /// On success, the number of bytes sent is returned, otherwise, the + /// encountered error is returned. pub fn send<'a, 'b>(&'b self, buffer: &'a [u8]) -> SendFut<'a, 'b> { SendFut { socket: self, @@ -147,7 +158,8 @@ impl UdpSocket { } } - /// TODO + /// Sends data on the socket to the given address. On success, returns the + /// number of bytes written. pub fn send_to<'a, 'b>(&'b self, buffer: &'a [u8], to: SocketAddr) -> SendToFut<'a, 'b> { SendToFut { socket: self, @@ -156,17 +168,13 @@ impl UdpSocket { } } - /// TODO + /// Connects the UDP socket setting the default destination for send() and + /// limiting packets that are read via `recv` from the address specified in + /// `addr`. pub fn connect(&self, addr: SocketAddr) -> std::io::Result<()> { tracing::info!("connectnig to {}", addr); let guard = self.socket.read().unwrap(); - let Some((socket_tokio, _state)) = guard.as_ref() else { - warn!("socket closed"); - return Err(std::io::Error::new( - std::io::ErrorKind::BrokenPipe, - "socket closed", - )); - }; + let (socket_tokio, _state) = guard.try_get_connected()?; let sock_ref = socket2::SockRef::from(&socket_tokio); sock_ref.connect(&socket2::SockAddr::from(addr))?; @@ -177,20 +185,14 @@ impl UdpSocket { /// Returns the local address of this socket. pub fn local_addr(&self) -> std::io::Result { let guard = self.socket.read().unwrap(); - let Some((socket, _)) = guard.as_ref() else { - warn!("socket closed"); - return Err(std::io::Error::new( - std::io::ErrorKind::BrokenPipe, - "socket closed", - )); - }; + let (socket, _state) = guard.try_get_connected()?; socket.local_addr() } /// Closes the socket, and waits for the underlying `libc::close` call to be finished. pub async fn close(&self) { - let socket = self.socket.write().unwrap().take(); + let socket = self.socket.write().unwrap().close(); self.wake_all(); if let Some((sock, _)) = socket { let std_sock = sock.into_std(); @@ -208,7 +210,7 @@ impl UdpSocket { /// Check if this socket is closed. pub fn is_closed(&self) -> bool { - self.socket.read().unwrap().is_none() + self.socket.read().unwrap().is_closed() } /// Handle potential read errors, updating internal state. @@ -244,7 +246,7 @@ impl UdpSocket { &self, waker: &AtomicWaker, cx: &mut std::task::Context<'_>, - ) -> Poll>> { + ) -> Poll> { let guard = match self.socket.try_read() { Ok(guard) => guard, Err(TryLockError::Poisoned(e)) => panic!("socket lock poisoned: {e}"), @@ -274,33 +276,35 @@ impl UdpSocket { self.send_waker.wake(); } + /// Checks if the socket needs a rebind, and if so does it. + /// + /// Returns an error if the rebind is needed, but failed. + fn maybe_rebind(&self) -> std::io::Result<()> { + if self.is_broken() { + match self.rebind() { + Ok(()) => { + // all good + } + Err(err) => { + warn!("failed to rebind socket: {:?}", err); + // TODO: improve error + let err = std::io::Error::new(std::io::ErrorKind::BrokenPipe, err.to_string()); + return Err(err); + } + } + } + Ok(()) + } + /// Poll for writable pub fn poll_writable(&self, cx: &mut std::task::Context<'_>) -> Poll> { loop { - // check if the socket needs a rebind - if self.is_broken() { - match self.rebind() { - Ok(()) => { - // all good - } - Err(err) => { - warn!("failed to rebind socket: {:?}", err); - // TODO: improve error - let err = - std::io::Error::new(std::io::ErrorKind::BrokenPipe, err.to_string()); - return Poll::Ready(Err(err)); - } - } + if let Err(err) = self.maybe_rebind() { + return Poll::Ready(Err(err)); } let guard = futures_lite::ready!(self.poll_read_socket(&self.send_waker, cx)); - let Some((socket, _state)) = guard.as_ref() else { - warn!("socket closed"); - return Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::BrokenPipe, - "socket closed", - ))); - }; + let (socket, _state) = guard.try_get_connected()?; match socket.poll_send_ready(cx) { Poll::Pending => { @@ -321,29 +325,18 @@ impl UdpSocket { /// Send a quinn based `Transmit`. pub fn try_send_quinn(&self, transmit: &Transmit<'_>) -> std::io::Result<()> { loop { - // check if the socket needs a rebind - if self.is_broken() { - match self.rebind() { - Ok(()) => { - // all good - } - Err(err) => { - warn!("failed to rebind socket: {:?}", err); - // TODO: improve error - let err = - std::io::Error::new(std::io::ErrorKind::BrokenPipe, err.to_string()); - return Err(err); - } + self.maybe_rebind()?; + + let guard = match self.socket.try_read() { + Ok(guard) => guard, + Err(TryLockError::Poisoned(e)) => { + panic!("lock poisoned: {:?}", e); + } + Err(TryLockError::WouldBlock) => { + return Err(std::io::Error::new(std::io::ErrorKind::WouldBlock, "")); } - } - let guard = self.socket.read().unwrap(); - let Some((socket, state)) = guard.as_ref() else { - warn!("socket closed"); - return Err(std::io::Error::new( - std::io::ErrorKind::BrokenPipe, - "socket closed", - )); }; + let (socket, state) = guard.try_get_connected()?; let res = socket.try_io(Interest::WRITABLE, || state.send(socket.into(), transmit)); @@ -367,29 +360,12 @@ impl UdpSocket { meta: &mut [quinn_udp::RecvMeta], ) -> Poll> { loop { - // check if the socket needs a rebind - if self.is_broken() { - match self.rebind() { - Ok(()) => { - // all good - } - Err(err) => { - warn!("failed to rebind socket: {:?}", err); - // TODO: improve error - let err = - std::io::Error::new(std::io::ErrorKind::BrokenPipe, err.to_string()); - return Poll::Ready(Err(err)); - } - } + if let Err(err) = self.maybe_rebind() { + return Poll::Ready(Err(err)); } + let guard = futures_lite::ready!(self.poll_read_socket(&self.recv_waker, cx)); - let Some((socket, state)) = guard.as_ref() else { - warn!("socket closed"); - return Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::BrokenPipe, - "socket closed", - ))); - }; + let (socket, state) = guard.try_get_connected()?; match socket.poll_recv_ready(cx) { Poll::Pending => { @@ -437,42 +413,33 @@ impl UdpSocket { } } - /// TODO + /// Whether transmitted datagrams might get fragmented by the IP layer + /// + /// Returns `false` on targets which employ e.g. the `IPV6_DONTFRAG` socket option. pub fn may_fragment(&self) -> std::io::Result { let guard = self.socket.read().unwrap(); - let Some((_, state)) = guard.as_ref() else { - warn!("socket closed"); - return Err(std::io::Error::new( - std::io::ErrorKind::BrokenPipe, - "socket closed", - )); - }; + let (_, state) = guard.try_get_connected()?; Ok(state.may_fragment()) } - /// TODO - pub fn max_transmit_segments(&self) -> std::io::Result { + /// The maximum amount of segments which can be transmitted if a platform + /// supports Generic Send Offload (GSO). + /// + /// This is 1 if the platform doesn't support GSO. Subject to change if errors are detected + /// while using GSO. + pub fn max_gso_segments(&self) -> std::io::Result { let guard = self.socket.read().unwrap(); - let Some((_, state)) = guard.as_ref() else { - warn!("socket closed"); - return Err(std::io::Error::new( - std::io::ErrorKind::BrokenPipe, - "socket closed", - )); - }; + let (_, state) = guard.try_get_connected()?; Ok(state.max_gso_segments()) } - /// TODO - pub fn max_receive_segments(&self) -> std::io::Result { + /// The number of segments to read when GRO is enabled. Used as a factor to + /// compute the receive buffer size. + /// + /// Returns 1 if the platform doesn't support GRO. + pub fn gro_segments(&self) -> std::io::Result { let guard = self.socket.read().unwrap(); - let Some((_, state)) = guard.as_ref() else { - warn!("socket closed"); - return Err(std::io::Error::new( - std::io::ErrorKind::BrokenPipe, - "socket closed", - )); - }; + let (_, state) = guard.try_get_connected()?; Ok(state.gro_segments()) } } @@ -491,30 +458,12 @@ impl Future for RecvFut<'_, '_> { let Self { socket, buffer } = &mut *self; loop { - // check if the socket needs a rebind - if socket.is_broken() { - match socket.rebind() { - Ok(()) => { - // all good - } - Err(err) => { - warn!("failed to rebind socket: {:?}", err); - // TODO: improve error - let err = - std::io::Error::new(std::io::ErrorKind::BrokenPipe, err.to_string()); - return Poll::Ready(Err(err)); - } - } + if let Err(err) = socket.maybe_rebind() { + return Poll::Ready(Err(err)); } let guard = futures_lite::ready!(socket.poll_read_socket(&socket.recv_waker, cx)); - let Some((inner_socket, _)) = guard.as_ref() else { - warn!("socket closed"); - return Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::BrokenPipe, - "socket closed", - ))); - }; + let (inner_socket, _state) = guard.try_get_connected()?; match inner_socket.poll_recv_ready(cx) { Poll::Pending => { @@ -559,29 +508,12 @@ impl Future for RecvFromFut<'_, '_> { let Self { socket, buffer } = &mut *self; loop { - // check if the socket needs a rebind - if socket.is_broken() { - match socket.rebind() { - Ok(()) => { - // all good - } - Err(err) => { - warn!("failed to rebind socket: {:?}", err); - // TODO: improve error - let err = - std::io::Error::new(std::io::ErrorKind::BrokenPipe, err.to_string()); - return Poll::Ready(Err(err)); - } - } + if let Err(err) = socket.maybe_rebind() { + return Poll::Ready(Err(err)); } + let guard = futures_lite::ready!(socket.poll_read_socket(&socket.recv_waker, cx)); - let Some((inner_socket, _)) = guard.as_ref() else { - warn!("socket closed"); - return Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::BrokenPipe, - "socket closed", - ))); - }; + let (inner_socket, _state) = guard.try_get_connected()?; match inner_socket.poll_recv_ready(cx) { Poll::Pending => { @@ -624,30 +556,13 @@ impl Future for SendFut<'_, '_> { fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { loop { - // check if the socket needs a rebind - if self.socket.is_broken() { - match self.socket.rebind() { - Ok(()) => { - // all good - } - Err(err) => { - warn!("failed to rebind socket: {:?}", err); - // TODO: improve error - let err = - std::io::Error::new(std::io::ErrorKind::BrokenPipe, err.to_string()); - return Poll::Ready(Err(err)); - } - } + if let Err(err) = self.socket.maybe_rebind() { + return Poll::Ready(Err(err)); } + let guard = futures_lite::ready!(self.socket.poll_read_socket(&self.socket.send_waker, cx)); - let Some((socket, _)) = guard.as_ref() else { - warn!("socket closed"); - return Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::BrokenPipe, - "socket closed", - ))); - }; + let (socket, _state) = guard.try_get_connected()?; match socket.poll_send_ready(cx) { Poll::Pending => { @@ -691,31 +606,13 @@ impl Future for SendToFut<'_, '_> { fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { loop { - // check if the socket needs a rebind - if self.socket.is_broken() { - match self.socket.rebind() { - Ok(()) => { - // all good - } - Err(err) => { - warn!("failed to rebind socket: {:?}", err); - // TODO: improve error - let err = - std::io::Error::new(std::io::ErrorKind::BrokenPipe, err.to_string()); - return Poll::Ready(Err(err)); - } - } + if let Err(err) = self.socket.maybe_rebind() { + return Poll::Ready(Err(err)); } let guard = futures_lite::ready!(self.socket.poll_read_socket(&self.socket.send_waker, cx)); - let Some((socket, _)) = guard.as_ref() else { - warn!("socket closed"); - return Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::BrokenPipe, - "socket closed", - ))); - }; + let (socket, _state) = guard.try_get_connected()?; match socket.poll_send_ready(cx) { Poll::Pending => { @@ -747,74 +644,138 @@ impl Future for SendToFut<'_, '_> { } } -fn inner_bind(addr: SocketAddr) -> Result<(tokio::net::UdpSocket, quinn_udp::UdpSocketState)> { - let network = IpFamily::from(addr.ip()); - let socket = socket2::Socket::new( - network.into(), - socket2::Type::DGRAM, - Some(socket2::Protocol::UDP), - ) - .context("socket create")?; +#[derive(Debug)] +enum SocketState { + Connected { + socket: tokio::net::UdpSocket, + state: quinn_udp::UdpSocketState, + /// The addr we are binding to. + addr: SocketAddr, + }, + Closed, +} - if let Err(err) = socket.set_recv_buffer_size(SOCKET_BUFFER_SIZE) { - warn!( - "failed to set recv_buffer_size to {}: {:?}", - SOCKET_BUFFER_SIZE, err - ); - } - if let Err(err) = socket.set_send_buffer_size(SOCKET_BUFFER_SIZE) { - warn!( - "failed to set send_buffer_size to {}: {:?}", - SOCKET_BUFFER_SIZE, err - ); - } - if network == IpFamily::V6 { - // Avoid dualstack - socket.set_only_v6(true).context("only IPv6")?; +impl SocketState { + fn try_get_connected( + &self, + ) -> std::io::Result<(&tokio::net::UdpSocket, &quinn_udp::UdpSocketState)> { + match self { + Self::Connected { + socket, + state, + addr: _, + } => Ok((socket, state)), + Self::Closed => { + warn!("socket closed"); + Err(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "socket closed", + )) + } + } } - // Binding must happen before calling quinn, otherwise `local_addr` - // is not yet available on all OSes. - socket.bind(&addr.into()).context("binding")?; + fn bind(addr: SocketAddr) -> Result { + let network = IpFamily::from(addr.ip()); + let socket = socket2::Socket::new( + network.into(), + socket2::Type::DGRAM, + Some(socket2::Protocol::UDP), + ) + .context("socket create")?; + + if let Err(err) = socket.set_recv_buffer_size(SOCKET_BUFFER_SIZE) { + warn!( + "failed to set recv_buffer_size to {}: {:?}", + SOCKET_BUFFER_SIZE, err + ); + } + if let Err(err) = socket.set_send_buffer_size(SOCKET_BUFFER_SIZE) { + warn!( + "failed to set send_buffer_size to {}: {:?}", + SOCKET_BUFFER_SIZE, err + ); + } + if network == IpFamily::V6 { + // Avoid dualstack + socket.set_only_v6(true).context("only IPv6")?; + } + + // Binding must happen before calling quinn, otherwise `local_addr` + // is not yet available on all OSes. + socket.bind(&addr.into()).context("binding")?; - // Ensure nonblocking - socket.set_nonblocking(true).context("nonblocking: true")?; + // Ensure nonblocking + socket.set_nonblocking(true).context("nonblocking: true")?; - let socket: std::net::UdpSocket = socket.into(); + let socket: std::net::UdpSocket = socket.into(); - // Convert into tokio UdpSocket - let socket = tokio::net::UdpSocket::from_std(socket).context("conversion to tokio")?; - let socket_ref = quinn_udp::UdpSockRef::from(&socket); - let socket_state = quinn_udp::UdpSocketState::new(socket_ref)?; + // Convert into tokio UdpSocket + let socket = tokio::net::UdpSocket::from_std(socket).context("conversion to tokio")?; + let socket_ref = quinn_udp::UdpSockRef::from(&socket); + let socket_state = quinn_udp::UdpSocketState::new(socket_ref)?; - if addr.port() != 0 { let local_addr = socket.local_addr().context("local addr")?; - ensure!( - local_addr.port() == addr.port(), - "wrong port bound: {:?}: wanted: {} got {}", - network, - addr.port(), - local_addr.port(), - ); + if addr.port() != 0 { + ensure!( + local_addr.port() == addr.port(), + "wrong port bound: {:?}: wanted: {} got {}", + network, + addr.port(), + local_addr.port(), + ); + } + + Ok(Self::Connected { + socket, + state: socket_state, + addr: local_addr, + }) + } + + fn rebind(&mut self) -> Result<()> { + let addr = match self { + Self::Connected { addr, .. } => *addr, + Self::Closed => { + bail!("socket is closed and cannot be rebound"); + } + }; + debug!("rebinding {}", addr); + + *self = SocketState::Closed; + *self = Self::bind(addr)?; + + Ok(()) + } + + fn is_closed(&self) -> bool { + matches!(self, Self::Closed) } - Ok((socket, socket_state)) + fn close(&mut self) -> Option<(tokio::net::UdpSocket, quinn_udp::UdpSocketState)> { + match std::mem::replace(self, SocketState::Closed) { + Self::Connected { socket, state, .. } => Some((socket, state)), + Self::Closed => None, + } + } } impl Drop for UdpSocket { fn drop(&mut self) { - debug!("dropping UdpSocket"); - // Only spawn_blocking if we are inside a tokio runtime, otherwise we just drop. - if let Ok(handle) = tokio::runtime::Handle::try_current() { - // No wakeup after dropping write lock here, since we're getting dropped. - if let Some((socket, _)) = self.socket.write().unwrap().take() { - // this will be empty if `close` was called before - let std_sock = socket.into_std(); - handle.spawn_blocking(move || { - // Calls libc::close, which can block - drop(std_sock); - }); + trace!("dropping UdpSocket"); + match self.socket.write().unwrap().close() { + Some((socket, _)) => { + if let Ok(handle) = tokio::runtime::Handle::try_current() { + // No wakeup after dropping write lock here, since we're getting dropped. + // this will be empty if `close` was called before + let std_sock = socket.into_std(); + handle.spawn_blocking(move || { + // Calls libc::close, which can block + drop(std_sock); + }); + } } + None => {} } } }