diff --git a/src/net/udp.rs b/src/net/udp.rs index de2e9b7..ba60bde 100644 --- a/src/net/udp.rs +++ b/src/net/udp.rs @@ -10,11 +10,15 @@ use crate::{ ToSocketAddrs, World, TRACING_TARGET, }; +use std::future::Future; +use std::pin::pin; +use std::task::{ready, Context, Poll}; use std::{ cmp, io::{self, Error, ErrorKind, Result}, net::{Ipv6Addr, SocketAddr}, }; +use tokio::io::ReadBuf; /// A simulated UDP socket. /// @@ -60,6 +64,22 @@ impl Rx { Ok((limit, datagram, origin)) } + /// Tries to receive from either the buffered message or the mpsc channel. + fn poll_recv_from( + &mut self, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let (datagram, origin) = if let Some(datagram) = self.buffer.take() { + datagram + } else { + ready!(self.recv.poll_recv(cx)).expect("sender should never be dropped") + }; + let limit = cmp::min(buf.remaining(), datagram.0.len()); + buf.put_slice(&datagram.0[..limit]); + Poll::Ready(Ok(origin)) + } + /// Waits for the socket to become readable. /// /// This function is usually paired with `try_recv_from()`. @@ -247,6 +267,21 @@ impl UdpSocket { Ok((limit, origin)) } + /// Tries to receive a single datagram message on the socket. On success, + /// appends to `buf` and returns the origin. + /// + /// If a message is too long to fit in the unfilled part of `buf` , + /// excess bytes may be discarded. + pub fn poll_recv_from( + &self, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let lock_future = pin!(self.rx.lock()); + let mut rx = ready!(lock_future.poll(cx)); + rx.poll_recv_from(cx, buf) + } + /// Waits for the socket to become readable. /// /// This function is usually paired with `try_recv_from()`. diff --git a/tests/udp.rs b/tests/udp.rs index b4f5925..e6488ff 100644 --- a/tests/udp.rs +++ b/tests/udp.rs @@ -1,3 +1,5 @@ +use std::future::poll_fn; +use std::task::Poll; use std::{ io::{self, ErrorKind}, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, @@ -5,7 +7,8 @@ use std::{ sync::{atomic::AtomicUsize, atomic::Ordering}, time::Duration, }; - +use tokio::io::ReadBuf; +use tokio::time::sleep; use tokio::{sync::oneshot, time::timeout}; use turmoil::{ lookup, @@ -159,6 +162,61 @@ fn try_ping_pong() -> Result { sim.run() } +#[test] +fn poll_recv() -> Result { + let mut sim = Builder::new().build(); + + sim.client("server", async { + let expected_origin = lookup("client"); + let sock = bind().await?; + let buffer = &mut [0u8; 64]; + let mut read_buf = ReadBuf::new(buffer); + + poll_fn(|cx| { + let received = sock.poll_recv_from(cx, &mut read_buf); + assert!(matches!(received, Poll::Pending)); + Poll::Ready(()) + }) + .await; + + // before client sends + sleep(Duration::from_millis(1000)).await; + // after client sends + + poll_fn(|cx| { + let received = sock.poll_recv_from(cx, &mut read_buf); + assert!(matches!(received , Poll::Ready(Ok(x)) if x.ip() == expected_origin)); + Poll::Ready(()) + }) + .await; + sock.readable().await?; + poll_fn(|cx| { + let received = sock.poll_recv_from(cx, &mut read_buf); + assert!(matches!(received , Poll::Ready(Ok(x)) if x.ip() == expected_origin)); + Poll::Ready(()) + }) + .await; + poll_fn(|cx| { + let received = sock.poll_recv_from(cx, &mut read_buf); + assert!(matches!(received, Poll::Pending)); + Poll::Ready(()) + }) + .await; + assert_eq!(read_buf.filled(), b"pingping"); + Ok(()) + }); + + sim.client("client", async { + let sock = bind().await.unwrap(); + sleep(Duration::from_millis(500)).await; + try_send_ping(&sock)?; + try_send_ping(&sock)?; + Ok(()) + }); + + sim.run() +} + #[test] fn recv_buf_is_clipped() -> Result { let mut sim = Builder::new().build();