Skip to content

Commit

Permalink
simplify handling of network timeouts
Browse files Browse the repository at this point in the history
  • Loading branch information
stlankes committed Sep 30, 2024
1 parent 3d00fbf commit 80278bb
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 171 deletions.
161 changes: 55 additions & 106 deletions src/executor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ use crate::drivers::mmio::get_network_driver;
use crate::drivers::net::NetworkDriver;
#[cfg(all(any(feature = "tcp", feature = "udp"), feature = "pci"))]
use crate::drivers::pci::get_network_driver;
#[cfg(any(feature = "tcp", feature = "udp"))]
use crate::executor::network::network_delay;
use crate::executor::task::AsyncTask;
use crate::io;
#[cfg(any(feature = "tcp", feature = "udp"))]
Expand Down Expand Up @@ -97,30 +95,11 @@ pub fn init() {
crate::executor::vsock::init();
}

#[inline]
pub(crate) fn now() -> u64 {
crate::arch::kernel::systemtime::now_micros()
}

/// Blocks the current thread on `f`, running the executor when idling.
pub(crate) fn poll_on<F, T>(future: F, timeout: Option<Duration>) -> io::Result<T>
pub(crate) fn poll_on<F, T>(future: F) -> io::Result<T>
where
F: Future<Output = io::Result<T>>,
{
#[cfg(any(feature = "tcp", feature = "udp"))]
let nic = get_network_driver();

// disable network interrupts
#[cfg(any(feature = "tcp", feature = "udp"))]
let no_retransmission = if let Some(nic) = nic {
let mut guard = nic.lock();
guard.set_polling_mode(true);
guard.get_checksums().tcp.tx()
} else {
true
};

let start = now();
let mut cx = Context::from_waker(Waker::noop());
let mut future = pin!(future);

Expand All @@ -129,42 +108,8 @@ where
run();

if let Poll::Ready(t) = future.as_mut().poll(&mut cx) {
#[cfg(any(feature = "tcp", feature = "udp"))]
if !no_retransmission {
let wakeup_time =
network_delay(Instant::from_micros_const(now().try_into().unwrap()))
.map(|d| crate::arch::processor::get_timer_ticks() + d.total_micros());
core_scheduler().add_network_timer(wakeup_time);
}

// allow network interrupts
#[cfg(any(feature = "tcp", feature = "udp"))]
if let Some(nic) = nic {
nic.lock().set_polling_mode(false);
}

return t;
}

if let Some(duration) = timeout {
if Duration::from_micros(now() - start) >= duration {
#[cfg(any(feature = "tcp", feature = "udp"))]
if !no_retransmission {
let wakeup_time =
network_delay(Instant::from_micros_const(now().try_into().unwrap()))
.map(|d| crate::arch::processor::get_timer_ticks() + d.total_micros());
core_scheduler().add_network_timer(wakeup_time);
}

// allow network interrupts
#[cfg(any(feature = "tcp", feature = "udp"))]
if let Some(nic) = nic {
nic.lock().set_polling_mode(false);
}

return Err(io::Error::ETIME);
}
}
}
}

