diff --git a/src/body/incoming.rs b/src/body/incoming.rs index e0beba6108..b7d1591cc4 100644 --- a/src/body/incoming.rs +++ b/src/body/incoming.rs @@ -403,6 +403,19 @@ impl Sender { .map_err(|err| err.into_inner().expect("just sent Ok")) } + #[cfg(feature = "http1")] + pub(crate) fn try_send_trailers( + &mut self, + trailers: HeaderMap, + ) -> Result<(), Option> { + let tx = match self.trailers_tx.take() { + Some(tx) => tx, + None => return Err(None), + }; + + tx.send(trailers).map_err(|err| Some(err)) + } + #[cfg(test)] pub(crate) fn abort(mut self) { self.send_error(crate::Error::new_body_write_aborted()); diff --git a/src/proto/h1/conn.rs b/src/proto/h1/conn.rs index 5f08982436..88b177203e 100644 --- a/src/proto/h1/conn.rs +++ b/src/proto/h1/conn.rs @@ -11,6 +11,7 @@ use bytes::{Buf, Bytes}; use futures_util::ready; use http::header::{HeaderValue, CONNECTION, TE}; use http::{HeaderMap, Method, Version}; +use http_body::Frame; use httparse::ParserConfig; use super::io::Buffered; @@ -268,10 +269,20 @@ where self.try_keep_alive(cx); } } else if msg.expect_continue && msg.head.version.gt(&Version::HTTP_10) { - self.state.reading = Reading::Continue(Decoder::new(msg.decode)); + let h1_max_header_size = None; // TODO: remove this when we land h1_max_header_size support + self.state.reading = Reading::Continue(Decoder::new( + msg.decode, + self.state.h1_max_headers, + h1_max_header_size, + )); wants = wants.add(Wants::EXPECT); } else { - self.state.reading = Reading::Body(Decoder::new(msg.decode)); + let h1_max_header_size = None; // TODO: remove this when we land h1_max_header_size support + self.state.reading = Reading::Body(Decoder::new( + msg.decode, + self.state.h1_max_headers, + h1_max_header_size, + )); } self.state.allow_trailer_fields = msg @@ -312,33 +323,41 @@ where pub(crate) fn poll_read_body( &mut self, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll>>> { debug_assert!(self.can_read_body()); let (reading, ret) = match self.state.reading { Reading::Body(ref mut decoder) => { match ready!(decoder.decode(cx, &mut self.io)) { - Ok(slice) => { - let (reading, chunk) = if decoder.is_eof() { - debug!("incoming body completed"); - ( - Reading::KeepAlive, - if !slice.is_empty() { - Some(Ok(slice)) - } else { - None - }, - ) - } else if slice.is_empty() { - error!("incoming body unexpectedly ended"); - // This should be unreachable, since all 3 decoders - // either set eof=true or return an Err when reading - // an empty slice... - (Reading::Closed, None) + Ok(frame) => { + if frame.is_data() { + let slice = frame.data_ref().unwrap_or_else(|| unreachable!()); + let (reading, maybe_frame) = if decoder.is_eof() { + debug!("incoming body completed"); + ( + Reading::KeepAlive, + if !slice.is_empty() { + Some(Ok(frame)) + } else { + None + }, + ) + } else if slice.is_empty() { + error!("incoming body unexpectedly ended"); + // This should be unreachable, since all 3 decoders + // either set eof=true or return an Err when reading + // an empty slice... + (Reading::Closed, None) + } else { + return Poll::Ready(Some(Ok(frame))); + }; + (reading, Poll::Ready(maybe_frame)) + } else if frame.is_trailers() { + (Reading::Closed, Poll::Ready(Some(Ok(frame)))) } else { - return Poll::Ready(Some(Ok(slice))); - }; - (reading, Poll::Ready(chunk)) + trace!("discarding unknown frame"); + (Reading::Closed, Poll::Ready(None)) + } } Err(e) => { debug!("incoming body decode error: {}", e); diff --git a/src/proto/h1/decode.rs b/src/proto/h1/decode.rs index 05f3cc6e60..7e34074d76 100644 --- a/src/proto/h1/decode.rs +++ b/src/proto/h1/decode.rs @@ -4,10 +4,13 @@ use std::io; use std::task::{Context, Poll}; use std::usize; -use bytes::Bytes; +use bytes::{BufMut, Bytes, BytesMut}; use futures_util::ready; +use http::{HeaderMap, HeaderName, HeaderValue}; +use http_body::Frame; use super::io::MemRead; +use super::role::DEFAULT_MAX_HEADERS; use super::DecodedLength; use self::Kind::{Chunked, Eof, Length}; @@ -17,6 +20,11 @@ use self::Kind::{Chunked, Eof, Length}; /// This limit is currentlty applied for the entire body, not per chunk. const CHUNKED_EXTENSIONS_LIMIT: u64 = 1024 * 16; +/// Maximum number of bytes allowed for all trailer fields. +/// +/// TODO: remove this when we land h1_max_header_size support +const TRAILER_LIMIT: usize = 1024 * 16; + /// Decoders to handle different Transfer-Encodings. /// /// If a message body does not include a Transfer-Encoding, it *should* @@ -26,7 +34,7 @@ pub(crate) struct Decoder { kind: Kind, } -#[derive(Debug, Clone, Copy, PartialEq)] +#[derive(Debug, Clone, PartialEq)] enum Kind { /// A Reader used when a Content-Length header is passed with a positive integer. Length(u64), @@ -35,6 +43,10 @@ enum Kind { state: ChunkedState, chunk_len: u64, extensions_cnt: u64, + trailers_buf: Option, + trailers_cnt: usize, + h1_max_headers: Option, + h1_max_header_size: Option, }, /// A Reader used for responses that don't indicate a length or chunked. /// @@ -81,12 +93,19 @@ impl Decoder { } } - pub(crate) fn chunked() -> Decoder { + pub(crate) fn chunked( + h1_max_headers: Option, + h1_max_header_size: Option, + ) -> Decoder { Decoder { kind: Kind::Chunked { state: ChunkedState::new(), chunk_len: 0, extensions_cnt: 0, + trailers_buf: None, + trailers_cnt: 0, + h1_max_headers, + h1_max_header_size, }, } } @@ -97,9 +116,13 @@ impl Decoder { } } - pub(super) fn new(len: DecodedLength) -> Self { + pub(super) fn new( + len: DecodedLength, + h1_max_headers: Option, + h1_max_header_size: Option, + ) -> Self { match len { - DecodedLength::CHUNKED => Decoder::chunked(), + DecodedLength::CHUNKED => Decoder::chunked(h1_max_headers, h1_max_header_size), DecodedLength::CLOSE_DELIMITED => Decoder::eof(), length => Decoder::length(length.danger_len()), } @@ -123,12 +146,12 @@ impl Decoder { &mut self, cx: &mut Context<'_>, body: &mut R, - ) -> Poll> { + ) -> Poll, io::Error>> { trace!("decode; state={:?}", self.kind); match self.kind { Length(ref mut remaining) => { if *remaining == 0 { - Poll::Ready(Ok(Bytes::new())) + Poll::Ready(Ok(Frame::data(Bytes::new()))) } else { let to_read = *remaining as usize; let buf = ready!(body.read_mem(cx, to_read))?; @@ -143,37 +166,77 @@ impl Decoder { } else { *remaining -= num; } - Poll::Ready(Ok(buf)) + Poll::Ready(Ok(Frame::data(buf))) } } Chunked { ref mut state, ref mut chunk_len, ref mut extensions_cnt, + ref mut trailers_buf, + ref mut trailers_cnt, + ref h1_max_headers, + ref h1_max_header_size, } => { + let h1_max_headers = h1_max_headers.unwrap_or(DEFAULT_MAX_HEADERS); + let h1_max_header_size = h1_max_header_size.unwrap_or(TRAILER_LIMIT); loop { let mut buf = None; // advances the chunked state - *state = ready!(state.step(cx, body, chunk_len, extensions_cnt, &mut buf))?; + *state = ready!(state.step( + cx, + body, + chunk_len, + extensions_cnt, + &mut buf, + trailers_buf, + trailers_cnt, + h1_max_headers, + h1_max_header_size + ))?; if *state == ChunkedState::End { trace!("end of chunked"); - return Poll::Ready(Ok(Bytes::new())); + + if trailers_buf.is_some() { + trace!("found possible trailers"); + + // decoder enforces that trailers count will not exceed h1_max_headers + if *trailers_cnt >= h1_max_headers { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidData, + "chunk trailers count overflow", + ))); + } + match decode_trailers( + &mut trailers_buf.take().expect("Trailer is None"), + *trailers_cnt, + ) { + Ok(headers) => { + return Poll::Ready(Ok(Frame::trailers(headers))); + } + Err(e) => { + return Poll::Ready(Err(e)); + } + } + } + + return Poll::Ready(Ok(Frame::data(Bytes::new()))); } if let Some(buf) = buf { - return Poll::Ready(Ok(buf)); + return Poll::Ready(Ok(Frame::data(buf))); } } } Eof(ref mut is_eof) => { if *is_eof { - Poll::Ready(Ok(Bytes::new())) + Poll::Ready(Ok(Frame::data(Bytes::new()))) } else { // 8192 chosen because its about 2 packets, there probably // won't be that much available, so don't have MemReaders // allocate buffers to big body.read_mem(cx, 8192).map_ok(|slice| { *is_eof = slice.is_empty(); - slice + Frame::data(slice) }) } } @@ -181,7 +244,7 @@ impl Decoder { } #[cfg(test)] - async fn decode_fut(&mut self, body: &mut R) -> Result { + async fn decode_fut(&mut self, body: &mut R) -> Result, io::Error> { futures_util::future::poll_fn(move |cx| self.decode(cx, body)).await } } @@ -216,6 +279,19 @@ macro_rules! or_overflow { ) } +macro_rules! put_u8 { + ($trailers_buf:expr, $byte:expr, $limit:expr) => { + $trailers_buf.put_u8($byte); + + if $trailers_buf.len() >= $limit { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidData, + "chunk trailers bytes over limit", + ))); + } + }; +} + impl ChunkedState { fn new() -> ChunkedState { ChunkedState::Start @@ -227,6 +303,10 @@ impl ChunkedState { size: &mut u64, extensions_cnt: &mut u64, buf: &mut Option, + trailers_buf: &mut Option, + trailers_cnt: &mut usize, + h1_max_headers: usize, + h1_max_header_size: usize, ) -> Poll> { use self::ChunkedState::*; match *self { @@ -238,10 +318,17 @@ impl ChunkedState { Body => ChunkedState::read_body(cx, body, size, buf), BodyCr => ChunkedState::read_body_cr(cx, body), BodyLf => ChunkedState::read_body_lf(cx, body), - Trailer => ChunkedState::read_trailer(cx, body), - TrailerLf => ChunkedState::read_trailer_lf(cx, body), - EndCr => ChunkedState::read_end_cr(cx, body), - EndLf => ChunkedState::read_end_lf(cx, body), + Trailer => ChunkedState::read_trailer(cx, body, trailers_buf, h1_max_header_size), + TrailerLf => ChunkedState::read_trailer_lf( + cx, + body, + trailers_buf, + trailers_cnt, + h1_max_headers, + h1_max_header_size, + ), + EndCr => ChunkedState::read_end_cr(cx, body, trailers_buf, h1_max_header_size), + EndLf => ChunkedState::read_end_lf(cx, body, trailers_buf, h1_max_header_size), End => Poll::Ready(Ok(ChunkedState::End)), } } @@ -442,19 +529,51 @@ impl ChunkedState { fn read_trailer( cx: &mut Context<'_>, rdr: &mut R, + trailers_buf: &mut Option, + h1_max_header_size: usize, ) -> Poll> { trace!("read_trailer"); - match byte!(rdr, cx) { + let byte = byte!(rdr, cx); + + put_u8!( + trailers_buf.as_mut().expect("trailers_buf is None"), + byte, + h1_max_header_size + ); + + match byte { b'\r' => Poll::Ready(Ok(ChunkedState::TrailerLf)), _ => Poll::Ready(Ok(ChunkedState::Trailer)), } } + fn read_trailer_lf( cx: &mut Context<'_>, rdr: &mut R, + trailers_buf: &mut Option, + trailers_cnt: &mut usize, + h1_max_headers: usize, + h1_max_header_size: usize, ) -> Poll> { - match byte!(rdr, cx) { - b'\n' => Poll::Ready(Ok(ChunkedState::EndCr)), + let byte = byte!(rdr, cx); + match byte { + b'\n' => { + if *trailers_cnt >= h1_max_headers { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidData, + "chunk trailers count overflow", + ))); + } + *trailers_cnt += 1; + + put_u8!( + trailers_buf.as_mut().expect("trailers_buf is None"), + byte, + h1_max_header_size + ); + + Poll::Ready(Ok(ChunkedState::EndCr)) + } _ => Poll::Ready(Err(io::Error::new( io::ErrorKind::InvalidInput, "Invalid trailer end LF", @@ -465,18 +584,48 @@ impl ChunkedState { fn read_end_cr( cx: &mut Context<'_>, rdr: &mut R, + trailers_buf: &mut Option, + h1_max_header_size: usize, ) -> Poll> { - match byte!(rdr, cx) { - b'\r' => Poll::Ready(Ok(ChunkedState::EndLf)), - _ => Poll::Ready(Ok(ChunkedState::Trailer)), + let byte = byte!(rdr, cx); + match byte { + b'\r' => { + if let Some(trailers_buf) = trailers_buf { + put_u8!(trailers_buf, byte, h1_max_header_size); + } + Poll::Ready(Ok(ChunkedState::EndLf)) + } + byte => { + match trailers_buf { + None => { + // 64 will fit a single Expires header without reallocating + let mut buf = BytesMut::with_capacity(64); + buf.put_u8(byte); + *trailers_buf = Some(buf); + } + Some(ref mut trailers_buf) => { + put_u8!(trailers_buf, byte, h1_max_header_size); + } + } + + Poll::Ready(Ok(ChunkedState::Trailer)) + } } } fn read_end_lf( cx: &mut Context<'_>, rdr: &mut R, + trailers_buf: &mut Option, + h1_max_header_size: usize, ) -> Poll> { - match byte!(rdr, cx) { - b'\n' => Poll::Ready(Ok(ChunkedState::End)), + let byte = byte!(rdr, cx); + match byte { + b'\n' => { + if let Some(trailers_buf) = trailers_buf { + put_u8!(trailers_buf, byte, h1_max_header_size); + } + Poll::Ready(Ok(ChunkedState::End)) + } _ => Poll::Ready(Err(io::Error::new( io::ErrorKind::InvalidInput, "Invalid chunk end LF", @@ -485,6 +634,48 @@ impl ChunkedState { } } +// TODO: disallow Transfer-Encoding, Content-Length, Trailer, etc in trailers ?? +fn decode_trailers(buf: &mut BytesMut, count: usize) -> Result { + let mut trailers = HeaderMap::new(); + let mut headers = vec![httparse::EMPTY_HEADER; count]; + let res = httparse::parse_headers(&buf, &mut headers); + match res { + Ok(httparse::Status::Complete((_, headers))) => { + for header in headers.iter() { + use std::convert::TryFrom; + let name = match HeaderName::try_from(header.name) { + Ok(name) => name, + Err(_) => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("Invalid header name: {:?}", &header), + )); + } + }; + + let value = match HeaderValue::from_bytes(header.value) { + Ok(value) => value, + Err(_) => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("Invalid header value: {:?}", &header), + )); + } + }; + + trailers.insert(name, value); + } + + Ok(trailers) + } + Ok(httparse::Status::Partial) => Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Partial header", + )), + Err(e) => Err(io::Error::new(io::ErrorKind::InvalidInput, e)), + } +} + #[derive(Debug)] struct IncompleteBody; @@ -554,9 +745,20 @@ mod tests { let rdr = &mut s.as_bytes(); let mut size = 0; let mut ext_cnt = 0; + let mut trailers_cnt = 0; loop { let result = futures_util::future::poll_fn(|cx| { - state.step(cx, rdr, &mut size, &mut ext_cnt, &mut None) + state.step( + cx, + rdr, + &mut size, + &mut ext_cnt, + &mut None, + &mut None, + &mut trailers_cnt, + DEFAULT_MAX_HEADERS, + TRAILER_LIMIT, + ) }) .await; let desc = format!("read_size failed for {:?}", s); @@ -573,9 +775,20 @@ mod tests { let rdr = &mut s.as_bytes(); let mut size = 0; let mut ext_cnt = 0; + let mut trailers_cnt = 0; loop { let result = futures_util::future::poll_fn(|cx| { - state.step(cx, rdr, &mut size, &mut ext_cnt, &mut None) + state.step( + cx, + rdr, + &mut size, + &mut ext_cnt, + &mut None, + &mut None, + &mut trailers_cnt, + DEFAULT_MAX_HEADERS, + TRAILER_LIMIT, + ) }) .await; state = match result { @@ -639,7 +852,16 @@ mod tests { async fn test_read_sized_early_eof() { let mut bytes = &b"foo bar"[..]; let mut decoder = Decoder::length(10); - assert_eq!(decoder.decode_fut(&mut bytes).await.unwrap().len(), 7); + assert_eq!( + decoder + .decode_fut(&mut bytes) + .await + .unwrap() + .data_ref() + .unwrap() + .len(), + 7 + ); let e = decoder.decode_fut(&mut bytes).await.unwrap_err(); assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof); } @@ -651,8 +873,17 @@ mod tests { 9\r\n\ foo bar\ "[..]; - let mut decoder = Decoder::chunked(); - assert_eq!(decoder.decode_fut(&mut bytes).await.unwrap().len(), 7); + let mut decoder = Decoder::chunked(None, None); + assert_eq!( + decoder + .decode_fut(&mut bytes) + .await + .unwrap() + .data_ref() + .unwrap() + .len(), + 7 + ); let e = decoder.decode_fut(&mut bytes).await.unwrap_err(); assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof); } @@ -661,10 +892,12 @@ mod tests { #[tokio::test] async fn test_read_chunked_single_read() { let mut mock_buf = &b"10\r\n1234567890abcdef\r\n0\r\n"[..]; - let buf = Decoder::chunked() + let buf = Decoder::chunked(None, None) .decode_fut(&mut mock_buf) .await - .expect("decode"); + .expect("decode") + .into_data() + .expect("unknown frame type"); assert_eq!(16, buf.len()); let result = String::from_utf8(buf.as_ref().to_vec()).expect("decode String"); assert_eq!("1234567890abcdef", &result); @@ -684,8 +917,13 @@ mod tests { scratch.extend(b"0\r\n\r\n"); let mut mock_buf = Bytes::from(scratch); - let mut decoder = Decoder::chunked(); - let buf1 = decoder.decode_fut(&mut mock_buf).await.expect("decode1"); + let mut decoder = Decoder::chunked(None, None); + let buf1 = decoder + .decode_fut(&mut mock_buf) + .await + .expect("decode1") + .into_data() + .expect("unknown frame type"); assert_eq!(&buf1[..], b"A"); let err = decoder @@ -700,7 +938,7 @@ mod tests { #[tokio::test] async fn test_read_chunked_trailer_with_missing_lf() { let mut mock_buf = &b"10\r\n1234567890abcdef\r\n0\r\nbad\r\r\n"[..]; - let mut decoder = Decoder::chunked(); + let mut decoder = Decoder::chunked(None, None); decoder.decode_fut(&mut mock_buf).await.expect("decode"); let e = decoder.decode_fut(&mut mock_buf).await.unwrap_err(); assert_eq!(e.kind(), io::ErrorKind::InvalidInput); @@ -710,20 +948,35 @@ mod tests { #[tokio::test] async fn test_read_chunked_after_eof() { let mut mock_buf = &b"10\r\n1234567890abcdef\r\n0\r\n\r\n"[..]; - let mut decoder = Decoder::chunked(); + let mut decoder = Decoder::chunked(None, None); // normal read - let buf = decoder.decode_fut(&mut mock_buf).await.unwrap(); + let buf = decoder + .decode_fut(&mut mock_buf) + .await + .unwrap() + .into_data() + .expect("unknown frame type"); assert_eq!(16, buf.len()); let result = String::from_utf8(buf.as_ref().to_vec()).expect("decode String"); assert_eq!("1234567890abcdef", &result); // eof read - let buf = decoder.decode_fut(&mut mock_buf).await.expect("decode"); + let buf = decoder + .decode_fut(&mut mock_buf) + .await + .expect("decode") + .into_data() + .expect("unknown frame type"); assert_eq!(0, buf.len()); // ensure read after eof also returns eof - let buf = decoder.decode_fut(&mut mock_buf).await.expect("decode"); + let buf = decoder + .decode_fut(&mut mock_buf) + .await + .expect("decode") + .into_data() + .expect("unknown frame type"); assert_eq!(0, buf.len()); } @@ -751,7 +1004,9 @@ mod tests { let buf = decoder .decode_fut(&mut ins) .await - .expect("unexpected decode error"); + .expect("unexpected decode error") + .into_data() + .expect("unexpected frame type"); if buf.is_empty() { break; // eof } @@ -783,7 +1038,7 @@ mod tests { async fn test_read_chunked_async() { let content = "3\r\nfoo\r\n3\r\nbar\r\n0\r\n\r\n"; let expected = "foobar"; - all_async_cases(content, expected, Decoder::chunked()).await; + all_async_cases(content, expected, Decoder::chunked(None, None)).await; } #[cfg(not(miri))] @@ -808,10 +1063,15 @@ mod tests { b.bytes = LEN as u64; b.iter(|| { - let mut decoder = Decoder::chunked(); + let mut decoder = Decoder::chunked(None, None); rt.block_on(async { let mut raw = content.clone(); - let chunk = decoder.decode_fut(&mut raw).await.unwrap(); + let chunk = decoder + .decode_fut(&mut raw) + .await + .unwrap() + .into_data() + .unwrap(); assert_eq!(chunk.len(), LEN); }); }); @@ -830,7 +1090,12 @@ mod tests { let mut decoder = Decoder::length(LEN as u64); rt.block_on(async { let mut raw = content.clone(); - let chunk = decoder.decode_fut(&mut raw).await.unwrap(); + let chunk = decoder + .decode_fut(&mut raw) + .await + .unwrap() + .into_data() + .unwrap(); assert_eq!(chunk.len(), LEN); }); }); @@ -843,4 +1108,110 @@ mod tests { .build() .expect("rt build") } + + #[test] + fn test_decode_trailers() { + let mut buf = BytesMut::new(); + buf.extend_from_slice( + b"Expires: Wed, 21 Oct 2015 07:28:00 GMT\r\nX-Stream-Error: failed to decode\r\n\r\n", + ); + let headers = decode_trailers(&mut buf, 2).expect("decode_trailers"); + assert_eq!(headers.len(), 2); + assert_eq!( + headers.get("Expires").unwrap(), + "Wed, 21 Oct 2015 07:28:00 GMT" + ); + assert_eq!(headers.get("X-Stream-Error").unwrap(), "failed to decode"); + } + + #[tokio::test] + async fn test_trailer_max_headers_enforced() { + let h1_max_headers = 10; + let mut scratch = vec![]; + scratch.extend(b"10\r\n1234567890abcdef\r\n0\r\n"); + for i in 0..h1_max_headers { + scratch.extend(format!("trailer{}: {}\r\n", i, i).as_bytes()); + } + scratch.extend(b"\r\n"); + let mut mock_buf = Bytes::from(scratch); + + let mut decoder = Decoder::chunked(Some(h1_max_headers), None); + + // ready chunked body + let buf = decoder + .decode_fut(&mut mock_buf) + .await + .unwrap() + .into_data() + .expect("unknown frame type"); + assert_eq!(16, buf.len()); + + // eof read + let err = decoder + .decode_fut(&mut mock_buf) + .await + .expect_err("trailer fields over limit"); + assert_eq!(err.kind(), io::ErrorKind::InvalidData); + } + + #[tokio::test] + async fn test_trailer_max_header_size_huge_trailer() { + let max_header_size = 1024; + let mut scratch = vec![]; + scratch.extend(b"10\r\n1234567890abcdef\r\n0\r\n"); + scratch.extend(format!("huge_trailer: {}\r\n", "x".repeat(max_header_size)).as_bytes()); + scratch.extend(b"\r\n"); + let mut mock_buf = Bytes::from(scratch); + + let mut decoder = Decoder::chunked(None, Some(max_header_size)); + + // ready chunked body + let buf = decoder + .decode_fut(&mut mock_buf) + .await + .unwrap() + .into_data() + .expect("unknown frame type"); + assert_eq!(16, buf.len()); + + // eof read + let err = decoder + .decode_fut(&mut mock_buf) + .await + .expect_err("trailers over limit"); + assert_eq!(err.kind(), io::ErrorKind::InvalidData); + } + + #[tokio::test] + async fn test_trailer_max_header_size_many_small_trailers() { + let max_headers = 10; + let header_size = 64; + let mut scratch = vec![]; + scratch.extend(b"10\r\n1234567890abcdef\r\n0\r\n"); + + for i in 0..max_headers { + scratch.extend(format!("trailer{}: {}\r\n", i, "x".repeat(header_size)).as_bytes()); + } + + scratch.extend(b"\r\n"); + let mut mock_buf = Bytes::from(scratch); + + let mut decoder = Decoder::chunked(None, Some(max_headers * header_size)); + + // ready chunked body + let buf = decoder + .decode_fut(&mut mock_buf) + .await + .unwrap() + .into_data() + .expect("unknown frame type"); + assert_eq!(16, buf.len()); + + // eof read + let err = decoder + .decode_fut(&mut mock_buf) + .await + .expect_err("trailers over limit"); + assert_eq!(err.kind(), io::ErrorKind::InvalidData); + } } diff --git a/src/proto/h1/dispatch.rs b/src/proto/h1/dispatch.rs index 3a4faf0487..be0a88c1f9 100644 --- a/src/proto/h1/dispatch.rs +++ b/src/proto/h1/dispatch.rs @@ -213,17 +213,39 @@ where } } match self.conn.poll_read_body(cx) { - Poll::Ready(Some(Ok(chunk))) => match body.try_send_data(chunk) { - Ok(()) => { - self.body_tx = Some(body); - } - Err(_canceled) => { - if self.conn.can_read_body() { - trace!("body receiver dropped before eof, closing"); - self.conn.close_read(); + Poll::Ready(Some(Ok(frame))) => { + if frame.is_data() { + let chunk = frame.into_data().unwrap_or_else(|_| unreachable!()); + match body.try_send_data(chunk) { + Ok(()) => { + self.body_tx = Some(body); + } + Err(_canceled) => { + if self.conn.can_read_body() { + trace!("body receiver dropped before eof, closing"); + self.conn.close_read(); + } + } + } + } else if frame.is_trailers() { + let trailers = + frame.into_trailers().unwrap_or_else(|_| unreachable!()); + match body.try_send_trailers(trailers) { + Ok(()) => { + self.body_tx = Some(body); + } + Err(_canceled) => { + if self.conn.can_read_body() { + trace!("body receiver dropped before eof, closing"); + self.conn.close_read(); + } + } } + } else { + // we should have dropped all unknown frames in poll_read_body + error!("unexpected frame"); } - }, + } Poll::Ready(None) => { // just drop, the body will close automatically } diff --git a/src/proto/h1/role.rs b/src/proto/h1/role.rs index f2c1b73084..a7f21e53aa 100644 --- a/src/proto/h1/role.rs +++ b/src/proto/h1/role.rs @@ -30,7 +30,7 @@ use crate::proto::h1::{ use crate::proto::RequestHead; use crate::proto::{BodyLength, MessageHead, RequestLine}; -const DEFAULT_MAX_HEADERS: usize = 100; +pub(crate) const DEFAULT_MAX_HEADERS: usize = 100; const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific #[cfg(feature = "server")] const MAX_URI_LEN: usize = (u16::MAX - 1) as usize; diff --git a/tests/client.rs b/tests/client.rs index 43e1f08acb..6c9d3d7587 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -34,6 +34,17 @@ where b.collect().await.map(|c| c.to_bytes()) } +async fn concat_with_trailers(b: B) -> Result<(Bytes, Option), B::Error> +where + B: hyper::body::Body, +{ + let collect = b.collect().await?; + let trailers = collect.trailers().cloned(); + let bytes = collect.to_bytes(); + + Ok((bytes, trailers)) +} + async fn tcp_connect(addr: &SocketAddr) -> std::io::Result> { TcpStream::connect(*addr).await.map(TokioIo::new) } @@ -122,6 +133,9 @@ macro_rules! test { status: $client_status:ident, headers: { $($response_header_name:expr => $response_header_val:expr,)* }, body: $response_body:expr, + $(trailers: {$( + $response_trailer_name:expr => $response_trailer_val:expr, + )*},)? ) => ( #[test] fn $name() { @@ -158,12 +172,23 @@ macro_rules! test { ); )* - let body = rt.block_on(concat(res)) + let (body, _trailers) = rt.block_on(concat_with_trailers(res)) .expect("body concat wait"); let expected_res_body = Option::<&[u8]>::from($response_body) .unwrap_or_default(); assert_eq!(body.as_ref(), expected_res_body); + + $($( + assert_eq!( + _trailers.as_ref().expect("trailers is None") + .get($response_trailer_name) + .expect(concat!("trailer header '", stringify!($response_trailer_name), "'")), + $response_trailer_val, + "trailer '{}'", + stringify!($response_trailer_name), + ); + )*)? } ); ( @@ -679,6 +704,94 @@ test! { body: None, } +test! { + name: client_res_body_chunked_with_trailer, + + server: + expected: "GET / HTTP/1.1\r\nte: trailers\r\nhost: {addr}\r\n\r\n", + reply: "\ + HTTP/1.1 200 OK\r\n\ + transfer-encoding: chunked\r\n\ + trailer: chunky-trailer\r\n\ + \r\n\ + 5\r\n\ + hello\r\n\ + 0\r\n\ + chunky-trailer: header data\r\n\ + \r\n\ + ", + + client: + request: { + method: GET, + url: "http://{addr}/", + headers: { + "te" => "trailers", + }, + }, + response: + status: OK, + headers: { + "Transfer-Encoding" => "chunked", + }, + body: &b"hello"[..], + trailers: { + "chunky-trailer" => "header data", + }, +} + +test! { + name: client_res_body_chunked_with_pathological_trailers, + + server: + expected: "GET / HTTP/1.1\r\nte: trailers\r\nhost: {addr}\r\n\r\n", + reply: "\ + HTTP/1.1 200 OK\r\n\ + transfer-encoding: chunked\r\n\ + trailer: chunky-trailer1, chunky-trailer2, chunky-trailer3, chunky-trailer4, chunky-trailer5\r\n\ + \r\n\ + 5\r\n\ + hello\r\n\ + 0\r\n\ + chunky-trailer1: header data1\r\n\ + chunky-trailer2: header data2\r\n\ + chunky-trailer3: header data3\r\n\ + chunky-trailer4: header data4\r\n\ + chunky-trailer5: header data5\r\n\ + sneaky-trailer: not in trailer header\r\n\ + transfer-encoding: chunked\r\n\ + content-length: 5\r\n\ + trailer: foo\r\n\ + \r\n\ + ", + + client: + request: { + method: GET, + url: "http://{addr}/", + headers: { + "te" => "trailers", + }, + }, + response: + status: OK, + headers: { + "Transfer-Encoding" => "chunked", + }, + body: &b"hello"[..], + trailers: { + "chunky-trailer1" => "header data1", + "chunky-trailer2" => "header data2", + "chunky-trailer3" => "header data3", + "chunky-trailer4" => "header data4", + "chunky-trailer5" => "header data5", + "sneaky-trailer" => "not in trailer header", + "transfer-encoding" => "chunked", + "content-length" => "5", + "trailer" => "foo", + }, +} + test! { name: client_get_req_body_sized, diff --git a/tests/server.rs b/tests/server.rs index 133cde963f..f3efbf3c6e 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -2693,7 +2693,7 @@ async fn http2_keep_alive_count_server_pings() { } #[test] -fn http1_trailer_fields() { +fn http1_trailer_send_fields() { let body = futures_util::stream::once(async move { Ok("hello".into()) }); let mut headers = HeaderMap::new(); headers.insert("chunky-trailer", "header data".parse().unwrap()); @@ -2780,6 +2780,35 @@ fn http1_trailer_fields_not_allowed() { assert_eq!(body, expected_body); } +#[test] +fn http1_trailer_recv_fields() { + let server = serve(); + let mut req = connect(server.addr()); + req.write_all( + b"\ + POST / HTTP/1.1\r\n\ + trailer: chunky-trailer\r\n\ + host: example.domain\r\n\ + transfer-encoding: chunked\r\n\ + \r\n\ + 5\r\n\ + hello\r\n\ + 0\r\n\ + chunky-trailer: header data\r\n\ + \r\n\ + ", + ) + .expect("writing"); + + assert_eq!(server.body(), b"hello"); + + let trailers = server.trailers(); + assert_eq!( + trailers.get("chunky-trailer"), + Some(&"header data".parse().unwrap()) + ); +} + // ------------------------------------------------- // the Server that is used to run all the tests with // ------------------------------------------------- @@ -2787,6 +2816,7 @@ fn http1_trailer_fields_not_allowed() { struct Serve { addr: SocketAddr, msg_rx: mpsc::Receiver, + trailers_rx: mpsc::Receiver, reply_tx: Mutex>, shutdown_signal: Option>, thread: Option>, @@ -2820,6 +2850,10 @@ impl Serve { Ok(buf) } + fn trailers(&self) -> HeaderMap { + self.trailers_rx.recv().expect("trailers") + } + fn reply(&self) -> ReplyBuilder<'_> { ReplyBuilder { tx: &self.reply_tx } } @@ -2933,6 +2967,7 @@ impl Drop for Serve { #[derive(Clone)] struct TestService { tx: mpsc::Sender, + trailers_tx: mpsc::Sender, reply: spmc::Receiver, } @@ -2963,6 +2998,7 @@ impl Service> for TestService { fn call(&self, mut req: Request) -> Self::Future { let tx = self.tx.clone(); + let trailers_tx = self.trailers_tx.clone(); let replies = self.reply.clone(); Box::pin(async move { @@ -2972,6 +3008,9 @@ impl Service> for TestService { if frame.is_data() { tx.send(Msg::Chunk(frame.into_data().unwrap().to_vec())) .unwrap(); + } else if frame.is_trailers() { + let trailers = frame.into_trailers().unwrap(); + trailers_tx.send(trailers).unwrap(); } } Err(err) => { @@ -3100,6 +3139,7 @@ impl ServeOptions { let (addr_tx, addr_rx) = mpsc::channel(); let (msg_tx, msg_rx) = mpsc::channel(); + let (trailers_tx, trailers_rx) = mpsc::channel(); let (reply_tx, reply_rx) = spmc::channel(); let (shutdown_tx, mut shutdown_rx) = oneshot::channel(); @@ -3123,6 +3163,7 @@ impl ServeOptions { loop { let msg_tx = msg_tx.clone(); + let trailers_tx = trailers_tx.clone(); let reply_rx = reply_rx.clone(); tokio::select! { @@ -3135,6 +3176,7 @@ impl ServeOptions { let reply_rx = reply_rx.clone(); let service = TestService { tx: msg_tx, + trailers_tx, reply: reply_rx, }; @@ -3162,6 +3204,7 @@ impl ServeOptions { Serve { msg_rx, + trailers_rx, reply_tx: Mutex::new(reply_tx), addr, shutdown_signal: Some(shutdown_tx),