Skip to content

Commit

Permalink
Merge pull request #30 from hjr3/push-mrnymlsxkzwl
Browse files Browse the repository at this point in the history
Add 'set_reset_reader_on_write' function to TimeoutStream.
  • Loading branch information
hjr3 authored Nov 3, 2024
2 parents e4538af + 7346290 commit 246bc07
Show file tree
Hide file tree
Showing 5 changed files with 212 additions and 13 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ readme = "README.md"

[dependencies]
hyper = "1.1"
hyper-util = { version = "0.1", features = ["client-legacy", "http1"] }
hyper-util = { version = "0.1.10", features = ["client-legacy", "http1"] }
pin-project-lite = "0.2"
tokio = "1.35"
tower-service = "0.3"
Expand All @@ -22,3 +22,4 @@ tokio = { version = "1.35", features = ["io-std", "io-util", "macros"] }
hyper = { version = "1.1", features = ["http1"] }
hyper-tls = "0.6"
http-body-util = "0.1"
hyper-util = { version = "0.1.10", features = ["client-legacy", "http1", "server", "server-graceful"] }
24 changes: 16 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,32 @@ A connect, read and write timeout aware connector to be used with hyper `Client`

## Problem

At the time this crate was created, hyper does not support timeouts. There is a way to do general timeouts, but no easy way to get connect, read and write specific timeouts.
At the time this crate was created, hyper did not support timeouts. There is a way to do general timeouts, but no easy way to get connect, read and write specific timeouts.

## Solution

There is a `TimeoutConnector` that implements the `hyper::Connect` trait. This connector wraps around `HttpConnector` or `HttpsConnector` values and provides timeouts.

**Note:** In hyper 0.11, a read or write timeout will return a _broken pipe_ error because of the way `tokio_proto::ClientProto` works
> [!IMPORTANT]
> The timeouts are on the underlying stream and _not_ the request.
- The read timeout will start when the underlying stream is first polled for read.
- The write timeout will start when the underlying stream is first polled for write.

Tokio often interleaves poll_read and poll_write calls to handle this bi-directional communication efficiently. Due to this behavior, both the read and write timeouts start at the same time. This means your read timeout can expire while the client is still writing the request to the server. If you are writing large bodies, consider using `set_reset_reader_on_write` to avoid this behavior.

## Usage

Hyper version compatibility:

* The `master` branch will track on going development for hyper.
* The `0.5` release supports hyper 1.0.
* The `0.4` release supports hyper 0.14.
* The `0.3` release supports hyper 0.13.
* The `0.2` release supports hyper 0.12.
* The `0.1` release supports hyper 0.11.
- The `master` branch will track on going development for hyper.
- The `0.5` release supports hyper 1.0.
- The `0.4` release supports hyper 0.14.
- The `0.3` release supports hyper 0.13.
- The `0.2` release supports hyper 0.12.
- The `0.1` release supports hyper 0.11.
- **Note:** In hyper 0.11, a read or write timeout will return a _broken pipe_ error because of the way `tokio_proto::ClientProto` works


Assuming you are using hyper 1.0, add this to your `Cargo.toml`:

Expand Down
14 changes: 14 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ pub struct TimeoutConnector<T> {
read_timeout: Option<Duration>,
/// Amount of time to wait writing request
write_timeout: Option<Duration>,
/// If true, resets the reader timeout whenever a write occures
reset_reader_on_write: bool,
}

impl<T> TimeoutConnector<T>
Expand All @@ -43,6 +45,7 @@ where
connect_timeout: None,
read_timeout: None,
write_timeout: None,
reset_reader_on_write: false,
}
}
}
Expand All @@ -67,6 +70,7 @@ where
let connect_timeout = self.connect_timeout;
let read_timeout = self.read_timeout;
let write_timeout = self.write_timeout;
let reset_reader_on_write = self.reset_reader_on_write;
let connecting = self.connector.call(dst);

let fut = async move {
Expand All @@ -86,6 +90,7 @@ where
};
stream.set_read_timeout(read_timeout);
stream.set_write_timeout(write_timeout);
stream.set_reset_reader_on_write(reset_reader_on_write);
Ok(Box::pin(stream))
};

