Skip to content

Commit

Permalink
use a async reader/write lock to protect the raw udp socket
Browse files Browse the repository at this point in the history
  • Loading branch information
stlankes committed Sep 22, 2024
1 parent 7fba8cc commit 4c1e545
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 41 deletions.
91 changes: 51 additions & 40 deletions src/fd/socket/udp.rs
Original file line number Diff line number Diff line change
@@ -1,33 +1,30 @@
use alloc::boxed::Box;
use core::future;
use core::ops::DerefMut;
use core::sync::atomic::{AtomicBool, Ordering};
use core::task::Poll;

use async_trait::async_trait;
use crossbeam_utils::atomic::AtomicCell;
use smoltcp::socket::udp;
use smoltcp::socket::udp::UdpMetadata;
use smoltcp::wire::IpEndpoint;

use crate::executor::block_on;
use crate::executor::network::{now, Handle, NetworkState, NIC};
use crate::executor::network::{now, Handle, NIC};
use crate::fd::{Endpoint, IoCtl, ListenEndpoint, ObjectInterface, PollEvent};
use crate::io;

#[derive(Debug)]
pub struct Socket {
handle: Handle,
nonblocking: AtomicBool,
endpoint: AtomicCell<Option<IpEndpoint>>,
nonblocking: bool,
endpoint: Option<IpEndpoint>,
}

