From 401dfaf0bd192c3772f4a66622a1e47c224c05b5 Mon Sep 17 00:00:00 2001 From: "Herman J. Radtke III" Date: Fri, 10 Nov 2023 22:29:56 -0500 Subject: [PATCH] fix(http1): code review fixes - use more idiomatic expressions - add TE as invalid header - add tests for encode_trailers - fix bug in encode_trailers when buffer is empty --- src/proto/h1/conn.rs | 16 +++-- src/proto/h1/encode.rs | 159 +++++++++++++++++++++++++++++++++++------ src/proto/h1/role.rs | 16 +---- 3 files changed, 148 insertions(+), 43 deletions(-) diff --git a/src/proto/h1/conn.rs b/src/proto/h1/conn.rs index c2719ac555..8007ebe5cb 100644 --- a/src/proto/h1/conn.rs +++ b/src/proto/h1/conn.rs @@ -8,7 +8,7 @@ use std::time::Duration; use crate::rt::{Read, Write}; use bytes::{Buf, Bytes}; -use http::header::{HeaderValue, CONNECTION}; +use http::header::{HeaderValue, CONNECTION, TE}; use http::{HeaderMap, Method, Version}; use httparse::ParserConfig; @@ -265,12 +265,14 @@ where self.state.reading = Reading::Body(Decoder::new(msg.decode)); } - if let Some(Ok(te_value)) = msg.head.headers.get("te").map(|v| v.to_str()) { - if te_value.eq_ignore_ascii_case("trailers") { - self.state.allow_trailer_fields = true; - } else { - self.state.allow_trailer_fields = false; - } + if msg + .head + .headers + .get(TE) + .map(|te_header| te_header == "trailers") + .unwrap_or(false) + { + self.state.allow_trailer_fields = true; } else { self.state.allow_trailer_fields = false; } diff --git a/src/proto/h1/encode.rs b/src/proto/h1/encode.rs index 3a42a795a2..cbea0c0b09 100644 --- a/src/proto/h1/encode.rs +++ b/src/proto/h1/encode.rs @@ -7,7 +7,7 @@ use bytes::{Buf, Bytes}; use http::{ header::{ AUTHORIZATION, CACHE_CONTROL, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_RANGE, - CONTENT_TYPE, HOST, MAX_FORWARDS, SET_COOKIE, TRAILER, TRANSFER_ENCODING, + CONTENT_TYPE, HOST, MAX_FORWARDS, SET_COOKIE, TE, TRAILER, TRANSFER_ENCODING, }, HeaderMap, HeaderName, HeaderValue, }; @@ -169,13 +169,8 @@ impl Encoder { title_case_headers: bool, ) -> Option> { match &self.kind { - Kind::Chunked(allowed_trailer_fields) => { - let allowed_trailer_fields_map = match allowed_trailer_fields { - Some(ref allowed_trailer_fields) => { - allowed_trailer_field_map(&allowed_trailer_fields) - } - None => return None, - }; + Kind::Chunked(Some(ref allowed_trailer_fields)) => { + let allowed_trailer_field_map = allowed_trailer_field_map(&allowed_trailer_fields); let mut cur_name = None; let mut allowed_trailers = HeaderMap::new(); @@ -186,8 +181,8 @@ impl Encoder { } let name = cur_name.as_ref().expect("current header name"); - if allowed_trailer_fields_map.contains_key(name.as_str()) - && !invalid_trailer_field(name) + if allowed_trailer_field_map.contains_key(name.as_str()) + && valid_trailer_field(name) { allowed_trailers.insert(name, value); } @@ -200,6 +195,10 @@ impl Encoder { write_headers(&allowed_trailers, &mut buf); } + if buf.is_empty() { + return None; + } + Some(EncodedBuf { kind: BufKind::Trailers(b"0\r\n".chain(Bytes::from(buf)).chain(b"\r\n")), }) @@ -256,20 +255,21 @@ impl Encoder { } } -fn invalid_trailer_field(name: &HeaderName) -> bool { +fn valid_trailer_field(name: &HeaderName) -> bool { match name { - &AUTHORIZATION => true, - &CACHE_CONTROL => true, - &CONTENT_ENCODING => true, - &CONTENT_LENGTH => true, - &CONTENT_RANGE => true, - &CONTENT_TYPE => true, - &HOST => true, - &MAX_FORWARDS => true, - &SET_COOKIE => true, - &TRAILER => true, - &TRANSFER_ENCODING => true, - _ => false, + &AUTHORIZATION => false, + &CACHE_CONTROL => false, + &CONTENT_ENCODING => false, + &CONTENT_LENGTH => false, + &CONTENT_RANGE => false, + &CONTENT_TYPE => false, + &HOST => false, + &MAX_FORWARDS => false, + &SET_COOKIE => false, + &TRAILER => false, + &TRANSFER_ENCODING => false, + &TE => false, + _ => true, } } @@ -439,7 +439,16 @@ impl std::error::Error for NotEof {} #[cfg(test)] mod tests { + use std::iter::FromIterator; + use bytes::BufMut; + use http::{ + header::{ + AUTHORIZATION, CACHE_CONTROL, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_RANGE, + CONTENT_TYPE, HOST, MAX_FORWARDS, SET_COOKIE, TE, TRAILER, TRANSFER_ENCODING, + }, + HeaderMap, HeaderName, HeaderValue, + }; use super::super::io::Cursor; use super::Encoder; @@ -514,4 +523,108 @@ mod tests { assert!(!encoder.is_eof()); encoder.end::<()>().unwrap(); } + + #[test] + fn chunked_with_valid_trailers() { + let encoder = Encoder::chunked(); + let trailers = vec![HeaderValue::from_static("chunky-trailer")]; + let encoder = encoder.into_chunked_with_trailing_fields(trailers); + + let mut headers = HeaderMap::new(); + headers.insert( + HeaderName::from_static("chunky-trailer"), + HeaderValue::from_static("header data"), + ); + headers.insert( + HeaderName::from_static("should-not-be-included"), + HeaderValue::from_static("oops"), + ); + + let buf1 = encoder.encode_trailers::<&[u8]>(headers, false).unwrap(); + + let mut dst = Vec::new(); + dst.put(buf1); + assert_eq!(dst, b"0\r\nchunky-trailer: header data\r\n\r\n"); + } + + #[test] + fn chunked_with_no_trailer_header() { + let encoder = Encoder::chunked(); + + let headers = HeaderMap::from_iter( + vec![( + HeaderName::from_static("chunky-trailer"), + HeaderValue::from_static("header data"), + )] + .into_iter(), + ); + + assert!(encoder + .encode_trailers::<&[u8]>(headers.clone(), false) + .is_none()); + + let trailers = vec![]; + let encoder = encoder.into_chunked_with_trailing_fields(trailers); + + assert!(encoder.encode_trailers::<&[u8]>(headers, false).is_none()); + } + + #[test] + fn chunked_with_invalid_trailers() { + let encoder = Encoder::chunked(); + + let trailers = format!( + "{},{},{},{},{},{},{},{},{},{},{},{}", + AUTHORIZATION, + CACHE_CONTROL, + CONTENT_ENCODING, + CONTENT_LENGTH, + CONTENT_RANGE, + CONTENT_TYPE, + HOST, + MAX_FORWARDS, + SET_COOKIE, + TRAILER, + TRANSFER_ENCODING, + TE, + ); + let trailers = vec![HeaderValue::from_str(&trailers).unwrap()]; + let encoder = encoder.into_chunked_with_trailing_fields(trailers); + + let mut headers = HeaderMap::new(); + headers.insert(AUTHORIZATION, HeaderValue::from_static("header data")); + headers.insert(CACHE_CONTROL, HeaderValue::from_static("header data")); + headers.insert(CONTENT_ENCODING, HeaderValue::from_static("header data")); + headers.insert(CONTENT_LENGTH, HeaderValue::from_static("header data")); + headers.insert(CONTENT_RANGE, HeaderValue::from_static("header data")); + headers.insert(CONTENT_TYPE, HeaderValue::from_static("header data")); + headers.insert(HOST, HeaderValue::from_static("header data")); + headers.insert(MAX_FORWARDS, HeaderValue::from_static("header data")); + headers.insert(SET_COOKIE, HeaderValue::from_static("header data")); + headers.insert(TRAILER, HeaderValue::from_static("header data")); + headers.insert(TRANSFER_ENCODING, HeaderValue::from_static("header data")); + headers.insert(TE, HeaderValue::from_static("header data")); + + assert!(encoder.encode_trailers::<&[u8]>(headers, true).is_none()); + } + + #[test] + fn chunked_with_title_case_headers() { + let encoder = Encoder::chunked(); + let trailers = vec![HeaderValue::from_static("chunky-trailer")]; + let encoder = encoder.into_chunked_with_trailing_fields(trailers); + + let headers = HeaderMap::from_iter( + vec![( + HeaderName::from_static("chunky-trailer"), + HeaderValue::from_static("header data"), + )] + .into_iter(), + ); + let buf1 = encoder.encode_trailers::<&[u8]>(headers, true).unwrap(); + + let mut dst = Vec::new(); + dst.put(buf1); + assert_eq!(dst, b"0\r\nChunky-Trailer: header data\r\n\r\n"); + } } diff --git a/src/proto/h1/role.rs b/src/proto/h1/role.rs index e9a38d569f..a5a08d8cab 100644 --- a/src/proto/h1/role.rs +++ b/src/proto/h1/role.rs @@ -1347,20 +1347,10 @@ impl Client { let encoder = encoder.map(|enc| { if enc.is_chunked() { - let mut allowed_trailer_fields: Option> = None; - let trailers = headers.get_all(header::TRAILER); - for trailer in trailers.iter() { - match allowed_trailer_fields { - Some(ref mut allowed_trailer_fields) => { - allowed_trailer_fields.push(trailer.clone()); - } - None => { - allowed_trailer_fields = Some(vec![trailer.clone()]); - } - } - } + let allowed_trailer_fields: Vec = + headers.get_all(header::TRAILER).iter().cloned().collect(); - if let Some(allowed_trailer_fields) = allowed_trailer_fields { + if allowed_trailer_fields.len() > 0 { return enc.into_chunked_with_trailing_fields(allowed_trailer_fields); } }