Skip to content

Commit

Permalink
refactor host extraction to common logic + fix authority rama-fp
Browse files Browse the repository at this point in the history
  • Loading branch information
glendc committed Apr 19, 2024
1 parent 61c474e commit 30c869d
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 84 deletions.
1 change: 0 additions & 1 deletion rama-fp/src/service/endpoints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,6 @@ impl From<RequestInfo> for Table {
("Fetch Mode".to_owned(), info.fetch_mode.to_string()),
("Resource Type".to_owned(), info.resource_type.to_string()),
("Initiator".to_owned(), info.initiator.to_string()),
("Uri".to_owned(), info.uri),
(
"Peer Address".to_owned(),
info.peer_addr.unwrap_or_default(),
Expand Down
90 changes: 90 additions & 0 deletions src/http/headers/extract.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
//! Http header Extract utilities.
use crate::http::{header::FORWARDED, HeaderMap};

const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";

/// Extract the host from the headers ([`HeaderMap`]).
pub fn extract_host_from_headers(headers: &HeaderMap) -> Option<String> {
if let Some(host) = parse_forwarded(headers) {
return Some(host.to_owned());
}

if let Some(host) = headers
.get(X_FORWARDED_HOST_HEADER_KEY)
.and_then(|host| host.to_str().ok())
{
return Some(host.to_owned());
}

if let Some(host) = headers
.get(http::header::HOST)
.and_then(|host| host.to_str().ok())
{
return Some(host.to_owned());
}

None
}

fn parse_forwarded(headers: &HeaderMap) -> Option<&str> {
// if there are multiple `Forwarded` `HeaderMap::get` will return the first one
let forwarded_values = headers.get(FORWARDED)?.to_str().ok()?;

// get the first set of values
let first_value = forwarded_values.split(',').next()?;

// find the value of the `host` field
first_value.split(';').find_map(|pair| {
let (key, value) = pair.split_once('=')?;
key.trim()
.eq_ignore_ascii_case("host")
.then(|| value.trim().trim_matches('"'))
})
}

#[cfg(test)]
mod tests {
use crate::http::HeaderName;

use super::*;

#[test]
fn forwarded_parsing() {
// the basic case
let headers = header_map(&[(FORWARDED, "host=192.0.2.60;proto=http;by=203.0.113.43")]);
let value = parse_forwarded(&headers).unwrap();
assert_eq!(value, "192.0.2.60");

// is case insensitive
let headers = header_map(&[(FORWARDED, "host=192.0.2.60;proto=http;by=203.0.113.43")]);
let value = parse_forwarded(&headers).unwrap();
assert_eq!(value, "192.0.2.60");

// ipv6
let headers = header_map(&[(FORWARDED, "host=\"[2001:db8:cafe::17]:4711\"")]);
let value = parse_forwarded(&headers).unwrap();
assert_eq!(value, "[2001:db8:cafe::17]:4711");

// multiple values in one header
let headers = header_map(&[(FORWARDED, "host=192.0.2.60, host=127.0.0.1")]);
let value = parse_forwarded(&headers).unwrap();
assert_eq!(value, "192.0.2.60");

// multiple header values
let headers = header_map(&[
(FORWARDED, "host=192.0.2.60"),
(FORWARDED, "host=127.0.0.1"),
]);
let value = parse_forwarded(&headers).unwrap();
assert_eq!(value, "192.0.2.60");
}

fn header_map(values: &[(HeaderName, &str)]) -> HeaderMap {
let mut headers = HeaderMap::new();
for (key, value) in values {
headers.append(key, value.parse().unwrap());
}
headers
}
}
2 changes: 2 additions & 0 deletions src/http/headers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,5 @@ pub mod authorization {
pub use headers::authorization::Credentials;
pub use headers::authorization::{Authorization, Basic, Bearer};
}

pub mod extract;
10 changes: 9 additions & 1 deletion src/http/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,15 @@ pub mod dep {
}
}

pub use self::dep::http::header;
pub mod header {
//! HTTP header types
pub use crate::http::dep::http::header::*;

/// Key str constant for the `X-Forwarded-Host` header.
pub const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";
}

pub use self::dep::http::header::HeaderMap;
pub use self::dep::http::header::HeaderName;
pub use self::dep::http::header::HeaderValue;
Expand Down
10 changes: 7 additions & 3 deletions src/http/request_context.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use super::{dep::http::request::Parts, Request, Version};
use super::{
dep::http::request::Parts, headers::extract::extract_host_from_headers, Request, Version,
};
use crate::uri::Scheme;

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -65,7 +67,8 @@ impl From<&Parts> for RequestContext {
let uri = &parts.uri;

let scheme = uri.scheme().into();
let host = uri.host().map(str::to_owned);
let host =
extract_host_from_headers(&parts.headers).or_else(|| uri.host().map(str::to_owned));
let port = uri.port().map(u16::from);
let http_version = parts.version;

Expand All @@ -83,7 +86,8 @@ impl<Body> From<&Request<Body>> for RequestContext {
let uri = req.uri();

let scheme = uri.scheme().into();
let host = uri.host().map(str::to_owned);
let host =
extract_host_from_headers(req.headers()).or_else(|| uri.host().map(str::to_owned));
let port = uri.port().map(u16::from);
let http_version = req.version();

Expand Down
83 changes: 4 additions & 79 deletions src/http/service/web/endpoint/extract/host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,10 @@ use std::ops::Deref;

use super::FromRequestParts;
use crate::http::{
dep::http::request::Parts,
header::{HeaderMap, FORWARDED},
StatusCode,
dep::http::request::Parts, headers::extract::extract_host_from_headers, StatusCode,
};
use crate::service::Context;

const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";

/// Extractor that resolves the hostname of the request.
///
/// Hostname is resolved through the following, in order:
Expand All @@ -30,24 +26,8 @@ where
type Rejection = StatusCode;

async fn from_request_parts(_ctx: &Context<S>, parts: &Parts) -> Result<Self, Self::Rejection> {
if let Some(host) = parse_forwarded(&parts.headers) {
return Ok(Host(host.to_owned()));
}

if let Some(host) = parts
.headers
.get(X_FORWARDED_HOST_HEADER_KEY)
.and_then(|host| host.to_str().ok())
{
return Ok(Host(host.to_owned()));
}

if let Some(host) = parts
.headers
.get(http::header::HOST)
.and_then(|host| host.to_str().ok())
{
return Ok(Host(host.to_owned()));
if let Some(host) = extract_host_from_headers(&parts.headers) {
return Ok(Host(host));
}

if let Some(host) = parts.uri.host() {
Expand All @@ -58,23 +38,6 @@ where
}
}

#[allow(warnings)]
fn parse_forwarded(headers: &HeaderMap) -> Option<&str> {
// if there are multiple `Forwarded` `HeaderMap::get` will return the first one
let forwarded_values = headers.get(FORWARDED)?.to_str().ok()?;

// get the first set of values
let first_value = forwarded_values.split(',').nth(0)?;

// find the value of the `host` field
first_value.split(';').find_map(|pair| {
let (key, value) = pair.split_once('=')?;
key.trim()
.eq_ignore_ascii_case("host")
.then(|| value.trim().trim_matches('"'))
})
}

impl Deref for Host {
type Target = str;

Expand All @@ -88,6 +51,7 @@ mod tests {
use super::*;

use crate::http::dep::http_body_util::BodyExt as _;
use crate::http::header::X_FORWARDED_HOST_HEADER_KEY;
use crate::http::service::web::WebService;
use crate::http::{Body, HeaderName, Request};
use crate::service::Service;
Expand Down Expand Up @@ -147,43 +111,4 @@ mod tests {
async fn uri_host() {
test_host_from_request("example.com", vec![]).await;
}

#[test]
fn forwarded_parsing() {
// the basic case
let headers = header_map(&[(FORWARDED, "host=192.0.2.60;proto=http;by=203.0.113.43")]);
let value = parse_forwarded(&headers).unwrap();
assert_eq!(value, "192.0.2.60");

// is case insensitive
let headers = header_map(&[(FORWARDED, "host=192.0.2.60;proto=http;by=203.0.113.43")]);
let value = parse_forwarded(&headers).unwrap();
assert_eq!(value, "192.0.2.60");

// ipv6
let headers = header_map(&[(FORWARDED, "host=\"[2001:db8:cafe::17]:4711\"")]);
let value = parse_forwarded(&headers).unwrap();
assert_eq!(value, "[2001:db8:cafe::17]:4711");

// multiple values in one header
let headers = header_map(&[(FORWARDED, "host=192.0.2.60, host=127.0.0.1")]);
let value = parse_forwarded(&headers).unwrap();
assert_eq!(value, "192.0.2.60");

// multiple header values
let headers = header_map(&[
(FORWARDED, "host=192.0.2.60"),
(FORWARDED, "host=127.0.0.1"),
]);
let value = parse_forwarded(&headers).unwrap();
assert_eq!(value, "192.0.2.60");
}

fn header_map(values: &[(HeaderName, &str)]) -> HeaderMap {
let mut headers = HeaderMap::new();
for (key, value) in values {
headers.append(key, value.parse().unwrap());
}
headers
}
}

0 comments on commit 30c869d

Please sign in to comment.