Expand Down Expand Up @@ -117,6 +122,15 @@ impl<T> TimeoutConnector<T> {
pub fn set_write_timeout(&mut self, val: Option<Duration>) {
self.write_timeout = val;
}

/// Reset on the reader timeout on write
///
/// This will reset the reader timeout when a write is done through the
/// the TimeoutReader. This is useful when you don't want to trigger
/// a reader timeout while writes are still be accepted.
pub fn set_reset_reader_on_write(&mut self, reset: bool) {
self.reset_reader_on_write = reset;
}
}

impl<T> Connection for TimeoutConnector<T>
Expand Down
75 changes: 71 additions & 4 deletions src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,20 @@ impl TimeoutState {
}
}

#[inline]
fn restart(self: Pin<&mut Self>) {
let this = self.project();

if *this.active {
let timeout = match this.timeout {
Some(timeout) => *timeout,
None => return,
};

this.cur.reset(Instant::now() + timeout);
}
}

#[inline]
fn poll_check(self: Pin<&mut Self>, cx: &mut Context) -> io::Result<()> {
let mut this = self.project();
Expand Down Expand Up @@ -93,6 +107,7 @@ pin_project! {
reader: R,
#[pin]
state: TimeoutState,
reset_on_write: bool,
}
}

Expand All @@ -107,6 +122,7 @@ where
TimeoutReader {
reader,
state: TimeoutState::new(),
reset_on_write: false,
}
}

Expand Down Expand Up @@ -152,6 +168,20 @@ where
}
}

impl<R> TimeoutReader<R>
where
R: Read + Write,
{
/// Reset on the reader timeout on write
///
/// This will reset the reader timeout when a write is done through the
/// the TimeoutReader. This is useful when you don't want to trigger
/// a reader timeout while writes are still be accepted.
pub fn set_reset_on_write(&mut self, reset: bool) {
self.reset_on_write = reset
}
}

impl<R> Read for TimeoutReader<R>
where
R: Read,
Expand Down Expand Up @@ -180,23 +210,43 @@ where
cx: &mut Context,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
self.project().reader.poll_write(cx, buf)
let this = self.project();
let r = this.reader.poll_write(cx, buf);
if *this.reset_on_write && r.is_ready() {
this.state.restart();
}
r
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
self.project().reader.poll_flush(cx)
let this = self.project();
let r = this.reader.poll_flush(cx);
if *this.reset_on_write && r.is_ready() {
this.state.restart();
}
r
}

fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
self.project().reader.poll_shutdown(cx)
let this = self.project();
let r = this.reader.poll_shutdown(cx);
if *this.reset_on_write && r.is_ready() {
this.state.restart();
}
r
}

fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context,
bufs: &[io::IoSlice],
) -> Poll<io::Result<usize>> {
self.project().reader.poll_write_vectored(cx, bufs)
let this = self.project();
let r = this.reader.poll_write_vectored(cx, bufs);
if *this.reset_on_write && r.is_ready() {
this.state.restart();
}
r
}

fn is_write_vectored(&self) -> bool {
Expand Down Expand Up @@ -408,6 +458,15 @@ where
.set_timeout_pinned(timeout)
}

/// Reset on the reader timeout on write
///
/// This will reset the reader timeout when a write is done through the
/// the TimeoutReader. This is useful when you don't want to trigger
/// a reader timeout while writes are still be accepted.
pub fn set_reset_reader_on_write(&mut self, reset: bool) {
self.stream.set_reset_on_write(reset);
}

/// Returns a shared reference to the inner stream.
pub fn get_ref(&self) -> &S {
self.stream.get_ref().get_ref()
Expand Down Expand Up @@ -507,6 +566,7 @@ pin_project! {
///
/// The returned future will resolve to both the I/O stream and the buffer
/// as well as the number of bytes read once the read operation is completed.
#[cfg(test)]
fn read<'a, R>(reader: &'a mut R, buf: &'a mut [u8]) -> ReadFut<'a, R>
where
R: Read + Unpin + ?Sized,
Expand All @@ -528,6 +588,7 @@ where
}
}

#[cfg(test)]
trait ReadExt: Read {
/// Pulls some bytes from this source into the specified buffer,
/// returning how many bytes were read.
Expand All @@ -549,6 +610,7 @@ pin_project! {

/// Tries to write some bytes from the given `buf` to the writer in an
/// asynchronous manner, returning a future.
#[cfg(test)]
fn write<'a, W>(writer: &'a mut W, buf: &'a [u8]) -> WriteFut<'a, W>
where
W: Write + Unpin + ?Sized,
Expand All @@ -568,6 +630,7 @@ where
}
}

