diff --git a/http-body-util/Cargo.toml b/http-body-util/Cargo.toml index c9d36aa..e4110b6 100644 --- a/http-body-util/Cargo.toml +++ b/http-body-util/Cargo.toml @@ -33,4 +33,4 @@ http-body = { version = "1", path = "../http-body" } pin-project-lite = "0.2" [dev-dependencies] -tokio = { version = "1", features = ["macros", "rt"] } +tokio = { version = "1", features = ["macros", "rt", "sync"] } diff --git a/http-body-util/src/combinators/mod.rs b/http-body-util/src/combinators/mod.rs index 0ecdb0b..38d2637 100644 --- a/http-body-util/src/combinators/mod.rs +++ b/http-body-util/src/combinators/mod.rs @@ -5,6 +5,7 @@ mod collect; mod frame; mod map_err; mod map_frame; +mod with_trailers; pub use self::{ box_body::{BoxBody, UnsyncBoxBody}, @@ -12,4 +13,5 @@ pub use self::{ frame::Frame, map_err::MapErr, map_frame::MapFrame, + with_trailers::WithTrailers, }; diff --git a/http-body-util/src/combinators/with_trailers.rs b/http-body-util/src/combinators/with_trailers.rs new file mode 100644 index 0000000..9a9e525 --- /dev/null +++ b/http-body-util/src/combinators/with_trailers.rs @@ -0,0 +1,158 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use futures_util::ready; +use http::HeaderMap; +use http_body::{Body, Frame}; +use pin_project_lite::pin_project; + +pin_project! { + /// Adds trailers to a body. + /// + /// See [`BodyExt::with_trailers`] for more details. + pub struct WithTrailers { + #[pin] + state: State, + } +} + +impl WithTrailers { + pub(crate) fn new(body: T, trailers: F) -> Self { + Self { + state: State::PollBody { + body, + trailers: Some(trailers), + }, + } + } +} + +pin_project! { + #[project = StateProj] + enum State { + PollBody { + #[pin] + body: T, + trailers: Option, + }, + PollTrailers { + #[pin] + trailers: F, + }, + Trailers { + trailers: Option, + } + } +} + +impl Body for WithTrailers +where + T: Body, + F: Future>>, +{ + type Data = T::Data; + type Error = T::Error; + + fn poll_frame( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + loop { + let mut this = self.as_mut().project(); + + let new_state: State<_, _> = match this.state.as_mut().project() { + StateProj::PollBody { body, trailers } => match ready!(body.poll_frame(cx)?) { + Some(frame) => { + return Poll::Ready(Some(Ok(frame))); + } + None => { + let trailers = trailers.take().unwrap(); + State::PollTrailers { trailers } + } + }, + StateProj::PollTrailers { trailers } => { + let trailers = ready!(trailers.poll(cx)?); + State::Trailers { trailers } + } + StateProj::Trailers { trailers } => { + return Poll::Ready(trailers.take().map(Frame::trailers).map(Ok)); + } + }; + + this.state.set(new_state); + } + } + + #[inline] + fn is_end_stream(&self) -> bool { + match &self.state { + State::PollBody { body, .. } => body.is_end_stream(), + State::PollTrailers { .. } | State::Trailers { .. } => true, + } + } + + #[inline] + fn size_hint(&self) -> http_body::SizeHint { + match &self.state { + State::PollBody { body, .. } => body.size_hint(), + State::PollTrailers { .. } | State::Trailers { .. } => Default::default(), + } + } +} + +#[cfg(test)] +mod tests { + use std::convert::Infallible; + + use bytes::Bytes; + use http::{HeaderMap, HeaderName, HeaderValue}; + + use crate::{BodyExt, Full}; + + #[allow(unused_imports)] + use super::*; + + #[tokio::test] + async fn works() { + let mut trailers = HeaderMap::new(); + trailers.insert( + HeaderName::from_static("foo"), + HeaderValue::from_static("bar"), + ); + + let body = + Full::::from("hello").with_trailers(std::future::ready(Some( + Ok::<_, Infallible>(trailers.clone()), + ))); + + futures_util::pin_mut!(body); + let waker = futures_util::task::noop_waker(); + let mut cx = Context::from_waker(&waker); + + let data = unwrap_ready(body.as_mut().poll_frame(&mut cx)) + .unwrap() + .unwrap() + .into_data() + .unwrap(); + assert_eq!(data, "hello"); + + let body_trailers = unwrap_ready(body.as_mut().poll_frame(&mut cx)) + .unwrap() + .unwrap() + .into_trailers() + .unwrap(); + assert_eq!(body_trailers, trailers); + + assert!(unwrap_ready(body.as_mut().poll_frame(&mut cx)).is_none()); + } + + fn unwrap_ready(poll: Poll) -> T { + match poll { + Poll::Ready(t) => t, + Poll::Pending => panic!("pending"), + } + } +} diff --git a/http-body-util/src/lib.rs b/http-body-util/src/lib.rs index 059ada6..1b715b3 100644 --- a/http-body-util/src/lib.rs +++ b/http-body-util/src/lib.rs @@ -89,6 +89,50 @@ pub trait BodyExt: http_body::Body { collected: Some(crate::Collected::default()), } } + + /// Add trailers to the body. + /// + /// The trailers will be sent when all previous frames have been sent and the `trailers` future + /// resolves. + /// + /// # Example + /// + /// ``` + /// use http::HeaderMap; + /// use http_body_util::{Full, BodyExt}; + /// use bytes::Bytes; + /// + /// # #[tokio::main] + /// async fn main() { + /// let (tx, rx) = tokio::sync::oneshot::channel::(); + /// + /// let body = Full::::from("Hello, World!") + /// // add trailers via a future + /// .with_trailers(async move { + /// match rx.await { + /// Ok(trailers) => Some(Ok(trailers)), + /// Err(_err) => None, + /// } + /// }); + /// + /// // compute the trailers in the background + /// tokio::spawn(async move { + /// let _ = tx.send(compute_trailers().await); + /// }); + /// + /// async fn compute_trailers() -> HeaderMap { + /// // ... + /// # unimplemented!() + /// } + /// # } + /// ``` + fn with_trailers(self, trailers: F) -> combinators::WithTrailers + where + Self: Sized, + F: std::future::Future>>, + { + combinators::WithTrailers::new(self, trailers) + } } impl BodyExt for T where T: http_body::Body {}