impl Socket {
pub fn new(handle: Handle) -> Self {
Self {
handle,
nonblocking: AtomicBool::new(false),
endpoint: AtomicCell::new(None),
nonblocking: false,
endpoint: None,
}
}

Expand All @@ -40,7 +37,7 @@ impl Socket {
result
}

async fn async_close(&self) -> io::Result<()> {
async fn close(&self) -> io::Result<()> {
future::poll_fn(|_cx| {
self.with(|socket| {
socket.close();
Expand All @@ -50,7 +47,7 @@ impl Socket {
.await
}

async fn async_write_with_meta(&self, buffer: &[u8], meta: &UdpMetadata) -> io::Result<usize> {
async fn write_with_meta(&self, buffer: &[u8], meta: &UdpMetadata) -> io::Result<usize> {
future::poll_fn(|cx| {
self.with(|socket| {
if socket.is_open() {
Expand All @@ -72,10 +69,7 @@ impl Socket {
})
.await
}
}

#[async_trait]
impl ObjectInterface for Socket {
async fn poll(&self, event: PollEvent) -> io::Result<PollEvent> {
future::poll_fn(|cx| {
self.with(|socket| {
Expand Down Expand Up @@ -130,10 +124,10 @@ impl ObjectInterface for Socket {
}
}

async fn connect(&self, endpoint: Endpoint) -> io::Result<()> {
async fn connect(&mut self, endpoint: Endpoint) -> io::Result<()> {
#[allow(irrefutable_let_patterns)]
if let Endpoint::Ip(endpoint) = endpoint {
self.endpoint.store(Some(endpoint));
self.endpoint = Some(endpoint);
Ok(())
} else {
Err(io::Error::EIO)
Expand All @@ -144,7 +138,7 @@ impl ObjectInterface for Socket {
#[allow(irrefutable_let_patterns)]
if let Endpoint::Ip(endpoint) = endpoint {
let meta = UdpMetadata::from(endpoint);
self.async_write_with_meta(buf, &meta).await
self.write_with_meta(buf, &meta).await
} else {
Err(io::Error::EIO)
}
Expand All @@ -156,7 +150,7 @@ impl ObjectInterface for Socket {
if socket.is_open() {
if socket.can_recv() {
match socket.recv_slice(buffer) {
Ok((len, meta)) => match self.endpoint.load() {
Ok((len, meta)) => match self.endpoint {
Some(ep) => {
if meta.endpoint == ep {
Poll::Ready(Ok((len, meta.endpoint)))
Expand Down Expand Up @@ -189,7 +183,7 @@ impl ObjectInterface for Socket {
if socket.is_open() {
if socket.can_recv() {
match socket.recv_slice(buffer) {
Ok((len, meta)) => match self.endpoint.load() {
Ok((len, meta)) => match self.endpoint {
Some(ep) => {
if meta.endpoint == ep {
Poll::Ready(Ok(len))
Expand All @@ -216,22 +210,22 @@ impl ObjectInterface for Socket {
}

async fn write(&self, buf: &[u8]) -> io::Result<usize> {
if let Some(endpoint) = self.endpoint.load() {
if let Some(endpoint) = self.endpoint {
let meta = UdpMetadata::from(endpoint);
self.async_write_with_meta(buf, &meta).await
self.write_with_meta(buf, &meta).await
} else {
Err(io::Error::EINVAL)
}
}

async fn ioctl(&self, cmd: IoCtl, value: bool) -> io::Result<()> {
async fn ioctl(&mut self, cmd: IoCtl, value: bool) -> io::Result<()> {
if cmd == IoCtl::NonBlocking {
if value {
info!("set device to nonblocking mode");
self.nonblocking.store(true, Ordering::Release);
self.nonblocking = true;
} else {
info!("set device to blocking mode");
self.nonblocking.store(false, Ordering::Release);
self.nonblocking = false;
}

Ok(())
Expand All @@ -241,26 +235,43 @@ impl ObjectInterface for Socket {
}
}

impl Clone for Socket {
fn clone(&self) -> Self {
let mut guard = NIC.lock();
impl Drop for Socket {
fn drop(&mut self) {
let _ = block_on(self.close(), None);
}
}

let handle = if let NetworkState::Initialized(nic) = guard.deref_mut() {
nic.create_udp_handle().unwrap()
} else {
panic!("Unable to create handle");
};
#[async_trait]
impl ObjectInterface for async_lock::RwLock<Socket> {
async fn poll(&self, event: PollEvent) -> io::Result<PollEvent> {
self.read().await.poll(event).await
}

Self {
handle,
nonblocking: AtomicBool::new(self.nonblocking.load(Ordering::Acquire)),
endpoint: AtomicCell::new(self.endpoint.load()),
}
async fn bind(&self, endpoint: ListenEndpoint) -> io::Result<()> {
self.read().await.bind(endpoint).await
}
}

impl Drop for Socket {
fn drop(&mut self) {
let _ = block_on(self.async_close(), None);
async fn connect(&self, endpoint: Endpoint) -> io::Result<()> {
self.write().await.connect(endpoint).await
}

async fn sendto(&self, buffer: &[u8], endpoint: Endpoint) -> io::Result<usize> {
self.read().await.sendto(buffer, endpoint).await
}

async fn recvfrom(&self, buffer: &mut [u8]) -> io::Result<(usize, Endpoint)> {
self.read().await.recvfrom(buffer).await
}

async fn read(&self, buffer: &mut [u8]) -> io::Result<usize> {
self.read().await.read(buffer).await
}

async fn write(&self, buf: &[u8]) -> io::Result<usize> {
self.read().await.write(buf).await
}

async fn ioctl(&self, cmd: IoCtl, value: bool) -> io::Result<()> {
self.write().await.ioctl(cmd, value).await
}
}
2 changes: 1 addition & 1 deletion src/syscalls/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ pub extern "C" fn sys_socket(domain: i32, type_: SockType, protocol: i32) -> i32
if type_.contains(SockType::SOCK_DGRAM) {
let handle = nic.create_udp_handle().unwrap();
drop(guard);
let socket = Arc::new(udp::Socket::new(handle));
let socket = Arc::new(async_lock::RwLock::new(udp::Socket::new(handle)));

if type_.contains(SockType::SOCK_NONBLOCK) {
block_on(socket.ioctl(IoCtl::NonBlocking, true), None).unwrap();
Expand Down

0 comments on commit 4c1e545

Please sign in to comment.