diff --git a/Cargo.toml b/Cargo.toml index db302d5..4d094fc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hyper-timeout" -version = "0.4.1" +version = "0.5.0" authors = ["Herman J. Radtke III "] edition = "2018" description = "A connect, read and write timeout aware connector to be used with hyper Client." @@ -11,12 +11,14 @@ repository = "https://github.com/hjr3/hyper-timeout" readme = "README.md" [dependencies] -hyper = { version = "0.14.2", features = ["client"] } +hyper = "1.0" +hyper-util = { version = "0.1", features = ["client-legacy", "http1"] } pin-project-lite = "0.2" -tokio = "1.0.0" -tokio-io-timeout = "1.1.0" +tokio = { version = "1.34.0" } +tower-service = "0.3" [dev-dependencies] -hyper = { version = "0.14", features = ["client", "http1", "tcp"] } -hyper-tls = "0.5" tokio = { version = "1.0.0", features = ["io-std", "io-util", "macros"] } +hyper = { version = "1.0", features = ["http1"] } +hyper-tls = "0.6" +http-body-util = "0.1" diff --git a/README.md b/README.md index 6c34c76..8b66716 100644 --- a/README.md +++ b/README.md @@ -19,16 +19,17 @@ There is a `TimeoutConnector` that implements the `hyper::Connect` trait. This c Hyper version compatibility: * The `master` branch will track on going development for hyper. +* The `0.5` release supports hyper 1.0. * The `0.4` release supports hyper 0.14. * The `0.3` release supports hyper 0.13. * The `0.2` release supports hyper 0.12. * The `0.1` release supports hyper 0.11. -Assuming you are using hyper 0.14, add this to your `Cargo.toml`: +Assuming you are using hyper 1.0, add this to your `Cargo.toml`: ```toml [dependencies] -hyper-timeout = "0.4" +hyper-timeout = "0.5" ``` See the [client example](./examples/client.rs) for a working example. diff --git a/examples/client.rs b/examples/client.rs index 4c39189..0e75c9c 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -1,8 +1,10 @@ use std::env; use std::time::Duration; -use hyper::{body::HttpBody as _, Client}; -use tokio::io::{self, AsyncWriteExt as _}; +use http_body_util::{BodyExt, Empty}; +use hyper::body::Bytes; +use hyper_util::{client::legacy::Client, rt::TokioExecutor}; +use tokio::io::{self, AsyncWriteExt}; use hyper_tls::HttpsConnector; @@ -22,22 +24,25 @@ async fn main() -> Result<(), Box> { let url = url.parse::().unwrap(); // This example uses `HttpsConnector`, but you can also use hyper `HttpConnector` - //let h = hyper::client::HttpConnector::new(); + //let h = hyper_util::client::legacy::connect::HttpConnector::new(); let h = HttpsConnector::new(); let mut connector = TimeoutConnector::new(h); connector.set_connect_timeout(Some(Duration::from_secs(5))); connector.set_read_timeout(Some(Duration::from_secs(5))); connector.set_write_timeout(Some(Duration::from_secs(5))); - let client = Client::builder().build::<_, hyper::Body>(connector); + let client = Client::builder(TokioExecutor::new()).build::<_, Empty>(connector); let mut res = client.get(url).await?; println!("Status: {}", res.status()); println!("Headers:\n{:#?}", res.headers()); - while let Some(chunk) = res.body_mut().data().await { - let chunk = chunk?; - io::stdout().write_all(&chunk).await? + while let Some(frame) = res.body_mut().frame().await { + let bytes = frame? + .into_data() + .map_err(|_| io::Error::new(io::ErrorKind::Other, "Error when consuming frame"))?; + io::stdout().write_all(&bytes).await?; } + Ok(()) } diff --git a/src/lib.rs b/src/lib.rs index c2e4285..ce0e90c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,20 +4,19 @@ use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Duration; -use tokio::io::{AsyncRead, AsyncWrite}; +use hyper::rt::{Read, Write}; use tokio::time::timeout; -use tokio_io_timeout::TimeoutStream; -use hyper::client::connect::{Connected, Connection}; -use hyper::{service::Service, Uri}; +use hyper::Uri; +use hyper_util::client::legacy::connect::{Connected, Connection}; +use tower_service::Service; mod stream; - -use stream::TimeoutConnectorStream; +use stream::TimeoutStream; type BoxError = Box; -/// A connector that enforces as connection timeout +/// A connector that enforces a connection timeout #[derive(Debug, Clone)] pub struct TimeoutConnector { /// A connector implementing the `Connect` trait @@ -33,7 +32,7 @@ pub struct TimeoutConnector { impl TimeoutConnector where T: Service + Send, - T::Response: AsyncRead + AsyncWrite + Send + Unpin, + T::Response: Read + Write + Send + Unpin, T::Future: Send + 'static, T::Error: Into, { @@ -51,11 +50,11 @@ where impl Service for TimeoutConnector where T: Service + Send, - T::Response: AsyncRead + AsyncWrite + Connection + Send + Unpin, + T::Response: Read + Write + Connection + Send + Unpin, T::Future: Send + 'static, T::Error: Into, { - type Response = Pin>>; + type Response = Pin>>; type Error = BoxError; #[allow(clippy::type_complexity)] type Future = Pin> + Send>>; @@ -71,7 +70,7 @@ where let connecting = self.connector.call(dst); let fut = async move { - let stream = match connect_timeout { + let mut stream = match connect_timeout { None => { let io = connecting.await.map_err(Into::into)?; TimeoutStream::new(io) @@ -85,11 +84,9 @@ where TimeoutStream::new(io) } }; - - let mut tm = TimeoutConnectorStream::new(stream); - tm.set_read_timeout(read_timeout); - tm.set_write_timeout(write_timeout); - Ok(Box::pin(tm)) + stream.set_read_timeout(read_timeout); + stream.set_write_timeout(write_timeout); + Ok(Box::pin(stream)) }; Box::pin(fut) @@ -124,8 +121,8 @@ impl TimeoutConnector { impl Connection for TimeoutConnector where - T: AsyncRead + AsyncWrite + Connection + Service + Send + Unpin, - T::Response: AsyncRead + AsyncWrite + Send + Unpin, + T: Read + Write + Connection + Service + Send + Unpin, + T::Response: Read + Write + Send + Unpin, T::Future: Send + 'static, T::Error: Into, { @@ -136,12 +133,15 @@ where #[cfg(test)] mod tests { - use std::error::Error; - use std::io; use std::time::Duration; + use std::{error::Error, io}; - use hyper::client::HttpConnector; - use hyper::Client; + use http_body_util::Empty; + use hyper::body::Bytes; + use hyper_util::{ + client::legacy::{connect::HttpConnector, Client}, + rt::TokioExecutor, + }; use super::TimeoutConnector; @@ -154,7 +154,7 @@ mod tests { let mut connector = TimeoutConnector::new(http); connector.set_connect_timeout(Some(Duration::from_millis(1))); - let client = Client::builder().build::<_, hyper::Body>(connector); + let client = Client::builder(TokioExecutor::new()).build::<_, Empty>(connector); let res = client.get(url).await; @@ -179,19 +179,17 @@ mod tests { // A 1 ms read timeout should be so short that we trigger a timeout error connector.set_read_timeout(Some(Duration::from_millis(1))); - let client = Client::builder().build::<_, hyper::Body>(connector); + let client = Client::builder(TokioExecutor::new()).build::<_, Empty>(connector); let res = client.get(url).await; - match res { - Ok(_) => panic!("Expected a timeout"), - Err(e) => { - if let Some(io_e) = e.source().unwrap().downcast_ref::() { - assert_eq!(io_e.kind(), io::ErrorKind::TimedOut); - } else { - panic!("Expected timeout error"); + if let Err(client_e) = res { + if let Some(hyper_e) = client_e.source() { + if let Some(io_e) = hyper_e.source().unwrap().downcast_ref::() { + return assert_eq!(io_e.kind(), io::ErrorKind::TimedOut); } } } + panic!("Expected timeout error"); } } diff --git a/src/stream.rs b/src/stream.rs index dd1bd35..d8010eb 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -1,37 +1,371 @@ +//! Wrappers for applying timeouts to IO operations. +//! +//! This used to depend on [tokio-io-timeout]. After Hyper 1.0 introduced hyper-specific IO traits, this was rewritten to use hyper IO traits instead of tokio IO traits. +//! +//! These timeouts are analogous to the read and write timeouts on traditional blocking sockets. A timeout countdown is +//! initiated when a read/write operation returns [`Poll::Pending`]. If a read/write does not return successfully before +//! the countdown expires, an [`io::Error`] with a kind of [`TimedOut`](io::ErrorKind::TimedOut) is returned. +#![warn(missing_docs)] + +use hyper::rt::{Read, ReadBuf, ReadBufCursor, Write}; +use hyper_util::client::legacy::connect::{Connected, Connection}; +use pin_project_lite::pin_project; +use std::future::Future; use std::io; -use std::io::IoSlice; use std::pin::Pin; -use std::task::{Context, Poll}; +use std::task::{ready, Context, Poll}; use std::time::Duration; +use tokio::time::{sleep_until, Instant, Sleep}; -use hyper::client::connect::{Connected, Connection}; -use pin_project_lite::pin_project; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -use tokio_io_timeout::TimeoutStream; +pin_project! { + #[derive(Debug)] + struct TimeoutState { + timeout: Option, + #[pin] + cur: Sleep, + active: bool, + } +} + +impl TimeoutState { + #[inline] + fn new() -> TimeoutState { + TimeoutState { + timeout: None, + cur: sleep_until(Instant::now()), + active: false, + } + } + + #[inline] + fn timeout(&self) -> Option { + self.timeout + } + + #[inline] + fn set_timeout(&mut self, timeout: Option) { + // since this takes &mut self, we can't yet be active + self.timeout = timeout; + } + + #[inline] + fn set_timeout_pinned(mut self: Pin<&mut Self>, timeout: Option) { + *self.as_mut().project().timeout = timeout; + self.reset(); + } + + #[inline] + fn reset(self: Pin<&mut Self>) { + let this = self.project(); + + if *this.active { + *this.active = false; + this.cur.reset(Instant::now()); + } + } + + #[inline] + fn poll_check(self: Pin<&mut Self>, cx: &mut Context) -> io::Result<()> { + let mut this = self.project(); + + let timeout = match this.timeout { + Some(timeout) => *timeout, + None => return Ok(()), + }; + + if !*this.active { + this.cur.as_mut().reset(Instant::now() + timeout); + *this.active = true; + } + + match this.cur.poll(cx) { + Poll::Ready(()) => Err(io::Error::from(io::ErrorKind::TimedOut)), + Poll::Pending => Ok(()), + } + } +} + +pin_project! { + /// An `hyper::rt::Read`er which applies a timeout to read operations. + #[derive(Debug)] + pub struct TimeoutReader { + #[pin] + reader: R, + #[pin] + state: TimeoutState, + } +} + +impl TimeoutReader +where + R: Read, +{ + /// Returns a new `TimeoutReader` wrapping the specified reader. + /// + /// There is initially no timeout. + pub fn new(reader: R) -> TimeoutReader { + TimeoutReader { + reader, + state: TimeoutState::new(), + } + } + + /// Returns the current read timeout. + pub fn timeout(&self) -> Option { + self.state.timeout() + } + + /// Sets the read timeout. + /// + /// This can only be used before the reader is pinned; use [`set_timeout_pinned`](Self::set_timeout_pinned) + /// otherwise. + pub fn set_timeout(&mut self, timeout: Option) { + self.state.set_timeout(timeout); + } + + /// Sets the read timeout. + /// + /// This will reset any pending timeout. Use [`set_timeout`](Self::set_timeout) instead if the reader is not yet + /// pinned. + pub fn set_timeout_pinned(self: Pin<&mut Self>, timeout: Option) { + self.project().state.set_timeout_pinned(timeout); + } + + /// Returns a shared reference to the inner reader. + pub fn get_ref(&self) -> &R { + &self.reader + } + + /// Returns a mutable reference to the inner reader. + pub fn get_mut(&mut self) -> &mut R { + &mut self.reader + } + + /// Returns a pinned mutable reference to the inner reader. + pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> { + self.project().reader + } + + /// Consumes the `TimeoutReader`, returning the inner reader. + pub fn into_inner(self) -> R { + self.reader + } +} + +impl Read for TimeoutReader +where + R: Read, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context, + buf: ReadBufCursor, + ) -> Poll> { + let this = self.project(); + let r = this.reader.poll_read(cx, buf); + match r { + Poll::Pending => this.state.poll_check(cx)?, + _ => this.state.reset(), + } + r + } +} + +impl Write for TimeoutReader +where + R: Write, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context, + buf: &[u8], + ) -> Poll> { + self.project().reader.poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.project().reader.poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.project().reader.poll_shutdown(cx) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context, + bufs: &[io::IoSlice], + ) -> Poll> { + self.project().reader.poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.reader.is_write_vectored() + } +} + +pin_project! { + /// An `hyper::rt::Write`er which applies a timeout to write operations. + #[derive(Debug)] + pub struct TimeoutWriter { + #[pin] + writer: W, + #[pin] + state: TimeoutState, + } +} + +impl TimeoutWriter +where + W: Write, +{ + /// Returns a new `TimeoutReader` wrapping the specified reader. + /// + /// There is initially no timeout. + pub fn new(writer: W) -> TimeoutWriter { + TimeoutWriter { + writer, + state: TimeoutState::new(), + } + } + + /// Returns the current write timeout. + pub fn timeout(&self) -> Option { + self.state.timeout() + } + + /// Sets the write timeout. + /// + /// This can only be used before the writer is pinned; use [`set_timeout_pinned`](Self::set_timeout_pinned) + /// otherwise. + pub fn set_timeout(&mut self, timeout: Option) { + self.state.set_timeout(timeout); + } + + /// Sets the write timeout. + /// + /// This will reset any pending timeout. Use [`set_timeout`](Self::set_timeout) instead if the reader is not yet + /// pinned. + pub fn set_timeout_pinned(self: Pin<&mut Self>, timeout: Option) { + self.project().state.set_timeout_pinned(timeout); + } + + /// Returns a shared reference to the inner writer. + pub fn get_ref(&self) -> &W { + &self.writer + } + + /// Returns a mutable reference to the inner writer. + pub fn get_mut(&mut self) -> &mut W { + &mut self.writer + } + + /// Returns a pinned mutable reference to the inner writer. + pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> { + self.project().writer + } + + /// Consumes the `TimeoutWriter`, returning the inner writer. + pub fn into_inner(self) -> W { + self.writer + } +} + +impl Write for TimeoutWriter +where + W: Write, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context, + buf: &[u8], + ) -> Poll> { + let this = self.project(); + let r = this.writer.poll_write(cx, buf); + match r { + Poll::Pending => this.state.poll_check(cx)?, + _ => this.state.reset(), + } + r + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let this = self.project(); + let r = this.writer.poll_flush(cx); + match r { + Poll::Pending => this.state.poll_check(cx)?, + _ => this.state.reset(), + } + r + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let this = self.project(); + let r = this.writer.poll_shutdown(cx); + match r { + Poll::Pending => this.state.poll_check(cx)?, + _ => this.state.reset(), + } + r + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context, + bufs: &[io::IoSlice], + ) -> Poll> { + let this = self.project(); + let r = this.writer.poll_write_vectored(cx, bufs); + match r { + Poll::Pending => this.state.poll_check(cx)?, + _ => this.state.reset(), + } + r + } + + fn is_write_vectored(&self) -> bool { + self.writer.is_write_vectored() + } +} + +impl Read for TimeoutWriter +where + W: Read, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context, + buf: ReadBufCursor, + ) -> Poll> { + self.project().writer.poll_read(cx, buf) + } +} pin_project! { - /// A timeout stream that implements required traits to be a Connector + /// A stream which applies read and write timeouts to an inner stream. #[derive(Debug)] - pub struct TimeoutConnectorStream { + pub struct TimeoutStream { #[pin] - stream: TimeoutStream + stream: TimeoutReader> } } -impl TimeoutConnectorStream +impl TimeoutStream where - S: AsyncRead + AsyncWrite + Unpin, + S: Read + Write, { - /// Returns a new `TimeoutConnectorStream` wrapping the specified stream. + /// Returns a new `TimeoutStream` wrapping the specified stream. /// /// There is initially no read or write timeout. - pub fn new(stream: TimeoutStream) -> TimeoutConnectorStream { - TimeoutConnectorStream { stream } + pub fn new(stream: S) -> TimeoutStream { + let writer = TimeoutWriter::new(stream); + let stream = TimeoutReader::new(writer); + TimeoutStream { stream } } /// Returns the current read timeout. pub fn read_timeout(&self) -> Option { - self.stream.read_timeout() + self.stream.timeout() } /// Sets the read timeout. @@ -39,23 +373,20 @@ where /// This can only be used before the stream is pinned; use /// [`set_read_timeout_pinned`](Self::set_read_timeout_pinned) otherwise. pub fn set_read_timeout(&mut self, timeout: Option) { - self.stream.set_read_timeout(timeout) + self.stream.set_timeout(timeout) } /// Sets the read timeout. /// - /// This will reset any pending read timeout. Use - /// [`set_read_timeout`](Self::set_read_timeout) instead if the stream has not yet been pinned. + /// This will reset any pending read timeout. Use [`set_read_timeout`](Self::set_read_timeout) instead if the stream + /// has not yet been pinned. pub fn set_read_timeout_pinned(self: Pin<&mut Self>, timeout: Option) { - self.project() - .stream - .as_mut() - .set_read_timeout_pinned(timeout) + self.project().stream.set_timeout_pinned(timeout) } /// Returns the current write timeout. pub fn write_timeout(&self) -> Option { - self.stream.write_timeout() + self.stream.get_ref().timeout() } /// Sets the write timeout. @@ -63,102 +394,353 @@ where /// This can only be used before the stream is pinned; use /// [`set_write_timeout_pinned`](Self::set_write_timeout_pinned) otherwise. pub fn set_write_timeout(&mut self, timeout: Option) { - self.stream.set_write_timeout(timeout) + self.stream.get_mut().set_timeout(timeout) } /// Sets the write timeout. /// - /// This will reset any pending write timeout. Use - /// [`set_write_timeout`](Self::set_write_timeout) instead if the stream has not yet been - /// pinned. + /// This will reset any pending write timeout. Use [`set_write_timeout`](Self::set_write_timeout) instead if the + /// stream has not yet been pinned. pub fn set_write_timeout_pinned(self: Pin<&mut Self>, timeout: Option) { self.project() .stream - .as_mut() - .set_write_timeout_pinned(timeout) + .get_pin_mut() + .set_timeout_pinned(timeout) } /// Returns a shared reference to the inner stream. pub fn get_ref(&self) -> &S { - self.stream.get_ref() + self.stream.get_ref().get_ref() } /// Returns a mutable reference to the inner stream. pub fn get_mut(&mut self) -> &mut S { - self.stream.get_mut() + self.stream.get_mut().get_mut() } /// Returns a pinned mutable reference to the inner stream. pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> { - self.project().stream.get_pin_mut() + self.project().stream.get_pin_mut().get_pin_mut() } /// Consumes the stream, returning the inner stream. pub fn into_inner(self) -> S { - self.stream.into_inner() + self.stream.into_inner().into_inner() } } -impl AsyncRead for TimeoutConnectorStream +impl Read for TimeoutStream where - S: AsyncRead + AsyncWrite + Unpin, + S: Read + Write, { fn poll_read( self: Pin<&mut Self>, cx: &mut Context, - buf: &mut ReadBuf, + buf: ReadBufCursor, ) -> Poll> { self.project().stream.poll_read(cx, buf) } } -impl AsyncWrite for TimeoutConnectorStream +impl Write for TimeoutStream where - S: AsyncRead + AsyncWrite + Unpin, + S: Read + Write, { fn poll_write( self: Pin<&mut Self>, - cx: &mut Context<'_>, + cx: &mut Context, buf: &[u8], ) -> Poll> { self.project().stream.poll_write(cx, buf) } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.project().stream.poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.project().stream.poll_shutdown(cx) + } + fn poll_write_vectored( self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[IoSlice<'_>], - ) -> Poll> { + cx: &mut Context, + bufs: &[io::IoSlice], + ) -> Poll> { self.project().stream.poll_write_vectored(cx, bufs) } fn is_write_vectored(&self) -> bool { self.stream.is_write_vectored() } +} - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().stream.poll_flush(cx) +impl Connection for TimeoutStream +where + S: Read + Write + Connection + Unpin, +{ + fn connected(&self) -> Connected { + self.get_ref().connected() } +} - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().stream.poll_shutdown(cx) +impl Connection for Pin>> +where + S: Read + Write + Connection + Unpin, +{ + fn connected(&self) -> Connected { + self.get_ref().connected() + } +} + +pin_project! { + /// A future which can be used to easily read available number of bytes to fill + /// a buffer. Based on the internal [tokio::io::util::read::Read] + struct ReadFut<'a, R: ?Sized> { + reader: &'a mut R, + buf: &'a mut [u8], } } -impl Connection for TimeoutConnectorStream +/// Tries to read some bytes directly into the given `buf` in asynchronous +/// manner, returning a future type. +/// +/// The returned future will resolve to both the I/O stream and the buffer +/// as well as the number of bytes read once the read operation is completed. +fn read<'a, R>(reader: &'a mut R, buf: &'a mut [u8]) -> ReadFut<'a, R> where - S: AsyncRead + AsyncWrite + Connection + Unpin, + R: Read + Unpin + ?Sized, { - fn connected(&self) -> Connected { - self.stream.get_ref().connected() + ReadFut { reader, buf } +} + +impl Future for ReadFut<'_, R> +where + R: Read + Unpin + ?Sized, +{ + type Output = io::Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let me = self.project(); + let mut buf = ReadBuf::new(me.buf); + ready!(Pin::new(me.reader).poll_read(cx, buf.unfilled()))?; + Poll::Ready(Ok(buf.filled().len())) + } +} + +trait ReadExt: Read { + /// Pulls some bytes from this source into the specified buffer, + /// returning how many bytes were read. + fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> ReadFut<'a, Self> + where + Self: Unpin, + { + read(self, buf) } } -impl Connection for Pin>> +pin_project! { + /// A future to write some of the buffer to an `AsyncWrite`.- + struct WriteFut<'a, W: ?Sized> { + writer: &'a mut W, + buf: &'a [u8], + } +} + +/// Tries to write some bytes from the given `buf` to the writer in an +/// asynchronous manner, returning a future. +fn write<'a, W>(writer: &'a mut W, buf: &'a [u8]) -> WriteFut<'a, W> where - S: AsyncRead + AsyncWrite + Connection + Unpin, + W: Write + Unpin + ?Sized, { - fn connected(&self) -> Connected { - self.stream.get_ref().connected() + WriteFut { writer, buf } +} + +impl Future for WriteFut<'_, W> +where + W: Write + Unpin + ?Sized, +{ + type Output = io::Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let me = self.project(); + Pin::new(&mut *me.writer).poll_write(cx, me.buf) + } +} + +trait WriteExt: Write { + /// Writes a buffer into this writer, returning how many bytes were + /// written. + fn write<'a>(&'a mut self, src: &'a [u8]) -> WriteFut<'a, Self> + where + Self: Unpin, + { + write(self, src) + } +} + +impl ReadExt for Pin<&mut TimeoutReader> +where + R: Read, +{ + fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> ReadFut<'a, Self> { + read(self, buf) + } +} + +impl WriteExt for Pin<&mut TimeoutWriter> +where + W: Write, +{ + fn write<'a>(&'a mut self, src: &'a [u8]) -> WriteFut<'a, Self> { + write(self, src) + } +} + +impl ReadExt for Pin<&mut TimeoutStream> +where + S: Read + Write, +{ + fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> ReadFut<'a, Self> { + read(self, buf) + } +} + +impl WriteExt for Pin<&mut TimeoutStream> +where + S: Read + Write, +{ + fn write<'a>(&'a mut self, src: &'a [u8]) -> WriteFut<'a, Self> { + write(self, src) + } +} + +#[cfg(test)] +mod test { + use super::*; + use hyper_util::rt::TokioIo; + use std::io::Write; + use std::net::TcpListener; + use std::thread; + use tokio::net::TcpStream; + use tokio::pin; + + pin_project! { + struct DelayStream { + #[pin] + sleep: Sleep, + } + } + + impl DelayStream { + fn new(until: Instant) -> Self { + DelayStream { + sleep: sleep_until(until), + } + } + } + + impl Read for DelayStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context, + _buf: ReadBufCursor, + ) -> Poll> { + match self.project().sleep.poll(cx) { + Poll::Ready(()) => Poll::Ready(Ok(())), + Poll::Pending => Poll::Pending, + } + } + } + + impl hyper::rt::Write for DelayStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context, + buf: &[u8], + ) -> Poll> { + match self.project().sleep.poll(cx) { + Poll::Ready(()) => Poll::Ready(Ok(buf.len())), + Poll::Pending => Poll::Pending, + } + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context) -> Poll> { + Poll::Ready(Ok(())) + } + } + + #[tokio::test] + async fn read_timeout() { + let reader = DelayStream::new(Instant::now() + Duration::from_millis(500)); + let mut reader = TimeoutReader::new(reader); + reader.set_timeout(Some(Duration::from_millis(100))); + pin!(reader); + + let r = reader.read(&mut [0, 1, 2]).await; + assert_eq!(r.err().unwrap().kind(), io::ErrorKind::TimedOut); + } + + #[tokio::test] + async fn read_ok() { + let reader = DelayStream::new(Instant::now() + Duration::from_millis(100)); + let mut reader = TimeoutReader::new(reader); + reader.set_timeout(Some(Duration::from_millis(500))); + pin!(reader); + + reader.read(&mut [0]).await.unwrap(); + } + + #[tokio::test] + async fn write_timeout() { + let writer = DelayStream::new(Instant::now() + Duration::from_millis(500)); + let mut writer = TimeoutWriter::new(writer); + writer.set_timeout(Some(Duration::from_millis(100))); + pin!(writer); + + let r = writer.write(&[0]).await; + assert_eq!(r.err().unwrap().kind(), io::ErrorKind::TimedOut); + } + + #[tokio::test] + async fn write_ok() { + let writer = DelayStream::new(Instant::now() + Duration::from_millis(100)); + let mut writer = TimeoutWriter::new(writer); + writer.set_timeout(Some(Duration::from_millis(500))); + pin!(writer); + + writer.write(&[0]).await.unwrap(); + } + + #[tokio::test] + async fn tcp_read() { + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = listener.local_addr().unwrap(); + + thread::spawn(move || { + let mut socket = listener.accept().unwrap().0; + thread::sleep(Duration::from_millis(10)); + socket.write_all(b"f").unwrap(); + thread::sleep(Duration::from_millis(500)); + let _ = socket.write_all(b"f"); // this may hit an eof + }); + + let s = TcpStream::connect(&addr).await.unwrap(); + let s = TokioIo::new(s); + let mut s = TimeoutStream::new(s); + s.set_read_timeout(Some(Duration::from_millis(100))); + pin!(s); + s.read(&mut [0]).await.unwrap(); + let r = s.read(&mut [0]).await; + + match r { + Ok(_) => panic!("unexpected success"), + Err(ref e) if e.kind() == io::ErrorKind::TimedOut => (), + Err(e) => panic!("{:?}", e), + } } }