diff --git a/src/proto/h1/conn.rs b/src/proto/h1/conn.rs index 43fba3b793..918ac62cdc 100644 --- a/src/proto/h1/conn.rs +++ b/src/proto/h1/conn.rs @@ -8,7 +8,7 @@ use std::time::Duration; use crate::rt::{Read, Write}; use bytes::{Buf, Bytes}; -use http::header::{HeaderValue, CONNECTION}; +use http::header::{HeaderValue, CONNECTION, TE}; use http::{HeaderMap, Method, Version}; use httparse::ParserConfig; @@ -75,6 +75,7 @@ where // We assume a modern world where the remote speaks HTTP/1.1. // If they tell us otherwise, we'll downgrade in `read_head`. version: Version::HTTP_11, + allow_trailer_fields: false, }, _marker: PhantomData, } @@ -264,6 +265,13 @@ where self.state.reading = Reading::Body(Decoder::new(msg.decode)); } + self.state.allow_trailer_fields = msg + .head + .headers + .get(TE) + .map(|te_header| te_header == "trailers") + .unwrap_or(false); + Poll::Ready(Some(Ok((msg.head, msg.decode, wants)))) } @@ -640,6 +648,31 @@ where self.state.writing = state; } + pub(crate) fn write_trailers(&mut self, trailers: HeaderMap) { + if T::is_server() && self.state.allow_trailer_fields == false { + debug!("trailers not allowed to be sent"); + return; + } + debug_assert!(self.can_write_body() && self.can_buffer_body()); + + match self.state.writing { + Writing::Body(ref encoder) => { + if let Some(enc_buf) = + encoder.encode_trailers(trailers, self.state.title_case_headers) + { + self.io.buffer(enc_buf); + + self.state.writing = if encoder.is_last() || encoder.is_close_delimited() { + Writing::Closed + } else { + Writing::KeepAlive + }; + } + } + _ => unreachable!("write_trailers invalid state: {:?}", self.state.writing), + } + } + pub(crate) fn write_body_and_end(&mut self, chunk: B) { debug_assert!(self.can_write_body() && self.can_buffer_body()); // empty chunks should be discarded at Dispatcher level @@ -842,6 +875,8 @@ struct State { upgrade: Option, /// Either HTTP/1.0 or 1.1 connection version: Version, + /// Flag to track if trailer fields are allowed to be sent + allow_trailer_fields: bool, } #[derive(Debug)] diff --git a/src/proto/h1/dispatch.rs b/src/proto/h1/dispatch.rs index c29c15dcae..0871af12ef 100644 --- a/src/proto/h1/dispatch.rs +++ b/src/proto/h1/dispatch.rs @@ -351,27 +351,33 @@ where *clear_body = true; crate::Error::new_user_body(e) })?; - let chunk = if let Ok(data) = frame.into_data() { - data - } else { - trace!("discarding non-data frame"); - continue; - }; - let eos = body.is_end_stream(); - if eos { - *clear_body = true; - if chunk.remaining() == 0 { - trace!("discarding empty chunk"); - self.conn.end_body()?; + + if frame.is_data() { + let chunk = frame.into_data().unwrap_or_else(|_| unreachable!()); + let eos = body.is_end_stream(); + if eos { + *clear_body = true; + if chunk.remaining() == 0 { + trace!("discarding empty chunk"); + self.conn.end_body()?; + } else { + self.conn.write_body_and_end(chunk); + } } else { - self.conn.write_body_and_end(chunk); + if chunk.remaining() == 0 { + trace!("discarding empty chunk"); + continue; + } + self.conn.write_body(chunk); } + } else if frame.is_trailers() { + *clear_body = true; + self.conn.write_trailers( + frame.into_trailers().unwrap_or_else(|_| unreachable!()), + ); } else { - if chunk.remaining() == 0 { - trace!("discarding empty chunk"); - continue; - } - self.conn.write_body(chunk); + trace!("discarding unknown frame"); + continue; } } else { *clear_body = true; diff --git a/src/proto/h1/encode.rs b/src/proto/h1/encode.rs index c98c55d664..90eeae4712 100644 --- a/src/proto/h1/encode.rs +++ b/src/proto/h1/encode.rs @@ -1,10 +1,19 @@ +use std::collections::HashMap; use std::fmt; use std::io::IoSlice; use bytes::buf::{Chain, Take}; -use bytes::Buf; +use bytes::{Buf, Bytes}; +use http::{ + header::{ + AUTHORIZATION, CACHE_CONTROL, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_RANGE, + CONTENT_TYPE, HOST, MAX_FORWARDS, SET_COOKIE, TE, TRAILER, TRANSFER_ENCODING, + }, + HeaderMap, HeaderName, HeaderValue, +}; use super::io::WriteBuf; +use super::role::{write_headers, write_headers_title_case}; type StaticBuf = &'static [u8]; @@ -26,7 +35,7 @@ pub(crate) struct NotEof(u64); #[derive(Debug, PartialEq, Clone)] enum Kind { /// An Encoder for when Transfer-Encoding includes `chunked`. - Chunked, + Chunked(Option>), /// An Encoder for when Content-Length is set. /// /// Enforces that the body is not longer than the Content-Length header. @@ -45,6 +54,7 @@ enum BufKind { Limited(Take), Chunked(Chain, StaticBuf>), ChunkedEnd(StaticBuf), + Trailers(Chain, StaticBuf>), } impl Encoder { @@ -55,7 +65,7 @@ impl Encoder { } } pub(crate) fn chunked() -> Encoder { - Encoder::new(Kind::Chunked) + Encoder::new(Kind::Chunked(None)) } pub(crate) fn length(len: u64) -> Encoder { @@ -67,6 +77,16 @@ impl Encoder { Encoder::new(Kind::CloseDelimited) } + pub(crate) fn into_chunked_with_trailing_fields(self, trailers: Vec) -> Encoder { + match self.kind { + Kind::Chunked(_) => Encoder { + kind: Kind::Chunked(Some(trailers)), + is_last: self.is_last, + }, + _ => self, + } + } + pub(crate) fn is_eof(&self) -> bool { matches!(self.kind, Kind::Length(0)) } @@ -89,10 +109,17 @@ impl Encoder { } } + pub(crate) fn is_chunked(&self) -> bool { + match self.kind { + Kind::Chunked(_) => true, + _ => false, + } + } + pub(crate) fn end(&self) -> Result>, NotEof> { match self.kind { Kind::Length(0) => Ok(None), - Kind::Chunked => Ok(Some(EncodedBuf { + Kind::Chunked(_) => Ok(Some(EncodedBuf { kind: BufKind::ChunkedEnd(b"0\r\n\r\n"), })), #[cfg(feature = "server")] @@ -109,7 +136,7 @@ impl Encoder { debug_assert!(len > 0, "encode() called with empty buf"); let kind = match self.kind { - Kind::Chunked => { + Kind::Chunked(_) => { trace!("encoding chunked {}B", len); let buf = ChunkSize::new(len) .chain(msg) @@ -136,6 +163,53 @@ impl Encoder { EncodedBuf { kind } } + pub(crate) fn encode_trailers( + &self, + trailers: HeaderMap, + title_case_headers: bool, + ) -> Option> { + match &self.kind { + Kind::Chunked(Some(ref allowed_trailer_fields)) => { + let allowed_trailer_field_map = allowed_trailer_field_map(&allowed_trailer_fields); + + let mut cur_name = None; + let mut allowed_trailers = HeaderMap::new(); + + for (opt_name, value) in trailers { + if let Some(n) = opt_name { + cur_name = Some(n); + } + let name = cur_name.as_ref().expect("current header name"); + + if allowed_trailer_field_map.contains_key(name.as_str()) + && valid_trailer_field(name) + { + allowed_trailers.insert(name, value); + } + } + + let mut buf = Vec::new(); + if title_case_headers { + write_headers_title_case(&allowed_trailers, &mut buf); + } else { + write_headers(&allowed_trailers, &mut buf); + } + + if buf.is_empty() { + return None; + } + + Some(EncodedBuf { + kind: BufKind::Trailers(b"0\r\n".chain(Bytes::from(buf)).chain(b"\r\n")), + }) + } + _ => { + debug!("attempted to encode trailers for non-chunked response"); + None + } + } + } + pub(super) fn encode_and_end(&self, msg: B, dst: &mut WriteBuf>) -> bool where B: Buf, @@ -144,7 +218,7 @@ impl Encoder { debug_assert!(len > 0, "encode() called with empty buf"); match self.kind { - Kind::Chunked => { + Kind::Chunked(_) => { trace!("encoding chunked {}B", len); let buf = ChunkSize::new(len) .chain(msg) @@ -181,6 +255,40 @@ impl Encoder { } } +fn valid_trailer_field(name: &HeaderName) -> bool { + match name { + &AUTHORIZATION => false, + &CACHE_CONTROL => false, + &CONTENT_ENCODING => false, + &CONTENT_LENGTH => false, + &CONTENT_RANGE => false, + &CONTENT_TYPE => false, + &HOST => false, + &MAX_FORWARDS => false, + &SET_COOKIE => false, + &TRAILER => false, + &TRANSFER_ENCODING => false, + &TE => false, + _ => true, + } +} + +fn allowed_trailer_field_map(allowed_trailer_fields: &Vec) -> HashMap { + let mut trailer_map = HashMap::new(); + + for header_value in allowed_trailer_fields { + if let Ok(header_str) = header_value.to_str() { + let items: Vec<&str> = header_str.split(',').map(|item| item.trim()).collect(); + + for item in items { + trailer_map.entry(item.to_string()).or_insert(()); + } + } + } + + trailer_map +} + impl Buf for EncodedBuf where B: Buf, @@ -192,6 +300,7 @@ where BufKind::Limited(ref b) => b.remaining(), BufKind::Chunked(ref b) => b.remaining(), BufKind::ChunkedEnd(ref b) => b.remaining(), + BufKind::Trailers(ref b) => b.remaining(), } } @@ -202,6 +311,7 @@ where BufKind::Limited(ref b) => b.chunk(), BufKind::Chunked(ref b) => b.chunk(), BufKind::ChunkedEnd(ref b) => b.chunk(), + BufKind::Trailers(ref b) => b.chunk(), } } @@ -212,6 +322,7 @@ where BufKind::Limited(ref mut b) => b.advance(cnt), BufKind::Chunked(ref mut b) => b.advance(cnt), BufKind::ChunkedEnd(ref mut b) => b.advance(cnt), + BufKind::Trailers(ref mut b) => b.advance(cnt), } } @@ -222,6 +333,7 @@ where BufKind::Limited(ref b) => b.chunks_vectored(dst), BufKind::Chunked(ref b) => b.chunks_vectored(dst), BufKind::ChunkedEnd(ref b) => b.chunks_vectored(dst), + BufKind::Trailers(ref b) => b.chunks_vectored(dst), } } } @@ -327,7 +439,16 @@ impl std::error::Error for NotEof {} #[cfg(test)] mod tests { + use std::iter::FromIterator; + use bytes::BufMut; + use http::{ + header::{ + AUTHORIZATION, CACHE_CONTROL, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_RANGE, + CONTENT_TYPE, HOST, MAX_FORWARDS, SET_COOKIE, TE, TRAILER, TRANSFER_ENCODING, + }, + HeaderMap, HeaderName, HeaderValue, + }; use super::super::io::Cursor; use super::Encoder; @@ -402,4 +523,145 @@ mod tests { assert!(!encoder.is_eof()); encoder.end::<()>().unwrap(); } + + #[test] + fn chunked_with_valid_trailers() { + let encoder = Encoder::chunked(); + let trailers = vec![HeaderValue::from_static("chunky-trailer")]; + let encoder = encoder.into_chunked_with_trailing_fields(trailers); + + let headers = HeaderMap::from_iter( + vec![ + ( + HeaderName::from_static("chunky-trailer"), + HeaderValue::from_static("header data"), + ), + ( + HeaderName::from_static("should-not-be-included"), + HeaderValue::from_static("oops"), + ), + ] + .into_iter(), + ); + + let buf1 = encoder.encode_trailers::<&[u8]>(headers, false).unwrap(); + + let mut dst = Vec::new(); + dst.put(buf1); + assert_eq!(dst, b"0\r\nchunky-trailer: header data\r\n\r\n"); + } + + #[test] + fn chunked_with_multiple_trailer_headers() { + let encoder = Encoder::chunked(); + let trailers = vec![ + HeaderValue::from_static("chunky-trailer"), + HeaderValue::from_static("chunky-trailer-2"), + ]; + let encoder = encoder.into_chunked_with_trailing_fields(trailers); + + let headers = HeaderMap::from_iter( + vec![ + ( + HeaderName::from_static("chunky-trailer"), + HeaderValue::from_static("header data"), + ), + ( + HeaderName::from_static("chunky-trailer-2"), + HeaderValue::from_static("more header data"), + ), + ] + .into_iter(), + ); + + let buf1 = encoder.encode_trailers::<&[u8]>(headers, false).unwrap(); + + let mut dst = Vec::new(); + dst.put(buf1); + assert_eq!( + dst, + b"0\r\nchunky-trailer: header data\r\nchunky-trailer-2: more header data\r\n\r\n" + ); + } + + #[test] + fn chunked_with_no_trailer_header() { + let encoder = Encoder::chunked(); + + let headers = HeaderMap::from_iter( + vec![( + HeaderName::from_static("chunky-trailer"), + HeaderValue::from_static("header data"), + )] + .into_iter(), + ); + + assert!(encoder + .encode_trailers::<&[u8]>(headers.clone(), false) + .is_none()); + + let trailers = vec![]; + let encoder = encoder.into_chunked_with_trailing_fields(trailers); + + assert!(encoder.encode_trailers::<&[u8]>(headers, false).is_none()); + } + + #[test] + fn chunked_with_invalid_trailers() { + let encoder = Encoder::chunked(); + + let trailers = format!( + "{},{},{},{},{},{},{},{},{},{},{},{}", + AUTHORIZATION, + CACHE_CONTROL, + CONTENT_ENCODING, + CONTENT_LENGTH, + CONTENT_RANGE, + CONTENT_TYPE, + HOST, + MAX_FORWARDS, + SET_COOKIE, + TRAILER, + TRANSFER_ENCODING, + TE, + ); + let trailers = vec![HeaderValue::from_str(&trailers).unwrap()]; + let encoder = encoder.into_chunked_with_trailing_fields(trailers); + + let mut headers = HeaderMap::new(); + headers.insert(AUTHORIZATION, HeaderValue::from_static("header data")); + headers.insert(CACHE_CONTROL, HeaderValue::from_static("header data")); + headers.insert(CONTENT_ENCODING, HeaderValue::from_static("header data")); + headers.insert(CONTENT_LENGTH, HeaderValue::from_static("header data")); + headers.insert(CONTENT_RANGE, HeaderValue::from_static("header data")); + headers.insert(CONTENT_TYPE, HeaderValue::from_static("header data")); + headers.insert(HOST, HeaderValue::from_static("header data")); + headers.insert(MAX_FORWARDS, HeaderValue::from_static("header data")); + headers.insert(SET_COOKIE, HeaderValue::from_static("header data")); + headers.insert(TRAILER, HeaderValue::from_static("header data")); + headers.insert(TRANSFER_ENCODING, HeaderValue::from_static("header data")); + headers.insert(TE, HeaderValue::from_static("header data")); + + assert!(encoder.encode_trailers::<&[u8]>(headers, true).is_none()); + } + + #[test] + fn chunked_with_title_case_headers() { + let encoder = Encoder::chunked(); + let trailers = vec![HeaderValue::from_static("chunky-trailer")]; + let encoder = encoder.into_chunked_with_trailing_fields(trailers); + + let headers = HeaderMap::from_iter( + vec![( + HeaderName::from_static("chunky-trailer"), + HeaderValue::from_static("header data"), + )] + .into_iter(), + ); + let buf1 = encoder.encode_trailers::<&[u8]>(headers, true).unwrap(); + + let mut dst = Vec::new(); + dst.put(buf1); + assert_eq!(dst, b"0\r\nChunky-Trailer: header data\r\n\r\n"); + } } diff --git a/src/proto/h1/role.rs b/src/proto/h1/role.rs index c30a4948f9..6828db75a7 100644 --- a/src/proto/h1/role.rs +++ b/src/proto/h1/role.rs @@ -629,6 +629,7 @@ impl Server { }; let mut encoder = Encoder::length(0); + let mut allowed_trailer_fields: Option> = None; let mut wrote_date = false; let mut cur_name = None; let mut is_name_written = false; @@ -815,6 +816,38 @@ impl Server { header::DATE => { wrote_date = true; } + header::TRAILER => { + // check that we actually can send a chunked body... + if msg.head.version == Version::HTTP_10 + || !Server::can_chunked(msg.req_method, msg.head.subject) + { + continue; + } + + if !is_name_written { + is_name_written = true; + header_name_writer.write_header_name_with_colon( + dst, + "trailer: ", + header::TRAILER, + ); + extend(dst, value.as_bytes()); + } else { + extend(dst, b", "); + extend(dst, value.as_bytes()); + } + + match allowed_trailer_fields { + Some(ref mut allowed_trailer_fields) => { + allowed_trailer_fields.push(value); + } + None => { + allowed_trailer_fields = Some(vec![value]); + } + } + + continue 'headers; + } _ => (), } //TODO: this should perhaps instead combine them into @@ -899,6 +932,12 @@ impl Server { extend(dst, b"\r\n"); } + if encoder.is_chunked() { + if let Some(allowed_trailer_fields) = allowed_trailer_fields { + encoder = encoder.into_chunked_with_trailing_fields(allowed_trailer_fields); + } + } + Ok(encoder.set_last(is_last)) } } @@ -1306,6 +1345,19 @@ impl Client { } }; + let encoder = encoder.map(|enc| { + if enc.is_chunked() { + let allowed_trailer_fields: Vec = + headers.get_all(header::TRAILER).iter().cloned().collect(); + + if !allowed_trailer_fields.is_empty() { + return enc.into_chunked_with_trailing_fields(allowed_trailer_fields); + } + } + + enc + }); + // This is because we need a second mutable borrow to remove // content-length header. if let Some(encoder) = encoder { @@ -1468,8 +1520,7 @@ fn title_case(dst: &mut Vec, name: &[u8]) { } } -#[cfg(feature = "client")] -fn write_headers_title_case(headers: &HeaderMap, dst: &mut Vec) { +pub(crate) fn write_headers_title_case(headers: &HeaderMap, dst: &mut Vec) { for (name, value) in headers { title_case(dst, name.as_str().as_bytes()); extend(dst, b": "); @@ -1478,8 +1529,7 @@ fn write_headers_title_case(headers: &HeaderMap, dst: &mut Vec) { } } -#[cfg(feature = "client")] -fn write_headers(headers: &HeaderMap, dst: &mut Vec) { +pub(crate) fn write_headers(headers: &HeaderMap, dst: &mut Vec) { for (name, value) in headers { extend(dst, name.as_str().as_bytes()); extend(dst, b": "); diff --git a/tests/client.rs b/tests/client.rs index b306016eea..8ac4a5e9b2 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -5,6 +5,7 @@ use std::convert::Infallible; use std::fmt; use std::future::Future; use std::io::{Read, Write}; +use std::iter::FromIterator; use std::net::{SocketAddr, TcpListener}; use std::pin::Pin; use std::thread; @@ -13,7 +14,7 @@ use std::time::Duration; use http::uri::PathAndQuery; use http_body_util::{BodyExt, StreamBody}; use hyper::body::Frame; -use hyper::header::HeaderValue; +use hyper::header::{HeaderMap, HeaderName, HeaderValue}; use hyper::{Method, Request, StatusCode, Uri, Version}; use bytes::Bytes; @@ -408,6 +409,15 @@ macro_rules! __client_req_prop { Frame::data, ))); }}; + + ($req_builder:ident, $body:ident, $addr:ident, body_stream_with_trailers: $body_e:expr) => {{ + use support::trailers::StreamBodyWithTrailers; + let (body, trailers) = $body_e; + $body = BodyExt::boxed(StreamBodyWithTrailers::with_trailers( + futures_util::TryStreamExt::map_ok(body, Frame::data), + trailers, + )); + }}; } macro_rules! __client_req_header { @@ -631,6 +641,44 @@ test! { body: &b"hello"[..], } +test! { + name: client_post_req_body_chunked_with_trailer, + + server: + expected: "\ + POST / HTTP/1.1\r\n\ + trailer: chunky-trailer\r\n\ + host: {addr}\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + 5\r\n\ + hello\r\n\ + 0\r\n\ + chunky-trailer: header data\r\n\ + \r\n\ + ", + reply: REPLY_OK, + + client: + request: { + method: POST, + url: "http://{addr}/", + headers: { + "trailer" => "chunky-trailer", + }, + body_stream_with_trailers: ( + (futures_util::stream::once(async { Ok::<_, Infallible>(Bytes::from("hello"))})), + HeaderMap::from_iter(vec![( + HeaderName::from_static("chunky-trailer"), + HeaderValue::from_static("header data") + )].into_iter())), + }, + response: + status: OK, + headers: {}, + body: None, +} + test! { name: client_get_req_body_sized, diff --git a/tests/server.rs b/tests/server.rs index 16a5a9afbe..edf569a860 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -19,7 +19,7 @@ use futures_channel::oneshot; use futures_util::future::{self, Either, FutureExt}; use h2::client::SendRequest; use h2::{RecvStream, SendStream}; -use http::header::{HeaderName, HeaderValue}; +use http::header::{HeaderMap, HeaderName, HeaderValue}; use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full, StreamBody}; use hyper::rt::Timer; use hyper::rt::{Read as AsyncRead, Write as AsyncWrite}; @@ -2595,6 +2595,94 @@ async fn http2_keep_alive_count_server_pings() { .expect("timed out waiting for pings"); } +#[test] +fn http1_trailer_fields() { + let body = futures_util::stream::once(async move { Ok("hello".into()) }); + let mut headers = HeaderMap::new(); + headers.insert("chunky-trailer", "header data".parse().unwrap()); + // Invalid trailer field that should not be sent + headers.insert("Host", "www.example.com".parse().unwrap()); + // Not specified in Trailer header, so should not be sent + headers.insert("foo", "bar".parse().unwrap()); + + let server = serve(); + server + .reply() + .header("transfer-encoding", "chunked") + .header("trailer", "chunky-trailer") + .body_stream_with_trailers(body, headers); + let mut req = connect(server.addr()); + req.write_all( + b"\ + GET / HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Connection: keep-alive\r\n\ + TE: trailers\r\n\ + \r\n\ + ", + ) + .expect("writing"); + + let chunky_trailer_chunk = b"\r\nchunky-trailer: header data\r\n\r\n"; + let res = read_until(&mut req, |buf| buf.ends_with(chunky_trailer_chunk)).expect("reading"); + let sres = s(&res); + + let expected_head = + "HTTP/1.1 200 OK\r\ntransfer-encoding: chunked\r\ntrailer: chunky-trailer\r\n"; + assert_eq!(&sres[..expected_head.len()], expected_head); + + // skip the date header + let date_fragment = "GMT\r\n\r\n"; + let pos = sres.find(date_fragment).expect("find GMT"); + let body = &sres[pos + date_fragment.len()..]; + + let expected_body = "5\r\nhello\r\n0\r\nchunky-trailer: header data\r\n\r\n"; + assert_eq!(body, expected_body); +} + +#[test] +fn http1_trailer_fields_not_allowed() { + let body = futures_util::stream::once(async move { Ok("hello".into()) }); + let mut headers = HeaderMap::new(); + headers.insert("chunky-trailer", "header data".parse().unwrap()); + + let server = serve(); + server + .reply() + .header("transfer-encoding", "chunked") + .header("trailer", "chunky-trailer") + .body_stream_with_trailers(body, headers); + let mut req = connect(server.addr()); + + // TE: trailers is not specified in request headers + req.write_all( + b"\ + GET / HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Connection: keep-alive\r\n\ + \r\n\ + ", + ) + .expect("writing"); + + let last_chunk = b"\r\n0\r\n\r\n"; + let res = read_until(&mut req, |buf| buf.ends_with(last_chunk)).expect("reading"); + let sres = s(&res); + + let expected_head = + "HTTP/1.1 200 OK\r\ntransfer-encoding: chunked\r\ntrailer: chunky-trailer\r\n"; + assert_eq!(&sres[..expected_head.len()], expected_head); + + // skip the date header + let date_fragment = "GMT\r\n\r\n"; + let pos = sres.find(date_fragment).expect("find GMT"); + let body = &sres[pos + date_fragment.len()..]; + + // no trailer fields should be sent because TE: trailers was not in request headers + let expected_body = "5\r\nhello\r\n0\r\n\r\n"; + assert_eq!(body, expected_body); +} + // ------------------------------------------------- // the Server that is used to run all the tests with // ------------------------------------------------- @@ -2700,6 +2788,19 @@ impl<'a> ReplyBuilder<'a> { self.tx.lock().unwrap().send(Reply::Body(body)).unwrap(); } + fn body_stream_with_trailers(self, stream: S, trailers: HeaderMap) + where + S: futures_util::Stream> + Send + Sync + 'static, + { + use futures_util::TryStreamExt; + use hyper::body::Frame; + use support::trailers::StreamBodyWithTrailers; + let mut stream_body = StreamBodyWithTrailers::new(stream.map_ok(Frame::data)); + stream_body.set_trailers(trailers); + let body = BodyExt::boxed(stream_body); + self.tx.lock().unwrap().send(Reply::Body(body)).unwrap(); + } + #[allow(dead_code)] fn error>(self, err: E) { self.tx diff --git a/tests/support/mod.rs b/tests/support/mod.rs index c796459412..1de834532d 100644 --- a/tests/support/mod.rs +++ b/tests/support/mod.rs @@ -24,6 +24,8 @@ mod tokiort; #[allow(unused)] pub use tokiort::{TokioExecutor, TokioIo, TokioTimer}; +pub mod trailers; + #[allow(unused_macros)] macro_rules! t { ( diff --git a/tests/support/trailers.rs b/tests/support/trailers.rs new file mode 100644 index 0000000000..a23664e31c --- /dev/null +++ b/tests/support/trailers.rs @@ -0,0 +1,76 @@ +use bytes::Buf; +use futures_util::stream::Stream; +use http::header::HeaderMap; +use http_body::{Body, Frame}; +use pin_project_lite::pin_project; +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +pin_project! { + /// A body created from a [`Stream`]. + #[derive(Clone, Debug)] + pub struct StreamBodyWithTrailers { + #[pin] + stream: S, + trailers: Option, + } +} + +impl StreamBodyWithTrailers { + /// Create a new `StreamBodyWithTrailers`. + pub fn new(stream: S) -> Self { + Self { + stream, + trailers: None, + } + } + + pub fn with_trailers(stream: S, trailers: HeaderMap) -> Self { + Self { + stream, + trailers: Some(trailers), + } + } + + pub fn set_trailers(&mut self, trailers: HeaderMap) { + self.trailers = Some(trailers); + } +} + +impl Body for StreamBodyWithTrailers +where + S: Stream, E>>, + D: Buf, +{ + type Data = D; + type Error = E; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + let project = self.project(); + match project.stream.poll_next(cx) { + Poll::Ready(Some(result)) => Poll::Ready(Some(result)), + Poll::Ready(None) => match project.trailers.take() { + Some(trailers) => Poll::Ready(Some(Ok(Frame::trailers(trailers)))), + None => Poll::Ready(None), + }, + Poll::Pending => Poll::Pending, + } + } +} + +impl Stream for StreamBodyWithTrailers { + type Item = S::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().stream.poll_next(cx) + } + + fn size_hint(&self) -> (usize, Option) { + self.stream.size_hint() + } +}