diff --git a/src/ext/informational.rs b/src/ext/informational.rs new file mode 100644 index 0000000000..e728580fa5 --- /dev/null +++ b/src/ext/informational.rs @@ -0,0 +1,86 @@ +use std::sync::Arc; + +#[derive(Clone)] +pub(crate) struct OnInformational(Arc); + +/// Add a callback for 1xx informational responses. +/// +/// # Example +/// +/// ``` +/// # let some_body = (); +/// let mut req = hyper::Request::new(some_body); +/// +/// hyper::ext::on_informational(&mut req, |res| { +/// println!("informational: {:?}", res.status()); +/// }); +/// +/// // send request on a client connection... +/// ``` +pub fn on_informational(req: &mut http::Request, callback: F) +where + F: Fn(Response<'_>) + Send + Sync + 'static, +{ + on_informational_raw(req, OnInformationalClosure(callback)); +} + +pub(crate) fn on_informational_raw(req: &mut http::Request, callback: C) +where + C: OnInformationalCallback + Send + Sync + 'static, +{ + req.extensions_mut() + .insert(OnInformational(Arc::new(callback))); +} + +// Sealed, not actually nameable bounds +pub(crate) trait OnInformationalCallback { + fn on_informational(&self, res: http::Response<()>); +} + +impl OnInformational { + pub(crate) fn call(&self, res: http::Response<()>) { + self.0.on_informational(res); + } +} + +struct OnInformationalClosure(F); + +impl OnInformationalCallback for OnInformationalClosure +where + F: Fn(Response<'_>) + Send + Sync + 'static, +{ + fn on_informational(&self, res: http::Response<()>) { + let res = Response(&res); + (self.0)(res); + } +} + +// A facade over http::Response. +// +// It purposefully hides being able to move the response out of the closure, +// while also not being able to expect it to be a reference `&Response`. +// (Otherwise, a closure can be written as `|res: &_|`, and then be broken if +// we make the closure take ownership.) +// +// With the type not being nameable, we could change from being a facade to +// being either a real reference, or moving the http::Response into the closure, +// in a backwards-compatible change in the future. +#[derive(Debug)] +pub struct Response<'a>(&'a http::Response<()>); + +impl Response<'_> { + #[inline] + pub fn status(&self) -> http::StatusCode { + self.0.status() + } + + #[inline] + pub fn version(&self) -> http::Version { + self.0.version() + } + + #[inline] + pub fn headers(&self) -> &http::HeaderMap { + self.0.headers() + } +} diff --git a/src/ext/mod.rs b/src/ext/mod.rs index 1235202291..6ae6f6da12 100644 --- a/src/ext/mod.rs +++ b/src/ext/mod.rs @@ -19,6 +19,13 @@ mod h1_reason_phrase; #[cfg(any(feature = "http1", feature = "ffi"))] pub use h1_reason_phrase::ReasonPhrase; +#[cfg(any(feature = "http1", feature = "client"))] +mod informational; +#[cfg(any(feature = "http1", feature = "client"))] +pub use informational::on_informational; +#[cfg(any(feature = "http1", feature = "client"))] +pub(crate) use informational::OnInformational; + #[cfg(feature = "http2")] /// Represents the `:protocol` pseudo-header used by /// the [Extended CONNECT Protocol]. diff --git a/src/ffi/http_types.rs b/src/ffi/http_types.rs index 8807e29481..a1ea03dc1b 100644 --- a/src/ffi/http_types.rs +++ b/src/ffi/http_types.rs @@ -70,7 +70,7 @@ pub struct hyper_headers { } #[derive(Clone)] -pub(crate) struct OnInformational { +struct OnInformational { func: hyper_request_on_informational_callback, data: UserDataPointer, } @@ -273,7 +273,7 @@ ffi_fn! { data: UserDataPointer(data), }; let req = non_null!(&mut *req ?= hyper_code::HYPERE_INVALID_ARG); - req.0.extensions_mut().insert(ext); + crate::ext::on_informational_raw(&mut req.0, ext); hyper_code::HYPERE_OK } } @@ -567,9 +567,10 @@ unsafe fn raw_name_value( // ===== impl OnInformational ===== -impl OnInformational { - pub(crate) fn call(&mut self, resp: Response) { - let mut resp = hyper_response::wrap(resp); +impl crate::ext::OnInformationalCallback for OnInformational { + fn on_informational(&self, res: http::Response<()>) { + let res = res.map(|()| IncomingBody::empty()); + let mut res = hyper_response::wrap(res); (self.func)(self.data.0, &mut resp); } } diff --git a/src/proto/h1/conn.rs b/src/proto/h1/conn.rs index 8ddf7558e1..39bcd47bd1 100644 --- a/src/proto/h1/conn.rs +++ b/src/proto/h1/conn.rs @@ -73,7 +73,7 @@ where preserve_header_order: false, title_case_headers: false, h09_responses: false, - #[cfg(feature = "ffi")] + #[cfg(feature = "client")] on_informational: None, notify_read: false, reading: Reading::Init, @@ -246,7 +246,7 @@ where #[cfg(feature = "ffi")] preserve_header_order: self.state.preserve_header_order, h09_responses: self.state.h09_responses, - #[cfg(feature = "ffi")] + #[cfg(feature = "client")] on_informational: &mut self.state.on_informational, }, ) { @@ -286,7 +286,7 @@ where self.state.h09_responses = false; // Drop any OnInformational callbacks, we're done there! - #[cfg(feature = "ffi")] + #[cfg(feature = "client")] { self.state.on_informational = None; } @@ -636,10 +636,10 @@ where debug_assert!(head.headers.is_empty()); self.state.cached_headers = Some(head.headers); - #[cfg(feature = "ffi")] + #[cfg(feature = "client")] { self.state.on_informational = - head.extensions.remove::(); + head.extensions.remove::(); } Some(encoder) @@ -943,8 +943,8 @@ struct State { /// If set, called with each 1xx informational response received for /// the current request. MUST be unset after a non-1xx response is /// received. - #[cfg(feature = "ffi")] - on_informational: Option, + #[cfg(feature = "client")] + on_informational: Option, /// Set to true when the Dispatcher should poll read operations /// again. See the `maybe_notify` method for more. notify_read: bool, diff --git a/src/proto/h1/io.rs b/src/proto/h1/io.rs index 950bfee098..d5afba683a 100644 --- a/src/proto/h1/io.rs +++ b/src/proto/h1/io.rs @@ -188,7 +188,7 @@ where #[cfg(feature = "ffi")] preserve_header_order: parse_ctx.preserve_header_order, h09_responses: parse_ctx.h09_responses, - #[cfg(feature = "ffi")] + #[cfg(feature = "client")] on_informational: parse_ctx.on_informational, }, )? { @@ -710,7 +710,7 @@ mod tests { #[cfg(feature = "ffi")] preserve_header_order: false, h09_responses: false, - #[cfg(feature = "ffi")] + #[cfg(feature = "client")] on_informational: &mut None, }; assert!(buffered diff --git a/src/proto/h1/mod.rs b/src/proto/h1/mod.rs index 017b8671fb..a8f36f5fd9 100644 --- a/src/proto/h1/mod.rs +++ b/src/proto/h1/mod.rs @@ -77,8 +77,8 @@ pub(crate) struct ParseContext<'a> { #[cfg(feature = "ffi")] preserve_header_order: bool, h09_responses: bool, - #[cfg(feature = "ffi")] - on_informational: &'a mut Option, + #[cfg(feature = "client")] + on_informational: &'a mut Option, } /// Passed to Http1Transaction::encode diff --git a/src/proto/h1/role.rs b/src/proto/h1/role.rs index 528c2b81dd..405f4b4645 100644 --- a/src/proto/h1/role.rs +++ b/src/proto/h1/role.rs @@ -1153,10 +1153,9 @@ impl Http1Transaction for Client { })); } - #[cfg(feature = "ffi")] if head.subject.is_informational() { if let Some(callback) = ctx.on_informational { - callback.call(head.into_response(crate::body::Incoming::empty())); + callback.call(head.into_response(())); } } @@ -1661,7 +1660,7 @@ mod tests { #[cfg(feature = "ffi")] preserve_header_order: false, h09_responses: false, - #[cfg(feature = "ffi")] + #[cfg(feature = "client")] on_informational: &mut None, }, ) @@ -1689,7 +1688,7 @@ mod tests { #[cfg(feature = "ffi")] preserve_header_order: false, h09_responses: false, - #[cfg(feature = "ffi")] + #[cfg(feature = "client")] on_informational: &mut None, }; let msg = Client::parse(&mut raw, ctx).unwrap().unwrap(); @@ -1734,7 +1733,7 @@ mod tests { #[cfg(feature = "ffi")] preserve_header_order: false, h09_responses: true, - #[cfg(feature = "ffi")] + #[cfg(feature = "client")] on_informational: &mut None, }; let msg = Client::parse(&mut raw, ctx).unwrap().unwrap(); @@ -1757,7 +1756,7 @@ mod tests { #[cfg(feature = "ffi")] preserve_header_order: false, h09_responses: false, - #[cfg(feature = "ffi")] + #[cfg(feature = "client")] on_informational: &mut None, }; Client::parse(&mut raw, ctx).unwrap_err(); @@ -1784,7 +1783,7 @@ mod tests { #[cfg(feature = "ffi")] preserve_header_order: false, h09_responses: false, - #[cfg(feature = "ffi")] + #[cfg(feature = "client")] on_informational: &mut None, }; let msg = Client::parse(&mut raw, ctx).unwrap().unwrap(); @@ -1808,7 +1807,7 @@ mod tests { #[cfg(feature = "ffi")] preserve_header_order: false, h09_responses: false, - #[cfg(feature = "ffi")] + #[cfg(feature = "client")] on_informational: &mut None, }; Client::parse(&mut raw, ctx).unwrap_err(); @@ -1828,7 +1827,7 @@ mod tests { #[cfg(feature = "ffi")] preserve_header_order: false, h09_responses: false, - #[cfg(feature = "ffi")] + #[cfg(feature = "client")] on_informational: &mut None, }; let parsed_message = Server::parse(&mut raw, ctx).unwrap().unwrap(); @@ -1867,7 +1866,7 @@ mod tests { #[cfg(feature = "ffi")] preserve_header_order: false, h09_responses: false, - #[cfg(feature = "ffi")] + #[cfg(feature = "client")] on_informational: &mut None, }, ) @@ -1888,7 +1887,7 @@ mod tests { #[cfg(feature = "ffi")] preserve_header_order: false, h09_responses: false, - #[cfg(feature = "ffi")] + #[cfg(feature = "client")] on_informational: &mut None, }, ) @@ -2118,7 +2117,7 @@ mod tests { #[cfg(feature = "ffi")] preserve_header_order: false, h09_responses: false, - #[cfg(feature = "ffi")] + #[cfg(feature = "client")] on_informational: &mut None, } ) @@ -2139,7 +2138,7 @@ mod tests { #[cfg(feature = "ffi")] preserve_header_order: false, h09_responses: false, - #[cfg(feature = "ffi")] + #[cfg(feature = "client")] on_informational: &mut None, }, ) @@ -2160,7 +2159,7 @@ mod tests { #[cfg(feature = "ffi")] preserve_header_order: false, h09_responses: false, - #[cfg(feature = "ffi")] + #[cfg(feature = "client")] on_informational: &mut None, }, ) @@ -2730,7 +2729,7 @@ mod tests { #[cfg(feature = "ffi")] preserve_header_order: false, h09_responses: false, - #[cfg(feature = "ffi")] + #[cfg(feature = "client")] on_informational: &mut None, }, ) @@ -2774,7 +2773,7 @@ mod tests { #[cfg(feature = "ffi")] preserve_header_order: false, h09_responses: false, - #[cfg(feature = "ffi")] + #[cfg(feature = "client")] on_informational: &mut None, }, ); @@ -2798,7 +2797,7 @@ mod tests { #[cfg(feature = "ffi")] preserve_header_order: false, h09_responses: false, - #[cfg(feature = "ffi")] + #[cfg(feature = "client")] on_informational: &mut None, }, ); @@ -2967,7 +2966,7 @@ mod tests { #[cfg(feature = "ffi")] preserve_header_order: false, h09_responses: false, - #[cfg(feature = "ffi")] + #[cfg(feature = "client")] on_informational: &mut None, }, ) @@ -3012,7 +3011,7 @@ mod tests { #[cfg(feature = "ffi")] preserve_header_order: false, h09_responses: false, - #[cfg(feature = "ffi")] + #[cfg(feature = "client")] on_informational: &mut None, }, ) diff --git a/tests/client.rs b/tests/client.rs index 6808a6855f..1f1a456f95 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -2098,6 +2098,46 @@ mod conn { let _res = client.send_request(req).await.expect("send_request"); } + #[tokio::test] + async fn client_on_informational_ext() { + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; + let (server, addr) = setup_std_test_server(); + + thread::spawn(move || { + let mut sock = server.accept().unwrap().0; + sock.set_read_timeout(Some(Duration::from_secs(5))).unwrap(); + sock.set_write_timeout(Some(Duration::from_secs(5))) + .unwrap(); + let mut buf = [0; 4096]; + sock.read(&mut buf).expect("read 1"); + sock.write_all(b"HTTP/1.1 100 Continue\r\n\r\n").unwrap(); + sock.write_all(b"HTTP/1.1 200 OK\r\ncontent-length: 0\r\n\r\n") + .unwrap(); + }); + + let tcp = tcp_connect(&addr).await.unwrap(); + + let (mut client, conn) = conn::http1::handshake(tcp).await.unwrap(); + + tokio::spawn(async move { + let _ = conn.await; + }); + + let mut req = Request::builder() + .uri("/a") + .body(Empty::::new()) + .unwrap(); + let cnt = Arc::new(AtomicUsize::new(0)); + let cnt2 = cnt.clone(); + hyper::ext::on_informational(&mut req, move |res| { + assert_eq!(res.status(), 100); + cnt2.fetch_add(1, Ordering::Relaxed); + }); + let _res = client.send_request(req).await.expect("send_request"); + assert_eq!(1, cnt.load(Ordering::Relaxed)); + } + #[tokio::test] async fn test_try_send_request() { use std::future::Future;