Skip to content

Commit

Permalink
fix(http1): code review fixes
Browse files Browse the repository at this point in the history
- use more idiomatic expressions
- add TE as invalid header
- add tests for encode_trailers
- fix bug in encode_trailers when buffer is empty
  • Loading branch information
hjr3 committed Nov 11, 2023
1 parent 34db8de commit 401dfaf
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 43 deletions.
16 changes: 9 additions & 7 deletions src/proto/h1/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::time::Duration;

use crate::rt::{Read, Write};
use bytes::{Buf, Bytes};
use http::header::{HeaderValue, CONNECTION};
use http::header::{HeaderValue, CONNECTION, TE};
use http::{HeaderMap, Method, Version};
use httparse::ParserConfig;

Expand Down Expand Up @@ -265,12 +265,14 @@ where
self.state.reading = Reading::Body(Decoder::new(msg.decode));
}

if let Some(Ok(te_value)) = msg.head.headers.get("te").map(|v| v.to_str()) {
if te_value.eq_ignore_ascii_case("trailers") {
self.state.allow_trailer_fields = true;
} else {
self.state.allow_trailer_fields = false;
}
if msg
.head
.headers
.get(TE)
.map(|te_header| te_header == "trailers")
.unwrap_or(false)
{
self.state.allow_trailer_fields = true;
} else {
self.state.allow_trailer_fields = false;
}
Expand Down
159 changes: 136 additions & 23 deletions src/proto/h1/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use bytes::{Buf, Bytes};
use http::{
header::{
AUTHORIZATION, CACHE_CONTROL, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_RANGE,
CONTENT_TYPE, HOST, MAX_FORWARDS, SET_COOKIE, TRAILER, TRANSFER_ENCODING,
CONTENT_TYPE, HOST, MAX_FORWARDS, SET_COOKIE, TE, TRAILER, TRANSFER_ENCODING,
},
HeaderMap, HeaderName, HeaderValue,
};
Expand Down Expand Up @@ -169,13 +169,8 @@ impl Encoder {
title_case_headers: bool,
) -> Option<EncodedBuf<B>> {
match &self.kind {
Kind::Chunked(allowed_trailer_fields) => {
let allowed_trailer_fields_map = match allowed_trailer_fields {
Some(ref allowed_trailer_fields) => {
allowed_trailer_field_map(&allowed_trailer_fields)
}
None => return None,
};
Kind::Chunked(Some(ref allowed_trailer_fields)) => {
let allowed_trailer_field_map = allowed_trailer_field_map(&allowed_trailer_fields);

let mut cur_name = None;
let mut allowed_trailers = HeaderMap::new();
Expand All @@ -186,8 +181,8 @@ impl Encoder {
}
let name = cur_name.as_ref().expect("current header name");

if allowed_trailer_fields_map.contains_key(name.as_str())
&& !invalid_trailer_field(name)
if allowed_trailer_field_map.contains_key(name.as_str())
&& valid_trailer_field(name)
{
allowed_trailers.insert(name, value);
}
Expand All @@ -200,6 +195,10 @@ impl Encoder {
write_headers(&allowed_trailers, &mut buf);
}

if buf.is_empty() {
return None;
}

Some(EncodedBuf {
kind: BufKind::Trailers(b"0\r\n".chain(Bytes::from(buf)).chain(b"\r\n")),
})
Expand Down Expand Up @@ -256,20 +255,21 @@ impl Encoder {
}
}

fn invalid_trailer_field(name: &HeaderName) -> bool {
fn valid_trailer_field(name: &HeaderName) -> bool {
match name {
&AUTHORIZATION => true,
&CACHE_CONTROL => true,
&CONTENT_ENCODING => true,
&CONTENT_LENGTH => true,
&CONTENT_RANGE => true,
&CONTENT_TYPE => true,
&HOST => true,
&MAX_FORWARDS => true,
&SET_COOKIE => true,
&TRAILER => true,
&TRANSFER_ENCODING => true,
_ => false,
&AUTHORIZATION => false,
&CACHE_CONTROL => false,
&CONTENT_ENCODING => false,
&CONTENT_LENGTH => false,
&CONTENT_RANGE => false,
&CONTENT_TYPE => false,
&HOST => false,
&MAX_FORWARDS => false,
&SET_COOKIE => false,
&TRAILER => false,
&TRANSFER_ENCODING => false,
&TE => false,
_ => true,
}
}

Expand Down Expand Up @@ -439,7 +439,16 @@ impl std::error::Error for NotEof {}

#[cfg(test)]
mod tests {
use std::iter::FromIterator;

use bytes::BufMut;
use http::{
header::{
AUTHORIZATION, CACHE_CONTROL, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_RANGE,
CONTENT_TYPE, HOST, MAX_FORWARDS, SET_COOKIE, TE, TRAILER, TRANSFER_ENCODING,
},
HeaderMap, HeaderName, HeaderValue,
};

use super::super::io::Cursor;
use super::Encoder;
Expand Down Expand Up @@ -514,4 +523,108 @@ mod tests {
assert!(!encoder.is_eof());
encoder.end::<()>().unwrap();
}

#[test]
fn chunked_with_valid_trailers() {
let encoder = Encoder::chunked();
let trailers = vec![HeaderValue::from_static("chunky-trailer")];
let encoder = encoder.into_chunked_with_trailing_fields(trailers);

let mut headers = HeaderMap::new();
headers.insert(
HeaderName::from_static("chunky-trailer"),
HeaderValue::from_static("header data"),
);
headers.insert(
HeaderName::from_static("should-not-be-included"),
HeaderValue::from_static("oops"),
);

let buf1 = encoder.encode_trailers::<&[u8]>(headers, false).unwrap();

let mut dst = Vec::new();
dst.put(buf1);
assert_eq!(dst, b"0\r\nchunky-trailer: header data\r\n\r\n");
}

#[test]
fn chunked_with_no_trailer_header() {
let encoder = Encoder::chunked();

let headers = HeaderMap::from_iter(
vec![(
HeaderName::from_static("chunky-trailer"),
HeaderValue::from_static("header data"),
)]
.into_iter(),
);

assert!(encoder
.encode_trailers::<&[u8]>(headers.clone(), false)
.is_none());

let trailers = vec![];
let encoder = encoder.into_chunked_with_trailing_fields(trailers);

assert!(encoder.encode_trailers::<&[u8]>(headers, false).is_none());
}

#[test]
fn chunked_with_invalid_trailers() {
let encoder = Encoder::chunked();

let trailers = format!(
"{},{},{},{},{},{},{},{},{},{},{},{}",
AUTHORIZATION,
CACHE_CONTROL,
CONTENT_ENCODING,
CONTENT_LENGTH,
CONTENT_RANGE,
CONTENT_TYPE,
HOST,
MAX_FORWARDS,
SET_COOKIE,
TRAILER,
TRANSFER_ENCODING,
TE,
);
let trailers = vec![HeaderValue::from_str(&trailers).unwrap()];
let encoder = encoder.into_chunked_with_trailing_fields(trailers);

let mut headers = HeaderMap::new();
headers.insert(AUTHORIZATION, HeaderValue::from_static("header data"));
headers.insert(CACHE_CONTROL, HeaderValue::from_static("header data"));
headers.insert(CONTENT_ENCODING, HeaderValue::from_static("header data"));
headers.insert(CONTENT_LENGTH, HeaderValue::from_static("header data"));
headers.insert(CONTENT_RANGE, HeaderValue::from_static("header data"));
headers.insert(CONTENT_TYPE, HeaderValue::from_static("header data"));
headers.insert(HOST, HeaderValue::from_static("header data"));
headers.insert(MAX_FORWARDS, HeaderValue::from_static("header data"));
headers.insert(SET_COOKIE, HeaderValue::from_static("header data"));
headers.insert(TRAILER, HeaderValue::from_static("header data"));
headers.insert(TRANSFER_ENCODING, HeaderValue::from_static("header data"));
headers.insert(TE, HeaderValue::from_static("header data"));

assert!(encoder.encode_trailers::<&[u8]>(headers, true).is_none());
}

#[test]
fn chunked_with_title_case_headers() {
let encoder = Encoder::chunked();
let trailers = vec![HeaderValue::from_static("chunky-trailer")];
let encoder = encoder.into_chunked_with_trailing_fields(trailers);

let headers = HeaderMap::from_iter(
vec![(
HeaderName::from_static("chunky-trailer"),
HeaderValue::from_static("header data"),
)]
.into_iter(),
);
let buf1 = encoder.encode_trailers::<&[u8]>(headers, true).unwrap();

let mut dst = Vec::new();
dst.put(buf1);
assert_eq!(dst, b"0\r\nChunky-Trailer: header data\r\n\r\n");
}
}
16 changes: 3 additions & 13 deletions src/proto/h1/role.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1347,20 +1347,10 @@ impl Client {

let encoder = encoder.map(|enc| {
if enc.is_chunked() {
let mut allowed_trailer_fields: Option<Vec<HeaderValue>> = None;
let trailers = headers.get_all(header::TRAILER);
for trailer in trailers.iter() {
match allowed_trailer_fields {
Some(ref mut allowed_trailer_fields) => {
allowed_trailer_fields.push(trailer.clone());
}
None => {
allowed_trailer_fields = Some(vec![trailer.clone()]);
}
}
}
let allowed_trailer_fields: Vec<HeaderValue> =
headers.get_all(header::TRAILER).iter().cloned().collect();

if let Some(allowed_trailer_fields) = allowed_trailer_fields {
if allowed_trailer_fields.len() > 0 {
return enc.into_chunked_with_trailing_fields(allowed_trailer_fields);
}
}
Expand Down

0 comments on commit 401dfaf

Please sign in to comment.