#[cfg(test)]
trait WriteExt: Write {
/// Writes a buffer into this writer, returning how many bytes were
/// written.
Expand All @@ -579,6 +642,7 @@ trait WriteExt: Write {
}
}

#[cfg(test)]
impl<R> ReadExt for Pin<&mut TimeoutReader<R>>
where
R: Read,
Expand All @@ -588,6 +652,7 @@ where
}
}

#[cfg(test)]
impl<W> WriteExt for Pin<&mut TimeoutWriter<W>>
where
W: Write,
Expand All @@ -597,6 +662,7 @@ where
}
}

#[cfg(test)]
impl<S> ReadExt for Pin<&mut TimeoutStream<S>>
where
S: Read + Write,
Expand All @@ -606,6 +672,7 @@ where
}
}

#[cfg(test)]
impl<S> WriteExt for Pin<&mut TimeoutStream<S>>
where
S: Read + Write,
Expand Down
109 changes: 109 additions & 0 deletions tests/client_upload.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use hyper::body::Bytes;
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{Request, Response};
use hyper_util::{client::legacy::Client, rt::TokioIo};
use std::{net::SocketAddr, time::Duration};
use tokio::io;
use tokio::net::TcpListener;
use tokio::sync::oneshot;
use tokio::task;

use hyper_timeout::TimeoutConnector;

async fn spawn_test_server(listener: TcpListener, shutdown_rx: oneshot::Receiver<()>) {
let http = http1::Builder::new();
let graceful = hyper_util::server::graceful::GracefulShutdown::new();
let mut signal = std::pin::pin!(shutdown_rx);

loop {
tokio::select! {
Ok((stream, _addr)) = listener.accept() => {
let io = TokioIo::new(stream);
let conn = http.serve_connection(io, service_fn(handle_request));
// watch this connection
let fut = graceful.watch(conn);
tokio::spawn(async move {
if let Err(e) = fut.await {
eprintln!("Error serving connection: {:?}", e);
}
});
},

_ = &mut signal => {
eprintln!("graceful shutdown signal received");
break;
}
}
}

tokio::select! {
_ = graceful.shutdown() => {
eprintln!("all connections gracefully closed");
},
_ = tokio::time::sleep(std::time::Duration::from_millis(100)) => {
eprintln!("timed out wait for all connections to close");
}
}
}

async fn handle_request(
req: Request<hyper::body::Incoming>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let body = req.collect().await.expect("Failed to read body").to_bytes();
assert!(!body.is_empty(), "empty body");

Ok(Response::new(full("finished")))
}

fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
Full::new(chunk.into())
.map_err(|never| match never {})
.boxed()
}

#[tokio::test]
async fn test_upload_timeout() {
let addr = SocketAddr::from(([127, 0, 0, 1], 0));
let listener = TcpListener::bind(addr)
.await
.expect("Failed to bind listener");
let (shutdown_tx, shutdown_rx) = oneshot::channel();

let server_addr = listener.local_addr().unwrap();

let server_handle = task::spawn(spawn_test_server(listener, shutdown_rx));

let h = hyper_util::client::legacy::connect::HttpConnector::new();
let mut connector = TimeoutConnector::new(h);
connector.set_read_timeout(Some(Duration::from_millis(5)));

// comment this out and the test will fail
connector.set_reset_reader_on_write(true);

let client = Client::builder(hyper_util::rt::TokioExecutor::new()).build(connector);

let body = vec![0; 10 * 1024 * 1024]; // 10MB
let req = Request::post(format!("http://{}/", server_addr))
.body(full(body))
.expect("request builder");

let mut res = client.request(req).await.expect("request failed");

let mut resp_body = Vec::new();
while let Some(frame) = res.body_mut().frame().await {
let bytes = frame
.expect("frame error")
.into_data()
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Error when consuming frame"))
.expect("data error");
resp_body.extend_from_slice(&bytes);
}

assert_eq!(res.status(), 200);
assert_eq!(resp_body, b"finished");

let _ = shutdown_tx.send(());
let _ = server_handle.await;
}

0 comments on commit 246bc07

Please sign in to comment.