Expand All @@ -174,98 +119,102 @@ where
F: Future<Output = io::Result<T>>,
{
#[cfg(any(feature = "tcp", feature = "udp"))]
let nic = get_network_driver();

// disable network interrupts
#[cfg(any(feature = "tcp", feature = "udp"))]
let no_retransmission = if let Some(nic) = nic {
let mut guard = nic.lock();
guard.set_polling_mode(true);
!guard.get_checksums().tcp.tx()
} else {
true
};
let device = get_network_driver();

let backoff = Backoff::new();
let start = now();
let start = crate::arch::kernel::systemtime::now_micros();
let task_notify = Arc::new(TaskNotify::new());
let waker = task_notify.clone().into();
let mut cx = Context::from_waker(&waker);
let mut future = pin!(future);

loop {
// run background tasks
run();
// check future
let result = future.as_mut().poll(&mut cx);

let now = now();
if let Poll::Ready(t) = future.as_mut().poll(&mut cx) {
#[cfg(any(feature = "tcp", feature = "udp"))]
if !no_retransmission {
let network_timer =
network_delay(Instant::from_micros_const(now.try_into().unwrap()))
.map(|d| crate::arch::processor::get_timer_ticks() + d.total_micros());
core_scheduler().add_network_timer(network_timer);
}
// run background all tasks, which poll also the network device
run();

let now = crate::arch::kernel::systemtime::now_micros();
if let Poll::Ready(t) = result {
// allow network interrupts
#[cfg(any(feature = "tcp", feature = "udp"))]
if let Some(nic) = nic {
nic.lock().set_polling_mode(false);
{
let delay = if let Ok(nic) = crate::executor::network::NIC.lock().as_nic_mut() {
nic.poll_delay(Instant::from_micros_const(now.try_into().unwrap()))
.map(|d| d.total_micros())
} else {
None
};
core_scheduler().add_network_timer(
delay.map(|d| crate::arch::processor::get_timer_ticks() + d),
);

if let Some(device) = device {
device.lock().set_polling_mode(false);
}
}

return t;
}

if let Some(duration) = timeout {
if Duration::from_micros(now - start) >= duration {
#[cfg(any(feature = "tcp", feature = "udp"))]
if !no_retransmission {
let network_timer =
network_delay(Instant::from_micros_const(now.try_into().unwrap()))
.map(|d| crate::arch::processor::get_timer_ticks() + d.total_micros());
core_scheduler().add_network_timer(network_timer);
}

// allow network interrupts
#[cfg(any(feature = "tcp", feature = "udp"))]
if let Some(nic) = nic {
nic.lock().set_polling_mode(false);
{
let delay = if let Ok(nic) = crate::executor::network::NIC.lock().as_nic_mut() {
nic.poll_delay(Instant::from_micros_const(now.try_into().unwrap()))
.map(|d| d.total_micros())
} else {
None
};
core_scheduler().add_network_timer(
delay.map(|d| crate::arch::processor::get_timer_ticks() + d),
);

if let Some(device) = device {
device.lock().set_polling_mode(false);
}
}

return Err(io::Error::ETIME);
}
}

#[cfg(any(feature = "tcp", feature = "udp"))]
{
let delay = network_delay(Instant::from_micros_const(now.try_into().unwrap()))
.map(|d| d.total_micros());
if backoff.is_completed() {
let delay = if let Ok(nic) = crate::executor::network::NIC.lock().as_nic_mut() {
nic.poll_delay(Instant::from_micros_const(now.try_into().unwrap()))
.map(|d| d.total_micros())
} else {
None
};

if backoff.is_completed() && delay.unwrap_or(10_000_000) > 10_000 {
if delay.unwrap_or(10_000_000) > 10_000 {
core_scheduler().add_network_timer(
delay.map(|d| crate::arch::processor::get_timer_ticks() + d),
);
let wakeup_time =
timeout.map(|duration| start + u64::try_from(duration.as_micros()).unwrap());
if !no_retransmission {
let ticks = crate::arch::processor::get_timer_ticks();
let network_timer = delay.map(|d| ticks + d);
core_scheduler().add_network_timer(network_timer);
}

// allow network interrupts
if let Some(nic) = nic {
nic.lock().set_polling_mode(false);
if let Some(device) = device {
device.lock().set_polling_mode(false);
}

// switch to another task
task_notify.wait(wakeup_time);

// restore default values
if let Some(nic) = nic {
nic.lock().set_polling_mode(true);
if let Some(device) = device {
device.lock().set_polling_mode(true);
}

backoff.reset();
} else {
backoff.snooze();
}
} else {
backoff.snooze();
}

#[cfg(not(any(feature = "tcp", feature = "udp")))]
Expand Down
25 changes: 3 additions & 22 deletions src/executor/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,10 +264,9 @@ impl<'a> NetworkInterface<'a> {
Ok(tcp_handle)
}

pub(crate) fn poll_common(&mut self, timestamp: Instant) {
let _ = self
.iface
.poll(timestamp, &mut self.device, &mut self.sockets);
pub(crate) fn poll_common(&mut self, timestamp: Instant) -> bool {
self.iface
.poll(timestamp, &mut self.device, &mut self.sockets)
}

pub(crate) fn poll_delay(&mut self, timestamp: Instant) -> Option<Duration> {
Expand Down Expand Up @@ -321,21 +320,3 @@ impl<'a> NetworkInterface<'a> {
Ok(self.sockets.get_mut(dns_handle))
}
}

#[inline]
pub(crate) fn network_delay(timestamp: Instant) -> Option<Duration> {
crate::executor::network::NIC
.lock()
.as_nic_mut()
.unwrap()
.poll_delay(timestamp)
}

#[inline]
fn network_poll(timestamp: Instant) {
crate::executor::network::NIC
.lock()
.as_nic_mut()
.unwrap()
.poll_common(timestamp);
}
13 changes: 3 additions & 10 deletions src/fd/socket/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use smoltcp::socket::tcp;
use smoltcp::time::Duration;

use crate::executor::block_on;
use crate::executor::network::{now, Handle, NIC};
use crate::executor::network::{Handle, NIC};
use crate::fd::{Endpoint, IoCtl, ListenEndpoint, ObjectInterface, PollEvent, SocketOption};
use crate::{io, DEFAULT_KEEP_ALIVE_INTERVAL};

Expand Down Expand Up @@ -52,20 +52,14 @@ impl Socket {
fn with<R>(&self, f: impl FnOnce(&mut tcp::Socket<'_>) -> R) -> R {
let mut guard = NIC.lock();
let nic = guard.as_nic_mut().unwrap();
let result = f(nic.get_mut_socket::<tcp::Socket<'_>>(*self.handle.front().unwrap()));
nic.poll_common(now());

result
f(nic.get_mut_socket::<tcp::Socket<'_>>(*self.handle.front().unwrap()))
}

fn with_context<R>(&self, f: impl FnOnce(&mut tcp::Socket<'_>, &mut iface::Context) -> R) -> R {
let mut guard = NIC.lock();
let nic = guard.as_nic_mut().unwrap();
let (s, cx) = nic.get_socket_and_context::<tcp::Socket<'_>>(*self.handle.front().unwrap());
let result = f(s, cx);
nic.poll_common(now());

result
f(s, cx)
}

async fn close(&self) -> io::Result<()> {
Expand Down Expand Up @@ -418,7 +412,6 @@ impl Socket {
.map(|_| ())
.map_err(|_| io::Error::EIO)?;
}
nic.poll_common(now());

Ok(())
}
Expand Down
7 changes: 2 additions & 5 deletions src/fd/socket/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use smoltcp::socket::udp::UdpMetadata;
use smoltcp::wire::IpEndpoint;

use crate::executor::block_on;
use crate::executor::network::{now, Handle, NIC};
use crate::executor::network::{Handle, NIC};
use crate::fd::{Endpoint, IoCtl, ListenEndpoint, ObjectInterface, PollEvent};
use crate::io;

Expand All @@ -31,10 +31,7 @@ impl Socket {
fn with<R>(&self, f: impl FnOnce(&mut udp::Socket<'_>) -> R) -> R {
let mut guard = NIC.lock();
let nic = guard.as_nic_mut().unwrap();
let result = f(nic.get_mut_socket::<udp::Socket<'_>>(self.handle));
nic.poll_common(now());

result
f(nic.get_mut_socket::<udp::Socket<'_>>(self.handle))
}

async fn close(&self) -> io::Result<()> {
Expand Down
53 changes: 25 additions & 28 deletions src/scheduler/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -462,35 +462,32 @@ impl Task {
))))
.unwrap();
let objmap = OBJECT_MAP.get().unwrap().clone();
let _ = poll_on(
async {
let mut guard = objmap.write().await;
if env::is_uhyve() {
guard
.try_insert(STDIN_FILENO, Arc::new(UhyveStdin::new()))
.map_err(|_| io::Error::EIO)?;
guard
.try_insert(STDOUT_FILENO, Arc::new(UhyveStdout::new()))
.map_err(|_| io::Error::EIO)?;
guard
.try_insert(STDERR_FILENO, Arc::new(UhyveStderr::new()))
.map_err(|_| io::Error::EIO)?;
} else {
guard
.try_insert(STDIN_FILENO, Arc::new(GenericStdin::new()))
.map_err(|_| io::Error::EIO)?;
guard
.try_insert(STDOUT_FILENO, Arc::new(GenericStdout::new()))
.map_err(|_| io::Error::EIO)?;
guard
.try_insert(STDERR_FILENO, Arc::new(GenericStderr::new()))
.map_err(|_| io::Error::EIO)?;
}
let _ = poll_on(async {
let mut guard = objmap.write().await;
if env::is_uhyve() {
guard
.try_insert(STDIN_FILENO, Arc::new(UhyveStdin::new()))
.map_err(|_| io::Error::EIO)?;
guard
.try_insert(STDOUT_FILENO, Arc::new(UhyveStdout::new()))
.map_err(|_| io::Error::EIO)?;
guard
.try_insert(STDERR_FILENO, Arc::new(UhyveStderr::new()))
.map_err(|_| io::Error::EIO)?;
} else {
guard
.try_insert(STDIN_FILENO, Arc::new(GenericStdin::new()))
.map_err(|_| io::Error::EIO)?;
guard
.try_insert(STDOUT_FILENO, Arc::new(GenericStdout::new()))
.map_err(|_| io::Error::EIO)?;
guard
.try_insert(STDERR_FILENO, Arc::new(GenericStderr::new()))
.map_err(|_| io::Error::EIO)?;
}

Ok(())
},
None,
);
Ok(())
});
}

Task {
Expand Down

0 comments on commit 80278bb

Please sign in to comment.