diff --git a/Cargo.lock b/Cargo.lock index c5d64c3e3..053c6c9d1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -637,6 +637,40 @@ version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "575f75dfd25738df5b91b8e43e14d44bda14637a58fae779fd2b064f8bf3e010" +[[package]] +name = "deadpool" +version = "0.12.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ed5957ff93768adf7a65ab167a17835c3d2c3c50d084fe305174c112f468e2f" +dependencies = [ + "deadpool-runtime", + "num_cpus", + "tokio", +] + +[[package]] +name = "deadpool-postgres" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d697d376cbfa018c23eb4caab1fd1883dd9c906a8c034e8d9a3cb06a7e0bef9" +dependencies = [ + "async-trait", + "deadpool", + "getrandom 0.2.15", + "tokio", + "tokio-postgres", + "tracing", +] + +[[package]] +name = "deadpool-runtime" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b" +dependencies = [ + "tokio", +] + [[package]] name = "deranged" version = "0.3.11" @@ -685,6 +719,7 @@ checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", "crypto-common", + "subtle", ] [[package]] @@ -808,6 +843,12 @@ dependencies = [ "serde_json", ] +[[package]] +name = "fallible-iterator" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7" + [[package]] name = "fastrand" version = "2.3.0" @@ -1162,6 +1203,12 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hermit-abi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" + [[package]] name = "hex" version = "0.4.3" @@ -1211,6 +1258,15 @@ dependencies = [ "tracing", ] +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + [[package]] name = "home" version = "0.5.11" @@ -1698,6 +1754,16 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" +[[package]] +name = "md-5" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf" +dependencies = [ + "cfg-if", + "digest", +] + [[package]] name = "md5" version = "0.7.0" @@ -1851,6 +1917,16 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" +[[package]] +name = "num_cpus" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" +dependencies = [ + "hermit-abi", + "libc", +] + [[package]] name = "object" version = "0.36.7" @@ -2027,6 +2103,24 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "phf" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078" +dependencies = [ + "phf_shared", +] + +[[package]] +name = "phf_shared" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5" 
+dependencies = [ + "siphasher", +] + [[package]] name = "pin-project" version = "1.1.9" @@ -2071,6 +2165,37 @@ version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e" +[[package]] +name = "postgres-protocol" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76ff0abab4a9b844b93ef7b81f1efc0a366062aaef2cd702c76256b5dc075c54" +dependencies = [ + "base64 0.22.1", + "byteorder", + "bytes", + "fallible-iterator", + "hmac", + "md-5", + "memchr", + "rand 0.9.0", + "sha2", + "stringprep", +] + +[[package]] +name = "postgres-types" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613283563cd90e1dfc3518d548caee47e0e725455ed619881f5cf21f36de4b48" +dependencies = [ + "bytes", + "fallible-iterator", + "postgres-protocol", + "serde", + "serde_json", +] + [[package]] name = "powerfmt" version = "0.2.0" @@ -2204,7 +2329,9 @@ dependencies = [ "base64 0.22.1", "bytes", "clap", + "deadpool-postgres", "hex", + "itertools 0.14.0", "mimalloc", "rama", "serde", @@ -2212,6 +2339,7 @@ dependencies = [ "terminal-prompt", "tikv-jemallocator", "tokio", + "tokio-postgres", "tracing", "tracing-subscriber", ] @@ -2289,10 +2417,6 @@ dependencies = [ "csv", "flate2", "futures-lite", - "headers", - "http", - "http-body", - "http-body-util", "http-range-header", "httpdate", "iri-string", @@ -2437,6 +2561,7 @@ dependencies = [ "rama-utils", "rustls", "serde", + "serde_json", "sha2", "socket2", "tokio", @@ -2512,11 +2637,17 @@ dependencies = [ name = "rama-ua" version = "0.2.0-alpha.7" dependencies = [ + "bytes", + "itertools 0.14.0", "rama-core", + "rama-http-types", + "rama-net", "rama-utils", + "rand 0.9.0", "serde", "serde_json", "tokio", + "tracing", ] [[package]] @@ -2982,6 +3113,12 @@ dependencies = [ "libc", ] +[[package]] +name = "siphasher" +version = "1.0.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" + [[package]] name = "slab" version = "0.4.9" @@ -3028,6 +3165,17 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "stringprep" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b4df3d392d81bd458a8a621b8bffbd2302a12ffe288a9d931670948749463b1" +dependencies = [ + "unicode-bidi", + "unicode-normalization", + "unicode-properties", +] + [[package]] name = "strsim" version = "0.11.1" @@ -3284,6 +3432,32 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-postgres" +version = "0.7.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c95d533c83082bb6490e0189acaa0bbeef9084e60471b696ca6988cd0541fb0" +dependencies = [ + "async-trait", + "byteorder", + "bytes", + "fallible-iterator", + "futures-channel", + "futures-util", + "log", + "parking_lot", + "percent-encoding", + "phf", + "pin-project-lite", + "postgres-protocol", + "postgres-types", + "rand 0.9.0", + "socket2", + "tokio", + "tokio-util", + "whoami", +] + [[package]] name = "tokio-rustls" version = "0.26.1" @@ -3499,6 +3673,12 @@ version = "2.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" +[[package]] +name = "unicode-bidi" +version = "0.3.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c1cb5db39152898a79168971543b1cb5020dff7fe43c8dc468b0885f5e29df5" + [[package]] name = "unicode-ident" version = "1.0.17" @@ -3514,6 +3694,12 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "unicode-properties" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"e70f2a8b45122e719eb623c01822704c4e0907e7e426a05927e1a1cfff5b75d0" + [[package]] name = "unicode-xid" version = "0.2.6" @@ -3633,6 +3819,12 @@ dependencies = [ "wit-bindgen-rt", ] +[[package]] +name = "wasite" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b" + [[package]] name = "wasm-bindgen" version = "0.2.100" @@ -3735,6 +3927,17 @@ dependencies = [ "rustix", ] +[[package]] +name = "whoami" +version = "1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "372d5b87f58ec45c384ba03563b03544dc5fadc3983e434b286913f5b4a9bb6d" +dependencies = [ + "redox_syscall", + "wasite", + "web-sys", +] + [[package]] name = "winapi" version = "0.3.9" diff --git a/Cargo.toml b/Cargo.toml index dcff4dc66..a9cb1d437 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,6 +57,7 @@ http = "1" http-body = "1" http-body-util = "0.1" http-range-header = "0.4.0" +deadpool-postgres = "0.14.1" httpdate = "1.0" boring = "4.9.1" tokio-boring = "4.9.1" @@ -142,6 +143,7 @@ slab = "0.4.2" indexmap = "2" rand = "0.9.0" walkdir = "2.3.2" +tokio-postgres = "0.7.13" env_logger = "0.11.6" httparse = "1.8" smallvec = { version = "1.12", features = ["const_generics", "const_new"] } diff --git a/examples/http_user_agent_classifier.rs b/examples/http_user_agent_classifier.rs index 094d2ec0b..84f04895e 100644 --- a/examples/http_user_agent_classifier.rs +++ b/examples/http_user_agent_classifier.rs @@ -47,8 +47,8 @@ async fn handle(ctx: Context<()>, _req: Request) -> Result "kind": ua.info().map(|info| info.kind.to_string()), "version": ua.info().and_then(|info| info.version), "platform": ua.platform().map(|p| p.to_string()), - "http_agent": ua.http_agent().to_string(), - "tls_agent": ua.tls_agent().to_string(), + "http_agent": ua.http_agent().as_ref().map(ToString::to_string), + "tls_agent": ua.tls_agent().as_ref().map(ToString::to_string), })) .into_response()) } diff 
--git a/rama-cli/Cargo.toml b/rama-cli/Cargo.toml index 3fa85d8b3..8b15486a1 100644 --- a/rama-cli/Cargo.toml +++ b/rama-cli/Cargo.toml @@ -22,7 +22,9 @@ mimalloc = ["dep:mimalloc"] base64 = { workspace = true } bytes = { workspace = true } clap = { workspace = true } +deadpool-postgres = { workspace = true } hex = { workspace = true } +itertools = { workspace = true } jemallocator = { workspace = true, optional = true } mimalloc = { workspace = true, optional = true } rama = { version = "0.2.0-alpha.7", path = "..", features = ["full"] } @@ -30,6 +32,7 @@ serde = { workspace = true } serde_json = { workspace = true } terminal-prompt = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } +tokio-postgres = { workspace = true, features = ["with-serde_json-1"] } tracing = { workspace = true } tracing-subscriber = { workspace = true, features = ["env-filter"] } diff --git a/rama-cli/assets/script.js b/rama-cli/assets/script.js index eb02762ba..79c4c169b 100644 --- a/rama-cli/assets/script.js +++ b/rama-cli/assets/script.js @@ -22,26 +22,10 @@ async function fetchWithBackoff(url, options) { throw new Error('Max retries exceeded'); } -// Function to make a GET request -async function makeGetRequest(url) { - const headers = { - 'x-CusToM-HEADER': `rama-fp${Date.now()}`, - 'x-CusToM-HEADER-eXtRa': `rama-fpeXtRa-${Date.now()}`, - }; - - const options = { - method: 'GET', - headers - }; - - return fetchWithBackoff(url, options); -} - // Function to make a POST request async function makePostRequest(url, number) { const headers = { - 'x-CusToM-HEADER': `rama-fp${Date.now()}`, - 'x-CusToM-HEADER-eXtRa': `rama-fpeXtRa-${Date.now()}`, + 'x-RAMA-custom-header-marker': `rama-fp${Date.now()}`, }; const body = JSON.stringify({ number }); @@ -60,8 +44,7 @@ function makeRequestWithXHR(url, method, number) { return new Promise((resolve, reject) => { const xhr = new XMLHttpRequest(); xhr.open(method, url); - xhr.setRequestHeader('x-CusToM-HEADER', 
`rama-fp${Date.now()}`); - xhr.setRequestHeader('x-CusToM-HEADER-eXtRa', `rama-fpeXtRa-${Date.now()}`); + xhr.setRequestHeader('x-RAMA-custom-header-marker', `rama-fp${Date.now()}`); xhr.onload = function () { if (xhr.status >= 200 && xhr.status < 300) { @@ -82,20 +65,20 @@ function makeRequestWithXHR(url, method, number) { // Main function to execute the requests async function main() { try { - // Fetch GET request - const response1 = await makeGetRequest('/api/fetch/number'); - const { number } = await response1.json(); + // Generate random numbers for the requests + const number = Math.floor(Math.random() * 1000) + 1; + const number2 = Math.floor(Math.random() * 1000) + 1; + + console.log('Generated random numbers:', number, number2); // Fetch POST request const response2 = await makePostRequest(`/api/fetch/number/${number}`, number); + console.log('Fetch POST request response:', response2); const result = await response2.json(); - // XMLHttpRequest GET request - const response3 = await makeRequestWithXHR('/api/xml/number', 'GET'); - const { number: number2 } = JSON.parse(response3); - // XMLHttpRequest POST request const response4 = await makeRequestWithXHR(`/api/xml/number/${number2}`, 'POST', number2); + console.log('XMLHttpRequest POST request response:', response4); const result2 = JSON.parse(response4); console.log('Requests completed successfully'); diff --git a/rama-cli/src/cmd/fp/data.rs b/rama-cli/src/cmd/fp/data.rs index 424db9584..d528c9153 100644 --- a/rama-cli/src/cmd/fp/data.rs +++ b/rama-cli/src/cmd/fp/data.rs @@ -1,9 +1,9 @@ -use super::State; +use super::{State, StorageAuthorized}; use rama::{ Context, - error::{BoxError, ErrorContext}, + error::{BoxError, ErrorContext, OpaqueError}, http::{ - HeaderMap, Request, + self, HeaderMap, HeaderName, Request, dep::http::{Extensions, request::Parts}, headers::Forwarded, proto::{h1::Http1HeaderMap, h2::PseudoHeaderOrder}, @@ -17,7 +17,7 @@ use rama::{ SecureTransport, client::{ClientHello, 
ClientHelloExtension}, }, - ua::UserAgent, + ua::{Http1Settings, Http2Settings, UserAgent}, }; use serde::Serialize; use std::{str::FromStr, sync::Arc}; @@ -210,8 +210,118 @@ pub(super) struct HttpInfo { pub(super) pseudo_headers: Option>, } -pub(super) fn get_http_info(headers: HeaderMap, ext: &mut Extensions) -> HttpInfo { - let headers: Vec<_> = Http1HeaderMap::new(headers, Some(ext)) +pub(super) async fn get_and_store_http_info( + ctx: &Context>, + headers: HeaderMap, + ext: &mut Extensions, + http_version: http::Version, + ua: String, + initiator: Initiator, +) -> Result { + let original_headers = Http1HeaderMap::new(headers, Some(ext)); + let pseudo_headers = ext.get::(); + + if ctx.contains::() { + if let Some(storage) = ctx.state().storage.as_ref() { + match http_version { + http::Version::HTTP_09 | http::Version::HTTP_10 | http::Version::HTTP_11 => { + match initiator { + Initiator::Navigator => { + storage + .store_h1_headers_navigate(ua, original_headers.clone()) + .await + .context("store h1 headers navigate")?; + } + Initiator::Fetch => { + storage + .store_h1_headers_fetch(ua, original_headers.clone()) + .await + .context("store h1 headers fetch")?; + } + Initiator::XMLHttpRequest => { + if let Some(header_name) = original_headers.get_original_name( + &HeaderName::from_static("x-rama-custom-header-marker"), + ) { + // Check if the header name is title-cased or not + let header_str = header_name.as_str(); + let title_case_headers = header_str.split('-').all(|part| { + part.chars().next().is_none_or(|c| c.is_ascii_uppercase()) + && part.chars().skip(1).all(|c| c.is_ascii_lowercase()) + }); + + tracing::debug!( + "Custom header marker found: {}, title-cased: {}", + header_str, + title_case_headers + ); + + storage + .store_h1_settings( + ua.clone(), + Http1Settings { title_case_headers }, + ) + .await + .context("store h1 settings")?; + } + + storage + .store_h1_headers_xhr(ua, original_headers.clone()) + .await + .context("store h1 headers xhr")?; + } + 
Initiator::Form => { + storage + .store_h1_headers_form(ua, original_headers.clone()) + .await + .context("store h1 headers form")?; + } + } + } + http::Version::HTTP_2 => { + if let Some(pseudo_headers) = pseudo_headers { + storage + .store_h2_settings( + ua.clone(), + Http2Settings { + http_pseudo_headers: Some(pseudo_headers.iter().collect()), + }, + ) + .await + .context("store h2 pseudo headers")?; + } + match initiator { + Initiator::Navigator => { + storage + .store_h2_headers_navigate(ua, original_headers.clone()) + .await + .context("store h2 headers navigate")?; + } + Initiator::Fetch => { + storage + .store_h2_headers_fetch(ua, original_headers.clone()) + .await + .context("store h2 headers fetch")?; + } + Initiator::XMLHttpRequest => { + storage + .store_h2_headers_xhr(ua, original_headers.clone()) + .await + .context("store h2 headers xhr")?; + } + Initiator::Form => { + storage + .store_h2_headers_form(ua, original_headers.clone()) + .await + .context("store h2 headers form")?; + } + } + } + _ => (), + } + } + } + + let headers: Vec<_> = original_headers .into_iter() .map(|(name, value)| { ( @@ -223,14 +333,13 @@ pub(super) fn get_http_info(headers: HeaderMap, ext: &mut Extensions) -> HttpInf }) .collect(); - let pseudo_headers: Option> = ext - .get::() - .map(|o| o.iter().map(|p| p.to_string()).collect()); + let pseudo_headers: Option> = + pseudo_headers.map(|o| o.iter().map(|p| p.to_string()).collect()); - HttpInfo { + Ok(HttpInfo { headers, pseudo_headers, - } + }) } #[derive(Debug, Clone, Serialize)] @@ -267,20 +376,31 @@ pub(super) enum TlsDisplayInfoExtensionData { Multi(Vec), } -pub(super) fn get_tls_display_info(ctx: &Context>) -> Option { - let hello: &ClientHello = ctx +pub(super) async fn get_tls_display_info_and_store( + ctx: &Context>, + ua: String, +) -> Result, OpaqueError> { + let hello: &ClientHello = match ctx .get::() - .and_then(|st| st.client_hello())?; + .and_then(|st| st.client_hello()) + { + Some(hello) => hello, + None => 
return Ok(None), + }; - let ja4 = Ja4::compute(ctx.extensions()) - .inspect_err(|err| tracing::error!(?err, "ja4 compute failure")) - .ok()?; + if ctx.contains::() { + if let Some(storage) = ctx.state().storage.as_ref() { + storage + .store_tls_client_hello(ua, hello.clone()) + .await + .context("store tls client hello")?; + } + } - let ja3 = Ja3::compute(ctx.extensions()) - .inspect_err(|err| tracing::error!(?err, "ja3 compute failure")) - .ok()?; + let ja4 = Ja4::compute(ctx.extensions()).context("ja4 compute")?; + let ja3 = Ja3::compute(ctx.extensions()).context("ja3 compute")?; - Some(TlsDisplayInfo { + Ok(Some(TlsDisplayInfo { ja4: Ja4DisplayInfo { full: format!("{ja4:?}"), hash: format!("{ja4}"), @@ -355,5 +475,5 @@ pub(super) fn get_tls_display_info(ctx: &Context>) -> Option>(), - }) + })) } diff --git a/rama-cli/src/cmd/fp/endpoints.rs b/rama-cli/src/cmd/fp/endpoints.rs index 8076bf2c6..6c94282a8 100644 --- a/rama-cli/src/cmd/fp/endpoints.rs +++ b/rama-cli/src/cmd/fp/endpoints.rs @@ -2,7 +2,8 @@ use super::{ State, data::{ DataSource, FetchMode, Initiator, RequestInfo, ResourceType, TlsDisplayInfo, UserAgentInfo, - get_http_info, get_ja4h_info, get_request_info, get_tls_display_info, get_user_agent_info, + get_and_store_http_info, get_ja4h_info, get_request_info, get_tls_display_info_and_store, + get_user_agent_info, }, }; use crate::cmd::fp::data::TlsDisplayInfoExtensionData; @@ -95,7 +96,18 @@ pub(super) async fn get_report( .await .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?; - let http_info = get_http_info(parts.headers, &mut parts.extensions); + let user_agent = user_agent_info.user_agent.clone(); + + let http_info = get_and_store_http_info( + &ctx, + parts.headers, + &mut parts.extensions, + parts.version, + user_agent.clone(), + Initiator::Navigator, + ) + .await + .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?; let head = r#""#.to_owned(); @@ -126,7 +138,10 @@ pub(super) 
async fn get_report( }); } - let tls_info = get_tls_display_info(&ctx); + let tls_info = get_tls_display_info_and_store(&ctx, user_agent) + .await + .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?; + if let Some(tls_info) = tls_info { let mut tls_tables = tls_info.into(); tables.append(&mut tls_tables); @@ -175,45 +190,6 @@ pub(super) struct APINumberParams { number: usize, } -pub(super) async fn get_api_fetch_number( - mut ctx: Context>, - req: Request, -) -> Result, Response> { - let ja4h = get_ja4h_info(&req); - - let (mut parts, _) = req.into_parts(); - - let user_agent_info = get_user_agent_info(&ctx).await; - - let request_info = get_request_info( - FetchMode::SameOrigin, - ResourceType::Xhr, - Initiator::Fetch, - &mut ctx, - &parts, - ) - .await - .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?; - - let http_info = get_http_info(parts.headers, &mut parts.extensions); - - let tls_info = get_tls_display_info(&ctx); - - Ok(Json(json!({ - "number": ctx.state().counter.fetch_add(1, std::sync::atomic::Ordering::AcqRel), - "fp": { - "user_agent_info": user_agent_info, - "request_info": request_info, - "tls_info": tls_info, - "http_info": json!({ - "headers": http_info.headers, - "pseudo_headers": http_info.pseudo_headers, - "ja4h": ja4h, - }), - } - }))) -} - pub(super) async fn post_api_fetch_number( Path(params): Path, mut ctx: Context>, @@ -225,6 +201,8 @@ pub(super) async fn post_api_fetch_number( let user_agent_info = get_user_agent_info(&ctx).await; + let user_agent = user_agent_info.user_agent.clone(); + let request_info = get_request_info( FetchMode::SameOrigin, ResourceType::Xhr, @@ -235,51 +213,23 @@ pub(super) async fn post_api_fetch_number( .await .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?; - let http_info = get_http_info(parts.headers, &mut parts.extensions); - - let tls_info = get_tls_display_info(&ctx); - - Ok(Json(json!({ - 
"number": params.number, - "fp": { - "user_agent_info": user_agent_info, - "request_info": request_info, - "tls_info": tls_info, - "http_info": json!({ - "headers": http_info.headers, - "pseudo_headers": http_info.pseudo_headers, - "ja4h": ja4h, - }), - } - }))) -} - -pub(super) async fn get_api_xml_http_request_number( - mut ctx: Context>, - req: Request, -) -> Result, Response> { - let ja4h = get_ja4h_info(&req); - - let (mut parts, _) = req.into_parts(); - - let user_agent_info = get_user_agent_info(&ctx).await; - - let request_info = get_request_info( - FetchMode::SameOrigin, - ResourceType::Xhr, + let http_info = get_and_store_http_info( + &ctx, + parts.headers, + &mut parts.extensions, + parts.version, + user_agent.clone(), Initiator::Fetch, - &mut ctx, - &parts, ) .await .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?; - let http_info = get_http_info(parts.headers, &mut parts.extensions); - - let tls_info = get_tls_display_info(&ctx); + let tls_info = get_tls_display_info_and_store(&ctx, user_agent) + .await + .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?; Ok(Json(json!({ - "number": ctx.state().counter.fetch_add(1, std::sync::atomic::Ordering::AcqRel), + "number": params.number, "fp": { "user_agent_info": user_agent_info, "request_info": request_info, @@ -304,6 +254,8 @@ pub(super) async fn post_api_xml_http_request_number( let user_agent_info = get_user_agent_info(&ctx).await; + let user_agent = user_agent_info.user_agent.clone(); + let request_info = get_request_info( FetchMode::SameOrigin, ResourceType::Xhr, @@ -314,9 +266,20 @@ pub(super) async fn post_api_xml_http_request_number( .await .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?; - let http_info = get_http_info(parts.headers, &mut parts.extensions); + let http_info = get_and_store_http_info( + &ctx, + parts.headers, + &mut parts.extensions, + parts.version, + user_agent.clone(), + 
Initiator::XMLHttpRequest, + ) + .await + .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?; - let tls_info = get_tls_display_info(&ctx); + let tls_info = get_tls_display_info_and_store(&ctx, user_agent) + .await + .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?; Ok(Json(json!({ "number": params.number, @@ -344,6 +307,8 @@ pub(super) async fn form(mut ctx: Context>, req: Request) -> Result>, req: Request) -> Result>, req: Request) -> Result Result<(), BoxError> { }; let address = format!("{}:{}", cfg.interface, cfg.port); - let ch_headers = [ - "Width", - "Downlink", - "Sec-CH-UA", - "Sec-CH-UA-Mobile", - "Sec-CH-UA-Full-Version", - "ETC", - "Save-Data", - "Sec-CH-UA-Platform", - "Sec-CH-Prefers-Reduced-Motion", - "Sec-CH-UA-Arch", - "Sec-CH-UA-Bitness", - "Sec-CH-UA-Model", - "Sec-CH-UA-Platform-Version", - "Sec-CH-UA-Prefers-Color-Scheme", - "Device-Memory", - "RTT", - "Sec-GPC", - ] - .join(", ") - .parse::() - .expect("parse header value"); + let ch_headers = all_client_hint_header_name_strings() + .join(", ") + .parse::() + .expect("parse header value"); graceful.spawn_task_fn(move |guard| async move { let inner_http_service = HijackLayer::new( @@ -233,9 +225,7 @@ pub async fn run(cfg: CliCommandFingerprint) -> Result<(), BoxError> { ) .layer(match_service!{ HttpMatcher::get("/report") => endpoints::get_report, - HttpMatcher::get("/api/fetch/number") => endpoints::get_api_fetch_number, HttpMatcher::post("/api/fetch/number/:number") => endpoints::post_api_fetch_number, - HttpMatcher::get("/api/xml/number") => endpoints::get_api_xml_http_request_number, HttpMatcher::post("/api/xml/number/:number") => endpoints::post_api_xml_http_request_number, HttpMatcher::method_get().or_method_post().and_path("/form") => endpoints::form, _ => Redirect::temporary("/consent"), @@ -250,6 +240,7 @@ pub async fn run(cfg: CliCommandFingerprint) -> Result<(), BoxError> { 
HeaderName::from_static("x-sponsored-by"), HeaderValue::from_static("fly.io"), ), + StorageAuthLayer, SetResponseHeaderLayer::if_not_present( HeaderName::from_static("accept-ch"), ch_headers.clone(), @@ -295,7 +286,10 @@ pub async fn run(cfg: CliCommandFingerprint) -> Result<(), BoxError> { }) ); - let tcp_listener = TcpListener::build_with_state(Arc::new(State::new(acme_data))) + let pg_url = std::env::var("DATABASE_URL").ok(); + let storage_auth = std::env::var("RAMA_FP_STORAGE_COOKIE").ok(); + + let tcp_listener = TcpListener::build_with_state(Arc::new(State::new(acme_data, pg_url, storage_auth.as_deref()).await.expect("create state"))) .bind(&address) .await .expect("bind TCP Listener"); @@ -365,3 +359,63 @@ impl FromStr for HttpVersion { }) } } + +#[derive(Debug, Clone, Default)] +struct StorageAuthLayer; + +impl Layer for StorageAuthLayer { + type Service = StorageAuthService; + + fn layer(&self, inner: S) -> Self::Service { + StorageAuthService { inner } + } +} + +struct StorageAuthService { + inner: S, +} + +impl std::fmt::Debug for StorageAuthService { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("StorageAuthService") + .field("inner", &self.inner) + .finish() + } +} + +impl Service, Request> for StorageAuthService +where + Body: Send + 'static, + S: Service, Request>, +{ + type Response = S::Response; + type Error = S::Error; + + async fn serve( + &self, + mut ctx: Context>, + mut req: Request, + ) -> Result { + if let Some(cookie) = req.headers().typed_get::() { + let cookie = cookie + .iter() + .map(|(k, v)| { + if k.eq_ignore_ascii_case("rama-storage-auth") { + if Some(v) == ctx.state().storage_auth.as_deref() { + ctx.insert(StorageAuthorized); + } + "rama-storage-auth=xxx".to_owned() + } else { + format!("{k}={v}") + } + }) + .join("; "); + if !cookie.is_empty() { + req.headers_mut() + .insert(COOKIE, HeaderValue::from_str(&cookie).unwrap()); + } + } + + self.inner.serve(ctx, req).await + } +} diff --git 
a/rama-cli/src/cmd/fp/state.rs b/rama-cli/src/cmd/fp/state.rs index 62f08e82a..fdbe5ba25 100644 --- a/rama-cli/src/cmd/fp/state.rs +++ b/rama-cli/src/cmd/fp/state.rs @@ -1,23 +1,36 @@ -use std::{collections::HashMap, sync::atomic::AtomicUsize}; +use std::collections::HashMap; -use super::data::DataSource; +use rama::error::{ErrorContext, OpaqueError}; + +use super::{data::DataSource, storage::Storage}; #[derive(Debug)] #[non_exhaustive] pub(super) struct State { pub(super) data_source: DataSource, - pub(super) counter: AtomicUsize, pub(super) acme: ACMEData, + pub(super) storage: Option, + pub(super) storage_auth: Option, } impl State { /// Create a new instance of [`State`]. - pub(super) fn new(acme: ACMEData) -> Self { - State { + pub(super) async fn new( + acme: ACMEData, + pg_url: Option, + storage_auth: Option<&str>, + ) -> Result { + let storage = match pg_url { + Some(pg_url) => Some(Storage::new(pg_url).await.context("create storage")?), + None => None, + }; + + Ok(State { data_source: DataSource::default(), - counter: AtomicUsize::new(0), acme, - } + storage, + storage_auth: storage_auth.map(|s| s.to_owned()), + }) } } diff --git a/rama-cli/src/cmd/fp/storage/mod.rs b/rama-cli/src/cmd/fp/storage/mod.rs new file mode 100644 index 000000000..c0c46be57 --- /dev/null +++ b/rama-cli/src/cmd/fp/storage/mod.rs @@ -0,0 +1,267 @@ +use rama::{ + error::{ErrorContext, OpaqueError}, + http::proto::h1::Http1HeaderMap, + net::tls::client::ClientHello, + ua::{Http1Settings, Http2Settings}, +}; + +mod postgres; +use postgres::Pool; +use tokio_postgres::types; + +#[derive(Debug, Clone)] +pub(super) struct Storage { + pool: Pool, +} + +impl Storage { + pub(super) async fn new(pg_url: String) -> Result { + tracing::debug!("create new storage with PG URL: {}", pg_url); + let pool = postgres::new_pool(pg_url).await?; + Ok(Self { pool }) + } +} + +impl Storage { + pub(super) async fn store_h1_settings( + &self, + ua: String, + settings: Http1Settings, + ) -> Result<(), 
OpaqueError> { + tracing::debug!("store h1 settings for UA '{ua}': {settings:?}"); + + let client = self.pool.get().await.context("get postgres client")?; + let n = client.execute( + "INSERT INTO \"ua-profiles\" (uastr, h1_settings) VALUES ($1, $2) ON CONFLICT (uastr) DO UPDATE SET h1_settings = $2", + &[&ua, &types::Json(settings)], + ).await.context("store h1 settings in postgres")?; + + if n != 1 { + tracing::error!( + "unexpected number of rows affected to store h1 settings for UA '{ua}': {n}" + ); + } + + Ok(()) + } + + pub(super) async fn store_h1_headers_navigate( + &self, + ua: String, + headers: Http1HeaderMap, + ) -> Result<(), OpaqueError> { + tracing::debug!("store h1 navigateheaders for UA '{ua}': {headers:?}"); + + let client = self.pool.get().await.context("get postgres client")?; + let n = client.execute( + "INSERT INTO \"ua-profiles\" (uastr, h1_headers_navigate) VALUES ($1, $2) ON CONFLICT (uastr) DO UPDATE SET h1_headers_navigate = $2", + &[&ua, &types::Json(headers)], + ).await.context("store h1 navigate headers in postgres")?; + + if n != 1 { + tracing::error!( + "unexpected number of rows affected to store h1 navigate headers for UA '{ua}': {n}" + ); + } + + Ok(()) + } + + pub(super) async fn store_h1_headers_fetch( + &self, + ua: String, + headers: Http1HeaderMap, + ) -> Result<(), OpaqueError> { + tracing::debug!("store h1 fetch headers for UA '{ua}': {headers:?}"); + + let client = self.pool.get().await.context("get postgres client")?; + let n = client.execute( + "INSERT INTO \"ua-profiles\" (uastr, h1_headers_fetch) VALUES ($1, $2) ON CONFLICT (uastr) DO UPDATE SET h1_headers_fetch = $2", + &[&ua, &types::Json(headers)], + ).await.context("store h1 fetch headers in postgres")?; + + if n != 1 { + tracing::error!( + "unexpected number of rows affected to store h1 fetch headers for UA '{ua}': {n}" + ); + } + + Ok(()) + } + + pub(super) async fn store_h1_headers_xhr( + &self, + ua: String, + headers: Http1HeaderMap, + ) -> Result<(), 
OpaqueError> { + tracing::debug!("store h1 xhr headers for UA '{ua}': {headers:?}"); + + let client = self.pool.get().await.context("get postgres client")?; + let n = client.execute( + "INSERT INTO \"ua-profiles\" (uastr, h1_headers_xhr) VALUES ($1, $2) ON CONFLICT (uastr) DO UPDATE SET h1_headers_xhr = $2", + &[&ua, &types::Json(headers)], + ).await.context("store h1 xhr headers in postgres")?; + + if n != 1 { + tracing::error!( + "unexpected number of rows affected to store h1 xhr headers for UA '{ua}': {n}" + ); + } + + Ok(()) + } + + pub(super) async fn store_h1_headers_form( + &self, + ua: String, + headers: Http1HeaderMap, + ) -> Result<(), OpaqueError> { + tracing::debug!("store h1 form headers for UA '{ua}': {headers:?}"); + + let client = self.pool.get().await.context("get postgres client")?; + let n = client.execute( + "INSERT INTO \"ua-profiles\" (uastr, h1_headers_form) VALUES ($1, $2) ON CONFLICT (uastr) DO UPDATE SET h1_headers_form = $2", + &[&ua, &types::Json(headers)], + ).await.context("store h1 form headers in postgres")?; + + if n != 1 { + tracing::error!( + "unexpected number of rows affected to store h1 form headers for UA '{ua}': {n}" + ); + } + + Ok(()) + } + + pub(super) async fn store_h2_settings( + &self, + ua: String, + settings: Http2Settings, + ) -> Result<(), OpaqueError> { + tracing::debug!("store h2 settings for UA '{ua}': {settings:?}"); + + let client = self.pool.get().await.context("get postgres client")?; + let n = client.execute( + "INSERT INTO \"ua-profiles\" (uastr, h2_settings) VALUES ($1, $2) ON CONFLICT (uastr) DO UPDATE SET h2_settings = $2", + &[&ua, &types::Json(settings)], + ).await.context("store h2 settings in postgres")?; + + if n != 1 { + tracing::error!( + "unexpected number of rows affected to store h2 settings for UA '{ua}': {n}" + ); + } + + Ok(()) + } + + pub(super) async fn store_h2_headers_navigate( + &self, + ua: String, + headers: Http1HeaderMap, + ) -> Result<(), OpaqueError> { + tracing::debug!("store h2 
navigate headers for UA '{ua}': {headers:?}"); + + let client = self.pool.get().await.context("get postgres client")?; + let n = client.execute( + "INSERT INTO \"ua-profiles\" (uastr, h2_headers_navigate) VALUES ($1, $2) ON CONFLICT (uastr) DO UPDATE SET h2_headers_navigate = $2", + &[&ua, &types::Json(headers)], + ).await.context("store h2 navigate headers in postgres")?; + + if n != 1 { + tracing::error!( + "unexpected number of rows affected to store h2 navigate headers for UA '{ua}': {n}" + ); + } + + Ok(()) + } + + pub(super) async fn store_h2_headers_fetch( + &self, + ua: String, + headers: Http1HeaderMap, + ) -> Result<(), OpaqueError> { + tracing::debug!("store h2 fetch headers for UA '{ua}': {headers:?}"); + + let client = self.pool.get().await.context("get postgres client")?; + let n = client.execute( + "INSERT INTO \"ua-profiles\" (uastr, h2_headers_fetch) VALUES ($1, $2) ON CONFLICT (uastr) DO UPDATE SET h2_headers_fetch = $2", + &[&ua, &types::Json(headers)], + ).await.context("store h2 fetch headers in postgres")?; + + if n != 1 { + tracing::error!( + "unexpected number of rows affected to store h2 fetch headers for UA '{ua}': {n}" + ); + } + + Ok(()) + } + + pub(super) async fn store_h2_headers_xhr( + &self, + ua: String, + headers: Http1HeaderMap, + ) -> Result<(), OpaqueError> { + tracing::debug!("store h2 xhr headers for UA '{ua}': {headers:?}"); + + let client = self.pool.get().await.context("get postgres client")?; + let n = client.execute( + "INSERT INTO \"ua-profiles\" (uastr, h2_headers_xhr) VALUES ($1, $2) ON CONFLICT (uastr) DO UPDATE SET h2_headers_xhr = $2", + &[&ua, &types::Json(headers)], + ).await.context("store h2 xhr headers in postgres")?; + + if n != 1 { + tracing::error!( + "unexpected number of rows affected to store h2 xhr headers for UA '{ua}': {n}" + ); + } + + Ok(()) + } + + pub(super) async fn store_h2_headers_form( + &self, + ua: String, + headers: Http1HeaderMap, + ) -> Result<(), OpaqueError> { + tracing::debug!("store h2 
form headers for UA '{ua}': {headers:?}"); + + let client = self.pool.get().await.context("get postgres client")?; + let n = client.execute( + "INSERT INTO \"ua-profiles\" (uastr, h2_headers_form) VALUES ($1, $2) ON CONFLICT (uastr) DO UPDATE SET h2_headers_form = $2", + &[&ua, &types::Json(headers)], + ).await.context("store h2 form headers in postgres")?; + + if n != 1 { + tracing::error!( + "unexpected number of rows affected to store h2 form headers for UA '{ua}': {n}" + ); + } + + Ok(()) + } + + pub(super) async fn store_tls_client_hello( + &self, + ua: String, + tls_client_hello: ClientHello, + ) -> Result<(), OpaqueError> { + tracing::debug!("store tls client hello for UA '{ua}': {tls_client_hello:?}"); + + let client = self.pool.get().await.context("get postgres client")?; + let n = client.execute( + "INSERT INTO \"ua-profiles\" (uastr, tls_client_hello) VALUES ($1, $2) ON CONFLICT (uastr) DO UPDATE SET tls_client_hello = $2", + &[&ua, &types::Json(tls_client_hello)], + ).await.context("store tls client hello in postgres")?; + + if n != 1 { + tracing::error!( + "unexpected number of rows affected to store tls client hello for UA '{ua}': {n}" + ); + } + + Ok(()) + } +} diff --git a/rama-cli/src/cmd/fp/storage/postgres.rs b/rama-cli/src/cmd/fp/storage/postgres.rs new file mode 100644 index 000000000..90fe12c1f --- /dev/null +++ b/rama-cli/src/cmd/fp/storage/postgres.rs @@ -0,0 +1,128 @@ +use std::{ + io, + pin::Pin, + task::{Context, Poll}, +}; + +use deadpool_postgres::Config; +use rama::{ + error::{ErrorContext, OpaqueError}, + net::{address::Host, stream::Stream}, + tls::{ + boring::client::{TlsStream, tls_connect}, + std::dep::boring::{hash::MessageDigest, nid::Nid, ssl::SslRef}, + }, +}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio_postgres::tls::{self, ChannelBinding, MakeTlsConnect, TlsConnect}; + +pub(super) use deadpool_postgres::Pool; + +pub(super) async fn new_pool(url: String) -> Result { + Config { + url: Some(url), + dbname: 
Some("fp".to_owned()), + ..Default::default() + } + .create_pool(None, MakeBoringTlsConnector) + .context("create postgres deadpool") +} + +#[derive(Debug, Clone, Default)] +#[non_exhaustive] +struct MakeBoringTlsConnector; + +#[derive(Debug, Clone)] +struct BoringTlsConnector { + host: Host, +} + +impl MakeTlsConnect for MakeBoringTlsConnector +where + S: Stream + Unpin, +{ + type Stream = BoringTlsStream; + type TlsConnect = BoringTlsConnector; + type Error = OpaqueError; + + fn make_tls_connect(&mut self, domain: &str) -> Result { + let host: Host = domain.parse().context("parse host from domain")?; + Ok(BoringTlsConnector { host }) + } +} + +impl TlsConnect for BoringTlsConnector +where + S: Stream + Unpin, +{ + type Stream = BoringTlsStream; + type Error = OpaqueError; + #[allow(clippy::type_complexity)] + type Future = Pin, Self::Error>> + Send>>; + + fn connect(self, stream: S) -> Self::Future { + Box::pin(async move { + let tls_stream = tls_connect(self.host, stream, None).await?; + Ok(BoringTlsStream(tls_stream)) + }) + } +} +/// The stream returned by `TlsConnector`. 
+struct BoringTlsStream(TlsStream); + +impl AsyncRead for BoringTlsStream +where + S: Stream + Unpin, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.0).poll_read(cx, buf) + } +} + +impl AsyncWrite for BoringTlsStream +where + S: Stream + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.0).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_shutdown(cx) + } +} + +impl tls::TlsStream for BoringTlsStream +where + S: Stream + Unpin, +{ + fn channel_binding(&self) -> ChannelBinding { + match tls_server_end_point(self.0.ssl_ref()) { + Some(buf) => ChannelBinding::tls_server_end_point(buf), + None => ChannelBinding::none(), + } + } +} + +fn tls_server_end_point(ssl: &SslRef) -> Option> { + let cert = ssl.peer_certificate()?; + let algo_nid = cert.signature_algorithm().object().nid(); + let signature_algorithms = algo_nid.signature_algorithms()?; + let md = match signature_algorithms.digest { + Nid::MD5 | Nid::SHA1 => MessageDigest::sha256(), + nid => MessageDigest::from_nid(nid)?, + }; + cert.digest(md).ok().map(|b| b.to_vec()) +} diff --git a/rama-fp/browserstack/main.py b/rama-fp/browserstack/main.py index 1a0dacd58..ebd39f456 100644 --- a/rama-fp/browserstack/main.py +++ b/rama-fp/browserstack/main.py @@ -70,6 +70,8 @@ def env(key): BROWSERSTACK_ACCESS_KEY = env("BROWSERSTACK_ACCESS_KEY") URL = os.environ.get("URL") or "https://hub.browserstack.com/wd/hub" +RAMA_FP_STORAGE_COOKIE = env("RAMA_FP_STORAGE_COOKIE") + def get_browser_option(browser): switcher = { @@ -154,6 +156,11 @@ def run_session(cap): options=options, ) + driver.set_cookie({ + 'name': 'rama-storage-auth', + 'value': RAMA_FP_STORAGE_COOKIE, + }) 
+ driver.get(entrypoint) print("ua", driver.execute_script("return navigator.userAgent;")) print("loc", driver.execute_script("return document.location.href;")) diff --git a/rama-fp/infra/deployments/pg/fly.toml b/rama-fp/infra/deployments/pg/fly.toml new file mode 100644 index 000000000..a10c31cc6 --- /dev/null +++ b/rama-fp/infra/deployments/pg/fly.toml @@ -0,0 +1,69 @@ +# fly.toml app configuration file generated for rama-fp-pg on 2025-02-25T21:24:55+01:00 +# +# See https://fly.io/docs/reference/configuration/ for information about how to use this file. +# + +app = 'rama-fp-pg' +primary_region = 'lhr' + +[env] + PRIMARY_REGION = 'lhr' + +[[mounts]] + source = 'pg_data' + destination = '/data' + +[[services]] + protocol = 'tcp' + internal_port = 5432 + auto_start_machines = false + + [[services.ports]] + port = 5432 + handlers = ['pg_tls'] + + [services.concurrency] + type = 'connections' + hard_limit = 1000 + soft_limit = 1000 + +[[services]] + protocol = 'tcp' + internal_port = 5433 + auto_start_machines = false + + [[services.ports]] + port = 5433 + handlers = ['pg_tls'] + + [services.concurrency] + type = 'connections' + hard_limit = 1000 + soft_limit = 1000 + +[checks] + [checks.pg] + port = 5500 + type = 'http' + interval = '15s' + timeout = '10s' + path = '/flycheck/pg' + + [checks.role] + port = 5500 + type = 'http' + interval = '15s' + timeout = '10s' + path = '/flycheck/role' + + [checks.vm] + port = 5500 + type = 'http' + interval = '15s' + timeout = '10s' + path = '/flycheck/vm' + +[[metrics]] + port = 9187 + path = '/metrics' + https = false diff --git a/rama-http-backend/src/client/conn.rs b/rama-http-backend/src/client/conn.rs index 55cd81310..eab219b16 100644 --- a/rama-http-backend/src/client/conn.rs +++ b/rama-http-backend/src/client/conn.rs @@ -3,7 +3,7 @@ use rama_core::{ Context, Layer, Service, error::{BoxError, OpaqueError}, }; -use rama_http_types::{Request, Version, dep::http_body}; +use rama_http_types::{Request, Version, 
conn::Http1ClientContextParams, dep::http_body}; use rama_net::{ client::{ConnectorService, EstablishedClientConnection}, stream::Stream, @@ -123,7 +123,11 @@ where } Version::HTTP_11 | Version::HTTP_10 | Version::HTTP_09 => { trace!(uri = %req.uri(), "create ~h1 client executor"); - let (sender, conn) = rama_http_core::client::conn::http1::handshake(io).await?; + let mut builder = rama_http_core::client::conn::http1::Builder::new(); + if let Some(params) = ctx.get::() { + builder.title_case_headers(params.title_header_case); + } + let (sender, conn) = builder.handshake(io).await?; ctx.spawn(async move { if let Err(err) = conn.await { diff --git a/rama-http-backend/src/client/mod.rs b/rama-http-backend/src/client/mod.rs index 40bcc37a5..417515564 100644 --- a/rama-http-backend/src/client/mod.rs +++ b/rama-http-backend/src/client/mod.rs @@ -1,6 +1,9 @@ //! Rama HTTP client module, //! which provides the [`HttpClient`] type to serve HTTP requests. +#[cfg(any(feature = "rustls", feature = "boring"))] +use std::sync::Arc; + use proxy::layer::HttpProxyConnector; use rama_core::{ Context, Service, @@ -14,7 +17,7 @@ use rama_tcp::client::service::TcpConnector; use rama_tls::std::client::{TlsConnector, TlsConnectorData}; #[cfg(any(feature = "rustls", feature = "boring"))] -use rama_net::tls::client::ClientConfig; +use rama_net::tls::client::{ClientConfig, ProxyClientConfig}; #[cfg(any(feature = "rustls", feature = "boring"))] use rama_core::error::ErrorContext; @@ -118,15 +121,29 @@ where #[cfg(any(feature = "rustls", feature = "boring"))] let connector = { - let proxy_tls_connector_data = match &self.proxy_tls_config { - Some(proxy_tls_config) => { + let proxy_tls_connector_data = match ( + ctx.get::(), + &self.proxy_tls_config, + ) { + (Some(proxy_tls_config), _) => { + trace!("create proxy tls connector using rama tls client config from ontext"); + proxy_tls_config + .0 + .as_ref() + .clone() + .try_into() + .context( + "HttpClient: create proxy tls connector data from 
tls config found in context", + )? + } + (None, Some(proxy_tls_config)) => { trace!("create proxy tls connector using pre-defined rama tls client config"); proxy_tls_config .clone() .try_into() .context("HttpClient: create proxy tls connector data from tls config")? } - None => { + (None, None) => { trace!("create proxy tls connector using the 'new_http_auto' constructor"); TlsConnectorData::new().context( "HttpClient: create proxy tls connector data with no application presets", @@ -138,15 +155,20 @@ where TlsConnector::tunnel(tcp_connector, None) .with_connector_data(proxy_tls_connector_data), ); - let tls_connector_data = match &self.tls_config { - Some(tls_config) => { + let tls_connector_data = match (ctx.get::>(), &self.tls_config) { + (Some(tls_config), _) => { + trace!("create tls connector using rama tls client config from ontext"); + tls_config.as_ref().clone().try_into().context( + "HttpClient: create tls connector data from tls config found in context", + )? + } + (None, Some(tls_config)) => { trace!("create tls connector using pre-defined rama tls client config"); - tls_config - .clone() - .try_into() - .context("HttpClient: create tls connector data from tls config")? + tls_config.clone().try_into().context( + "HttpClient: create tls connector data from pre-defined tls config", + )? } - None => { + (None, None) => { trace!("create tls connector using the 'new_http_auto' constructor"); TlsConnectorData::new_http_auto() .context("HttpClient: create tls connector data for http (auto)")? diff --git a/rama-http-types/src/compression.rs b/rama-http-types/src/compression.rs new file mode 100644 index 000000000..007869f15 --- /dev/null +++ b/rama-http-types/src/compression.rs @@ -0,0 +1,5 @@ +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)] +#[non_exhaustive] +/// Marker type that can be used to request an opt-in +/// decompression layer to decompress a body in case it is compressed. 
+pub struct DecompressIfPossible; diff --git a/rama-http-types/src/conn.rs b/rama-http-types/src/conn.rs new file mode 100644 index 000000000..0f662cc52 --- /dev/null +++ b/rama-http-types/src/conn.rs @@ -0,0 +1,15 @@ +//! HTTP connection utilities. + +#[derive(Debug, Clone, Default)] +/// Optional parameters that can be set in the [`Context`] of a (h1) request +/// to customise the connection of the h1 connection. +/// +/// Can be used by Http connector services, especially in the context of proxies, +/// where there might not be one static config that is to be applied to all client connections. +pub struct Http1ClientContextParams { + /// Set whether HTTP/1 connections will write header names as title case at + /// the socket level. + /// + /// Default is false. + pub title_header_case: bool, +} diff --git a/rama-http-types/src/headers/client_hints.rs b/rama-http-types/src/headers/client_hints.rs new file mode 100644 index 000000000..b85a56d43 --- /dev/null +++ b/rama-http-types/src/headers/client_hints.rs @@ -0,0 +1,292 @@ +macro_rules! 
client_hint { + ( + #[doc = $ch_doc:literal] + pub enum ClientHint { + $( + #[doc = $doc:literal] + $name:ident($($str:literal),*), + )+ + } + ) => { + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] + pub enum ClientHint { + $( + #[doc = $doc] + $name, + )+ + } + + impl ClientHint { + #[doc = "Checks if the client hint is low entropy, meaning that it will be send by default."] + pub fn is_low_entropy(&self) -> bool { + matches!(self, Self::SaveData | Self::Ua | Self::Mobile | Self::Platform) + } + + #[inline] + #[doc = "Attempts to convert a `HeaderName` to a `ClientHint`."] + pub fn match_header_name(name: &crate::HeaderName) -> Option { + name.try_into().ok() + } + + #[doc = "Returns the preferred string representation of the client hint."] + pub fn as_str(&self) -> &'static str { + match self { + $( + Self::$name => { + const VARIANTS: &'static [&'static str] = &[$($str,)+]; + VARIANTS[0] + }, + )+ + } + } + } + + rama_utils::macros::error::static_str_error! { + /// Client Hint Parsing Error + pub struct ClientHintParsingError; + } + + impl TryFrom<&str> for ClientHint { + type Error = ClientHintParsingError; + + fn try_from(name: &str) -> Result { + rama_utils::macros::match_ignore_ascii_case_str! 
{ + match (name) { + $( + $($str)|+ => Ok(Self::$name), + )+ + _ => Err(ClientHintParsingError), + } + } + } + } + + impl TryFrom for ClientHint { + type Error = ClientHintParsingError; + + fn try_from(name: String) -> Result { + Self::try_from(name.as_str()) + } + } + + impl TryFrom<$crate::HeaderName> for ClientHint { + type Error = ClientHintParsingError; + + fn try_from(name: $crate::HeaderName) -> Result { + Self::try_from(name.as_str()) + } + } + + impl TryFrom<&$crate::HeaderName> for ClientHint { + type Error = ClientHintParsingError; + + fn try_from(name: &$crate::HeaderName) -> Result { + Self::try_from(name.as_str()) + } + } + + impl std::str::FromStr for ClientHint { + type Err = ClientHintParsingError; + + #[inline] + fn from_str(s: &str) -> Result { + Self::try_from(s) + } + } + + impl std::fmt::Display for ClientHint { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.as_str()) + } + } + + impl serde::Serialize for ClientHint { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(self.as_str()) + } + } + + impl<'de> serde::Deserialize<'de> for ClientHint { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + use serde::de::Error; + let s = >::deserialize(deserializer)?; + Self::try_from(s.as_ref()).map_err(D::Error::custom) + } + } + + #[doc = "Returns an iterator over all client hint header name strings."] + pub fn all_client_hint_header_name_strings() -> impl Iterator { + [ + $( + $($str,)+ + )+ + ].into_iter() + } + + #[doc = "Returns an iterator over all client hint header names."] + pub fn all_client_hint_header_names() -> impl Iterator { + all_client_hint_header_name_strings().map($crate::HeaderName::from_static) + } + }; +} + +// NOTE: we are open to contributions to this module, +// e.g. 
in case you wish typed headers for each or some of these client hint headers, +// we gladly mentor and guide you in the process. + +client_hint! { + #[doc = "Client Hints are a set of HTTP Headers and a JavaScript API that allow web browsers to send detailed information about the client device and browser to web servers. They are designed to be a successor to User-Agent, and provide a standardized way for web servers to optimize content for the client without relying on unreliable user-agent string-based detection or browser fingerprinting techniques."] + pub enum ClientHint { + /// Sec-CH-UA represents a user agent's branding and version. + Ua("sec-ch-ua"), + /// Sec-CH-UA-Full-Version represents the user agent's full version. + FullVersion("sec-ch-ua-full-version"), + /// Sec-CH-UA-Full-Version-List represents the full version for each brand in its brands list. + FullVersionList("sec-ch-ua-full-version-list"), + /// Sec-CH-UA-Platform represents the platform on which a given user agent is executing. + Platform("sec-ch-ua-platform"), + /// Sec-CH-UA-Platform-Version represents the platform version on which a given user agent is executing. + PlatformVersion("sec-ch-ua-platform-version"), + /// Sec-CH-UA-Arch represents the architecture of the platform on which a given user agent is executing. + Arch("sec-ch-ua-arch"), + /// Sec-CH-UA-Bitness represents the bitness of the architecture of the platform on which a given user agent is executing. + Bitness("sec-ch-ua-bitness"), + /// Sec-CH-UA-WoW64 is used to detect whether or not a user agent binary is running in 32-bit mode on 64-bit Windows. + Wow64("sec-ch-ua-wow64"), + /// Sec-CH-UA-Model represents the device on which a given user agent is executing. + Model("sec-ch-ua-model"), + /// Sec-CH-UA-Mobile is used to detect whether or not a user agent prefers a «mobile» user experience. 
+ Mobile("sec-ch-ua-mobile"), + /// Sec-CH-UA-Form-Factors represents the form-factors of a device, historically represented as a token in the User-Agent string. + FormFactor("sec-ch-ua-form-factors"), + /// Sec-CH-Lang (or Lang) represents the user's language preference. + Lang("sec-ch-lang", "lang"), + /// Sec-CH-Save-Data (or Save-Data) represents the user agent's preference for reduced data usage. + SaveData("sec-ch-save-data", "save-data"), + /// Sec-CH-Width gives a server the layout width of the image. + Width("sec-ch-width"), + /// Sec-CH-Viewport-Width (or Viewport-Width) is the width of the user's viewport in CSS pixels. + ViewportWidth("sec-ch-viewport-width", "viewport-width"), + /// Sec-CH-Viewport-Height represents the user-agent's current viewport height. + ViewportHeight("sec-ch-viewport-height"), + /// Sec-CH-DPR (or DPR) reports the ratio of physical pixels to CSS pixels of the user's screen. + Dpr("sec-ch-dpr", "dpr"), + /// Sec-CH-Device-Memory (or Device-Memory) reveals the approximate amount of memory the current device has in GiB. Because this information could be used to fingerprint users, the value of Device-Memory is intentionally coarse. Valid values are 0.25, 0.5, 1, 2, 4, and 8. + DeviceMemory("sec-ch-device-memory", "device-memory"), + /// Sec-CH-RTT (or RTT) provides the approximate Round Trip Time, in milliseconds, on the application layer. The RTT hint, unlike transport layer RTT, includes server processing time. The value of RTT is rounded to the nearest 25 milliseconds to prevent fingerprinting. + Rtt("sec-ch-rtt", "rtt"), + /// Sec-CH-Downlink (or Downlink) expressed in megabits per second (Mbps), reveals the approximate downstream speed of the user's connection. The value is rounded to the nearest multiple of 25 kilobits per second. Because again, fingerprinting. + Downlink("sec-ch-downlink", "downlink"), + /// Sec-CH-ECT (or ECT) stands for Effective Connection Type. 
Its value is one of an enumerated list of connection types, each of which describes a connection within specified ranges of both RTT and Downlink values. Valid values for ECT are 4g, 3g, 2g, and slow-2g. + Ect("sec-ch-ect", "ect"), + /// Sec-CH-Prefers-Color-Scheme represents the user's preferred color scheme. + PrefersColorScheme("sec-ch-prefers-color-scheme"), + /// Sec-CH-Prefers-Reduced-Motion is used to detect if the user has requested the system minimize the amount of animation or motion it uses. + PrefersReducedMotion("sec-ch-prefers-reduced-motion"), + /// Sec-CH-Prefers-Reduced-Transparency is used to detect if the user has requested the system minimize the amount of transparent or translucent layer effects it uses. + PrefersReducedTransparency("sec-ch-prefers-reduced-transparency"), + /// Sec-CH-Prefers-Contrast is used to detect if the user has requested that the web content is presented with a higher (or lower) contrast. + PrefersContrast("sec-ch-prefers-contrast"), + /// Sec-CH-Forced-Colors is used to detect if the user agent has enabled a forced colors mode where it enforces a user-chosen limited color palette on the page. 
+ ForcedColors("sec-ch-forced-colors"), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_client_hint_ua_from_str() { + let hint = ClientHint::try_from("Sec-CH-UA").unwrap(); + assert_eq!(hint, ClientHint::Ua); + } + + #[test] + fn test_client_hint_ua_from_str_lowercase() { + let hint = ClientHint::try_from("sec-ch-ua").unwrap(); + assert_eq!(hint, ClientHint::Ua); + } + + #[test] + fn test_client_hint_ua_from_str_uppercase() { + let hint = ClientHint::try_from("SEC-CH-UA").unwrap(); + assert_eq!(hint, ClientHint::Ua); + } + + #[test] + fn test_client_hint_ua_from_str_mixedcase() { + let hint = ClientHint::try_from("Sec-CH-UA").unwrap(); + assert_eq!(hint, ClientHint::Ua); + } + + #[test] + fn test_client_hint_low_entropy() { + let hints = [ + "Sec-CH-UA", + "Sec-CH-UA-Mobile", + "Sec-CH-UA-Platform", + "Save-Data", + "Sec-CH-Save-Data", + ]; + + for hint in hints { + let hint = ClientHint::try_from(hint).expect(hint); + assert!(hint.is_low_entropy()); + } + } + + #[test] + fn test_client_hint_high_entropy() { + let hints = [ + "Sec-CH-UA-Full-Version", + "Sec-CH-UA-Full-Version-List", + "Sec-CH-UA-Platform-Version", + "Sec-CH-UA-Arch", + "Sec-CH-UA-Bitness", + "Sec-CH-UA-WoW64", + "Sec-CH-UA-Model", + "Sec-CH-UA-Form-Factors", + "Sec-CH-Width", + "Sec-CH-Viewport-Width", + "Sec-CH-Viewport-Height", + "Sec-CH-DPR", + "Sec-CH-Device-Memory", + "Sec-CH-RTT", + "Sec-CH-Downlink", + "Sec-CH-ECT", + "Sec-CH-Prefers-Color-Scheme", + "Sec-CH-Prefers-Reduced-Motion", + "Sec-CH-Prefers-Reduced-Transparency", + "Sec-CH-Prefers-Contrast", + "Sec-CH-Forced-Colors", + ]; + + for hint in hints { + let hint = ClientHint::try_from(hint).expect(hint); + assert!(!hint.is_low_entropy()); + } + } + + #[test] + fn test_all_client_hint_header_name_strings_contains_some_hints() { + let strings = all_client_hint_header_name_strings().collect::>(); + assert!(strings.contains(&"sec-ch-ua"), "{:?}", strings); + } + + #[test] + fn test_all_client_hint_header_names() { 
+ let names = all_client_hint_header_names().collect::>(); + let strings = all_client_hint_header_name_strings().collect::>(); + assert_eq!(names.len(), strings.len()); + for (name, string) in names.iter().zip(strings.iter()) { + assert_eq!(name.as_str(), *string); + } + } +} diff --git a/rama-http/src/headers/common/accept.rs b/rama-http-types/src/headers/common/accept.rs similarity index 89% rename from rama-http/src/headers/common/accept.rs rename to rama-http-types/src/headers/common/accept.rs index caedd4943..cdd4300c2 100644 --- a/rama-http/src/headers/common/accept.rs +++ b/rama-http-types/src/headers/common/accept.rs @@ -1,5 +1,6 @@ use crate::dep::mime::{self, Mime}; -use crate::headers::{self, Header, QualityValue}; +use crate::headers::specifier::QualityValue; +use crate::headers::{self, Header}; use crate::{HeaderName, HeaderValue}; use std::iter::FromIterator; @@ -35,10 +36,10 @@ fn qitem(mime: Mime) -> QualityValue { /// # Examples /// ``` /// use std::iter::FromIterator; -/// use rama_http::headers::{Accept, QualityValue, HeaderMapExt}; -/// use rama_http::dep::mime; +/// use rama_http_types::headers::{Accept, specifier::QualityValue, HeaderMapExt}; +/// use rama_http_types::dep::mime; /// -/// let mut headers = rama_http::HeaderMap::new(); +/// let mut headers = rama_http_types::HeaderMap::new(); /// /// headers.typed_insert( /// Accept::from_iter(vec![ @@ -49,10 +50,10 @@ fn qitem(mime: Mime) -> QualityValue { /// /// ``` /// use std::iter::FromIterator; -/// use rama_http::headers::{Accept, QualityValue, HeaderMapExt}; -/// use rama_http::dep::mime; +/// use rama_http_types::headers::{Accept, specifier::QualityValue, HeaderMapExt}; +/// use rama_http_types::dep::mime; /// -/// let mut headers = rama_http::HeaderMap::new(); +/// let mut headers = rama_http_types::HeaderMap::new(); /// headers.typed_insert( /// Accept::from_iter(vec![ /// QualityValue::new(mime::APPLICATION_JSON, Default::default()), @@ -61,10 +62,10 @@ fn qitem(mime: Mime) -> 
QualityValue { /// ``` /// ``` /// use std::iter::FromIterator; -/// use rama_http::headers::{Accept, QualityValue, HeaderMapExt}; -/// use rama_http::dep::mime; +/// use rama_http_types::headers::{Accept, specifier::QualityValue, HeaderMapExt}; +/// use rama_http_types::dep::mime; /// -/// let mut headers = rama_http::HeaderMap::new(); +/// let mut headers = rama_http_types::HeaderMap::new(); /// /// headers.typed_insert( /// Accept::from_iter(vec![ @@ -161,7 +162,7 @@ mod tests { mime::{TEXT_HTML, TEXT_PLAIN, TEXT_PLAIN_UTF_8}, }; - use crate::headers::Quality; + use crate::headers::specifier::Quality; macro_rules! test_header { ($name: ident, $input: expr, $expected: expr) => { diff --git a/rama-http/src/headers/common/mod.rs b/rama-http-types/src/headers/common/mod.rs similarity index 100% rename from rama-http/src/headers/common/mod.rs rename to rama-http-types/src/headers/common/mod.rs diff --git a/rama-http-types/src/headers/encoding/accept_encoding.rs b/rama-http-types/src/headers/encoding/accept_encoding.rs new file mode 100644 index 000000000..8d1ac16b5 --- /dev/null +++ b/rama-http-types/src/headers/encoding/accept_encoding.rs @@ -0,0 +1,99 @@ +use super::SupportedEncodings; +use crate::HeaderValue; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct AcceptEncoding { + gzip: bool, + deflate: bool, + br: bool, + zstd: bool, +} + +impl AcceptEncoding { + pub fn maybe_to_header_value(self) -> Option { + let accept = match (self.gzip(), self.deflate(), self.br(), self.zstd()) { + (true, true, true, false) => "gzip,deflate,br", + (true, true, false, false) => "gzip,deflate", + (true, false, true, false) => "gzip,br", + (true, false, false, false) => "gzip", + (false, true, true, false) => "deflate,br", + (false, true, false, false) => "deflate", + (false, false, true, false) => "br", + (true, true, true, true) => "zstd,gzip,deflate,br", + (true, true, false, true) => "zstd,gzip,deflate", + (true, false, true, true) => "zstd,gzip,br", + (true, false, 
false, true) => "zstd,gzip", + (false, true, true, true) => "zstd,deflate,br", + (false, true, false, true) => "zstd,deflate", + (false, false, true, true) => "zstd,br", + (false, false, false, true) => "zstd", + (false, false, false, false) => return None, + }; + Some(HeaderValue::from_static(accept)) + } + + pub fn set_gzip(&mut self, enable: bool) { + self.gzip = enable; + } + + pub fn with_gzip(mut self, enable: bool) -> Self { + self.gzip = enable; + self + } + + pub fn set_deflate(&mut self, enable: bool) { + self.deflate = enable; + } + + pub fn with_deflate(mut self, enable: bool) -> Self { + self.deflate = enable; + self + } + + pub fn set_br(&mut self, enable: bool) { + self.br = enable; + } + + pub fn with_br(mut self, enable: bool) -> Self { + self.br = enable; + self + } + + pub fn set_zstd(&mut self, enable: bool) { + self.zstd = enable; + } + + pub fn with_zstd(mut self, enable: bool) -> Self { + self.zstd = enable; + self + } +} + +impl SupportedEncodings for AcceptEncoding { + fn gzip(&self) -> bool { + self.gzip + } + + fn deflate(&self) -> bool { + self.deflate + } + + fn br(&self) -> bool { + self.br + } + + fn zstd(&self) -> bool { + self.zstd + } +} + +impl Default for AcceptEncoding { + fn default() -> Self { + AcceptEncoding { + gzip: true, + deflate: true, + br: true, + zstd: true, + } + } +} diff --git a/rama-http/src/layer/util/content_encoding.rs b/rama-http-types/src/headers/encoding/mod.rs similarity index 64% rename from rama-http/src/layer/util/content_encoding.rs rename to rama-http-types/src/headers/encoding/mod.rs index 7442191a1..57bf9b967 100644 --- a/rama-http/src/layer/util/content_encoding.rs +++ b/rama-http-types/src/headers/encoding/mod.rs @@ -1,18 +1,41 @@ -//! Types and functions for handling content encoding. +//! Utility types and functions that can be used in context of encoding headers. 
use rama_utils::macros::match_ignore_ascii_case_str; +use std::fmt; -pub(crate) trait SupportedEncodings: Copy { +mod accept_encoding; +pub use accept_encoding::AcceptEncoding; + +use super::specifier::{Quality, QualityValue}; + +pub trait SupportedEncodings: Copy { fn gzip(&self) -> bool; fn deflate(&self) -> bool; fn br(&self) -> bool; fn zstd(&self) -> bool; } -// This enum's variants are ordered from least to most preferred. +impl SupportedEncodings for bool { + fn gzip(&self) -> bool { + *self + } + + fn deflate(&self) -> bool { + *self + } + + fn br(&self) -> bool { + *self + } + + fn zstd(&self) -> bool { + *self + } +} + #[derive(Copy, Clone, Debug, Ord, PartialOrd, PartialEq, Eq, Hash)] -pub(crate) enum Encoding { - #[allow(dead_code)] +/// This enum's variants are ordered from least to most preferred. +pub enum Encoding { Identity, Deflate, Gzip, @@ -20,9 +43,21 @@ pub(crate) enum Encoding { Zstd, } +impl fmt::Display for Encoding { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.as_str()) + } +} + +impl From for http::HeaderValue { + #[inline] + fn from(encoding: Encoding) -> Self { + http::HeaderValue::from_static(encoding.as_str()) + } +} + impl Encoding { - #[allow(dead_code)] - fn to_str(self) -> &'static str { + fn as_str(self) -> &'static str { match self { Encoding::Identity => "identity", Encoding::Gzip => "gzip", @@ -32,8 +67,7 @@ impl Encoding { } } - #[allow(dead_code)] - pub(crate) fn to_file_extension(self) -> Option<&'static std::ffi::OsStr> { + pub fn to_file_extension(self) -> Option<&'static std::ffi::OsStr> { match self { Encoding::Gzip => Some(std::ffi::OsStr::new(".gz")), Encoding::Deflate => Some(std::ffi::OsStr::new(".zz")), @@ -43,122 +77,72 @@ impl Encoding { } } - #[allow(dead_code)] - pub(crate) fn into_header_value(self) -> http::HeaderValue { - http::HeaderValue::from_static(self.to_str()) - } - - fn parse(s: &str, _supported_encoding: impl SupportedEncodings) -> Option { + fn parse(s: &str, 
supported_encoding: impl SupportedEncodings) -> Option { match_ignore_ascii_case_str! { match (s) { - "gzip" | "x-gzip" if _supported_encoding.gzip() => Some(Encoding::Gzip), - "deflate" if _supported_encoding.deflate() => Some(Encoding::Deflate), - "br" if _supported_encoding.br() => Some(Encoding::Brotli), - "zstd" if _supported_encoding.zstd() => Some(Encoding::Zstd), + "gzip" | "x-gzip" if supported_encoding.gzip() => Some(Encoding::Gzip), + "deflate" if supported_encoding.deflate() => Some(Encoding::Deflate), + "br" if supported_encoding.br() => Some(Encoding::Brotli), + "zstd" if supported_encoding.zstd() => Some(Encoding::Zstd), "identity" => Some(Encoding::Identity), _ => None, } } } - #[cfg(feature = "compression")] - // based on https://github.com/http-rs/accept-encoding - #[allow(dead_code)] - pub(crate) fn from_headers( + pub fn maybe_from_content_encoding_header( + headers: &http::HeaderMap, + supported_encoding: impl SupportedEncodings, + ) -> Option { + headers + .get(http::header::CONTENT_ENCODING) + .and_then(|hval| hval.to_str().ok()) + .and_then(|s| Encoding::parse(s, supported_encoding)) + } + + #[inline] + pub fn from_content_encoding_header( headers: &http::HeaderMap, supported_encoding: impl SupportedEncodings, ) -> Self { - Encoding::preferred_encoding(encodings(headers, supported_encoding)) + Encoding::maybe_from_content_encoding_header(headers, supported_encoding) .unwrap_or(Encoding::Identity) } - #[allow(dead_code)] - pub(crate) fn preferred_encoding( - accepted_encodings: impl Iterator, + pub fn maybe_from_accept_encoding_headers( + headers: &http::HeaderMap, + supported_encoding: impl SupportedEncodings, ) -> Option { - accepted_encodings - .filter(|(_, qvalue)| qvalue.0 > 0) - .max_by_key(|&(encoding, qvalue)| (qvalue, encoding)) - .map(|(encoding, _)| encoding) + Encoding::maybe_preferred_encoding(parse_accept_encoding_headers( + headers, + supported_encoding, + )) } -} - -// Allowed q-values are numbers between 0 and 1 with at most 
3 digits in the fractional part. They -// are presented here as an unsigned integer between 0 and 1000. -#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub(crate) struct QValue(u16); -impl QValue { #[inline] - fn one() -> Self { - Self(1000) - } - - // Parse a q-value as specified in RFC 7231 section 5.3.1. - fn parse(s: &str) -> Option { - let mut c = s.chars(); - // Parse "q=" (case-insensitively). - match c.next() { - Some('q' | 'Q') => (), - _ => return None, - }; - match c.next() { - Some('=') => (), - _ => return None, - }; - - // Parse leading digit. Since valid q-values are between 0.000 and 1.000, only "0" and "1" - // are allowed. - let mut value = match c.next() { - Some('0') => 0, - Some('1') => 1000, - _ => return None, - }; - - // Parse optional decimal point. - match c.next() { - Some('.') => (), - None => return Some(Self(value)), - _ => return None, - }; - - // Parse optional fractional digits. The value of each digit is multiplied by `factor`. - // Since the q-value is represented as an integer between 0 and 1000, `factor` is `100` for - // the first digit, `10` for the next, and `1` for the digit after that. - let mut factor = 100; - loop { - match c.next() { - Some(n @ '0'..='9') => { - // If `factor` is less than `1`, three digits have already been parsed. A - // q-value having more than 3 fractional digits is invalid. - if factor < 1 { - return None; - } - // Add the digit's value multiplied by `factor` to `value`. - value += factor * (n as u16 - '0' as u16); - } - None => { - // No more characters to parse. Check that the value representing the q-value is - // in the valid range. 
- return if value <= 1000 { - Some(Self(value)) - } else { - None - }; - } - _ => return None, - }; - factor /= 10; - } + pub fn from_accept_encoding_headers( + headers: &http::HeaderMap, + supported_encoding: impl SupportedEncodings, + ) -> Self { + Encoding::maybe_from_accept_encoding_headers(headers, supported_encoding) + .unwrap_or(Encoding::Identity) + } + + pub fn maybe_preferred_encoding( + accepted_encodings: impl Iterator>, + ) -> Option { + accepted_encodings + .filter(|qval| qval.quality.as_u16() > 0) + .max_by_key(|qval| (qval.quality, qval.value)) + .map(|qval| qval.value) } } -// based on https://github.com/http-rs/accept-encoding -#[allow(dead_code)] -pub(crate) fn encodings<'a>( +/// based on https://github.com/http-rs/accept-encoding +pub fn parse_accept_encoding_headers<'a>( headers: &'a http::HeaderMap, supported_encoding: impl SupportedEncodings + 'a, -) -> impl Iterator + 'a { +) -> impl Iterator> + 'a { headers .get_all(http::header::ACCEPT_ENCODING) .iter() @@ -173,16 +157,16 @@ pub(crate) fn encodings<'a>( }; let qval = if let Some(qval) = v.next() { - QValue::parse(qval.trim())? + qval.trim().parse::().ok()? 
} else { - QValue::one() + Quality::one() }; - Some((encoding, qval)) + Some(QualityValue::new(encoding, qval)) }) } -#[cfg(all(test, feature = "compression"))] +#[cfg(test)] mod tests { use super::*; @@ -209,7 +193,8 @@ mod tests { #[test] fn no_accept_encoding_header() { - let encoding = Encoding::from_headers(&http::HeaderMap::new(), SupportedEncodingsAll); + let encoding = + Encoding::from_accept_encoding_headers(&http::HeaderMap::new(), SupportedEncodingsAll); assert_eq!(Encoding::Identity, encoding); } @@ -220,7 +205,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Gzip, encoding); } @@ -231,7 +216,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip,br"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); } @@ -242,7 +227,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip,x-gzip"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Gzip, encoding); } @@ -253,7 +238,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("deflate,x-gzip"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Gzip, encoding); } @@ -264,7 +249,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip,deflate,br"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = 
Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); } @@ -275,7 +260,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.5,br"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); } @@ -286,7 +271,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.5,deflate,br"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); } @@ -301,7 +286,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("br"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); } @@ -316,7 +301,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("br"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); } @@ -335,7 +320,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("br"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); } @@ -346,7 +331,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.5,br;q=0.8"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); 
assert_eq!(Encoding::Brotli, encoding); let mut headers = http::HeaderMap::new(); @@ -354,7 +339,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.8,br;q=0.5"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Gzip, encoding); let mut headers = http::HeaderMap::new(); @@ -362,7 +347,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.995,br;q=0.999"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); } @@ -373,7 +358,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.5,deflate;q=0.6,br;q=0.8"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); let mut headers = http::HeaderMap::new(); @@ -381,7 +366,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.8,deflate;q=0.6,br;q=0.5"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Gzip, encoding); let mut headers = http::HeaderMap::new(); @@ -389,7 +374,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.6,deflate;q=0.8,br;q=0.5"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Deflate, encoding); let mut headers = http::HeaderMap::new(); @@ -397,7 +382,7 @@ mod tests { http::header::ACCEPT_ENCODING, 
http::HeaderValue::from_static("gzip;q=0.995,deflate;q=0.997,br;q=0.999"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); } @@ -408,7 +393,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("invalid,gzip"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Gzip, encoding); } @@ -419,7 +404,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Identity, encoding); let mut headers = http::HeaderMap::new(); @@ -427,7 +412,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0."), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Identity, encoding); let mut headers = http::HeaderMap::new(); @@ -435,7 +420,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0,br;q=0.5"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); } @@ -446,7 +431,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gZiP"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Gzip, encoding); let mut headers = http::HeaderMap::new(); @@ 
-454,7 +439,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.5,br;Q=0.8"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); } @@ -465,7 +450,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static(" gzip\t; q=0.5 ,\tbr ;\tq=0.8\t"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Brotli, encoding); } @@ -476,7 +461,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q =0.5"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Identity, encoding); let mut headers = http::HeaderMap::new(); @@ -484,7 +469,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q= 0.5"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Identity, encoding); } @@ -495,7 +480,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=-0.1"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Identity, encoding); let mut headers = http::HeaderMap::new(); @@ -503,7 +488,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=00.5"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); 
assert_eq!(Encoding::Identity, encoding); let mut headers = http::HeaderMap::new(); @@ -511,7 +496,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=0.5000"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Identity, encoding); let mut headers = http::HeaderMap::new(); @@ -519,7 +504,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=.5"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Identity, encoding); let mut headers = http::HeaderMap::new(); @@ -527,7 +512,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=1.01"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Identity, encoding); let mut headers = http::HeaderMap::new(); @@ -535,7 +520,7 @@ mod tests { http::header::ACCEPT_ENCODING, http::HeaderValue::from_static("gzip;q=1.001"), ); - let encoding = Encoding::from_headers(&headers, SupportedEncodingsAll); + let encoding = Encoding::from_accept_encoding_headers(&headers, SupportedEncodingsAll); assert_eq!(Encoding::Identity, encoding); } } diff --git a/rama-http-types/src/headers/mod.rs b/rama-http-types/src/headers/mod.rs index 13ea70b89..05bb13f3e 100644 --- a/rama-http-types/src/headers/mod.rs +++ b/rama-http-types/src/headers/mod.rs @@ -90,3 +90,16 @@ pub mod authorization { mod ext; #[doc(inline)] pub use ext::HeaderExt; + +pub mod encoding; +pub mod specifier; + +pub mod util; + +mod common; +pub use common::Accept; + +mod client_hints; +pub use client_hints::{ + ClientHint, all_client_hint_header_name_strings, 
all_client_hint_header_names, +}; diff --git a/rama-http-types/src/headers/specifier/mod.rs b/rama-http-types/src/headers/specifier/mod.rs new file mode 100644 index 000000000..aca3af972 --- /dev/null +++ b/rama-http-types/src/headers/specifier/mod.rs @@ -0,0 +1,6 @@ +//! Specifiers that can be used as part of header values. +//! +//! An example is the [`QValue`] used in function of several headers such as 'accept-encoding'. + +mod quality_value; +pub use quality_value::{Quality, QualityValue}; diff --git a/rama-http/src/headers/util/quality_value.rs b/rama-http-types/src/headers/specifier/quality_value.rs similarity index 74% rename from rama-http/src/headers/util/quality_value.rs rename to rama-http-types/src/headers/specifier/quality_value.rs index c0328b13c..bbb22ebb2 100644 --- a/rama-http/src/headers/util/quality_value.rs +++ b/rama-http-types/src/headers/specifier/quality_value.rs @@ -24,6 +24,80 @@ use self::internal::IntoQuality; #[derive(Copy, Clone, Debug, Eq, Ord, PartialEq, PartialOrd, Hash)] pub struct Quality(u16); +impl Quality { + #[inline] + pub fn one() -> Self { + Self(1000) + } + + #[inline] + pub fn as_u16(&self) -> u16 { + self.0 + } +} + +impl str::FromStr for Quality { + type Err = crate::headers::Error; + + // Parse a q-value as specified in RFC 7231 section 5.3.1. + fn from_str(s: &str) -> Result { + let mut c = s.chars(); + // Parse "q=" (case-insensitively). + match c.next() { + Some('q' | 'Q') => (), + _ => return Err(crate::headers::Error::invalid()), + }; + match c.next() { + Some('=') => (), + _ => return Err(crate::headers::Error::invalid()), + }; + + // Parse leading digit. Since valid q-values are between 0.000 and 1.000, only "0" and "1" + // are allowed. + let mut value = match c.next() { + Some('0') => 0, + Some('1') => 1000, + _ => return Err(crate::headers::Error::invalid()), + }; + + // Parse optional decimal point. 
+ match c.next() { + Some('.') => (), + None => return Ok(Self(value)), + _ => return Err(crate::headers::Error::invalid()), + }; + + // Parse optional fractional digits. The value of each digit is multiplied by `factor`. + // Since the q-value is represented as an integer between 0 and 1000, `factor` is `100` for + // the first digit, `10` for the next, and `1` for the digit after that. + let mut factor = 100; + loop { + match c.next() { + Some(n @ '0'..='9') => { + // If `factor` is less than `1`, three digits have already been parsed. A + // q-value having more than 3 fractional digits is invalid. + if factor < 1 { + return Err(crate::headers::Error::invalid()); + } + // Add the digit's value multiplied by `factor` to `value`. + value += factor * (n as u16 - '0' as u16); + } + None => { + // No more characters to parse. Check that the value representing the q-value is + // in the valid range. + return if value <= 1000 { + Ok(Self(value)) + } else { + Err(crate::headers::Error::invalid()) + }; + } + _ => return Err(crate::headers::Error::invalid()), + }; + factor /= 10; + } + } +} + impl Default for Quality { fn default() -> Quality { Quality(1000) @@ -40,6 +114,8 @@ pub struct QualityValue { pub quality: Quality, } +impl Copy for QualityValue {} + impl QualityValue { /// Creates a new `QualityValue` from an item and a quality. pub const fn new(value: T, quality: Quality) -> QualityValue { @@ -92,7 +168,7 @@ impl str::FromStr for QualityValue { fn from_str(s: &str) -> Result, crate::headers::Error> { // Set defaults used if parsing fails. 
let mut raw_item = s; - let mut quality = 1f32; + let mut quality = Quality::one(); let mut parts = s.rsplitn(2, ';').map(|x| x.trim()); if let (Some(first), Some(second), None) = (parts.next(), parts.next(), parts.next()) { @@ -100,26 +176,13 @@ impl str::FromStr for QualityValue { return Err(crate::headers::Error::invalid()); } if first.starts_with("q=") || first.starts_with("Q=") { - let q_part = &first[2..]; - if q_part.len() > 5 { - return Err(crate::headers::Error::invalid()); - } - match q_part.parse::() { - Ok(q_value) => { - if (0f32..=1f32).contains(&q_value) { - quality = q_value; - raw_item = second; - } else { - return Err(crate::headers::Error::invalid()); - } - } - Err(_) => return Err(crate::headers::Error::invalid()), - } + quality = Quality::from_str(first)?; + raw_item = second; } } match raw_item.parse::() { // we already checked above that the quality is within range - Ok(item) => Ok(QualityValue::new(item, from_f32(quality))), + Ok(item) => Ok(QualityValue::new(item, quality)), Err(_) => Err(crate::headers::Error::invalid()), } } diff --git a/rama-http/src/headers/util/csv.rs b/rama-http-types/src/headers/util/csv.rs similarity index 89% rename from rama-http/src/headers/util/csv.rs rename to rama-http-types/src/headers/util/csv.rs index c9f65c59b..350a072d5 100644 --- a/rama-http/src/headers/util/csv.rs +++ b/rama-http-types/src/headers/util/csv.rs @@ -8,7 +8,7 @@ use crate::HeaderValue; use crate::headers::Error; /// Reads a comma-delimited raw header into a Vec. -pub(crate) fn from_comma_delimited<'i, I, T, E>(values: &mut I) -> Result +pub fn from_comma_delimited<'i, I, T, E>(values: &mut I) -> Result where I: Iterator, T: ::std::str::FromStr, @@ -30,7 +30,7 @@ where } /// Format an array into a comma-delimited string. 
-pub(crate) fn fmt_comma_delimited( +pub fn fmt_comma_delimited( f: &mut fmt::Formatter, mut iter: impl Iterator, ) -> fmt::Result { diff --git a/rama-http-types/src/headers/util/mod.rs b/rama-http-types/src/headers/util/mod.rs new file mode 100644 index 000000000..d4996f0bb --- /dev/null +++ b/rama-http-types/src/headers/util/mod.rs @@ -0,0 +1 @@ +pub mod csv; diff --git a/rama-http-types/src/lib.rs b/rama-http-types/src/lib.rs index cf60b2e7d..af66afc55 100644 --- a/rama-http-types/src/lib.rs +++ b/rama-http-types/src/lib.rs @@ -35,8 +35,12 @@ pub use response::{IntoResponse, IntoResponseParts, Response}; pub mod proto; +pub mod compression; + pub mod headers; +pub mod conn; + pub mod dep { //! Dependencies for rama http modules. //! diff --git a/rama-http-types/src/proto/h1/headers/map.rs b/rama-http-types/src/proto/h1/headers/map.rs index 84e051c6c..e316e6e12 100644 --- a/rama-http-types/src/proto/h1/headers/map.rs +++ b/rama-http-types/src/proto/h1/headers/map.rs @@ -1,4 +1,10 @@ -use std::collections::{self, HashMap}; +use std::{ + borrow::Cow, + collections::{self, HashMap}, +}; + +use http::header::AsHeaderName; +use serde::{Deserialize, Serialize, de::Error as _, ser::Error as _}; use super::{ Http1HeaderName, @@ -50,6 +56,22 @@ impl Http1HeaderMap { } } + #[inline] + pub fn get(&self, key: impl AsHeaderName) -> Option<&HeaderValue> { + self.headers.get(key) + } + + pub fn get_original_name(&self, key: &HeaderName) -> Option<&Http1HeaderName> { + self.original_headers + .iter() + .find(|header| header.header_name() == key) + } + + #[inline] + pub fn contains_key(&self, key: impl AsHeaderName) -> bool { + self.headers.contains_key(key) + } + pub fn into_headers(self) -> HeaderMap { self.headers } @@ -132,6 +154,41 @@ impl IntoIterator for Http1HeaderMap { } } +impl Serialize for Http1HeaderMap { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let headers: Result, _> = self + .clone() + .into_iter() + .map(|(name, value)| 
{ + let value = value.to_str().map_err(S::Error::custom)?; + Ok::<_, S::Error>((name, value.to_owned())) + }) + .collect(); + headers?.serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for Http1HeaderMap { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let headers = )>>::deserialize(deserializer)?; + headers + .into_iter() + .map(|(name, value)| { + Ok::<_, D::Error>(( + name, + HeaderValue::from_str(&value).map_err(D::Error::custom)?, + )) + }) + .collect() + } +} + #[derive(Debug)] pub struct Http1HeaderMapIntoIter { state: Http1HeaderMapIntoIterState, @@ -186,7 +243,9 @@ impl Iterator for Http1HeaderMapIntoIter { } #[derive(Debug)] -struct HeaderMapValueRemover { +/// Utility that can be used to be able to remove +/// headers from an [`HeaderMap`] in random order, one by one. +pub struct HeaderMapValueRemover { header_map: HeaderMap, removed_values: Option>>, } @@ -201,7 +260,7 @@ impl From for HeaderMapValueRemover { } impl HeaderMapValueRemover { - fn remove(&mut self, header: &HeaderName) -> Option { + pub fn remove(&mut self, header: &HeaderName) -> Option { match self.header_map.entry(header) { header::Entry::Occupied(occupied_entry) => { let (k, mut values) = occupied_entry.remove_entry_mult(); @@ -244,7 +303,8 @@ impl IntoIterator for HeaderMapValueRemover { } #[derive(Debug)] -struct HeaderMapValueRemoverIntoIter { +/// Porduced by the [`IntoIterator`] implementation for [`HeaderMapValueRemover`]. 
+pub struct HeaderMapValueRemoverIntoIter { cached_header_name: Option, cached_headers: Option>>, removed_headers: diff --git a/rama-http-types/src/proto/h1/headers/mod.rs b/rama-http-types/src/proto/h1/headers/mod.rs index 9a7569098..3d6154658 100644 --- a/rama-http-types/src/proto/h1/headers/mod.rs +++ b/rama-http-types/src/proto/h1/headers/mod.rs @@ -12,4 +12,6 @@ pub use name::{Http1HeaderName, IntoHttp1HeaderName, TryIntoHttp1HeaderName}; pub mod original; mod map; -pub use map::{Http1HeaderMap, Http1HeaderMapIntoIter}; +pub use map::{ + HeaderMapValueRemover, HeaderMapValueRemoverIntoIter, Http1HeaderMap, Http1HeaderMapIntoIter, +}; diff --git a/rama-http-types/src/proto/h1/headers/original.rs b/rama-http-types/src/proto/h1/headers/original.rs index 91cb7bb18..741d160e3 100644 --- a/rama-http-types/src/proto/h1/headers/original.rs +++ b/rama-http-types/src/proto/h1/headers/original.rs @@ -35,6 +35,10 @@ impl OriginalHttp1Headers { pub fn is_empty(&self) -> bool { self.ordered_headers.is_empty() } + + pub fn iter(&self) -> impl Iterator { + self.ordered_headers.iter() + } } impl OriginalHttp1Headers { diff --git a/rama-http-types/src/proto/h2/pseudo_header.rs b/rama-http-types/src/proto/h2/pseudo_header.rs index 619f02ef7..516afea8d 100644 --- a/rama-http-types/src/proto/h2/pseudo_header.rs +++ b/rama-http-types/src/proto/h2/pseudo_header.rs @@ -137,6 +137,26 @@ impl IntoIterator for PseudoHeaderOrder { } } +impl FromIterator for PseudoHeaderOrder { + fn from_iter>(iter: T) -> Self { + let mut this = Self::new(); + for header in iter { + this.push(header); + } + this + } +} + +impl<'a> FromIterator<&'a PseudoHeader> for PseudoHeaderOrder { + fn from_iter>(iter: T) -> Self { + let mut this = Self::new(); + for header in iter { + this.push(*header); + } + this + } +} + #[derive(Debug)] /// Iterator over a copy of [`PseudoHeaderOrder`]. 
pub struct PseudoHeaderOrderIter { diff --git a/rama-http/Cargo.toml b/rama-http/Cargo.toml index ac596aa22..18f7ebcc2 100644 --- a/rama-http/Cargo.toml +++ b/rama-http/Cargo.toml @@ -33,10 +33,6 @@ bytes = { workspace = true } const_format = { workspace = true } csv = { workspace = true } futures-lite = { workspace = true } -headers = { workspace = true } -http = { workspace = true } -http-body = { workspace = true } -http-body-util = { workspace = true } http-range-header = { workspace = true } httpdate = { workspace = true } iri-string = { workspace = true } diff --git a/rama-http/src/headers/forwarded/exotic_forward_ip.rs b/rama-http/src/headers/forwarded/exotic_forward_ip.rs index e715801fd..f48148461 100644 --- a/rama-http/src/headers/forwarded/exotic_forward_ip.rs +++ b/rama-http/src/headers/forwarded/exotic_forward_ip.rs @@ -91,7 +91,7 @@ macro_rules! exotic_forward_ip_headers { fn decode<'i, I: Iterator>( values: &mut I, - ) -> Result { + ) -> Result { Ok($name( values .next() diff --git a/rama-http/src/headers/mod.rs b/rama-http/src/headers/mod.rs index 06dc4a44d..543145c96 100644 --- a/rama-http/src/headers/mod.rs +++ b/rama-http/src/headers/mod.rs @@ -20,12 +20,13 @@ //! //! ```rust //! use rama_http::{headers::Header, HeaderName, HeaderValue}; +//! use rama_http_types::{header, headers}; //! //! struct Dnt(bool); //! //! impl Header for Dnt { //! fn name() -> &'static HeaderName { -//! &http::header::DNT +//! &header::DNT //! } //! //! 
fn decode<'i, I>(values: &mut I) -> Result @@ -67,7 +68,7 @@ pub use rama_http_types::headers::{Header, HeaderMapExt}; #[doc(inline)] pub use rama_http_types::headers::{ - AcceptRanges, AccessControlAllowCredentials, AccessControlAllowHeaders, + Accept, AcceptRanges, AccessControlAllowCredentials, AccessControlAllowHeaders, AccessControlAllowMethods, AccessControlAllowOrigin, AccessControlExposeHeaders, AccessControlMaxAge, AccessControlRequestHeaders, AccessControlRequestMethod, Age, Allow, Authorization, CacheControl, Connection, ContentDisposition, ContentEncoding, ContentLength, @@ -78,9 +79,23 @@ pub use rama_http_types::headers::{ StrictTransportSecurity, Te, TransferEncoding, Upgrade, UserAgent, Vary, }; -mod common; #[doc(inline)] -pub use common::Accept; +pub use rama_http_types::headers::specifier::{Quality, QualityValue}; + +#[doc(inline)] +pub use rama_http_types::headers::util; + +#[doc(inline)] +pub use rama_http_types::headers::ClientHint; + +pub mod client_hints { + //! Http (UA) Client Hints + + #[doc(inline)] + pub use rama_http_types::headers::{ + all_client_hint_header_name_strings, all_client_hint_header_names, + }; +} mod forwarded; #[doc(inline)] @@ -93,13 +108,10 @@ pub mod authorization { //! Authorization header and types. #[doc(inline)] - pub use ::headers::authorization::Credentials; + pub use rama_http_types::headers::authorization::Credentials; #[doc(inline)] - pub use ::headers::authorization::{Authorization, Basic, Bearer}; + pub use rama_http_types::headers::authorization::{Authorization, Basic, Bearer}; } #[doc(inline)] pub use ::rama_http_types::headers::HeaderExt; - -pub(crate) mod util; -pub use util::quality_value::{Quality, QualityValue}; diff --git a/rama-http/src/headers/util/mod.rs b/rama-http/src/headers/util/mod.rs deleted file mode 100644 index af107ff1b..000000000 --- a/rama-http/src/headers/util/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub(crate) mod csv; -/// Internal utility functions for headers. 
-pub(crate) mod quality_value; diff --git a/rama-http/src/io/request.rs b/rama-http/src/io/request.rs index d890f0e5a..81a4342ac 100644 --- a/rama-http/src/io/request.rs +++ b/rama-http/src/io/request.rs @@ -85,7 +85,7 @@ where for (name, value) in header_map { match parts.version { - http::Version::HTTP_2 | http::Version::HTTP_3 => { + rama_http_types::Version::HTTP_2 | rama_http_types::Version::HTTP_3 => { // write lower-case for H2/H3 w.write_all( format!("{}: {}\r\n", name.header_name().as_str(), value.to_str()?) diff --git a/rama-http/src/io/response.rs b/rama-http/src/io/response.rs index 260924fdf..2061befb5 100644 --- a/rama-http/src/io/response.rs +++ b/rama-http/src/io/response.rs @@ -73,7 +73,7 @@ where for (name, value) in header_map { match parts.version { - http::Version::HTTP_2 | http::Version::HTTP_3 => { + rama_http_types::Version::HTTP_2 | rama_http_types::Version::HTTP_3 => { // write lower-case for H2/H3 w.write_all( format!("{}: {}\r\n", name.header_name().as_str(), value.to_str()?) 
diff --git a/rama-http/src/layer/auth/add_authorization.rs b/rama-http/src/layer/auth/add_authorization.rs index d65e80667..3f549c3aa 100644 --- a/rama-http/src/layer/auth/add_authorization.rs +++ b/rama-http/src/layer/auth/add_authorization.rs @@ -287,9 +287,13 @@ where mut req: Request, ) -> Result { if let Some(value) = &self.value { - if !self.if_not_present || !req.headers().contains_key(http::header::AUTHORIZATION) { + if !self.if_not_present + || !req + .headers() + .contains_key(rama_http_types::header::AUTHORIZATION) + { req.headers_mut() - .insert(http::header::AUTHORIZATION, value.clone()); + .insert(rama_http_types::header::AUTHORIZATION, value.clone()); } } self.inner.serve(ctx, req).await @@ -344,7 +348,10 @@ mod tests { async fn making_header_sensitive() { let svc = ValidateRequestHeaderLayer::bearer("foo").layer(service_fn( |request: Request| async move { - let auth = request.headers().get(http::header::AUTHORIZATION).unwrap(); + let auth = request + .headers() + .get(rama_http_types::header::AUTHORIZATION) + .unwrap(); assert!(auth.is_sensitive()); Ok::<_, Infallible>(Response::new(Body::empty())) diff --git a/rama-http/src/layer/auth/async_require_authorization.rs b/rama-http/src/layer/auth/async_require_authorization.rs index 1b72abba1..52c6a8c35 100644 --- a/rama-http/src/layer/auth/async_require_authorization.rs +++ b/rama-http/src/layer/auth/async_require_authorization.rs @@ -271,7 +271,7 @@ mod tests { let authorized = request .headers() .get(header::AUTHORIZATION) - .and_then(|it: &http::HeaderValue| it.to_str().ok()) + .and_then(|it: &rama_http_types::HeaderValue| it.to_str().ok()) .and_then(|it| it.strip_prefix("Bearer ")) .map(|it| it == "69420") .unwrap_or(false); diff --git a/rama-http/src/layer/auth/require_authorization.rs b/rama-http/src/layer/auth/require_authorization.rs index d7eb1dfdf..5110d2f17 100644 --- a/rama-http/src/layer/auth/require_authorization.rs +++ b/rama-http/src/layer/auth/require_authorization.rs @@ -64,7 +64,6 @@ 
use rama_core::Context; use std::{fmt, marker::PhantomData, sync::Arc}; use rama_net::user::UserId; -use sealed::AuthorizerSeal; const BASE64: base64::engine::GeneralPurpose = base64::engine::general_purpose::STANDARD; @@ -283,7 +282,7 @@ mod sealed { use super::*; /// Private trait that contains the actual authorization logic - pub(super) trait AuthorizerSeal: Send + Sync + 'static { + pub trait AuthorizerSeal: Send + Sync + 'static { /// Check if the given header value is valid for this authorizer. fn is_valid(&self, header_value: &HeaderValue) -> bool; @@ -293,7 +292,7 @@ mod sealed { impl AuthorizerSeal for Basic { fn is_valid(&self, header_value: &HeaderValue) -> bool { - header_value == &self.header_value + header_value == self.header_value } fn www_authenticate_header() -> Option { @@ -303,7 +302,7 @@ mod sealed { impl AuthorizerSeal for Bearer { fn is_valid(&self, header_value: &HeaderValue) -> bool { - header_value == &self.header_value + header_value == self.header_value } fn www_authenticate_header() -> Option { diff --git a/rama-http/src/layer/body_limit.rs b/rama-http/src/layer/body_limit.rs index 574805c50..1a7af821b 100644 --- a/rama-http/src/layer/body_limit.rs +++ b/rama-http/src/layer/body_limit.rs @@ -85,7 +85,10 @@ impl Service> for BodyLimitService where S: Service>, State: Clone + Send + Sync + 'static, - ReqBody: http_body::Body> + Send + Sync + 'static, + ReqBody: rama_http_types::dep::http_body::Body> + + Send + + Sync + + 'static, { type Response = S::Response; type Error = S::Error; diff --git a/rama-http/src/layer/catch_panic.rs b/rama-http/src/layer/catch_panic.rs index 7cc97b171..8caa8d72b 100644 --- a/rama-http/src/layer/catch_panic.rs +++ b/rama-http/src/layer/catch_panic.rs @@ -280,7 +280,7 @@ impl ResponseForPanic for DefaultResponseForPanic { #[allow(clippy::declare_interior_mutable_const)] const TEXT_PLAIN: HeaderValue = HeaderValue::from_static("text/plain; charset=utf-8"); res.headers_mut() - .insert(http::header::CONTENT_TYPE, 
TEXT_PLAIN); + .insert(rama_http_types::header::CONTENT_TYPE, TEXT_PLAIN); res } diff --git a/rama-http/src/layer/classify/status_in_range_is_error.rs b/rama-http/src/layer/classify/status_in_range_is_error.rs index b23569da5..db3e4db20 100644 --- a/rama-http/src/layer/classify/status_in_range_is_error.rs +++ b/rama-http/src/layer/classify/status_in_range_is_error.rs @@ -53,7 +53,7 @@ impl ClassifyResponse for StatusInRangeAsFailures { fn classify_response( self, - res: &http::Response, + res: &rama_http_types::Response, ) -> ClassifiedResponse { if self.range.contains(&res.status().as_u16()) { let class = StatusInRangeFailureClass::StatusCode(res.status()); diff --git a/rama-http/src/layer/compression/body.rs b/rama-http/src/layer/compression/body.rs index 3e6161842..b90d3f1f5 100644 --- a/rama-http/src/layer/compression/body.rs +++ b/rama-http/src/layer/compression/body.rs @@ -143,11 +143,11 @@ where } } - fn size_hint(&self) -> http_body::SizeHint { + fn size_hint(&self) -> rama_http_types::dep::http_body::SizeHint { if let BodyInner::Identity { inner } = &self.inner { inner.size_hint() } else { - http_body::SizeHint::new() + rama_http_types::dep::http_body::SizeHint::new() } } diff --git a/rama-http/src/layer/compression/layer.rs b/rama-http/src/layer/compression/layer.rs index 59ec6fee4..2e98327cf 100644 --- a/rama-http/src/layer/compression/layer.rs +++ b/rama-http/src/layer/compression/layer.rs @@ -1,7 +1,8 @@ use super::predicate::DefaultPredicate; use super::{Compression, Predicate}; -use crate::layer::util::compression::{AcceptEncoding, CompressionLevel}; +use crate::layer::util::compression::CompressionLevel; use rama_core::Layer; +use rama_http_types::headers::encoding::AcceptEncoding; /// Compress response bodies of the underlying service. 
/// diff --git a/rama-http/src/layer/compression/mod.rs b/rama-http/src/layer/compression/mod.rs index 2d8c66f97..c44efb2d1 100644 --- a/rama-http/src/layer/compression/mod.rs +++ b/rama-http/src/layer/compression/mod.rs @@ -117,9 +117,9 @@ mod tests { struct Always; impl Predicate for Always { - fn should_compress(&self, _: &http::Response) -> bool + fn should_compress(&self, _: &rama_http_types::Response) -> bool where - B: http_body::Body, + B: rama_http_types::dep::http_body::Body, { true } @@ -295,9 +295,9 @@ mod tests { #[allow(clippy::dbg_macro)] impl Predicate for EveryOtherResponse { - fn should_compress(&self, _: &http::Response) -> bool + fn should_compress(&self, _: &rama_http_types::Response) -> bool where - B: http_body::Body, + B: rama_http_types::dep::http_body::Body, { let mut guard = self.0.write().unwrap(); let should_compress = *guard % 2 != 0; diff --git a/rama-http/src/layer/compression/predicate.rs b/rama-http/src/layer/compression/predicate.rs index 967fc5087..505c225c8 100644 --- a/rama-http/src/layer/compression/predicate.rs +++ b/rama-http/src/layer/compression/predicate.rs @@ -13,7 +13,7 @@ use std::{fmt, sync::Arc}; /// Predicate used to determine if a response should be compressed or not. pub trait Predicate: Clone { /// Should this response be compressed or not? 
- fn should_compress(&self, response: &http::Response) -> bool + fn should_compress(&self, response: &rama_http_types::Response) -> bool where B: Body; @@ -36,7 +36,7 @@ impl Predicate for F where F: Fn(StatusCode, Version, &HeaderMap, &Extensions) -> bool + Clone, { - fn should_compress(&self, response: &http::Response) -> bool + fn should_compress(&self, response: &rama_http_types::Response) -> bool where B: Body, { @@ -52,7 +52,7 @@ impl Predicate for Option where T: Predicate, { - fn should_compress(&self, response: &http::Response) -> bool + fn should_compress(&self, response: &rama_http_types::Response) -> bool where B: Body, { @@ -76,7 +76,7 @@ where Lhs: Predicate, Rhs: Predicate, { - fn should_compress(&self, response: &http::Response) -> bool + fn should_compress(&self, response: &rama_http_types::Response) -> bool where B: Body, { @@ -136,7 +136,7 @@ impl Default for DefaultPredicate { } impl Predicate for DefaultPredicate { - fn should_compress(&self, response: &http::Response) -> bool + fn should_compress(&self, response: &rama_http_types::Response) -> bool where B: Body, { @@ -168,7 +168,7 @@ impl Default for SizeAbove { } impl Predicate for SizeAbove { - fn should_compress(&self, response: &http::Response) -> bool + fn should_compress(&self, response: &rama_http_types::Response) -> bool where B: Body, { @@ -225,7 +225,7 @@ impl NotForContentType { } impl Predicate for NotForContentType { - fn should_compress(&self, response: &http::Response) -> bool + fn should_compress(&self, response: &rama_http_types::Response) -> bool where B: Body, { @@ -263,7 +263,7 @@ impl fmt::Debug for Str { } } -fn content_type(response: &http::Response) -> &str { +fn content_type(response: &rama_http_types::Response) -> &str { response .headers() .get(header::CONTENT_TYPE) diff --git a/rama-http/src/layer/compression/service.rs b/rama-http/src/layer/compression/service.rs index 8d48fa36a..c3bfe0162 100644 --- a/rama-http/src/layer/compression/service.rs +++ 
b/rama-http/src/layer/compression/service.rs @@ -4,9 +4,10 @@ use super::body::BodyInner; use super::predicate::{DefaultPredicate, Predicate}; use crate::dep::http_body::Body; use crate::layer::util::compression::WrapBody; -use crate::layer::util::{compression::AcceptEncoding, content_encoding::Encoding}; use crate::{Request, Response, header}; use rama_core::{Context, Service}; +use rama_http_types::HeaderValue; +use rama_http_types::headers::encoding::{AcceptEncoding, Encoding}; use rama_utils::macros::define_inner_service_accessors; /// Compress response bodies of the underlying service. @@ -192,7 +193,7 @@ where ctx: Context, req: Request, ) -> Result { - let encoding = Encoding::from_headers(req.headers(), self.accept); + let encoding = Encoding::from_accept_encoding_headers(req.headers(), self.accept); let res = self.inner.serve(ctx, req).await?; @@ -259,7 +260,7 @@ where parts .headers - .insert(header::CONTENT_ENCODING, encoding.into_header_value()); + .insert(header::CONTENT_ENCODING, HeaderValue::from(encoding)); let res = Response::from_parts(parts, body); Ok(res) diff --git a/rama-http/src/layer/cors/tests.rs b/rama-http/src/layer/cors/tests.rs index 312fd2563..2642f34a4 100644 --- a/rama-http/src/layer/cors/tests.rs +++ b/rama-http/src/layer/cors/tests.rs @@ -57,7 +57,7 @@ async fn test_allow_origin_async_predicate() { }); let valid_origin = HeaderValue::from_static("http://example.com"); - let parts = http::Request::new("hello world").into_parts().0; + let parts = rama_http_types::Request::new("hello world").into_parts().0; let header = allow_origin .to_future(Some(&valid_origin), &parts) @@ -67,7 +67,7 @@ async fn test_allow_origin_async_predicate() { assert_eq!(header.1, valid_origin); let invalid_origin = HeaderValue::from_static("http://example.org"); - let parts = http::Request::new("hello world").into_parts().0; + let parts = rama_http_types::Request::new("hello world").into_parts().0; let res = allow_origin.to_future(Some(&invalid_origin), 
&parts).await; assert!(res.is_none()); diff --git a/rama-http/src/layer/decompression/layer.rs b/rama-http/src/layer/decompression/layer.rs index 775e39096..df90d2348 100644 --- a/rama-http/src/layer/decompression/layer.rs +++ b/rama-http/src/layer/decompression/layer.rs @@ -1,6 +1,6 @@ use super::Decompression; -use crate::layer::util::compression::AcceptEncoding; use rama_core::Layer; +use rama_http_types::headers::encoding::AcceptEncoding; /// Decompresses response bodies of the underlying service. /// @@ -11,6 +11,7 @@ use rama_core::Layer; #[derive(Debug, Default, Clone)] pub struct DecompressionLayer { accept: AcceptEncoding, + only_if_requested: bool, } impl Layer for DecompressionLayer { @@ -20,6 +21,7 @@ impl Layer for DecompressionLayer { Decompression { inner: service, accept: self.accept, + only_if_requested: self.only_if_requested, } } } @@ -77,4 +79,22 @@ impl DecompressionLayer { self.accept.set_zstd(enable); self } + + /// Sets whether to only decompress bodies if it is requested + /// via the response extension or request context. + /// + /// A request is made using the [`rama_http_types::compression::DecompressIfPossible`] marker type. + pub fn only_if_requested(mut self, enable: bool) -> Self { + self.only_if_requested = enable; + self + } + + /// Sets whether to only decompress bodies if it is requested + /// via the response extension or request context. + /// + /// A request is made using the [`rama_http_types::compression::DecompressIfPossible`] marker type. 
+ pub fn set_only_if_requested(&mut self, enable: bool) -> &mut Self { + self.only_if_requested = enable; + self + } } diff --git a/rama-http/src/layer/decompression/request/layer.rs b/rama-http/src/layer/decompression/request/layer.rs index 985e8e629..6e1b97e5e 100644 --- a/rama-http/src/layer/decompression/request/layer.rs +++ b/rama-http/src/layer/decompression/request/layer.rs @@ -1,6 +1,6 @@ use super::service::RequestDecompression; -use crate::layer::util::compression::AcceptEncoding; use rama_core::Layer; +use rama_http_types::headers::encoding::AcceptEncoding; /// Decompresses request bodies and calls its underlying service. /// diff --git a/rama-http/src/layer/decompression/request/service.rs b/rama-http/src/layer/decompression/request/service.rs index be2aa2fef..2e35d54cd 100644 --- a/rama-http/src/layer/decompression/request/service.rs +++ b/rama-http/src/layer/decompression/request/service.rs @@ -5,13 +5,13 @@ use crate::dep::http_body_util::{BodyExt, Empty, combinators::UnsyncBoxBody}; use crate::layer::{ decompression::DecompressionBody, decompression::body::BodyInner, - util::compression::{AcceptEncoding, CompressionLevel, WrapBody}, - util::content_encoding::SupportedEncodings, + util::compression::{CompressionLevel, WrapBody}, }; use crate::{HeaderValue, Request, Response, StatusCode, header}; use bytes::Buf; use rama_core::error::BoxError; use rama_core::{Context, Service}; +use rama_http_types::headers::encoding::{AcceptEncoding, SupportedEncodings}; use rama_utils::macros::define_inner_service_accessors; /// Decompresses request bodies and calls its underlying service. 
@@ -124,7 +124,7 @@ where .header( header::ACCEPT_ENCODING, accept - .to_header_value() + .maybe_to_header_value() .unwrap_or(HeaderValue::from_static("identity")), ) .status(StatusCode::UNSUPPORTED_MEDIA_TYPE) diff --git a/rama-http/src/layer/decompression/service.rs b/rama-http/src/layer/decompression/service.rs index cd4031590..d3c23c02c 100644 --- a/rama-http/src/layer/decompression/service.rs +++ b/rama-http/src/layer/decompression/service.rs @@ -2,15 +2,14 @@ use std::fmt; use super::{DecompressionBody, body::BodyInner}; use crate::dep::http_body::Body; -use crate::layer::util::{ - compression::{AcceptEncoding, CompressionLevel, WrapBody}, - content_encoding::SupportedEncodings, -}; +use crate::layer::util::compression::{CompressionLevel, WrapBody}; use crate::{ Request, Response, header::{self, ACCEPT_ENCODING}, }; use rama_core::{Context, Service}; +use rama_http_types::compression::DecompressIfPossible; +use rama_http_types::headers::encoding::{AcceptEncoding, SupportedEncodings}; use rama_utils::macros::define_inner_service_accessors; /// Decompresses response bodies of the underlying service. @@ -22,6 +21,7 @@ use rama_utils::macros::define_inner_service_accessors; pub struct Decompression { pub(crate) inner: S, pub(crate) accept: AcceptEncoding, + pub(crate) only_if_requested: bool, } impl Decompression { @@ -30,6 +30,7 @@ impl Decompression { Self { inner: service, accept: AcceptEncoding::default(), + only_if_requested: false, } } @@ -82,6 +83,24 @@ impl Decompression { self.accept.set_zstd(enable); self } + + /// Sets whether to only decompress bodies if it is requested + /// via the response extension or request context. + /// + /// A request is made using the [`DecompressIfPossible`] marker type. + pub fn only_if_requested(mut self, enable: bool) -> Self { + self.only_if_requested = enable; + self + } + + /// Sets whether to only decompress bodies if it is requested + /// via the response extension or request context. 
+ /// + /// A request is made using the [`DecompressIfPossible`] marker type. + pub fn set_only_if_requested(&mut self, enable: bool) -> &mut Self { + self.only_if_requested = enable; + self + } } impl fmt::Debug for Decompression { @@ -89,6 +108,7 @@ impl fmt::Debug for Decompression { f.debug_struct("Decompression") .field("inner", &self.inner) .field("accept", &self.accept) + .field("only_if_requested", &self.only_if_requested) .finish() } } @@ -98,6 +118,7 @@ impl Clone for Decompression { Decompression { inner: self.inner.clone(), accept: self.accept, + only_if_requested: self.only_if_requested, } } } @@ -118,15 +139,27 @@ where mut req: Request, ) -> Result { if let header::Entry::Vacant(entry) = req.headers_mut().entry(ACCEPT_ENCODING) { - if let Some(accept) = self.accept.to_header_value() { + if let Some(accept) = self.accept.maybe_to_header_value() { entry.insert(accept); } } + let decompression_requested = ctx.contains::(); + let res = self.inner.serve(ctx, req).await?; let (mut parts, body) = res.into_parts(); + if self.only_if_requested + && !(decompression_requested + || parts.extensions.get::().is_some()) + { + return Ok(Response::from_parts( + parts, + DecompressionBody::new(BodyInner::identity(body)), + )); + } + let res = if let header::Entry::Occupied(entry) = parts.headers.entry(header::CONTENT_ENCODING) { let body = match entry.get().as_bytes() { diff --git a/rama-http/src/layer/opentelemetry.rs b/rama-http/src/layer/opentelemetry.rs index 289ff56f0..24ce3fcf6 100644 --- a/rama-http/src/layer/opentelemetry.rs +++ b/rama-http/src/layer/opentelemetry.rs @@ -267,11 +267,11 @@ impl RequestMetricsService { attributes.push(KeyValue::new(HTTP_REQUEST_METHOD, req.method().to_string())); if let Some(http_version) = request_ctx.as_ref().and_then(|rc| match rc.http_version { - http::Version::HTTP_09 => Some("0.9"), - http::Version::HTTP_10 => Some("1.0"), - http::Version::HTTP_11 => Some("1.1"), - http::Version::HTTP_2 => Some("2"), - 
http::Version::HTTP_3 => Some("3"), + rama_http_types::Version::HTTP_09 => Some("0.9"), + rama_http_types::Version::HTTP_10 => Some("1.0"), + rama_http_types::Version::HTTP_11 => Some("1.1"), + rama_http_types::Version::HTTP_2 => Some("2"), + rama_http_types::Version::HTTP_3 => Some("3"), _ => None, }) { attributes.push(KeyValue::new(NETWORK_PROTOCOL_VERSION, http_version)); diff --git a/rama-http/src/layer/required_header/request.rs b/rama-http/src/layer/required_header/request.rs index 748b0de3e..4ce8db2f1 100644 --- a/rama-http/src/layer/required_header/request.rs +++ b/rama-http/src/layer/required_header/request.rs @@ -228,7 +228,7 @@ mod test { |_ctx: Context<()>, req: Request| async move { assert!(req.headers().contains_key(HOST)); assert!(req.headers().contains_key(USER_AGENT)); - Ok::<_, Infallible>(http::Response::new(Body::empty())) + Ok::<_, Infallible>(rama_http_types::Response::new(Body::empty())) }, )); @@ -252,7 +252,7 @@ mod test { req.headers().get(USER_AGENT).and_then(|v| v.to_str().ok()), Some("foo") ); - Ok::<_, Infallible>(http::Response::new(Body::empty())) + Ok::<_, Infallible>(rama_http_types::Response::new(Body::empty())) })); let req = Request::builder() @@ -275,7 +275,7 @@ mod test { req.headers().get(USER_AGENT).unwrap(), RAMA_ID_HEADER_VALUE.to_str().unwrap() ); - Ok::<_, Infallible>(http::Response::new(Body::empty())) + Ok::<_, Infallible>(rama_http_types::Response::new(Body::empty())) })); let req = Request::builder() @@ -302,7 +302,7 @@ mod test { req.headers().get(USER_AGENT).and_then(|v| v.to_str().ok()), Some("foo") ); - Ok::<_, Infallible>(http::Response::new(Body::empty())) + Ok::<_, Infallible>(rama_http_types::Response::new(Body::empty())) })); let req = Request::builder() diff --git a/rama-http/src/layer/retry/body.rs b/rama-http/src/layer/retry/body.rs index 5eefa01e7..47d0ff15e 100644 --- a/rama-http/src/layer/retry/body.rs +++ b/rama-http/src/layer/retry/body.rs @@ -1,4 +1,5 @@ use bytes::Bytes; +use 
rama_http_types::dep::http_body; #[derive(Debug, Clone)] /// A body that can be clone and used for requests that have to be rertried. diff --git a/rama-http/src/layer/retry/policy.rs b/rama-http/src/layer/retry/policy.rs index 859e69e87..fcb00d17d 100644 --- a/rama-http/src/layer/retry/policy.rs +++ b/rama-http/src/layer/retry/policy.rs @@ -191,7 +191,7 @@ macro_rules! impl_retry_policy_either { async fn retry( &self, ctx: Context, - req: http::Request, + req: rama_http_types::Request, result: Result, ) -> PolicyResult { match self { @@ -204,8 +204,8 @@ macro_rules! impl_retry_policy_either { fn clone_input( &self, ctx: &Context, - req: &http::Request, - ) -> Option<(Context, http::Request)> { + req: &rama_http_types::Request, + ) -> Option<(Context, rama_http_types::Request)> { match self { $( rama_core::combinators::$id::$param(policy) => policy.clone_input(ctx, req), diff --git a/rama-http/src/layer/trace/body.rs b/rama-http/src/layer/trace/body.rs index 1a9918601..9793807c1 100644 --- a/rama-http/src/layer/trace/body.rs +++ b/rama-http/src/layer/trace/body.rs @@ -96,7 +96,7 @@ where self.inner.is_end_stream() } - fn size_hint(&self) -> http_body::SizeHint { + fn size_hint(&self) -> rama_http_types::dep::http_body::SizeHint { self.inner.size_hint() } } diff --git a/rama-http/src/layer/trace/on_response.rs b/rama-http/src/layer/trace/on_response.rs index f8941e35e..fc44faa1e 100644 --- a/rama-http/src/layer/trace/on_response.rs +++ b/rama-http/src/layer/trace/on_response.rs @@ -171,7 +171,7 @@ fn status(res: &Response) -> Option { // For simplicity, we simply check that the content type starts with "application/grpc". 
let is_grpc = res .headers() - .get(http::header::CONTENT_TYPE) + .get(rama_http_types::header::CONTENT_TYPE) .is_some_and(|value| value.as_bytes().starts_with("application/grpc".as_bytes())); if is_grpc { diff --git a/rama-http/src/layer/traffic_writer/response.rs b/rama-http/src/layer/traffic_writer/response.rs index 76b3d0dd2..18564677b 100644 --- a/rama-http/src/layer/traffic_writer/response.rs +++ b/rama-http/src/layer/traffic_writer/response.rs @@ -287,7 +287,7 @@ where .map_err(|err| OpaqueError::from_boxed(err.into())) .context("printer prepare: collect response body")? .to_bytes(); - let resp: http::Response = + let resp: rama_http_types::Response = Response::from_parts(parts.clone(), Body::from(body_bytes.clone())); self.writer.write_response(resp).await; Response::from_parts(parts, Body::from(body_bytes)) diff --git a/rama-http/src/layer/ua.rs b/rama-http/src/layer/ua.rs index 9d3bfd620..84896ff77 100644 --- a/rama-http/src/layer/ua.rs +++ b/rama-http/src/layer/ua.rs @@ -28,7 +28,7 @@ //! //! let _ = service //! .get("http://www.example.com") -//! .typed_header(headers::UserAgent::from_static(UA)) +//! .typed_header(rama_http_types::headers::UserAgent::from_static(UA)) //! .send(Context::default()) //! .await //! 
.unwrap(); @@ -123,35 +123,42 @@ where mut ctx: Context, req: Request, ) -> impl Future> + Send + '_ { - let mut user_agent = req - .headers() - .typed_get::() - .map(|ua| UserAgent::new(ua.to_string())); - - if let Some(overwrites) = self + let overwrites = self .overwrite_header .as_ref() .and_then(|header| req.headers().get(header)) .map(|header| header.as_bytes()) - .and_then(|value| serde_html_form::from_bytes::(value).ok()) - { - if let Some(ua) = overwrites.ua { - user_agent = Some(UserAgent::new(ua)); - } - if let Some(ref mut ua) = user_agent { + .and_then(|value| serde_html_form::from_bytes::(value).ok()); + + let mut user_agent = overwrites + .as_ref() + .and_then(|o| o.ua.as_deref()) + .map(UserAgent::new) + .or_else(|| { + req.headers() + .typed_get::() + .map(|ua| UserAgent::new(ua.to_string())) + }); + + if let Some(mut ua) = user_agent.take() { + if let Some(overwrites) = overwrites { if let Some(http_agent) = overwrites.http { - ua.with_http_agent(http_agent); + ua.set_http_agent(http_agent); } if let Some(tls_agent) = overwrites.tls { - ua.with_tls_agent(tls_agent); + ua.set_tls_agent(tls_agent); } if let Some(preserve_ua) = overwrites.preserve_ua { - ua.with_preserve_ua_header(preserve_ua); + ua.set_preserve_ua_header(preserve_ua); + } + if let Some(req_init) = overwrites.req_init { + ua.set_request_initiator(req_init); + } + if let Some(req_client_hints) = overwrites.req_client_hints { + ua.set_requested_client_hints(req_client_hints); } } - } - if let Some(ua) = user_agent.take() { ctx.insert(ua); } @@ -204,8 +211,11 @@ mod tests { use crate::layer::required_header::AddRequiredRequestHeadersLayer; use crate::service::client::HttpClientExt; use crate::{IntoResponse, Response, StatusCode, headers}; + use itertools::Itertools; use rama_core::Context; use rama_core::service::service_fn; + use rama_http_types::headers::ClientHint; + use rama_ua::RequestInitiator; use std::convert::Infallible; #[tokio::test] @@ -236,6 +246,34 @@ mod tests { 
.unwrap(); } + #[tokio::test] + async fn test_user_agent_classifier_layer_ua_iphone_app() { + const UA: &str = "iPhone App/1.0"; + + async fn handle(ctx: Context, _req: Request) -> Result { + let ua: &UserAgent = ctx.get().unwrap(); + + assert_eq!(ua.header_str(), UA); + assert!(ua.info().is_none()); + assert_eq!(ua.platform(), Some(PlatformKind::IOS)); + assert_eq!(ua.http_agent(), None); + assert_eq!(ua.tls_agent(), None); + assert!(!ua.preserve_ua_header()); + assert!(ua.request_initiator().is_none()); + + Ok(StatusCode::OK.into_response()) + } + + let service = UserAgentClassifierLayer::new().layer(service_fn(handle)); + + let _ = service + .get("http://www.example.com") + .typed_header(headers::UserAgent::from_static(UA)) + .send(Context::default()) + .await + .unwrap(); + } + #[tokio::test] async fn test_user_agent_classifier_layer_ua_chrome() { const UA: &str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36 Edg/124.0.2478.67"; @@ -274,6 +312,10 @@ mod tests { assert_eq!(ua_info.kind, UserAgentKind::Chromium); assert_eq!(ua_info.version, Some(124)); assert_eq!(ua.platform(), Some(PlatformKind::Windows)); + assert_eq!( + ua.requested_client_hints().join(", "), + "sec-ch-downlink, sec-ch-ect" + ); Ok(StatusCode::OK.into_response()) } @@ -288,6 +330,7 @@ mod tests { "x-proxy-ua", serde_html_form::to_string(&UserAgentOverwrites { ua: Some(UA.to_owned()), + req_client_hints: Some(vec![ClientHint::Downlink, ClientHint::Ect]), ..Default::default() }) .unwrap(), @@ -306,10 +349,11 @@ mod tests { assert_eq!(ua.header_str(), UA); assert!(ua.info().is_none()); - assert!(ua.platform().is_none()); - assert_eq!(ua.http_agent(), HttpAgent::Safari); - assert_eq!(ua.tls_agent(), TlsAgent::Boringssl); + assert_eq!(ua.platform(), Some(PlatformKind::IOS)); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); + assert_eq!(ua.tls_agent(), Some(TlsAgent::Boringssl)); assert!(ua.preserve_ua_header()); + 
assert_eq!(ua.request_initiator(), Some(RequestInitiator::Xhr)); Ok(StatusCode::OK.into_response()) } @@ -324,9 +368,11 @@ mod tests { "x-proxy-ua", serde_html_form::to_string(&UserAgentOverwrites { ua: Some(UA.to_owned()), - http: Some(HttpAgent::Safari), + http: Some(HttpAgent::Firefox), tls: Some(TlsAgent::Boringssl), preserve_ua: Some(true), + req_init: Some(RequestInitiator::Xhr), + req_client_hints: Some(vec![ClientHint::Downlink]), }) .unwrap(), ) diff --git a/rama-http/src/layer/util/compression.rs b/rama-http/src/layer/util/compression.rs index a46978885..780f9d78e 100644 --- a/rama-http/src/layer/util/compression.rs +++ b/rama-http/src/layer/util/compression.rs @@ -1,7 +1,5 @@ //! Types used by compression and decompression middleware. -use super::content_encoding::SupportedEncodings; -use crate::HeaderValue; use crate::dep::http_body::{Body, Frame}; use bytes::{Buf, Bytes, BytesMut}; use futures_lite::Stream; @@ -16,88 +14,6 @@ use std::{ use tokio::io::AsyncRead; use tokio_util::io::StreamReader; -#[derive(Debug, Clone, Copy)] -pub(crate) struct AcceptEncoding { - pub(crate) gzip: bool, - pub(crate) deflate: bool, - pub(crate) br: bool, - pub(crate) zstd: bool, -} - -impl AcceptEncoding { - #[allow(dead_code)] - pub(crate) fn to_header_value(self) -> Option { - let accept = match (self.gzip(), self.deflate(), self.br(), self.zstd()) { - (true, true, true, false) => "gzip,deflate,br", - (true, true, false, false) => "gzip,deflate", - (true, false, true, false) => "gzip,br", - (true, false, false, false) => "gzip", - (false, true, true, false) => "deflate,br", - (false, true, false, false) => "deflate", - (false, false, true, false) => "br", - (true, true, true, true) => "zstd,gzip,deflate,br", - (true, true, false, true) => "zstd,gzip,deflate", - (true, false, true, true) => "zstd,gzip,br", - (true, false, false, true) => "zstd,gzip", - (false, true, true, true) => "zstd,deflate,br", - (false, true, false, true) => "zstd,deflate", - (false, false, true, 
true) => "zstd,br", - (false, false, false, true) => "zstd", - (false, false, false, false) => return None, - }; - Some(HeaderValue::from_static(accept)) - } - - #[allow(dead_code)] - pub(crate) fn set_gzip(&mut self, enable: bool) { - self.gzip = enable; - } - - #[allow(dead_code)] - pub(crate) fn set_deflate(&mut self, enable: bool) { - self.deflate = enable; - } - - #[allow(dead_code)] - pub(crate) fn set_br(&mut self, enable: bool) { - self.br = enable; - } - - #[allow(dead_code)] - pub(crate) fn set_zstd(&mut self, enable: bool) { - self.zstd = enable; - } -} - -impl SupportedEncodings for AcceptEncoding { - fn gzip(&self) -> bool { - self.gzip - } - - fn deflate(&self) -> bool { - self.deflate - } - - fn br(&self) -> bool { - self.br - } - - fn zstd(&self) -> bool { - self.zstd - } -} - -impl Default for AcceptEncoding { - fn default() -> Self { - AcceptEncoding { - gzip: true, - deflate: true, - br: true, - zstd: true, - } - } -} - /// A `Body` that has been converted into an `AsyncRead`. pub(crate) type AsyncReadBody = StreamReader, ::Error>, ::Data>; @@ -329,7 +245,7 @@ where } #[inline] - fn size_hint(&self) -> http_body::SizeHint { + fn size_hint(&self) -> rama_http_types::dep::http_body::SizeHint { self.body.size_hint() } } diff --git a/rama-http/src/layer/util/mod.rs b/rama-http/src/layer/util/mod.rs index 1fdeba4e7..b55fcaa15 100644 --- a/rama-http/src/layer/util/mod.rs +++ b/rama-http/src/layer/util/mod.rs @@ -2,5 +2,3 @@ #[cfg(feature = "compression")] pub(crate) mod compression; - -pub(crate) mod content_encoding; diff --git a/rama-http/src/matcher/mod.rs b/rama-http/src/matcher/mod.rs index b9d191ef8..42b6fd1ea 100644 --- a/rama-http/src/matcher/mod.rs +++ b/rama-http/src/matcher/mod.rs @@ -1,9 +1,9 @@ -//! [`service::Matcher`]s implementations to match on [`http::Request`]s. +//! [`service::Matcher`]s implementations to match on [`rama_http_types::Request`]s. //! //! See [`service::matcher` module] for more information. //! //! 
[`service::Matcher`]: rama_core::matcher::Matcher -//! [`http::Request`]: crate::Request +//! [`rama_http_types::Request`]: crate::Request //! [`service::matcher` module]: rama_core use crate::Request; use rama_core::{Context, context::Extensions, matcher::IteratorMatcherExt}; @@ -455,7 +455,10 @@ impl HttpMatcher { } /// Create a [`HeaderMatcher`] matcher. - pub fn header(name: http::header::HeaderName, value: http::header::HeaderValue) -> Self { + pub fn header( + name: rama_http_types::header::HeaderName, + value: rama_http_types::header::HeaderValue, + ) -> Self { Self { kind: HttpMatcherKind::Header(HeaderMatcher::is(name, value)), negate: false, @@ -467,8 +470,8 @@ impl HttpMatcher { /// See [`HeaderMatcher`] for more information. pub fn and_header( self, - name: http::header::HeaderName, - value: http::header::HeaderValue, + name: rama_http_types::header::HeaderName, + value: rama_http_types::header::HeaderValue, ) -> Self { self.and(Self::header(name, value)) } @@ -478,15 +481,15 @@ impl HttpMatcher { /// See [`HeaderMatcher`] for more information. pub fn or_header( self, - name: http::header::HeaderName, - value: http::header::HeaderValue, + name: rama_http_types::header::HeaderName, + value: rama_http_types::header::HeaderValue, ) -> Self { self.or(Self::header(name, value)) } /// Create a [`HeaderMatcher`] matcher when the given header exists /// to match on the existence of a header. - pub fn header_exists(name: http::header::HeaderName) -> Self { + pub fn header_exists(name: rama_http_types::header::HeaderName) -> Self { Self { kind: HttpMatcherKind::Header(HeaderMatcher::exists(name)), negate: false, @@ -497,7 +500,7 @@ impl HttpMatcher { /// on top of the existing set of [`HttpMatcher`] matchers. /// /// See [`HeaderMatcher`] for more information. 
- pub fn and_header_exists(self, name: http::header::HeaderName) -> Self { + pub fn and_header_exists(self, name: rama_http_types::header::HeaderName) -> Self { self.and(Self::header_exists(name)) } @@ -505,14 +508,14 @@ impl HttpMatcher { /// as an alternative to the existing set of [`HttpMatcher`] matchers. /// /// See [`HeaderMatcher`] for more information. - pub fn or_header_exists(self, name: http::header::HeaderName) -> Self { + pub fn or_header_exists(self, name: rama_http_types::header::HeaderName) -> Self { self.or(Self::header_exists(name)) } /// Create a [`HeaderMatcher`] matcher to match on it containing the given value. pub fn header_contains( - name: http::header::HeaderName, - value: http::header::HeaderValue, + name: rama_http_types::header::HeaderName, + value: rama_http_types::header::HeaderValue, ) -> Self { Self { kind: HttpMatcherKind::Header(HeaderMatcher::contains(name, value)), @@ -526,8 +529,8 @@ impl HttpMatcher { /// See [`HeaderMatcher`] for more information. pub fn and_header_contains( self, - name: http::header::HeaderName, - value: http::header::HeaderValue, + name: rama_http_types::header::HeaderName, + value: rama_http_types::header::HeaderValue, ) -> Self { self.and(Self::header_contains(name, value)) } @@ -538,8 +541,8 @@ impl HttpMatcher { /// See [`HeaderMatcher`] for more information. 
pub fn or_header_contains( self, - name: http::header::HeaderName, - value: http::header::HeaderValue, + name: rama_http_types::header::HeaderName, + value: rama_http_types::header::HeaderValue, ) -> Self { self.or(Self::header_contains(name, value)) } diff --git a/rama-http/src/service/fs/mod.rs b/rama-http/src/service/fs/mod.rs index 8663f7a4d..1a8dc83d4 100644 --- a/rama-http/src/service/fs/mod.rs +++ b/rama-http/src/service/fs/mod.rs @@ -2,8 +2,8 @@ use bytes::Bytes; use futures_lite::Stream; -use http_body::{Body, Frame}; use pin_project_lite::pin_project; +use rama_http_types::dep::http_body::{Body, Frame}; use std::{ fmt, io, pin::Pin, diff --git a/rama-http/src/service/fs/serve_dir/future.rs b/rama-http/src/service/fs/serve_dir/future.rs index 142bddda0..128c97156 100644 --- a/rama-http/src/service/fs/serve_dir/future.rs +++ b/rama-http/src/service/fs/serve_dir/future.rs @@ -3,11 +3,12 @@ use crate::{ Body, HeaderValue, Request, Response, StatusCode, dep::http_body_util::BodyExt, header::{self, ALLOW}, - layer::util::content_encoding::Encoding, service::fs::AsyncReadBody, }; use bytes::Bytes; use rama_core::{Context, Service, error::BoxError}; +use rama_http_types::dep::http_body; +use rama_http_types::headers::encoding::Encoding; use std::{convert::Infallible, io}; pub(super) async fn consume_open_file_result( @@ -25,7 +26,8 @@ where Ok(OpenFileOutput::Redirect { location }) => { let mut res = response_with_status(StatusCode::TEMPORARY_REDIRECT); - res.headers_mut().insert(http::header::LOCATION, location); + res.headers_mut() + .insert(rama_http_types::header::LOCATION, location); Ok(res) } @@ -123,7 +125,7 @@ fn build_response(output: FileOpened) -> Response { .maybe_encoding .filter(|encoding| *encoding != Encoding::Identity) { - builder = builder.header(header::CONTENT_ENCODING, encoding.into_header_value()); + builder = builder.header(header::CONTENT_ENCODING, HeaderValue::from(encoding)); } if let Some(last_modified) = output.last_modified { diff 
--git a/rama-http/src/service/fs/serve_dir/mod.rs b/rama-http/src/service/fs/serve_dir/mod.rs index 0db993aad..4981ca95c 100644 --- a/rama-http/src/service/fs/serve_dir/mod.rs +++ b/rama-http/src/service/fs/serve_dir/mod.rs @@ -1,13 +1,11 @@ -use crate::dep::http_body::Body as HttpBody; -use crate::layer::{ - set_status::SetStatus, - util::content_encoding::{SupportedEncodings, encodings}, -}; +use crate::dep::http_body::{self, Body as HttpBody}; +use crate::layer::set_status::SetStatus; use crate::{Body, HeaderValue, Method, Request, Response, StatusCode, header}; use bytes::Bytes; use percent_encoding::percent_decode; use rama_core::error::BoxError; use rama_core::{Context, Service}; +use rama_http_types::headers::encoding::{SupportedEncodings, parse_accept_encoding_headers}; use std::{ convert::Infallible, path::{Component, Path, PathBuf}, @@ -536,7 +534,7 @@ impl ServeDir { .and_then(|value| value.to_str().ok()) .map(|s| s.to_owned()); - let negotiated_encodings: Vec<_> = encodings( + let negotiated_encodings: Vec<_> = parse_accept_encoding_headers( req.headers(), self.precompressed_variants.unwrap_or_default(), ) diff --git a/rama-http/src/service/fs/serve_dir/open_file.rs b/rama-http/src/service/fs/serve_dir/open_file.rs index 198703bc3..6d92cf371 100644 --- a/rama-http/src/service/fs/serve_dir/open_file.rs +++ b/rama-http/src/service/fs/serve_dir/open_file.rs @@ -2,9 +2,9 @@ use super::{ ServeVariant, headers::{IfModifiedSince, IfUnmodifiedSince, LastModified}, }; -use crate::layer::util::content_encoding::{Encoding, QValue}; use crate::{HeaderValue, Method, Request, Uri, header}; use http_range_header::RangeUnsatisfiableError; +use rama_http_types::headers::{encoding::Encoding, specifier::QualityValue}; use std::{ ffi::OsStr, fs::Metadata, @@ -40,7 +40,7 @@ pub(super) async fn open_file( variant: ServeVariant, mut path_to_file: PathBuf, req: Request, - negotiated_encodings: Vec<(Encoding, QValue)>, + negotiated_encodings: Vec>, range_header: Option, 
buf_chunk_size: usize, ) -> io::Result { @@ -172,9 +172,10 @@ fn check_modified_headers( // to the corresponding file extension for the encoding. fn preferred_encoding( path: &mut PathBuf, - negotiated_encoding: &[(Encoding, QValue)], + negotiated_encoding: &[QualityValue], ) -> Option { - let preferred_encoding = Encoding::preferred_encoding(negotiated_encoding.iter().copied()); + let preferred_encoding = + Encoding::maybe_preferred_encoding(negotiated_encoding.iter().copied()); if let Some(file_extension) = preferred_encoding.and_then(|encoding| encoding.to_file_extension()) @@ -199,7 +200,7 @@ fn preferred_encoding( // file the uncompressed file is used as a fallback. async fn open_file_with_fallback( mut path: PathBuf, - mut negotiated_encoding: Vec<(Encoding, QValue)>, + mut negotiated_encoding: Vec>, ) -> io::Result<(File, Option)> { let (file, encoding) = loop { // Get the preferred encoding among the negotiated ones. @@ -211,8 +212,7 @@ async fn open_file_with_fallback( // to reset the path before the next iteration. path.set_extension(OsStr::new("")); // Remove the encoding from the negotiated_encodings since the file doesn't exist - negotiated_encoding - .retain(|(negotiated_encoding, _)| *negotiated_encoding != encoding); + negotiated_encoding.retain(|qv| qv.value != encoding); continue; } (Err(err), _) => return Err(err), @@ -226,7 +226,7 @@ async fn open_file_with_fallback( // file the uncompressed file is used as a fallback. async fn file_metadata_with_fallback( mut path: PathBuf, - mut negotiated_encoding: Vec<(Encoding, QValue)>, + mut negotiated_encoding: Vec>, ) -> io::Result<(Metadata, Option)> { let (file, encoding) = loop { // Get the preferred encoding among the negotiated ones. @@ -238,8 +238,7 @@ async fn file_metadata_with_fallback( // to reset the path before the next iteration. 
path.set_extension(OsStr::new("")); // Remove the encoding from the negotiated_encodings since the file doesn't exist - negotiated_encoding - .retain(|(negotiated_encoding, _)| *negotiated_encoding != encoding); + negotiated_encoding.retain(|qv| qv.value != encoding); continue; } (Err(err), _) => return Err(err), @@ -288,7 +287,7 @@ async fn is_dir(path_to_file: &Path) -> bool { } fn append_slash_on_path(uri: Uri) -> Uri { - let http::uri::Parts { + let rama_http_types::dep::http::uri::Parts { scheme, authority, path_and_query, diff --git a/rama-http/src/service/fs/serve_dir/tests.rs b/rama-http/src/service/fs/serve_dir/tests.rs index 9f313f319..d1e036c9d 100644 --- a/rama-http/src/service/fs/serve_dir/tests.rs +++ b/rama-http/src/service/fs/serve_dir/tests.rs @@ -395,7 +395,7 @@ async fn redirect_to_trailing_slash_on_dir() { assert_eq!(res.status(), StatusCode::TEMPORARY_REDIRECT); - let location = &res.headers()[http::header::LOCATION]; + let location = &res.headers()[rama_http_types::header::LOCATION]; assert_eq!(location, "/src/"); } diff --git a/rama-http/src/service/fs/serve_file.rs b/rama-http/src/service/fs/serve_file.rs index 77220613a..3c0c45b0d 100644 --- a/rama-http/src/service/fs/serve_file.rs +++ b/rama-http/src/service/fs/serve_file.rs @@ -148,7 +148,7 @@ mod compression_tests { #[cfg(feature = "compression")] async fn precompressed_zstd() { use async_compression::tokio::bufread::ZstdDecoder; - use http_body_util::BodyExt; + use rama_http_types::dep::http_body_util::BodyExt; use tokio::io::AsyncReadExt; let svc = ServeFile::new("../test-files/precompressed.txt").precompressed_zstd(); diff --git a/rama-http/src/service/web/endpoint/extract/authority.rs b/rama-http/src/service/web/endpoint/extract/authority.rs index 86d69224d..2ff6b7688 100644 --- a/rama-http/src/service/web/endpoint/extract/authority.rs +++ b/rama-http/src/service/web/endpoint/extract/authority.rs @@ -1,7 +1,7 @@ use super::FromRequestContextRefPair; -use 
crate::dep::http::request::Parts; use crate::utils::macros::define_http_rejection; use rama_core::Context; +use rama_http_types::dep::http::request::Parts; use rama_net::address; use rama_net::http::RequestContext; use rama_utils::macros::impl_deref; @@ -80,7 +80,7 @@ mod tests { test_authority_from_request( "/", "some-domain:123", - vec![(&http::header::HOST, "some-domain:123")], + vec![(&rama_http_types::header::HOST, "some-domain:123")], ) .await; } @@ -102,7 +102,7 @@ mod tests { "some-domain:456", vec![ (&X_FORWARDED_HOST, "some-domain:456"), - (&http::header::HOST, "some-domain:123"), + (&rama_http_types::header::HOST, "some-domain:123"), ], ) .await; diff --git a/rama-http/src/service/web/endpoint/extract/body/csv.rs b/rama-http/src/service/web/endpoint/extract/body/csv.rs index f37836fe8..01353cf86 100644 --- a/rama-http/src/service/web/endpoint/extract/body/csv.rs +++ b/rama-http/src/service/web/endpoint/extract/body/csv.rs @@ -100,9 +100,12 @@ mod test { StatusCode::OK }); - let req = http::Request::builder() - .method(http::Method::POST) - .header(http::header::CONTENT_TYPE, "text/csv; charset=utf-8") + let req = rama_http_types::Request::builder() + .method(rama_http_types::Method::POST) + .header( + rama_http_types::header::CONTENT_TYPE, + "text/csv; charset=utf-8", + ) .body("name,age,alive\nglen,42,\nadr,40,true\n".into()) .unwrap(); let resp = service.serve(Context::default(), req).await.unwrap(); @@ -122,9 +125,9 @@ mod test { let service = WebService::default() .post("/", |Csv(_): Csv>| async move { StatusCode::OK }); - let req = http::Request::builder() - .method(http::Method::POST) - .header(http::header::CONTENT_TYPE, "text/plain") + let req = rama_http_types::Request::builder() + .method(rama_http_types::Method::POST) + .header(rama_http_types::header::CONTENT_TYPE, "text/plain") .body(r#"{"name": "glen", "age": 42}"#.into()) .unwrap(); let resp = service.serve(Context::default(), req).await.unwrap(); @@ -143,9 +146,12 @@ mod test { let 
service = WebService::default() .post("/", |Csv(_): Csv>| async move { StatusCode::OK }); - let req = http::Request::builder() - .method(http::Method::POST) - .header(http::header::CONTENT_TYPE, "text/csv; charset=utf-8") + let req = rama_http_types::Request::builder() + .method(rama_http_types::Method::POST) + .header( + rama_http_types::header::CONTENT_TYPE, + "text/csv; charset=utf-8", + ) // the missing column last line should trigger an error .body("name,age,alive\nglen,42,\nadr,40\n".into()) .unwrap(); diff --git a/rama-http/src/service/web/endpoint/extract/body/json.rs b/rama-http/src/service/web/endpoint/extract/body/json.rs index c58628062..3c6959dbe 100644 --- a/rama-http/src/service/web/endpoint/extract/body/json.rs +++ b/rama-http/src/service/web/endpoint/extract/body/json.rs @@ -85,10 +85,10 @@ mod test { assert_eq!(body.alive, None); }); - let req = http::Request::builder() - .method(http::Method::POST) + let req = rama_http_types::Request::builder() + .method(rama_http_types::Method::POST) .header( - http::header::CONTENT_TYPE, + rama_http_types::header::CONTENT_TYPE, "application/json; charset=utf-8", ) .body(r#"{"name": "glen", "age": 42}"#.into()) @@ -109,9 +109,9 @@ mod test { let service = WebService::default().post("/", |Json(_): Json| async move { StatusCode::OK }); - let req = http::Request::builder() - .method(http::Method::POST) - .header(http::header::CONTENT_TYPE, "text/plain") + let req = rama_http_types::Request::builder() + .method(rama_http_types::Method::POST) + .header(rama_http_types::header::CONTENT_TYPE, "text/plain") .body(r#"{"name": "glen", "age": 42}"#.into()) .unwrap(); let resp = service.serve(Context::default(), req).await.unwrap(); @@ -130,10 +130,10 @@ mod test { let service = WebService::default().post("/", |Json(_): Json| async move { StatusCode::OK }); - let req = http::Request::builder() - .method(http::Method::POST) + let req = rama_http_types::Request::builder() + .method(rama_http_types::Method::POST) .header( - 
http::header::CONTENT_TYPE, + rama_http_types::header::CONTENT_TYPE, "application/json; charset=utf-8", ) .body(r#"deal with it, or not?!"#.into()) diff --git a/rama-http/src/service/web/endpoint/extract/host.rs b/rama-http/src/service/web/endpoint/extract/host.rs index f49c68ce7..e65af74b2 100644 --- a/rama-http/src/service/web/endpoint/extract/host.rs +++ b/rama-http/src/service/web/endpoint/extract/host.rs @@ -1,7 +1,7 @@ use super::FromRequestContextRefPair; -use crate::dep::http::request::Parts; use crate::utils::macros::define_http_rejection; use rama_core::Context; +use rama_http_types::dep::http::request::Parts; use rama_net::address; use rama_net::http::RequestContext; use rama_utils::macros::impl_deref; @@ -75,7 +75,7 @@ mod tests { test_host_from_request( "/", "some-domain", - vec![(&http::header::HOST, "some-domain:123")], + vec![(&rama_http_types::header::HOST, "some-domain:123")], ) .await; } @@ -97,7 +97,7 @@ mod tests { "some-domain", vec![ (&X_FORWARDED_HOST, "some-domain:456"), - (&http::header::HOST, "some-domain:123"), + (&rama_http_types::header::HOST, "some-domain:123"), ], ) .await; diff --git a/rama-http/src/service/web/endpoint/extract/typed_header.rs b/rama-http/src/service/web/endpoint/extract/typed_header.rs index 482d80808..230dde28e 100644 --- a/rama-http/src/service/web/endpoint/extract/typed_header.rs +++ b/rama-http/src/service/web/endpoint/extract/typed_header.rs @@ -128,7 +128,7 @@ impl TypedHeaderRejectionReason { impl IntoResponse for TypedHeaderRejection { fn into_response(self) -> Response { - (http::StatusCode::BAD_REQUEST, self.to_string()).into_response() + (rama_http_types::StatusCode::BAD_REQUEST, self.to_string()).into_response() } } diff --git a/rama-net/Cargo.toml b/rama-net/Cargo.toml index 49c1d027c..e118f8713 100644 --- a/rama-net/Cargo.toml +++ b/rama-net/Cargo.toml @@ -50,6 +50,7 @@ venndb = { workspace = true, optional = true } itertools = { workspace = true } nom = { workspace = true } quickcheck = { workspace = 
true } +serde_json = { workspace = true } tokio = { workspace = true, features = ["full"] } tokio-test = { workspace = true } diff --git a/rama-net/src/tls/client/config.rs b/rama-net/src/tls/client/config.rs index 2c73e9805..47598d662 100644 --- a/rama-net/src/tls/client/config.rs +++ b/rama-net/src/tls/client/config.rs @@ -1,5 +1,15 @@ +use std::sync::Arc; + use super::{ClientHelloExtension, merge_client_hello_lists}; -use crate::tls::{CipherSuite, CompressionAlgorithm, DataEncoding, KeyLogIntent}; +use crate::tls::{CipherSuite, CompressionAlgorithm, DataEncoding, KeyLogIntent, ProtocolVersion}; + +#[derive(Debug, Clone, Default)] +/// Common API to configure a Proxy TLS Client +/// +/// See [`ClientConfig`] for more information, +/// this is only a new-type wrapper to be able to differentiate +/// the info found in context for a dynamic https client. +pub struct ProxyClientConfig(pub Arc); #[derive(Debug, Clone, Default)] /// Common API to configure a TLS Client @@ -96,3 +106,14 @@ impl From for ClientConfig { } } } + +impl From for super::ClientHello { + fn from(value: ClientConfig) -> Self { + super::ClientHello { + protocol_version: ProtocolVersion::TLSv1_2, + cipher_suites: value.cipher_suites.unwrap_or_default(), + compression_algorithms: value.compression_algorithms.unwrap_or_default(), + extensions: value.extensions.unwrap_or_default(), + } + } +} diff --git a/rama-net/src/tls/client/hello/mod.rs b/rama-net/src/tls/client/hello/mod.rs index 20191a633..9e1c22c61 100644 --- a/rama-net/src/tls/client/hello/mod.rs +++ b/rama-net/src/tls/client/hello/mod.rs @@ -1,3 +1,5 @@ +use serde::{Deserialize, Serialize}; + use crate::address::Host; use crate::tls::{ ApplicationProtocol, CipherSuite, ECPointFormat, ExtensionId, ProtocolVersion, SignatureScheme, @@ -10,7 +12,7 @@ mod rustls; #[cfg(feature = "boring")] mod boring; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize, Hash)] /// When a client first connects to a server, it is required to 
send /// the ClientHello as its first message. /// @@ -125,7 +127,7 @@ impl ClientHello { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize, Hash)] /// Extensions that can be set in a [`ClientHello`] message by a TLS client. /// /// While its name may infer that an extension is by definition optional, diff --git a/rama-net/src/tls/client/mod.rs b/rama-net/src/tls/client/mod.rs index 17207d619..d0c2d39e1 100644 --- a/rama-net/src/tls/client/mod.rs +++ b/rama-net/src/tls/client/mod.rs @@ -20,7 +20,7 @@ pub(crate) use parser::parse_client_hello; mod config; #[doc(inline)] -pub use config::{ClientAuth, ClientAuthData, ClientConfig, ServerVerifyMode}; +pub use config::{ClientAuth, ClientAuthData, ClientConfig, ProxyClientConfig, ServerVerifyMode}; use super::{ApplicationProtocol, DataEncoding, ProtocolVersion}; diff --git a/rama-net/src/tls/enums/mod.rs b/rama-net/src/tls/enums/mod.rs index bcaea6e94..fba5ac7b8 100644 --- a/rama-net/src/tls/enums/mod.rs +++ b/rama-net/src/tls/enums/mod.rs @@ -66,6 +66,27 @@ macro_rules! enum_builder { ::std::fmt::UpperHex::fmt(&u8::from(*self), f) } } + + impl ::serde::Serialize for $enum_name { + #[inline] + fn serialize(&self, serializer: S) -> Result + where + S: ::serde::Serializer, + { + u8::from(*self).serialize(serializer) + } + } + + impl<'de> ::serde::Deserialize<'de> for $enum_name { + #[inline] + fn deserialize(deserializer: D) -> Result + where + D: ::serde::Deserializer<'de>, + { + let n = u8::deserialize(deserializer)?; + Ok(n.into()) + } + } }; ( $(#[$comment:meta])* @@ -133,6 +154,27 @@ macro_rules! 
enum_builder { ::std::fmt::UpperHex::fmt(&u16::from(*self), f) } } + + impl ::serde::Serialize for $enum_name { + #[inline] + fn serialize(&self, serializer: S) -> Result + where + S: ::serde::Serializer, + { + u16::from(*self).serialize(serializer) + } + } + + impl<'de> ::serde::Deserialize<'de> for $enum_name { + #[inline] + fn deserialize(deserializer: D) -> Result + where + D: ::serde::Deserializer<'de>, + { + let n = u16::deserialize(deserializer)?; + Ok(n.into()) + } + } }; ( $(#[$comment:meta])* @@ -142,7 +184,7 @@ macro_rules! enum_builder { ) => { $(#[$comment])* #[non_exhaustive] - #[derive(Debug, PartialEq, Eq, Clone)] + #[derive(Debug, PartialEq, Eq, Clone, Hash)] $enum_vis enum $enum_name { $( $enum_var),* ,Unknown(Vec) @@ -256,6 +298,34 @@ macro_rules! enum_builder { } } } + + impl ::serde::Serialize for $enum_name { + #[inline] + fn serialize(&self, serializer: S) -> Result + where + S: ::serde::Serializer, + { + match self { + $( $enum_name::$enum_var => { + $enum_val.serialize(serializer) + }),* + ,$enum_name::Unknown(x) => { + x.serialize(serializer) + } + } + } + } + + impl<'de> ::serde::Deserialize<'de> for $enum_name { + #[inline] + fn deserialize(deserializer: D) -> Result + where + D: ::serde::Deserializer<'de>, + { + let b = <::std::borrow::Cow<'de, [u8]>>::deserialize(deserializer)?; + Ok(b.as_ref().into()) + } + } }; } @@ -1028,4 +1098,45 @@ mod tests { assert_eq!(12, r.position()); assert_eq!(&INPUT.as_bytes()[3..12], b"\x08http/1.1"); } + + #[test] + fn test_enum_u8_serialize_deserialize() { + let p: ECPointFormat = serde_json::from_str( + &serde_json::to_string(&ECPointFormat::ANSIX962CompressedChar2).unwrap(), + ) + .unwrap(); + assert_eq!(ECPointFormat::ANSIX962CompressedChar2, p); + + let p: ECPointFormat = + serde_json::from_str(&serde_json::to_string(&ECPointFormat::from(42u8)).unwrap()) + .unwrap(); + assert_eq!(ECPointFormat::from(42u8), p); + } + + #[test] + fn test_enum_u16_serialize_deserialize() { + let p: SupportedGroup = + 
serde_json::from_str(&serde_json::to_string(&SupportedGroup::BRAINPOOLP384R1).unwrap()) + .unwrap(); + assert_eq!(SupportedGroup::BRAINPOOLP384R1, p); + + let p: SupportedGroup = + serde_json::from_str(&serde_json::to_string(&SupportedGroup::from(0xffffu16)).unwrap()) + .unwrap(); + assert_eq!(SupportedGroup::from(0xffffu16), p); + } + + #[test] + fn test_enum_bytes_serialize_deserialize() { + let p: ApplicationProtocol = + serde_json::from_str(&serde_json::to_string(&ApplicationProtocol::HTTP_3).unwrap()) + .unwrap(); + assert_eq!(ApplicationProtocol::HTTP_3, p); + + let p: ApplicationProtocol = serde_json::from_str( + &serde_json::to_string(&ApplicationProtocol::from(b"foobar")).unwrap(), + ) + .unwrap(); + assert_eq!(ApplicationProtocol::from(b"foobar"), p); + } } diff --git a/rama-tls/src/boring/client/connector.rs b/rama-tls/src/boring/client/connector.rs index cb8ac7b83..0175d281b 100644 --- a/rama-tls/src/boring/client/connector.rs +++ b/rama-tls/src/boring/client/connector.rs @@ -1,19 +1,19 @@ -use super::TlsConnectorData; -use crate::types::TlsTunnel; -use pin_project_lite::pin_project; use private::{ConnectorKindAuto, ConnectorKindSecure, ConnectorKindTunnel}; -use rama_core::error::{BoxError, ErrorExt, OpaqueError}; +use rama_core::error::{BoxError, ErrorContext, ErrorExt, OpaqueError}; use rama_core::{Context, Layer, Service}; use rama_net::address::Host; use rama_net::client::{ConnectorService, EstablishedClientConnection}; use rama_net::stream::Stream; use rama_net::tls::ApplicationProtocol; -use rama_net::tls::client::NegotiatedTlsParameters; +use rama_net::tls::client::{ClientConfig, NegotiatedTlsParameters}; use rama_net::transport::TryRefIntoTransportContext; use std::fmt; -use tokio::io::{AsyncRead, AsyncWrite}; +use std::sync::Arc; use tokio_boring::SslStream; +use super::{AutoTlsStream, TlsConnectorData, TlsStream}; +use crate::types::TlsTunnel; + /// A [`Layer`] which wraps the given service with a [`TlsConnector`]. 
/// /// See [`TlsConnector`] for more information. @@ -253,15 +253,27 @@ where return Ok(EstablishedClientConnection { ctx, req, - conn: AutoTlsStream { - inner: AutoTlsStreamData::Plain { inner: conn }, - }, + conn: AutoTlsStream::plain(conn), }); } let host = transport_ctx.authority.host().clone(); - let connector_data = ctx.get().cloned(); + let connector_data = match ctx.get::() { + Some(cd) => Some(cd.clone()), + None => match ctx.get::>() { + // support info passed down by layers such as tls emulators + Some(tls_config) => Some( + tls_config + .as_ref() + .clone() + .try_into() + .context("turn context ClientConfig into boring connector data")?, + ), + None => None, + }, + }; + let (stream, negotiated_params) = self.handshake(connector_data, host, conn).await?; tracing::trace!( @@ -274,9 +286,7 @@ where Ok(EstablishedClientConnection { ctx, req, - conn: AutoTlsStream { - inner: AutoTlsStreamData::Secure { inner: stream }, - }, + conn: AutoTlsStream::secure(stream), }) } } @@ -289,7 +299,7 @@ where + Send + 'static, { - type Response = EstablishedClientConnection, State, Request>; + type Response = EstablishedClientConnection, State, Request>; type Error = BoxError; async fn serve( @@ -314,8 +324,23 @@ where let host = transport_ctx.authority.host().clone(); - let connector_data = ctx.get().cloned(); + let connector_data = match ctx.get::() { + Some(cd) => Some(cd.clone()), + None => match ctx.get::>() { + // support info passed down by layers such as tls emulators + Some(tls_config) => Some( + tls_config + .as_ref() + .clone() + .try_into() + .context("turn context ClientConfig into boring connector data")?, + ), + None => None, + }, + }; + let (conn, negotiated_params) = self.handshake(connector_data, host, conn).await?; + let conn = TlsStream::new(conn); ctx.insert(negotiated_params); Ok(EstablishedClientConnection { ctx, req, conn }) @@ -353,14 +378,26 @@ where return Ok(EstablishedClientConnection { ctx, req, - conn: AutoTlsStream { - inner: 
AutoTlsStreamData::Plain { inner: conn }, - }, + conn: AutoTlsStream::plain(conn), }); } }; - let connector_data = ctx.get().cloned(); + let connector_data = match ctx.get::() { + Some(cd) => Some(cd.clone()), + None => match ctx.get::>() { + // support info passed down by layers such as tls emulators + Some(tls_config) => Some( + tls_config + .as_ref() + .clone() + .try_into() + .context("turn context ClientConfig into boring connector data")?, + ), + None => None, + }, + }; + let (stream, negotiated_params) = self.handshake(connector_data, host, conn).await?; ctx.insert(negotiated_params); @@ -368,13 +405,39 @@ where Ok(EstablishedClientConnection { ctx, req, - conn: AutoTlsStream { - inner: AutoTlsStreamData::Secure { inner: stream }, - }, + conn: AutoTlsStream::secure(stream), }) } } +pub async fn tls_connect( + server_host: Host, + stream: T, + connector_data: Option<&TlsConnectorData>, +) -> Result, OpaqueError> +where + T: Stream + Unpin, +{ + let client_config_data = match connector_data { + Some(connector_data) => connector_data.try_to_build_config()?, + None => TlsConnectorData::new()?.try_to_build_config()?, + }; + let server_host = client_config_data.server_name.unwrap_or(server_host); + let stream = tokio_boring::connect( + client_config_data.config, + server_host.to_string().as_str(), + stream, + ) + .await + .map_err(|err| match err.as_io_error() { + Some(err) => OpaqueError::from_display(err.to_string()) + .context("boring ssl connector: connect") + .into_boxed(), + None => OpaqueError::from_display("boring ssl connector: connect").into_boxed(), + })?; + Ok(TlsStream::new(stream)) +} + impl TlsConnector { async fn handshake( &self, @@ -386,23 +449,8 @@ impl TlsConnector { T: Stream + Unpin, { let connector_data = connector_data.as_ref().or(self.connector_data.as_ref()); - let client_config_data = match connector_data { - Some(connector_data) => connector_data.try_to_build_config()?, - None => 
TlsConnectorData::new_http_auto()?.try_to_build_config()?, - }; - let server_host = client_config_data.server_name.unwrap_or(server_host); - let stream = tokio_boring::connect( - client_config_data.config, - server_host.to_string().as_str(), - stream, - ) - .await - .map_err(|err| match err.as_io_error() { - Some(err) => OpaqueError::from_display(err.to_string()) - .context("boring ssl connector: connect") - .into_boxed(), - None => OpaqueError::from_display("boring ssl connector: connect").into_boxed(), - })?; + + let TlsStream { inner: stream } = tls_connect(server_host, stream, connector_data).await?; let params = match stream.ssl().session() { Some(ssl_session) => { @@ -447,94 +495,6 @@ impl TlsConnector { } } -pin_project! { - /// A stream which can be either a secure or a plain stream. - pub struct AutoTlsStream { - #[pin] - inner: AutoTlsStreamData, - } -} - -impl fmt::Debug for AutoTlsStream { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("AutoTlsStream") - .field("inner", &self.inner) - .finish() - } -} - -pin_project! { - #[project = AutoTlsStreamDataProj] - /// A stream which can be either a secure or a plain stream. - enum AutoTlsStreamData { - /// A secure stream. - Secure{ #[pin] inner: SslStream }, - /// A plain stream. 
- Plain { #[pin] inner: S }, - } -} - -impl fmt::Debug for AutoTlsStreamData { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - AutoTlsStreamData::Secure { inner } => f.debug_tuple("Secure").field(inner).finish(), - AutoTlsStreamData::Plain { inner } => f.debug_tuple("Plain").field(inner).finish(), - } - } -} - -impl AsyncRead for AutoTlsStream -where - S: Stream + Unpin, -{ - fn poll_read( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> std::task::Poll> { - match self.project().inner.project() { - AutoTlsStreamDataProj::Secure { inner } => inner.poll_read(cx, buf), - AutoTlsStreamDataProj::Plain { inner } => inner.poll_read(cx, buf), - } - } -} - -impl AsyncWrite for AutoTlsStream -where - S: Stream + Unpin, -{ - fn poll_write( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> std::task::Poll> { - match self.project().inner.project() { - AutoTlsStreamDataProj::Secure { inner } => inner.poll_write(cx, buf), - AutoTlsStreamDataProj::Plain { inner } => inner.poll_write(cx, buf), - } - } - - fn poll_flush( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - match self.project().inner.project() { - AutoTlsStreamDataProj::Secure { inner } => inner.poll_flush(cx), - AutoTlsStreamDataProj::Plain { inner } => inner.poll_flush(cx), - } - } - - fn poll_shutdown( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - match self.project().inner.project() { - AutoTlsStreamDataProj::Secure { inner } => inner.poll_shutdown(cx), - AutoTlsStreamDataProj::Plain { inner } => inner.poll_shutdown(cx), - } - } -} - mod private { use rama_net::address::Host; diff --git a/rama-tls/src/boring/client/mod.rs b/rama-tls/src/boring/client/mod.rs index 5e97fd8ab..a44bfbd73 100644 --- a/rama-tls/src/boring/client/mod.rs +++ b/rama-tls/src/boring/client/mod.rs @@ -1,8 +1,14 @@ 
-//! TLS client support for Rama. +//! TLS (Boring) client support for Rama. + +mod tls_stream_auto; +pub use tls_stream_auto::AutoTlsStream; + +mod tls_stream; +pub use tls_stream::TlsStream; mod connector; #[doc(inline)] -pub use connector::{AutoTlsStream, TlsConnector, TlsConnectorLayer}; +pub use connector::{TlsConnector, TlsConnectorLayer, tls_connect}; mod connector_data; #[doc(inline)] diff --git a/rama-tls/src/boring/client/tls_stream.rs b/rama-tls/src/boring/client/tls_stream.rs new file mode 100644 index 000000000..0bcc1b001 --- /dev/null +++ b/rama-tls/src/boring/client/tls_stream.rs @@ -0,0 +1,85 @@ +use std::fmt; + +use boring::ssl::SslRef; +use pin_project_lite::pin_project; +use rama_net::stream::Stream; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_boring::SslStream; + +pin_project! { + /// A stream which can be either a secure or a plain stream. + pub struct TlsStream { + #[pin] + pub(super) inner: SslStream, + } +} + +impl TlsStream { + pub(super) fn new(inner: SslStream) -> Self { + Self { inner } + } + + pub fn ssl_ref(&self) -> &SslRef { + self.inner.ssl() + } +} + +impl fmt::Debug for TlsStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TlsStream") + .field("inner", &self.inner) + .finish() + } +} + +impl AsyncRead for TlsStream +where + S: Stream + Unpin, +{ + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + self.project().inner.poll_read(cx, buf) + } +} + +impl AsyncWrite for TlsStream +where + S: Stream + Unpin, +{ + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + self.project().inner.poll_write(cx, buf) + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.project().inner.poll_flush(cx) + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + cx: 
&mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.project().inner.poll_shutdown(cx) + } + + fn poll_write_vectored( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> std::task::Poll> { + self.project().inner.poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + false + } +} diff --git a/rama-tls/src/boring/client/tls_stream_auto.rs b/rama-tls/src/boring/client/tls_stream_auto.rs new file mode 100644 index 000000000..0888dbe36 --- /dev/null +++ b/rama-tls/src/boring/client/tls_stream_auto.rs @@ -0,0 +1,132 @@ +use std::fmt; + +use boring::ssl::SslRef; +use pin_project_lite::pin_project; +use rama_net::stream::Stream; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_boring::SslStream; + +pin_project! { + /// A stream which can be either a secure or a plain stream. + pub struct AutoTlsStream { + #[pin] + inner: AutoTlsStreamData, + } +} + +impl AutoTlsStream { + pub(super) fn secure(inner: SslStream) -> Self { + Self { + inner: AutoTlsStreamData::Secure { inner }, + } + } + + pub(super) fn plain(inner: S) -> Self { + Self { + inner: AutoTlsStreamData::Plain { inner }, + } + } + + pub fn ssl_ref(&self) -> Option<&SslRef> { + match &self.inner { + AutoTlsStreamData::Secure { inner } => Some(inner.ssl()), + AutoTlsStreamData::Plain { .. } => None, + } + } +} + +impl fmt::Debug for AutoTlsStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("AutoTlsStream") + .field("inner", &self.inner) + .finish() + } +} + +pin_project! { + #[project = AutoTlsStreamDataProj] + /// A stream which can be either a secure or a plain stream. + enum AutoTlsStreamData { + /// A secure stream. + Secure{ #[pin] inner: SslStream }, + /// A plain stream. 
+ Plain { #[pin] inner: S }, + } +} + +impl fmt::Debug for AutoTlsStreamData { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AutoTlsStreamData::Secure { inner } => f.debug_tuple("Secure").field(inner).finish(), + AutoTlsStreamData::Plain { inner } => f.debug_tuple("Plain").field(inner).finish(), + } + } +} + +impl AsyncRead for AutoTlsStream +where + S: Stream + Unpin, +{ + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + match self.project().inner.project() { + AutoTlsStreamDataProj::Secure { inner } => inner.poll_read(cx, buf), + AutoTlsStreamDataProj::Plain { inner } => inner.poll_read(cx, buf), + } + } +} + +impl AsyncWrite for AutoTlsStream +where + S: Stream + Unpin, +{ + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + match self.project().inner.project() { + AutoTlsStreamDataProj::Secure { inner } => inner.poll_write(cx, buf), + AutoTlsStreamDataProj::Plain { inner } => inner.poll_write(cx, buf), + } + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.project().inner.project() { + AutoTlsStreamDataProj::Secure { inner } => inner.poll_flush(cx), + AutoTlsStreamDataProj::Plain { inner } => inner.poll_flush(cx), + } + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.project().inner.project() { + AutoTlsStreamDataProj::Secure { inner } => inner.poll_shutdown(cx), + AutoTlsStreamDataProj::Plain { inner } => inner.poll_shutdown(cx), + } + } + + fn poll_write_vectored( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> std::task::Poll> { + let buf = bufs + .iter() + .find(|b| !b.is_empty()) + .map_or(&[][..], |b| &**b); + self.poll_write(cx, 
buf) + } + + fn is_write_vectored(&self) -> bool { + false + } +} diff --git a/rama-tls/src/rustls/client/connector.rs b/rama-tls/src/rustls/client/connector.rs index bf15fe5e6..d5d08c48f 100644 --- a/rama-tls/src/rustls/client/connector.rs +++ b/rama-tls/src/rustls/client/connector.rs @@ -10,6 +10,7 @@ use rama_net::address::Host; use rama_net::client::{ConnectorService, EstablishedClientConnection}; use rama_net::stream::Stream; use rama_net::tls::ApplicationProtocol; +use rama_net::tls::client::ClientConfig; use rama_net::tls::client::NegotiatedTlsParameters; use rama_net::transport::TryRefIntoTransportContext; use std::fmt; @@ -268,7 +269,21 @@ where "TlsConnector(auto): attempt to secure inner connection", ); - let connector_data = ctx.get().cloned(); + let connector_data = match ctx.get::() { + Some(cd) => Some(cd.clone()), + None => match ctx.get::>() { + // support info passed down by layers such as tls emulators + Some(tls_config) => Some( + tls_config + .as_ref() + .clone() + .try_into() + .context("turn context ClientConfig into rustls connector data")?, + ), + None => None, + }, + }; + let (stream, negotiated_params) = self.handshake(connector_data, server_host, conn).await?; tracing::trace!( @@ -322,7 +337,21 @@ where let server_host = transport_ctx.authority.host().clone(); - let connector_data = ctx.get().cloned(); + let connector_data = match ctx.get::() { + Some(cd) => Some(cd.clone()), + None => match ctx.get::>() { + // support info passed down by layers such as tls emulators + Some(tls_config) => Some( + tls_config + .as_ref() + .clone() + .try_into() + .context("turn context ClientConfig into rustls connector data")?, + ), + None => None, + }, + }; + let (conn, negotiated_params) = self.handshake(connector_data, server_host, conn).await?; ctx.insert(negotiated_params); @@ -368,7 +397,21 @@ where } }; - let connector_data = ctx.get().cloned(); + let connector_data = match ctx.get::() { + Some(cd) => Some(cd.clone()), + None => match ctx.get::>() 
{ + // support info passed down by layers such as tls emulators + Some(tls_config) => Some( + tls_config + .as_ref() + .clone() + .try_into() + .context("turn context ClientConfig into rustls connector data")?, + ), + None => None, + }, + }; + let (conn, negotiated_params) = self.handshake(connector_data, server_host, conn).await?; ctx.insert(negotiated_params); diff --git a/rama-ua/Cargo.toml b/rama-ua/Cargo.toml index 664f271aa..654cb608e 100644 --- a/rama-ua/Cargo.toml +++ b/rama-ua/Cargo.toml @@ -13,10 +13,20 @@ rust-version = { workspace = true } [lints] workspace = true +[features] +default = [] +tls = ["rama-net/tls"] + [dependencies] +bytes = { workspace = true } +itertools = { workspace = true } rama-core = { version = "0.2.0-alpha.7", path = "../rama-core" } +rama-http-types = { version = "0.2.0-alpha.7", path = "../rama-http-types" } +rama-net = { version = "0.2.0-alpha.7", path = "../rama-net", features = ["http"] } rama-utils = { version = "0.2.0-alpha.7", path = "../rama-utils" } +rand = { workspace = true } serde = { workspace = true, features = ["derive"] } +tracing = { workspace = true } [dev-dependencies] serde_json = { workspace = true } diff --git a/rama-ua/src/emulate/layer.rs b/rama-ua/src/emulate/layer.rs new file mode 100644 index 000000000..baa982893 --- /dev/null +++ b/rama-ua/src/emulate/layer.rs @@ -0,0 +1,135 @@ +use std::fmt; + +use rama_core::Layer; +use rama_http_types::HeaderName; + +use super::UserAgentSelectFallback; + +pub struct UserAgentEmulateLayer

{ + provider: P, + optional: bool, + try_auto_detect_user_agent: bool, + input_header_order: Option, + select_fallback: Option, +} + +impl fmt::Debug for UserAgentEmulateLayer

{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("UserAgentEmulateLayer") + .field("provider", &self.provider) + .field("optional", &self.optional) + .field( + "try_auto_detect_user_agent", + &self.try_auto_detect_user_agent, + ) + .field("input_header_order", &self.input_header_order) + .field("select_fallback", &self.select_fallback) + .finish() + } +} + +impl Clone for UserAgentEmulateLayer

{ + fn clone(&self) -> Self { + Self { + provider: self.provider.clone(), + optional: self.optional, + try_auto_detect_user_agent: self.try_auto_detect_user_agent, + input_header_order: self.input_header_order.clone(), + select_fallback: self.select_fallback, + } + } +} + +impl

UserAgentEmulateLayer

{ + pub fn new(provider: P) -> Self { + Self { + provider, + optional: false, + try_auto_detect_user_agent: false, + input_header_order: None, + select_fallback: None, + } + } + + /// When no user agent profile was found it will + /// fail the request unless optional is true. In case of + /// the latter the service will do nothing. + pub fn optional(mut self, optional: bool) -> Self { + self.optional = optional; + self + } + + /// See [`Self::optional`]. + pub fn set_optional(&mut self, optional: bool) -> &mut Self { + self.optional = optional; + self + } + + /// If true, the layer will try to auto-detect the user agent from the request, + /// but only in case that info is not yet found in the context. + pub fn try_auto_detect_user_agent(mut self, try_auto_detect_user_agent: bool) -> Self { + self.try_auto_detect_user_agent = try_auto_detect_user_agent; + self + } + + /// See [`Self::try_auto_detect_user_agent`]. + pub fn set_try_auto_detect_user_agent( + &mut self, + try_auto_detect_user_agent: bool, + ) -> &mut Self { + self.try_auto_detect_user_agent = try_auto_detect_user_agent; + self + } + + /// Define a header that if present is to contain a CSV header name list, + /// that allows you to define the desired header order for the (extra) headers + /// found in the input (http) request. + /// + /// Extra meaning any headers not considered a base header and already defined + /// by the (selected) User Agent Profile. + /// + /// This can be useful because your http client might not respect the header casing + /// and/or order of the headers taken together. Using this metadata allows you to + /// communicate this data through anyway. If however your http client does respect + /// casing and order, or you don't care about some of it, you might not need it. + pub fn input_header_order(mut self, name: HeaderName) -> Self { + self.input_header_order = Some(name); + self + } + + /// See [`Self::input_header_order`]. 
+ pub fn set_input_header_order(&mut self, name: HeaderName) -> &mut Self { + self.input_header_order = Some(name); + self + } + + /// Choose what to do in case no profile could be selected + /// using the regular pre-conditions as specified by the provider. + pub fn select_fallback(mut self, fb: UserAgentSelectFallback) -> Self { + self.select_fallback = Some(fb); + self + } + + /// See [`Self::select_fallback`]. + pub fn set_select_fallback(&mut self, fb: UserAgentSelectFallback) -> &mut Self { + self.select_fallback = Some(fb); + self + } +} + +impl Layer for UserAgentEmulateLayer

{ + type Service = super::UserAgentEmulateService; + + fn layer(&self, inner: S) -> Self::Service { + let mut svc = super::UserAgentEmulateService::new(inner, self.provider.clone()) + .optional(self.optional) + .try_auto_detect_user_agent(self.try_auto_detect_user_agent); + if let Some(fb) = self.select_fallback { + svc.set_select_fallback(fb); + } + if let Some(name) = self.input_header_order.clone() { + svc.set_input_header_order(name); + } + svc + } +} diff --git a/rama-ua/src/emulate/mod.rs b/rama-ua/src/emulate/mod.rs new file mode 100644 index 000000000..bb78ac71e --- /dev/null +++ b/rama-ua/src/emulate/mod.rs @@ -0,0 +1,8 @@ +mod provider; +pub use provider::{UserAgentProvider, UserAgentSelectFallback}; + +mod layer; +pub use layer::UserAgentEmulateLayer; + +mod service; +pub use service::UserAgentEmulateService; diff --git a/rama-ua/src/emulate/provider.rs b/rama-ua/src/emulate/provider.rs new file mode 100644 index 000000000..5ce04c7c6 --- /dev/null +++ b/rama-ua/src/emulate/provider.rs @@ -0,0 +1,105 @@ +use std::sync::Arc; + +use rama_core::Context; + +use crate::{UserAgentDatabase, UserAgentProfile}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] +/// Fallback strategy that can be injected into the context +/// to customise what a provider can be requested to do +/// in case the preconditions for UA selection were not fulfilled. +/// +/// It is advised only fallback for pre-conditions and not +/// post-selection failure as the latter would be rather confusing. +/// +/// For example if you request a Chromium profile you do not expect a Firefox one. +/// However if you do not give any filters it is fair to assume a random profile is desired, +/// given those all satisfy the abscence of filters. 
+pub enum UserAgentSelectFallback { + #[default] + Abort, + Random, +} + +pub trait UserAgentProvider: Send + Sync + 'static { + fn select_user_agent_profile(&self, ctx: &Context) -> Option<&UserAgentProfile>; +} + +impl UserAgentProvider for () { + #[inline] + fn select_user_agent_profile(&self, _ctx: &Context) -> Option<&UserAgentProfile> { + None + } +} + +impl UserAgentProvider for UserAgentProfile { + #[inline] + fn select_user_agent_profile(&self, _ctx: &Context) -> Option<&UserAgentProfile> { + Some(self) + } +} + +impl UserAgentProvider for UserAgentDatabase { + #[inline] + fn select_user_agent_profile(&self, ctx: &Context) -> Option<&UserAgentProfile> { + match (ctx.get(), ctx.get()) { + (Some(agent), _) => self.get(agent), + (None, Some(UserAgentSelectFallback::Random)) => self.rnd(), + (None, None | Some(UserAgentSelectFallback::Abort)) => None, + } + } +} + +impl UserAgentProvider for Option

+where + P: UserAgentProvider, +{ + #[inline] + fn select_user_agent_profile(&self, ctx: &Context) -> Option<&UserAgentProfile> { + self.as_ref().and_then(|p| p.select_user_agent_profile(ctx)) + } +} + +impl UserAgentProvider for Arc

+where + P: UserAgentProvider, +{ + #[inline] + fn select_user_agent_profile(&self, ctx: &Context) -> Option<&UserAgentProfile> { + self.as_ref().select_user_agent_profile(ctx) + } +} + +impl UserAgentProvider for Box

+where + P: UserAgentProvider, +{ + #[inline] + fn select_user_agent_profile(&self, ctx: &Context) -> Option<&UserAgentProfile> { + self.as_ref().select_user_agent_profile(ctx) + } +} + +macro_rules! impl_user_agent_provider_either { + ($id:ident, $($param:ident),+ $(,)?) => { + impl UserAgentProvider for ::rama_core::combinators::$id<$($param),+> + where + $( + $param: UserAgentProvider, + )+ + { + fn select_user_agent_profile( + &self, + ctx: &Context, + ) -> Option<&UserAgentProfile> { + match self { + $( + ::rama_core::combinators::$id::$param(s) => s.select_user_agent_profile(ctx), + )+ + } + } + } + }; +} + +::rama_core::combinators::impl_either!(impl_user_agent_provider_either); diff --git a/rama-ua/src/emulate/service.rs b/rama-ua/src/emulate/service.rs new file mode 100644 index 000000000..2895f1447 --- /dev/null +++ b/rama-ua/src/emulate/service.rs @@ -0,0 +1,1788 @@ +use std::fmt; + +use rama_core::{ + Context, Service, + error::{BoxError, ErrorContext, OpaqueError}, +}; +use rama_http_types::{ + HeaderMap, HeaderName, IntoResponse, Method, Request, Response, Version, + compression::DecompressIfPossible, + conn::Http1ClientContextParams, + header::{ + ACCEPT, ACCEPT_LANGUAGE, AUTHORIZATION, CONTENT_LENGTH, CONTENT_TYPE, COOKIE, HOST, ORIGIN, + REFERER, USER_AGENT, + }, + headers::{ + ClientHint, + encoding::{Encoding, parse_accept_encoding_headers}, + }, + proto::{ + h1::{ + Http1HeaderMap, + headers::{HeaderMapValueRemover, original::OriginalHttp1Headers}, + }, + h2::PseudoHeaderOrder, + }, +}; +use rama_net::{Protocol, http::RequestContext}; + +use crate::{ + CUSTOM_HEADER_MARKER, HttpAgent, HttpHeadersProfile, RequestInitiator, UserAgent, + UserAgentProfile, contains_ignore_ascii_case, starts_with_ignore_ascii_case, +}; + +use super::{UserAgentProvider, UserAgentSelectFallback}; + +pub struct UserAgentEmulateService { + inner: S, + provider: P, + optional: bool, + try_auto_detect_user_agent: bool, + input_header_order: Option, + select_fallback: 
Option, +} + +impl fmt::Debug for UserAgentEmulateService { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("UserAgentEmulateService") + .field("inner", &self.inner) + .field("provider", &self.provider) + .field("optional", &self.optional) + .field( + "try_auto_detect_user_agent", + &self.try_auto_detect_user_agent, + ) + .field("input_header_order", &self.input_header_order) + .field("select_fallback", &self.select_fallback) + .finish() + } +} + +impl Clone for UserAgentEmulateService { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + provider: self.provider.clone(), + optional: self.optional, + try_auto_detect_user_agent: self.try_auto_detect_user_agent, + input_header_order: self.input_header_order.clone(), + select_fallback: self.select_fallback, + } + } +} + +impl UserAgentEmulateService { + pub fn new(inner: S, provider: P) -> Self { + Self { + inner, + provider, + optional: false, + try_auto_detect_user_agent: false, + input_header_order: None, + select_fallback: None, + } + } + + /// When no user agent profile was found it will + /// fail the request unless optional is true. In case of + /// the latter the service will do nothing. + pub fn optional(mut self, optional: bool) -> Self { + self.optional = optional; + self + } + + /// See [`Self::optional`]. + pub fn set_optional(&mut self, optional: bool) -> &mut Self { + self.optional = optional; + self + } + + /// If true, the service will try to auto-detect the user agent from the request, + /// but only in case that info is not yet found in the context. + pub fn try_auto_detect_user_agent(mut self, try_auto_detect_user_agent: bool) -> Self { + self.try_auto_detect_user_agent = try_auto_detect_user_agent; + self + } + + /// See [`Self::try_auto_detect_user_agent`]. 
+ pub fn set_try_auto_detect_user_agent( + &mut self, + try_auto_detect_user_agent: bool, + ) -> &mut Self { + self.try_auto_detect_user_agent = try_auto_detect_user_agent; + self + } + + /// Define a header that if present is to contain a CSV header name list, + /// that allows you to define the desired header order for the (extra) headers + /// found in the input (http) request. + /// + /// Extra meaning any headers not considered a base header and already defined + /// by the (selected) User Agent Profile. + /// + /// This can be useful because your http client might not respect the header casing + /// and/or order of the headers taken together. Using this metadata allows you to + /// communicate this data through anyway. If however your http client does respect + /// casing and order, or you don't care about some of it, you might not need it. + pub fn input_header_order(mut self, name: HeaderName) -> Self { + self.input_header_order = Some(name); + self + } + + /// See [`Self::input_header_order`]. + pub fn set_input_header_order(&mut self, name: HeaderName) -> &mut Self { + self.input_header_order = Some(name); + self + } + + /// Choose what to do in case no profile could be selected + /// using the regular pre-conditions as specified by the provider. + pub fn select_fallback(mut self, fb: UserAgentSelectFallback) -> Self { + self.select_fallback = Some(fb); + self + } + + /// See [`Self::select_fallback`]. 
+ pub fn set_select_fallback(&mut self, fb: UserAgentSelectFallback) -> &mut Self { + self.select_fallback = Some(fb); + self + } +} + +impl Service> for UserAgentEmulateService +where + State: Clone + Send + Sync + 'static, + Body: Send + Sync + 'static, + S: Service, Response: IntoResponse, Error: Into>, + P: UserAgentProvider, +{ + type Response = Response; + type Error = BoxError; + + async fn serve( + &self, + mut ctx: Context, + mut req: Request, + ) -> Result { + if let Some(fallback) = self.select_fallback { + ctx.insert(fallback); + } + + if self.try_auto_detect_user_agent && !ctx.contains::() { + match req + .headers() + .get(USER_AGENT) + .and_then(|ua| ua.to_str().ok()) + { + Some(ua_str) => { + let user_agent = UserAgent::new(ua_str); + tracing::trace!( + ua_str = %ua_str, + %user_agent, + "user agent auto-detected from request" + ); + ctx.insert(user_agent); + } + None => { + tracing::debug!( + "user agent auto-detection not possible: no user agent header present" + ); + } + } + } + + let profile = match self.provider.select_user_agent_profile(&ctx) { + Some(profile) => profile, + None => { + return if self.optional { + Ok(self + .inner + .serve(ctx, req) + .await + .map_err(Into::into)? 
+ .into_response()) + } else { + Err(OpaqueError::from_display( + "requirement not fulfilled: user agent profile could not be selected", + ) + .into_boxed()) + }; + } + }; + + tracing::debug!( + ua_kind = %profile.ua_kind, + ua_version = ?profile.ua_version, + platform = ?profile.platform, + "user agent profile selected for emulation" + ); + + let preserve_http = matches!( + ctx.get::().and_then(|ua| ua.http_agent()), + Some(HttpAgent::Preserve), + ); + + let requested_client_hints = ctx + .get::() + .map(|ua| ua.requested_client_hints().copied().collect::>()); + + let mut original_requested_encodings = None; + + if preserve_http { + tracing::trace!( + ua_kind = %profile.ua_kind, + ua_version = ?profile.ua_version, + platform = ?profile.platform, + "user agent emulation: skip http settings as http is instructed to be preserved" + ); + } else { + emulate_http_settings(&mut ctx, &mut req, profile); + match get_base_http_headers(&ctx, &req, profile) { + Some(base_http_headers) => { + let original_http_header_order = get_original_http_header_order( + &ctx, + &req, + self.input_header_order.as_ref(), + ) + .context("collect original http header order")?; + + let original_headers = req.headers().clone(); + + let preserve_ua_header = ctx + .get::() + .map(|ua| ua.preserve_ua_header()) + .unwrap_or_default(); + + let is_secure_request = match ctx.get::() { + Some(request_ctx) => request_ctx.protocol.is_secure(), + None => req + .uri() + .scheme() + .map(|s| Protocol::from(s.clone()).is_secure()) + .unwrap_or_default(), + }; + + original_requested_encodings = Some( + parse_accept_encoding_headers(&original_headers, true) + .map(|qv| qv.value) + .collect::>(), + ); + + let output_headers = merge_http_headers( + base_http_headers, + original_http_header_order, + original_headers, + preserve_ua_header, + is_secure_request, + requested_client_hints.as_deref(), + ); + + tracing::trace!( + ua_kind = %profile.ua_kind, + ua_version = ?profile.ua_version, + platform = 
?profile.platform, + "user agent emulation: http settings and headers emulated" + ); + let (output_headers, original_headers) = output_headers.into_parts(); + *req.headers_mut() = output_headers; + req.extensions_mut().insert(original_headers); + } + None => { + tracing::debug!( + "user agent emulation: no http headers to emulate: no base http headers found" + ); + } + } + } + + #[cfg(feature = "tls")] + { + use crate::TlsAgent; + + let preserve_tls = matches!( + ctx.get::().and_then(|ua| ua.tls_agent()), + Some(TlsAgent::Preserve), + ); + if preserve_tls { + tracing::trace!( + ua_kind = %profile.ua_kind, + ua_version = ?profile.ua_version, + platform = ?profile.platform, + "user agent emulation: skip tls settings as http is instructed to be preserved" + ); + } else { + // client_config's Arc is to be lazilly cloned by a tls connector + // only when a connection is to be made, as to play nicely + // with concepts such as connection pooling + ctx.insert(profile.tls.client_config.clone()); + } + } + + // serve emulated http(s) request via inner service + let mut res = self + .inner + .serve(ctx, req) + .await + .map_err(Into::into)? + .into_response(); + + if let Some(original_requested_encodings) = original_requested_encodings { + if let Some(content_encoding) = + Encoding::maybe_from_content_encoding_header(res.headers(), true) + { + if !original_requested_encodings.contains(&content_encoding) { + // Only request decompression if the server used a content-encoding + // not listed in the original request's Accept-Encoding header + // or because the callee didn't set this header at all. 
+ res.extensions_mut().insert(DecompressIfPossible::default()); + } + } + } + + Ok(res) + } +} + +fn emulate_http_settings( + ctx: &mut Context, + req: &mut Request, + profile: &UserAgentProfile, +) { + match req.version() { + Version::HTTP_09 | Version::HTTP_10 | Version::HTTP_11 => { + tracing::trace!( + ua_kind = %profile.ua_kind, + ua_version = ?profile.ua_version, + platform = ?profile.platform, + "UA emulation add http1-specific settings", + ); + ctx.insert(Http1ClientContextParams { + title_header_case: profile.http.h1.settings.title_case_headers, + }); + } + Version::HTTP_2 => { + tracing::trace!( + ua_kind = %profile.ua_kind, + ua_version = ?profile.ua_version, + platform = ?profile.platform, + "UA emulation add h2-specific settings", + ); + req.extensions_mut().insert(PseudoHeaderOrder::from_iter( + profile + .http + .h2 + .settings + .http_pseudo_headers + .iter() + .flatten(), + )); + } + Version::HTTP_3 => tracing::debug!( + "UA emulation not yet supported for h3: not applying anything h3-specific" + ), + _ => tracing::debug!( + version = ?req.version(), + "UA emulation not supported for unknown http version: not applying anything version-specific", + ), + } +} + +fn get_base_http_headers<'a, Body, State>( + ctx: &Context, + req: &Request, + profile: &'a UserAgentProfile, +) -> Option<&'a Http1HeaderMap> { + let headers_profile = match req.version() { + Version::HTTP_09 | Version::HTTP_10 | Version::HTTP_11 => &profile.http.h1.headers, + Version::HTTP_2 => &profile.http.h2.headers, + _ => { + tracing::debug!( + version = ?req.version(), + "UA emulation not supported for unknown http version: not applying anything version-specific", + ); + return None; + } + }; + Some( + match ctx.get::().and_then(|ua| ua.request_initiator()) { + Some(req_init) => { + tracing::trace!(%req_init, "base http headers defined based on hint from UserAgent (overwrite)"); + get_base_http_headers_from_req_init(req_init, headers_profile) + } + // NOTE: the primitive checks below 
are pretty bad, + // feel free to help improve. Just need to make sure it has good enough fallbacks, + // and that they are cheap enough to check. + None => match *req.method() { + Method::GET => { + let req_init = if headers_contains_partial_value( + req.headers(), + &X_REQUESTED_WITH, + "XmlHttpRequest", + ) { + RequestInitiator::Xhr + } else { + RequestInitiator::Navigate + }; + tracing::trace!(%req_init, "base http headers defined based on Get=NavigateOrXhr assumption"); + get_base_http_headers_from_req_init(req_init, headers_profile) + } + Method::POST => { + let req_init = if headers_contains_partial_value( + req.headers(), + &X_REQUESTED_WITH, + "XmlHttpRequest", + ) { + RequestInitiator::Xhr + } else if headers_contains_partial_value(req.headers(), &CONTENT_TYPE, "form-") + { + RequestInitiator::Form + } else { + RequestInitiator::Fetch + }; + tracing::trace!(%req_init, "base http headers defined based on Post=FormOrFetch assumption"); + get_base_http_headers_from_req_init(req_init, headers_profile) + } + _ => { + let req_init = if headers_contains_partial_value( + req.headers(), + &X_REQUESTED_WITH, + "XmlHttpRequest", + ) { + RequestInitiator::Xhr + } else { + RequestInitiator::Fetch + }; + tracing::trace!(%req_init, "base http headers defined based on XhrOrFetch assumption"); + get_base_http_headers_from_req_init(req_init, headers_profile) + } + }, + }, + ) +} + +static X_REQUESTED_WITH: HeaderName = HeaderName::from_static("x-requested-with"); + +fn headers_contains_partial_value(headers: &HeaderMap, name: &HeaderName, value: &str) -> bool { + headers + .get(name) + .and_then(|value| value.to_str().ok()) + .map(|s| contains_ignore_ascii_case(s, value).is_some()) + .unwrap_or_default() +} + +fn get_base_http_headers_from_req_init( + req_init: RequestInitiator, + headers: &HttpHeadersProfile, +) -> &Http1HeaderMap { + match req_init { + RequestInitiator::Navigate => &headers.navigate, + RequestInitiator::Form => 
headers.form.as_ref().unwrap_or(&headers.navigate), + RequestInitiator::Xhr => headers + .xhr + .as_ref() + .or(headers.fetch.as_ref()) + .unwrap_or(&headers.navigate), + RequestInitiator::Fetch => headers + .fetch + .as_ref() + .or(headers.xhr.as_ref()) + .unwrap_or(&headers.navigate), + } +} + +fn get_original_http_header_order( + ctx: &Context, + req: &Request, + input_header_order: Option<&HeaderName>, +) -> Result, OpaqueError> { + if let Some(header) = input_header_order.and_then(|name| req.headers().get(name)) { + let s = header.to_str().context("interpret header as a utf-8 str")?; + let mut headers = OriginalHttp1Headers::with_capacity(s.matches(',').count()); + for s in s.split(',') { + let s = s.trim(); + if s.is_empty() { + continue; + } + headers.push(s.parse().context("parse header part as h1 headern name")?); + } + return Ok(Some(headers)); + } + Ok(ctx.get().or_else(|| req.extensions().get()).cloned()) +} + +fn merge_http_headers( + base_http_headers: &Http1HeaderMap, + original_http_header_order: Option, + original_headers: HeaderMap, + preserve_ua_header: bool, + is_secure_request: bool, + requested_client_hints: Option<&[ClientHint]>, +) -> Http1HeaderMap { + let mut original_headers = HeaderMapValueRemover::from(original_headers); + + let mut output_headers_a = Vec::new(); + let mut output_headers_b = Vec::new(); + + let mut output_headers_ref = &mut output_headers_a; + + let is_header_allowed = |header_name: &HeaderName| { + if let Some(hint) = ClientHint::match_header_name(header_name) { + is_secure_request + && (hint.is_low_entropy() + || requested_client_hints + .map(|hints| hints.contains(&hint)) + .unwrap_or_default()) + } else { + is_secure_request || !starts_with_ignore_ascii_case(header_name.as_str(), "sec-fetch") + } + }; + + // put all "base" headers in correct order, and with proper name casing + for (base_name, base_value) in base_http_headers.clone().into_iter() { + let base_header_name = base_name.header_name(); + let 
original_value = original_headers.remove(base_header_name); + match base_header_name { + &ACCEPT | &ACCEPT_LANGUAGE => { + let value = original_value.unwrap_or(base_value); + output_headers_ref.push((base_name, value)); + } + &REFERER | &COOKIE | &AUTHORIZATION | &HOST | &ORIGIN | &CONTENT_LENGTH + | &CONTENT_TYPE => { + if let Some(value) = original_value { + output_headers_ref.push((base_name, value)); + } + } + &USER_AGENT => { + if preserve_ua_header { + let value = original_value.unwrap_or(base_value); + output_headers_ref.push((base_name, value)); + } else { + output_headers_ref.push((base_name, base_value)); + } + } + _ => { + if base_header_name == CUSTOM_HEADER_MARKER { + output_headers_ref = &mut output_headers_b; + } else if is_header_allowed(base_header_name) { + output_headers_ref.push((base_name, base_value)); + } + } + } + } + + // respect original header order of original headers where possible + for header_name in original_http_header_order.into_iter().flatten() { + if let Some(value) = original_headers.remove(header_name.header_name()) { + if is_header_allowed(header_name.header_name()) { + output_headers_a.push((header_name, value)); + } + } + } + + let original_headers_iter = original_headers + .into_iter() + .filter(|(header_name, _)| is_header_allowed(header_name.header_name())); + + Http1HeaderMap::from_iter( + output_headers_a + .into_iter() + .chain(original_headers_iter) // add all remaining original headers in any order within the right loc + .chain(output_headers_b), + ) +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::{convert::Infallible, str::FromStr}; + + use itertools::Itertools as _; + use rama_core::service::service_fn; + use rama_http_types::{ + Body, BodyExtractExt, HeaderValue, header::ETAG, proto::h1::Http1HeaderName, + }; + + use crate::{ + Http1Profile, Http1Settings, Http2Profile, Http2Settings, HttpHeadersProfile, HttpProfile, + }; + + #[test] + fn test_merge_http_headers() { + struct TestCase { + description: 
&'static str, + base_http_headers: Vec<(&'static str, &'static str)>, + original_http_header_order: Option>, + original_headers: Vec<(&'static str, &'static str)>, + preserve_ua_header: bool, + is_secure_request: bool, + requested_client_hints: Option>, + expected: Vec<(&'static str, &'static str)>, + } + + let test_cases = [ + TestCase { + description: "empty", + base_http_headers: vec![], + original_http_header_order: None, + original_headers: vec![], + preserve_ua_header: false, + is_secure_request: false, + requested_client_hints: None, + expected: vec![], + }, + TestCase { + description: "base headers only", + base_http_headers: vec![ + ("Accept", "text/html"), + ("Content-Type", "application/json"), + ], + original_http_header_order: None, + original_headers: vec![], + preserve_ua_header: false, + is_secure_request: false, + requested_client_hints: None, + expected: vec![("Accept", "text/html")], + }, + TestCase { + description: "base headers only with content-type", + base_http_headers: vec![ + ("Accept", "text/html"), + ("Content-Type", "application/json"), + ], + original_http_header_order: None, + original_headers: vec![("content-type", "text/xml")], + preserve_ua_header: false, + is_secure_request: false, + requested_client_hints: None, + expected: vec![("Accept", "text/html"), ("Content-Type", "text/xml")], + }, + TestCase { + description: "original headers only", + base_http_headers: vec![], + original_http_header_order: None, + original_headers: vec![("accept", "text/html")], + preserve_ua_header: false, + is_secure_request: false, + requested_client_hints: None, + expected: vec![("accept", "text/html")], + }, + TestCase { + description: "original and base headers, no conflicts", + base_http_headers: vec![("accept", "text/html"), ("user-agent", "python/3.10")], + original_http_header_order: None, + original_headers: vec![("content-type", "application/json")], + preserve_ua_header: false, + is_secure_request: false, + requested_client_hints: None, + 
expected: vec![ + ("accept", "text/html"), + ("user-agent", "python/3.10"), + ("content-type", "application/json"), + ], + }, + TestCase { + description: "original and base headers, with conflicts", + base_http_headers: vec![ + ("accept", "text/html"), + ("content-type", "text/html"), + ("user-agent", "python/3.10"), + ], + original_http_header_order: Some(vec!["content-type", "user-agent"]), + original_headers: vec![ + ("content-type", "application/json"), + ("user-agent", "php/8.0"), + ], + preserve_ua_header: false, + is_secure_request: false, + requested_client_hints: None, + expected: vec![ + ("accept", "text/html"), + ("content-type", "application/json"), + ("user-agent", "python/3.10"), + ], + }, + TestCase { + description: "original and base headers, with conflicts, preserve ua header", + base_http_headers: vec![ + ("accept", "text/html"), + ("content-type", "text/html"), + ("user-agent", "python/3.10"), + ], + original_http_header_order: Some(vec!["content-type", "user-agent"]), + original_headers: vec![ + ("content-type", "application/json"), + ("user-agent", "php/8.0"), + ], + preserve_ua_header: true, + is_secure_request: false, + requested_client_hints: None, + expected: vec![ + ("accept", "text/html"), + ("content-type", "application/json"), + ("user-agent", "php/8.0"), + ], + }, + TestCase { + description: "no opt-in base headers defined", + base_http_headers: vec![ + ("accept", "text/html"), + ("authorization", "Bearer 1234567890"), + ("cookie", "session=1234567890"), + ("referer", "https://example.com"), + ], + original_http_header_order: Some(vec!["content-type", "user-agent"]), + original_headers: vec![ + ("content-type", "application/json"), + ("user-agent", "php/8.0"), + ], + preserve_ua_header: false, + is_secure_request: false, + requested_client_hints: None, + expected: vec![ + ("accept", "text/html"), + ("content-type", "application/json"), + ("user-agent", "php/8.0"), + ], + }, + TestCase { + description: "some opt-in base headers 
defined", + base_http_headers: vec![ + ("accept", "text/html"), + ("authorization", "Bearer 1234567890"), + ("cookie", "session=1234567890"), + ("referer", "https://example.com"), + ], + original_http_header_order: Some(vec![ + "content-type", + "cookie", + "user-agent", + "referer", + ]), + original_headers: vec![ + ("content-type", "application/json"), + ("cookie", "foo=bar"), + ("user-agent", "php/8.0"), + ("referer", "https://ramaproxy.org"), + ], + preserve_ua_header: false, + is_secure_request: false, + requested_client_hints: None, + expected: vec![ + ("accept", "text/html"), + ("cookie", "foo=bar"), + ("referer", "https://ramaproxy.org"), + ("content-type", "application/json"), + ("user-agent", "php/8.0"), + ], + }, + TestCase { + description: "all opt-in base headers defined", + base_http_headers: vec![ + ("accept", "text/html"), + ("authorization", "Bearer 1234567890"), + ("cookie", "session=1234567890"), + ("referer", "https://example.com"), + ], + original_http_header_order: Some(vec![ + "content-type", + "cookie", + "user-agent", + "referer", + "authorization", + ]), + original_headers: vec![ + ("content-type", "application/json"), + ("cookie", "foo=bar"), + ("user-agent", "php/8.0"), + ("referer", "https://ramaproxy.org"), + ("authorization", "Bearer 42"), + ], + preserve_ua_header: false, + is_secure_request: false, + requested_client_hints: None, + expected: vec![ + ("accept", "text/html"), + ("authorization", "Bearer 42"), + ("cookie", "foo=bar"), + ("referer", "https://ramaproxy.org"), + ("content-type", "application/json"), + ("user-agent", "php/8.0"), + ], + }, + TestCase { + description: "all opt-in base headers defined, with custom header marker", + base_http_headers: vec![ + ("accept", "text/html"), + ("authorization", "Bearer 1234567890"), + ("x-rama-custom-header-marker", "1"), + ("cookie", "session=1234567890"), + ("referer", "https://example.com"), + ], + original_http_header_order: Some(vec![ + "content-type", + "cookie", + "user-agent", 
+ "referer", + "authorization", + ]), + original_headers: vec![ + ("content-type", "application/json"), + ("cookie", "foo=bar"), + ("user-agent", "php/8.0"), + ("referer", "https://ramaproxy.org"), + ("authorization", "Bearer 42"), + ], + preserve_ua_header: false, + is_secure_request: false, + requested_client_hints: None, + expected: vec![ + ("accept", "text/html"), + ("authorization", "Bearer 42"), + ("content-type", "application/json"), + ("user-agent", "php/8.0"), + ("cookie", "foo=bar"), + ("referer", "https://ramaproxy.org"), + ], + }, + TestCase { + description: "realistic browser example", + base_http_headers: vec![ + ("Host", "www.google.com"), + ( + "User-Agent", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", + ), + ( + "Accept", + "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8", + ), + ("Accept-Language", "en-US,en;q=0.9"), + ("Accept-Encoding", "gzip, deflate, br"), + ("Connection", "keep-alive"), + ("Referer", "https://www.google.com/"), + ("Upgrade-Insecure-Requests", "1"), + ("x-rama-custom-header-marker", "1"), + ("Cookie", "rama-ua-test=1"), + ("Sec-Fetch-Dest", "document"), + ("Sec-Fetch-Mode", "navigate"), + ("Sec-Fetch-Site", "cross-site"), + ("Sec-Fetch-User", "?1"), + ("DNT", "1"), + ("Sec-GPC", "1"), + ("Priority", "u=0, i"), + ], + original_http_header_order: Some(vec![ + "x-show-price", + "x-show-price-currency", + "accept-language", + "cookie", + ]), + original_headers: vec![ + ("x-show-price", "true"), + ("x-show-price-currency", "USD"), + ("accept-language", "fr-FR,fr;q=0.9,en-US;q=0.8,en;q=0.7"), + ("cookie", "session=on; foo=bar"), + ("x-requested-with", "XMLHttpRequest"), + ("host", "www.example.com"), + ], + preserve_ua_header: false, + is_secure_request: false, + requested_client_hints: None, + expected: vec![ + ("Host", "www.example.com"), + ( + "User-Agent", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) 
AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", + ), + ( + "Accept", + "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8", + ), + ("Accept-Language", "fr-FR,fr;q=0.9,en-US;q=0.8,en;q=0.7"), + ("Accept-Encoding", "gzip, deflate, br"), + ("Connection", "keep-alive"), + ("Upgrade-Insecure-Requests", "1"), + ("x-show-price", "true"), + ("x-show-price-currency", "USD"), + ("x-requested-with", "XMLHttpRequest"), + ("Cookie", "session=on; foo=bar"), + ("DNT", "1"), + ("Sec-GPC", "1"), + ("Priority", "u=0, i"), + ], + }, + TestCase { + description: "realistic browser example over tls", + base_http_headers: vec![ + ("Host", "www.google.com"), + ( + "User-Agent", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", + ), + ( + "Accept", + "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8", + ), + ("Accept-Language", "en-US,en;q=0.9"), + ("Accept-Encoding", "gzip, deflate, br"), + ("Connection", "keep-alive"), + ("Referer", "https://www.google.com/"), + ("Upgrade-Insecure-Requests", "1"), + ("x-rama-custom-header-marker", "1"), + ("Cookie", "rama-ua-test=1"), + ("Sec-Fetch-Dest", "document"), + ("Sec-Fetch-Mode", "navigate"), + ("Sec-Fetch-Site", "cross-site"), + ("Sec-Fetch-User", "?1"), + ("DNT", "1"), + ("Sec-GPC", "1"), + ("Priority", "u=0, i"), + ], + original_http_header_order: Some(vec![ + "x-show-price", + "x-show-price-currency", + "accept-language", + "cookie", + ]), + original_headers: vec![ + ("x-show-price", "true"), + ("x-show-price-currency", "USD"), + ("accept-language", "fr-FR,fr;q=0.9,en-US;q=0.8,en;q=0.7"), + ("cookie", "session=on; foo=bar"), + ("x-requested-with", "XMLHttpRequest"), + ], + preserve_ua_header: false, + is_secure_request: true, + requested_client_hints: None, + expected: vec![ + ( + "User-Agent", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, 
like Gecko) Chrome/120.0.0.0 Safari/537.36", + ), + ( + "Accept", + "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8", + ), + ("Accept-Language", "fr-FR,fr;q=0.9,en-US;q=0.8,en;q=0.7"), + ("Accept-Encoding", "gzip, deflate, br"), + ("Connection", "keep-alive"), + ("Upgrade-Insecure-Requests", "1"), + ("x-show-price", "true"), + ("x-show-price-currency", "USD"), + ("x-requested-with", "XMLHttpRequest"), + ("Cookie", "session=on; foo=bar"), + ("Sec-Fetch-Dest", "document"), + ("Sec-Fetch-Mode", "navigate"), + ("Sec-Fetch-Site", "cross-site"), + ("Sec-Fetch-User", "?1"), + ("DNT", "1"), + ("Sec-GPC", "1"), + ("Priority", "u=0, i"), + ], + }, + TestCase { + description: "realistic browser example over tls with requested client hints", + base_http_headers: vec![ + ("Host", "www.google.com"), + ( + "User-Agent", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", + ), + ( + "Accept", + "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8", + ), + ("Accept-Language", "en-US,en;q=0.9"), + ("Accept-Encoding", "gzip, deflate, br"), + ("Connection", "keep-alive"), + ("Referer", "https://www.google.com/"), + ("Upgrade-Insecure-Requests", "1"), + ("x-rama-custom-header-marker", "1"), + ("Cookie", "rama-ua-test=1"), + ("Sec-Fetch-Dest", "document"), + ("Sec-Fetch-Mode", "navigate"), + ("Sec-Fetch-Site", "cross-site"), + ("Sec-Fetch-User", "?1"), + ("Sec-CH-Downlink", "100"), + ("Sec-CH-Ect", "4g"), + ("Sec-CH-RTT", "100"), + ("Sec-CH-UA-Arch", "arm"), + ("Sec-CH-UA-Bitness", "64"), + ("Sec-CH-UA-Full-Version", "120.0.0.0"), + ("Sec-CH-UA-Full-Version-List", "Chrome 120.0.0.0"), + ("Sec-CH-UA-Mobile", "?0"), + ("Sec-CH-UA-Platform", "macOS"), + ("Sec-CH-UA-Platform-Version", "10.15.7"), + ("DNT", "1"), + ("Sec-GPC", "1"), + ("Priority", "u=0, i"), + ], + original_http_header_order: Some(vec![ + "x-show-price", + "x-show-price-currency", + 
"accept-language", + "cookie", + "Sec-CH-UA-Model", + ]), + original_headers: vec![ + ("x-show-price", "true"), + ("x-show-price-currency", "USD"), + ("accept-language", "fr-FR,fr;q=0.9,en-US;q=0.8,en;q=0.7"), + ("cookie", "session=on; foo=bar"), + ("x-requested-with", "XMLHttpRequest"), + ("sec-ch-ua-model", "Macintosh"), + ], + preserve_ua_header: false, + is_secure_request: true, + requested_client_hints: Some(vec![ + "Downlink", + "Ect", + "RTT", + "Sec-CH-UA-Arch", + "Sec-CH-UA-Bitness", + "Sec-CH-UA-Model", + ]), + expected: vec![ + ( + "User-Agent", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", + ), + ( + "Accept", + "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8", + ), + ("Accept-Language", "fr-FR,fr;q=0.9,en-US;q=0.8,en;q=0.7"), + ("Accept-Encoding", "gzip, deflate, br"), + ("Connection", "keep-alive"), + ("Upgrade-Insecure-Requests", "1"), + ("x-show-price", "true"), + ("x-show-price-currency", "USD"), + ("Sec-CH-UA-Model", "Macintosh"), + ("x-requested-with", "XMLHttpRequest"), + ("Cookie", "session=on; foo=bar"), + ("Sec-Fetch-Dest", "document"), + ("Sec-Fetch-Mode", "navigate"), + ("Sec-Fetch-Site", "cross-site"), + ("Sec-Fetch-User", "?1"), + ("Sec-CH-Downlink", "100"), + ("Sec-CH-Ect", "4g"), + ("Sec-CH-RTT", "100"), + ("Sec-CH-UA-Arch", "arm"), + ("Sec-CH-UA-Bitness", "64"), + ("Sec-CH-UA-Mobile", "?0"), // not requested, but low entropy + ("Sec-CH-UA-Platform", "macOS"), // not requested, but low entropy + ("DNT", "1"), + ("Sec-GPC", "1"), + ("Priority", "u=0, i"), + ], + }, + ]; + + for test_case in test_cases { + let base_http_headers = + Http1HeaderMap::from_iter(test_case.base_http_headers.into_iter().map( + |(name, value)| { + ( + Http1HeaderName::from_str(name).unwrap(), + HeaderValue::from_static(value), + ) + }, + )); + let original_http_header_order = test_case.original_http_header_order.map(|headers| { + 
OriginalHttp1Headers::from_iter( + headers + .into_iter() + .map(|header| Http1HeaderName::from_str(header).unwrap()), + ) + }); + let original_headers = HeaderMap::from_iter( + test_case.original_headers.into_iter().map(|(name, value)| { + ( + HeaderName::from_static(name), + HeaderValue::from_static(value), + ) + }), + ); + let preserve_ua_header = test_case.preserve_ua_header; + let is_secure_request = test_case.is_secure_request; + let requested_client_hints = test_case.requested_client_hints.map(|hints| { + hints + .into_iter() + .map(|hint| ClientHint::from_str(hint).unwrap()) + .collect::>() + }); + + let output_headers = merge_http_headers( + &base_http_headers, + original_http_header_order, + original_headers, + preserve_ua_header, + is_secure_request, + requested_client_hints.as_deref(), + ); + + let output_str = output_headers + .into_iter() + .map(|(name, value)| format!("{}: {}\r\n", name, value.to_str().unwrap())) + .join(""); + + let expected_str = test_case + .expected + .iter() + .map(|(name, value)| format!("{}: {}\r\n", name, value)) + .join(""); + + assert_eq!( + output_str, expected_str, + "test case '{}' failed", + test_case.description + ); + } + } + + #[test] + fn test_get_original_http_header_order() { + struct TestCase { + description: &'static str, + req: Request, + ctx: Option>, + expected_input_header_order: &'static str, + } + + let test_cases = [ + TestCase { + description: "empty request", + req: Request::new(Body::empty()), + ctx: None, + expected_input_header_order: "", + }, + TestCase { + description: "request original header order in req extensions", + req: { + let mut req = Request::new(Body::empty()); + req.extensions_mut() + .insert(OriginalHttp1Headers::from_iter([ + Http1HeaderName::from_str("x-REQUESTED-with").unwrap(), + Http1HeaderName::from_str("Accept").unwrap(), + ])); + req + }, + ctx: None, + expected_input_header_order: "x-REQUESTED-with,Accept", + }, + TestCase { + description: "request original header order in 
ctx", + req: { + let mut req = Request::new(Body::empty()); + req.headers_mut().insert( + HeaderName::from_static("foo"), + HeaderValue::from_static("BAR"), + ); + req + }, + // ctx has precedence over req extensions + ctx: Some({ + let mut ctx = Context::default(); + ctx.insert(OriginalHttp1Headers::from_iter([ + Http1HeaderName::from_str("x-REQUESTED-with").unwrap(), + Http1HeaderName::from_str("Accept").unwrap(), + ])); + ctx + }), + expected_input_header_order: "x-REQUESTED-with,Accept", + }, + TestCase { + description: "request with headers but no original header order", + req: { + let mut req = Request::new(Body::empty()); + req.headers_mut().insert( + HeaderName::from_static("foo"), + HeaderValue::from_static("BAR"), + ); + req + }, + ctx: None, + expected_input_header_order: "", + }, + ]; + + for test_case in test_cases { + let input_header_order = get_original_http_header_order( + test_case.ctx.as_ref().unwrap_or(&Context::default()), + &test_case.req, + None, + ) + .unwrap(); + let input_header_order_str = input_header_order + .map(|headers| { + headers + .into_iter() + .map(|header| header.to_string()) + .join(",") + }) + .unwrap_or_default(); + assert_eq!( + input_header_order_str, test_case.expected_input_header_order, + "{}", + test_case.description, + ); + } + } + + #[tokio::test] + async fn test_get_base_h2_headers() { + let ua = UserAgent::new( + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", + ); + + let ua_profile = UserAgentProfile { + ua_kind: ua.ua_kind().unwrap(), + ua_version: ua.ua_version(), + platform: ua.platform(), + http: HttpProfile { + h1: Http1Profile { + headers: HttpHeadersProfile { + navigate: Http1HeaderMap::default(), + fetch: None, + xhr: None, + form: None, + }, + settings: Http1Settings::default(), + }, + h2: Http2Profile { + headers: HttpHeadersProfile { + navigate: Http1HeaderMap::new( + [(ETAG, HeaderValue::from_str("navigate").unwrap())] + 
.into_iter() + .collect(), + None, + ), + fetch: None, + xhr: None, + form: None, + }, + settings: Http2Settings::default(), + }, + }, + #[cfg(feature = "tls")] + tls: crate::TlsProfile { + client_config: std::sync::Arc::new(rama_net::tls::client::ClientConfig::default()), + }, + }; + + let ua_service = UserAgentEmulateService::new( + service_fn(async |req: Request| { + Ok::<_, Infallible>( + req.headers() + .get(ETAG) + .map(|header| header.to_str().unwrap().to_owned()) + .unwrap_or_default(), + ) + }), + ua_profile, + ); + + let req = Request::builder() + .method(Method::GET) + .body(Body::empty()) + .unwrap(); + let res = ua_service.serve(Context::default(), req).await.unwrap(); + let body = res.into_body().try_into_string().await.unwrap(); + assert_eq!(body, ""); + + let req = Request::builder() + .method(Method::GET) + .version(Version::HTTP_2) + .body(Body::empty()) + .unwrap(); + let res = ua_service.serve(Context::default(), req).await.unwrap(); + let body = res.into_body().try_into_string().await.unwrap(); + assert_eq!(body, "navigate"); + } + + #[tokio::test] + async fn test_get_base_http_headers_profile_with_only_navigate_headers() { + let ua = UserAgent::new( + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", + ); + let ua_profile = UserAgentProfile { + ua_kind: ua.ua_kind().unwrap(), + ua_version: ua.ua_version(), + platform: ua.platform(), + http: HttpProfile { + h1: Http1Profile { + headers: HttpHeadersProfile { + navigate: Http1HeaderMap::new( + [(ETAG, HeaderValue::from_str("navigate").unwrap())] + .into_iter() + .collect(), + None, + ), + xhr: None, + fetch: None, + form: None, + }, + settings: Http1Settings::default(), + }, + h2: Http2Profile { + headers: HttpHeadersProfile { + navigate: Http1HeaderMap::default(), + fetch: None, + xhr: None, + form: None, + }, + settings: Http2Settings::default(), + }, + }, + #[cfg(feature = "tls")] + tls: crate::TlsProfile { + client_config: 
std::sync::Arc::new(rama_net::tls::client::ClientConfig::default()), + }, + }; + + let ua_service = UserAgentEmulateService::new( + service_fn(async |req: Request| { + Ok::<_, Infallible>( + req.headers() + .get(ETAG) + .unwrap() + .to_str() + .unwrap() + .to_owned(), + ) + }), + ua_profile, + ); + + let req = Request::builder() + .method(Method::DELETE) + .body(Body::empty()) + .unwrap(); + let res = ua_service.serve(Context::default(), req).await.unwrap(); + let body = res.into_body().try_into_string().await.unwrap(); + assert_eq!(body, "navigate"); + } + + #[tokio::test] + async fn test_get_base_http_headers_profile_without_fetch_headers() { + let ua = UserAgent::new( + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", + ); + let ua_profile = UserAgentProfile { + ua_kind: ua.ua_kind().unwrap(), + ua_version: ua.ua_version(), + platform: ua.platform(), + http: HttpProfile { + h1: Http1Profile { + headers: HttpHeadersProfile { + navigate: Http1HeaderMap::new( + [(ETAG, HeaderValue::from_str("navigate").unwrap())] + .into_iter() + .collect(), + None, + ), + xhr: Some(Http1HeaderMap::new( + [(ETAG, HeaderValue::from_str("xhr").unwrap())] + .into_iter() + .collect(), + None, + )), + fetch: None, + form: Some(Http1HeaderMap::new( + [(ETAG, HeaderValue::from_str("form").unwrap())] + .into_iter() + .collect(), + None, + )), + }, + settings: Http1Settings::default(), + }, + h2: Http2Profile { + headers: HttpHeadersProfile { + navigate: Http1HeaderMap::default(), + fetch: None, + xhr: None, + form: None, + }, + settings: Http2Settings::default(), + }, + }, + #[cfg(feature = "tls")] + tls: crate::TlsProfile { + client_config: std::sync::Arc::new(rama_net::tls::client::ClientConfig::default()), + }, + }; + + let ua_service = UserAgentEmulateService::new( + service_fn(async |req: Request| { + Ok::<_, Infallible>( + req.headers() + .get(ETAG) + .unwrap() + .to_str() + .unwrap() + .to_owned(), + ) + }), + 
ua_profile, + ); + + let req = Request::builder() + .method(Method::DELETE) + .body(Body::empty()) + .unwrap(); + let res = ua_service.serve(Context::default(), req).await.unwrap(); + let body = res.into_body().try_into_string().await.unwrap(); + assert_eq!(body, "xhr"); + } + + #[tokio::test] + async fn test_get_base_http_headers_profile_without_xhr_headers() { + let ua = UserAgent::new( + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", + ); + let ua_profile = UserAgentProfile { + ua_kind: ua.ua_kind().unwrap(), + ua_version: ua.ua_version(), + platform: ua.platform(), + http: HttpProfile { + h1: Http1Profile { + headers: HttpHeadersProfile { + navigate: Http1HeaderMap::new( + [(ETAG, HeaderValue::from_str("navigate").unwrap())] + .into_iter() + .collect(), + None, + ), + fetch: Some(Http1HeaderMap::new( + [(ETAG, HeaderValue::from_str("fetch").unwrap())] + .into_iter() + .collect(), + None, + )), + xhr: None, + form: Some(Http1HeaderMap::new( + [(ETAG, HeaderValue::from_str("form").unwrap())] + .into_iter() + .collect(), + None, + )), + }, + settings: Http1Settings::default(), + }, + h2: Http2Profile { + headers: HttpHeadersProfile { + navigate: Http1HeaderMap::default(), + fetch: None, + xhr: None, + form: None, + }, + settings: Http2Settings::default(), + }, + }, + #[cfg(feature = "tls")] + tls: crate::TlsProfile { + client_config: std::sync::Arc::new(rama_net::tls::client::ClientConfig::default()), + }, + }; + + let ua_service = UserAgentEmulateService::new( + service_fn(async |req: Request| { + Ok::<_, Infallible>( + req.headers() + .get(ETAG) + .unwrap() + .to_str() + .unwrap() + .to_owned(), + ) + }), + ua_profile, + ); + + let req = Request::builder() + .method(Method::DELETE) + .header( + HeaderName::from_static("x-requested-with"), + "XmlHttpRequest", + ) + .body(Body::empty()) + .unwrap(); + let res = ua_service.serve(Context::default(), req).await.unwrap(); + let body = 
res.into_body().try_into_string().await.unwrap(); + assert_eq!(body, "fetch"); + } + + #[tokio::test] + async fn test_get_base_http_headers() { + let ua = UserAgent::new( + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", + ); + let ua_profile = UserAgentProfile { + ua_kind: ua.ua_kind().unwrap(), + ua_version: ua.ua_version(), + platform: ua.platform(), + http: HttpProfile { + h1: Http1Profile { + headers: HttpHeadersProfile { + navigate: Http1HeaderMap::new( + [(ETAG, HeaderValue::from_str("navigate").unwrap())] + .into_iter() + .collect(), + None, + ), + fetch: Some(Http1HeaderMap::new( + [(ETAG, HeaderValue::from_str("fetch").unwrap())] + .into_iter() + .collect(), + None, + )), + xhr: Some(Http1HeaderMap::new( + [(ETAG, HeaderValue::from_str("xhr").unwrap())] + .into_iter() + .collect(), + None, + )), + form: Some(Http1HeaderMap::new( + [(ETAG, HeaderValue::from_str("form").unwrap())] + .into_iter() + .collect(), + None, + )), + }, + settings: Http1Settings::default(), + }, + h2: Http2Profile { + headers: HttpHeadersProfile { + navigate: Http1HeaderMap::default(), + fetch: None, + xhr: None, + form: None, + }, + settings: Http2Settings::default(), + }, + }, + #[cfg(feature = "tls")] + tls: crate::TlsProfile { + client_config: std::sync::Arc::new(rama_net::tls::client::ClientConfig::default()), + }, + }; + + let ua_service = UserAgentEmulateService::new( + service_fn(async |req: Request| { + Ok::<_, Infallible>( + req.headers() + .get(ETAG) + .unwrap() + .to_str() + .unwrap() + .to_owned(), + ) + }), + ua_profile, + ); + + struct TestCase { + description: &'static str, + method: Option, + headers: Option, + ctx: Option>, + expected: &'static str, + } + + let test_cases = [ + TestCase { + description: "GET request", + method: None, + headers: None, + ctx: None, + expected: "navigate", + }, + TestCase { + description: "GET request with XRW header", + method: None, + headers: Some( + [( + 
HeaderName::from_static("x-requested-with"), + HeaderValue::from_static("XmlHttpRequest"), + )] + .into_iter() + .collect(), + ), + ctx: None, + expected: "xhr", + }, + TestCase { + description: "GET request with RequestInitiator hint Navigate", + method: None, + headers: None, + ctx: Some({ + let mut ctx = Context::default(); + ctx.insert( + UserAgent::new("").with_request_initiator(RequestInitiator::Navigate), + ); + ctx + }), + expected: "navigate", + }, + TestCase { + description: "GET request with RequestInitiator hint Form", + method: None, + headers: None, + ctx: Some({ + let mut ctx = Context::default(); + ctx.insert(UserAgent::new("").with_request_initiator(RequestInitiator::Form)); + ctx + }), + expected: "form", + }, + TestCase { + description: "explicit GET request", + method: Some(Method::GET), + headers: None, + ctx: None, + expected: "navigate", + }, + TestCase { + description: "explicit POST request", + method: Some(Method::POST), + headers: None, + ctx: None, + expected: "fetch", + }, + TestCase { + description: "explicit POST request with XRW header", + method: Some(Method::POST), + headers: Some( + [( + HeaderName::from_static("x-requested-with"), + HeaderValue::from_static("XmlHttpRequest"), + )] + .into_iter() + .collect(), + ), + ctx: None, + expected: "xhr", + }, + TestCase { + description: "explicit POST request with multipart/form-data and XRW header", + method: Some(Method::POST), + headers: Some( + [ + ( + CONTENT_TYPE, + HeaderValue::from_static( + "multipart/form-data; boundary=ExampleBoundaryString", + ), + ), + ( + HeaderName::from_static("x-requested-with"), + HeaderValue::from_static("XmlHttpRequest"), + ), + ] + .into_iter() + .collect(), + ), + ctx: None, + expected: "xhr", + }, + TestCase { + description: "explicit POST request with application/x-www-form-urlencoded and XRW header", + method: Some(Method::POST), + headers: Some( + [ + ( + CONTENT_TYPE, + HeaderValue::from_static("application/x-www-form-urlencoded"), + ), + ( + 
HeaderName::from_static("x-requested-with"), + HeaderValue::from_static("XmlHttpRequest"), + ), + ] + .into_iter() + .collect(), + ), + ctx: None, + expected: "xhr", + }, + TestCase { + description: "explicit POST request with multipart/form-data", + method: Some(Method::POST), + headers: Some( + [( + CONTENT_TYPE, + HeaderValue::from_static( + "multipart/form-data; boundary=ExampleBoundaryString", + ), + )] + .into_iter() + .collect(), + ), + ctx: None, + expected: "form", + }, + TestCase { + description: "explicit POST request with application/x-www-form-urlencoded", + method: Some(Method::POST), + headers: Some( + [( + CONTENT_TYPE, + HeaderValue::from_static("application/x-www-form-urlencoded"), + )] + .into_iter() + .collect(), + ), + ctx: None, + expected: "form", + }, + TestCase { + description: "explicit DELETE request with XRW header", + method: Some(Method::DELETE), + headers: Some( + [( + HeaderName::from_static("x-requested-with"), + HeaderValue::from_static("XmlHttpRequest"), + )] + .into_iter() + .collect(), + ), + ctx: None, + expected: "xhr", + }, + TestCase { + description: "explicit DELETE request", + method: Some(Method::DELETE), + headers: None, + ctx: None, + expected: "fetch", + }, + TestCase { + description: "explicit DELETE request with RequestInitiator hint", + method: Some(Method::DELETE), + headers: None, + ctx: Some({ + let mut ctx = Context::default(); + ctx.insert(UserAgent::new("").with_request_initiator(RequestInitiator::Xhr)); + ctx + }), + expected: "xhr", + }, + ]; + + for test_case in test_cases { + let mut req = Request::builder() + .method(test_case.method.unwrap_or(Method::GET)) + .body(Body::empty()) + .unwrap(); + if let Some(headers) = test_case.headers { + req.headers_mut().extend(headers); + } + let ctx = test_case.ctx.unwrap_or_default(); + let res = ua_service.serve(ctx, req).await.unwrap(); + let body = res.into_body().try_into_string().await.unwrap(); + assert_eq!(body, test_case.expected, "{}", 
test_case.description); + } + } +} diff --git a/rama-ua/src/lib.rs b/rama-ua/src/lib.rs index 94d4fb209..4c196c6f7 100644 --- a/rama-ua/src/lib.rs +++ b/rama-ua/src/lib.rs @@ -1,9 +1,14 @@ -//! User Agent (UA) parser and types. +//! User Agent (UA) parser and profiles. //! -//! This module provides a parser ([`UserAgent::new`]) for User Agents +//! This crate provides a parser ([`UserAgent::new`]) for User Agents //! as well as a classifier (`UserAgentClassifierLayer` in `rama_http`) that can be used to //! classify incoming requests based on their User Agent (header). //! +//! These can be used to know what UA is connecting to a server, +//! but it can also be used to emulate the UA from a client +//! via the profiles that are found in this crate as well, +//! be it builtin modules or custom ones. +//! //! Learn more about User Agents (UA) and why Rama supports it //! at . //! @@ -55,36 +60,11 @@ #![cfg_attr(test, allow(clippy::float_cmp))] #![cfg_attr(not(test), warn(clippy::print_stdout, clippy::dbg_macro))] -use serde::{Deserialize, Serialize}; - -mod info; -pub use info::{ - DeviceKind, HttpAgent, PlatformKind, TlsAgent, UserAgent, UserAgentInfo, UserAgentKind, -}; - -mod parse; -use parse::parse_http_user_agent_header; +mod ua; +pub use ua::*; -/// Information that can be used to overwrite the [`UserAgent`] of an http request. -/// -/// Used by the `UserAgentClassifier` (see `rama-http`) to overwrite the specified -/// information duing the classification of the [`UserAgent`]. -#[derive(Debug, Clone, Default, Serialize, Deserialize)] -pub struct UserAgentOverwrites { - /// Overwrite the [`UserAgent`] of the http `Request` with a custom value. - /// - /// This value will be used instead of - /// the 'User-Agent' http (header) value. - /// - /// This is useful in case you cannot set the User-Agent header in your request. - pub ua: Option, - /// Overwrite the [`HttpAgent`] of the http `Request` with a custom value. 
- pub http: Option, - /// Overwrite the [`TlsAgent`] of the http `Request` with a custom value. - pub tls: Option, - /// Preserve the original [`UserAgent`] header of the http `Request`. - pub preserve_ua: Option, -} +mod profile; +pub use profile::*; -#[cfg(test)] -mod parse_tests; +mod emulate; +pub use emulate::*; diff --git a/rama-ua/src/profile/db.rs b/rama-ua/src/profile/db.rs new file mode 100644 index 000000000..95d20cf3a --- /dev/null +++ b/rama-ua/src/profile/db.rs @@ -0,0 +1,472 @@ +use itertools::Itertools as _; +use rand::seq::IndexedRandom as _; +use std::collections::HashMap; + +use crate::{DeviceKind, PlatformKind, UserAgent, UserAgentKind, UserAgentProfile}; + +#[derive(Debug, Default)] +pub struct UserAgentDatabase { + profiles: Vec, + + map_ua_string: HashMap, + + map_ua_kind: HashMap>, + map_platform: HashMap<(UserAgentKind, PlatformKind), Vec>, + map_device: HashMap<(UserAgentKind, DeviceKind), Vec>, +} + +impl UserAgentDatabase { + #[inline] + pub fn len(&self) -> usize { + self.profiles.len() + } + + #[inline] + pub fn is_empty(&self) -> bool { + self.profiles.is_empty() + } + + pub fn iter_ua_str(&self) -> impl Iterator { + self.map_ua_string.keys().map(|s| s.as_str()) + } + + pub fn iter_ua_kind(&self) -> impl Iterator { + self.map_ua_kind.keys() + } + + pub fn iter_platform(&self) -> impl Iterator { + self.map_platform + .keys() + .map(|(_, platform)| platform) + .dedup() + } + + pub fn iter_device(&self) -> impl Iterator { + self.map_device.keys().map(|(_, device)| device).dedup() + } + + pub fn insert(&mut self, profile: UserAgentProfile) { + let index = self.profiles.len(); + if let Some(ua_header) = profile.ua_str() { + self.map_ua_string.insert(ua_header.to_owned(), index); + } + + self.map_ua_kind + .entry(profile.ua_kind) + .or_default() + .push(index); + + if let Some(platform) = profile.platform { + self.map_platform + .entry((profile.ua_kind, platform)) + .or_default() + .push(index); + self.map_device + .entry((profile.ua_kind, 
platform.device())) + .or_default() + .push(index); + } + + self.profiles.push(profile); + } + + pub fn rnd(&self) -> Option<&UserAgentProfile> { + let ua_kind = self.market_rnd_ua_kind(); + self.map_ua_kind + .get(&ua_kind) + .and_then(|v| v.choose(&mut rand::rng())) + .and_then(|idx| self.profiles.get(*idx)) + } + + pub fn get(&self, ua: &UserAgent) -> Option<&UserAgentProfile> { + if let Some(profile) = self + .map_ua_string + .get(ua.header_str()) + .and_then(|idx| self.profiles.get(*idx)) + { + return Some(profile); + } + + match (ua.ua_kind(), ua.platform(), ua.device()) { + (Some(ua_kind), Some(platform), _) => { + // UA + Platform Match (e.g. chrome windows) + self.map_platform + .get(&(ua_kind, platform)) + .and_then(|v| v.choose(&mut rand::rng())) + .and_then(|idx| self.profiles.get(*idx)) + } + (Some(ua_kind), None, Some(device)) => { + // UA + Device match (e.g. firefox desktop) + self.map_device + .get(&(ua_kind, device)) + .and_then(|v| v.choose(&mut rand::rng())) + .and_then(|idx| self.profiles.get(*idx)) + } + (Some(ua_kind), None, None) => { + // random profile for this UA + self.map_ua_kind + .get(&ua_kind) + .and_then(|v| v.choose(&mut rand::rng())) + .and_then(|idx| self.profiles.get(*idx)) + } + (None, Some(platform), _) => { + // NOTE: I guestimated these numbers... 
Feel free to help improve these + let ua_kind = match platform { + PlatformKind::Windows => self.market_rnd_ua_kind_with_shares(7, 0), + PlatformKind::MacOS => self.market_rnd_ua_kind_with_shares(9, 35), + PlatformKind::Linux => self.market_rnd_ua_kind_with_shares(22, 0), + PlatformKind::Android => self.market_rnd_ua_kind_with_shares(3, 0), + PlatformKind::IOS => self.market_rnd_ua_kind_with_shares(5, 42), + }; + self.map_platform + .get(&(ua_kind, platform)) + .and_then(|v| v.choose(&mut rand::rng())) + .and_then(|idx| self.profiles.get(*idx)) + } + (None, None, device) => { + // random ua kind matching with device or not + match device { + Some(device) => { + let ua_kind = match device { + // https://gs.statcounter.com/browser-market-share/desktop/worldwide (feb 2025) + DeviceKind::Desktop => self.market_rnd_ua_kind_with_shares(7, 9), + // https://gs.statcounter.com/browser-market-share/mobile/worldwide (feb 2025) + DeviceKind::Mobile => self.market_rnd_ua_kind_with_shares(1, 23), + }; + self.map_device + .get(&(ua_kind, device)) + .and_then(|v| v.choose(&mut rand::rng())) + .and_then(|idx| self.profiles.get(*idx)) + } + None => { + let ua_kind = self.market_rnd_ua_kind(); + self.map_ua_kind + .get(&ua_kind) + .and_then(|v| v.choose(&mut rand::rng())) + .and_then(|idx| self.profiles.get(*idx)) + } + } + } + } + } + + #[inline] + pub fn iter(&self) -> impl Iterator { + self.profiles.iter() + } + + fn market_rnd_ua_kind(&self) -> UserAgentKind { + // https://gs.statcounter.com/browser-market-share/mobile/worldwide (feb 2025) + self.market_rnd_ua_kind_with_shares(3, 18) + } + + fn market_rnd_ua_kind_with_shares(&self, firefox: i32, safari: i32) -> UserAgentKind { + let r = rand::random_range(0..=100); + if r < firefox && self.map_ua_kind.contains_key(&UserAgentKind::Firefox) { + UserAgentKind::Firefox + } else if r < safari + firefox && self.map_ua_kind.contains_key(&UserAgentKind::Safari) { + UserAgentKind::Safari + } else { + UserAgentKind::Chromium + } + } +} + 
+impl FromIterator for UserAgentDatabase { + fn from_iter>(iter: T) -> Self { + let iter = iter.into_iter(); + let (lb, _) = iter.size_hint(); + + let mut db = UserAgentDatabase { + profiles: Vec::with_capacity(lb), + ..Default::default() + }; + + for profile in iter { + db.insert(profile); + } + + db + } +} + +#[cfg(test)] +mod tests { + use rama_http_types::{HeaderValue, header::USER_AGENT, proto::h1::Http1HeaderMap}; + + use super::*; + + #[test] + fn test_ua_db_empty() { + let db = UserAgentDatabase::default(); + assert_eq!(db.iter().count(), 0); + assert!(db.get(&UserAgent::new("")).is_none()); + assert!(db.get(&UserAgent::new("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36")).is_none()); + + let rnd = db.rnd(); + assert!(rnd.is_none()); + + assert!(db.iter_ua_str().next().is_none()); + assert!(db.iter_ua_kind().next().is_none()); + assert!(db.iter_platform().next().is_none()); + assert!(db.iter_device().next().is_none()); + } + + #[test] + fn test_ua_db_get_by_ua_str() { + let db = get_dummy_ua_db(); + + let profile = db.get(&UserAgent::new("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0")).unwrap(); + assert_eq!(profile.ua_kind, UserAgentKind::Chromium); + assert_eq!(profile.ua_version, Some(120)); + assert_eq!(profile.platform, Some(PlatformKind::Windows)); + assert_eq!( + profile + .http + .h1 + .headers + .navigate + .get(USER_AGENT) + .unwrap() + .to_str() + .unwrap(), + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0" + ); + assert_eq!( + profile + .http + .h2 + .headers + .navigate + .get(USER_AGENT) + .unwrap() + .to_str() + .unwrap(), + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0" + ); + } + + #[test] + fn test_ua_db_get_by_ua_kind_and_device() { + 
let db = get_dummy_ua_db(); + let test_cases = [ + ( + "Chrome Desktop", + UserAgentKind::Chromium, + DeviceKind::Desktop, + ), + ("Chrome Mobile", UserAgentKind::Chromium, DeviceKind::Mobile), + ( + "Desktop Firefox", + UserAgentKind::Firefox, + DeviceKind::Desktop, + ), + ( + "Mobile with Firefox", + UserAgentKind::Firefox, + DeviceKind::Mobile, + ), + ( + "Safari on Desktop", + UserAgentKind::Safari, + DeviceKind::Desktop, + ), + ("mobile&safari", UserAgentKind::Safari, DeviceKind::Mobile), + ]; + + for (ua_str, ua_kind, device) in test_cases { + let profile = db.get(&UserAgent::new(ua_str)).expect(ua_str); + assert_eq!(profile.ua_kind, ua_kind); + assert!( + profile + .platform + .map(|p| p.device() == device) + .unwrap_or_default() + ); + } + } + + #[test] + fn test_ua_db_get_by_ua_kind_and_platform() { + let db = get_dummy_ua_db(); + let test_cases = [ + ( + "Chrome Windows", + UserAgentKind::Chromium, + PlatformKind::Windows, + ), + ("MacOS Chrome", UserAgentKind::Chromium, PlatformKind::MacOS), + ( + "Chrome&Windows", + UserAgentKind::Chromium, + PlatformKind::Windows, + ), + ( + "Firefox on Windows", + UserAgentKind::Firefox, + PlatformKind::Windows, + ), + ( + "MacOS with Firefox", + UserAgentKind::Firefox, + PlatformKind::MacOS, + ), + ( + "Firefox + Linux", + UserAgentKind::Firefox, + PlatformKind::Linux, + ), + ]; + + for (ua_str, ua_kind, platform) in test_cases { + let profile = db.get(&UserAgent::new(ua_str)).expect(ua_str); + assert_eq!(profile.ua_kind, ua_kind); + assert_eq!(profile.platform, Some(platform)); + } + } + + #[test] + fn test_ua_db_get_by_ua_kind() { + let db = get_dummy_ua_db(); + let test_cases = [ + ("Firefox", UserAgentKind::Firefox), + ("Safari", UserAgentKind::Safari), + ("Chrome", UserAgentKind::Chromium), + ("Chromium", UserAgentKind::Chromium), + ]; + + for (ua_str, ua_kind) in test_cases { + let profile = db.get(&UserAgent::new(ua_str)).expect(ua_str); + assert_eq!(profile.ua_kind, ua_kind, "ua_str: {}", ua_str); + } + } + + 
#[test] + fn test_ua_db_get_by_device() { + let db = get_dummy_ua_db(); + let test_cases = [ + ("Desktop", DeviceKind::Desktop), + ("DESKTOP", DeviceKind::Desktop), + ("desktop", DeviceKind::Desktop), + ("Mobile", DeviceKind::Mobile), + ("MOBILE", DeviceKind::Mobile), + ("mobile", DeviceKind::Mobile), + ]; + + for (ua_str, device) in test_cases { + let profile = db.get(&UserAgent::new(ua_str)).expect(ua_str); + assert_eq!( + profile.platform.map(|p| p.device() == device), + Some(true), + "ua_str: {}", + ua_str + ); + } + } + + #[test] + fn test_ua_db_rnd() { + let db = get_dummy_ua_db(); + + let mut set = std::collections::HashSet::new(); + for _ in 0..db.len() * 1000 { + let rnd = db.rnd().unwrap(); + set.insert( + rnd.http + .h1 + .headers + .navigate + .get(USER_AGENT) + .expect("ua header") + .to_str() + .expect("utf-8 ua header value") + .to_owned(), + ); + } + + assert_eq!(set.len(), db.len()); + } + + fn dummy_ua_profile_from_str(s: &str) -> UserAgentProfile { + let ua = UserAgent::new(s); + UserAgentProfile { + ua_kind: ua.ua_kind().unwrap(), + ua_version: ua.ua_version(), + platform: ua.platform(), + http: crate::HttpProfile { + h1: crate::Http1Profile { + headers: crate::HttpHeadersProfile { + navigate: Http1HeaderMap::new( + [(USER_AGENT, HeaderValue::from_str(s).unwrap())] + .into_iter() + .collect(), + None, + ), + fetch: None, + xhr: None, + form: None, + }, + settings: crate::Http1Settings::default(), + }, + h2: crate::Http2Profile { + headers: crate::HttpHeadersProfile { + navigate: Http1HeaderMap::new( + [(USER_AGENT, HeaderValue::from_str(s).unwrap())] + .into_iter() + .collect(), + None, + ), + fetch: None, + xhr: None, + form: None, + }, + settings: crate::Http2Settings::default(), + }, + }, + #[cfg(feature = "tls")] + tls: crate::TlsProfile { + client_config: std::sync::Arc::new(rama_net::tls::client::ClientConfig::default()), + }, + } + } + + fn get_dummy_ua_db() -> UserAgentDatabase { + let mut db = UserAgentDatabase::default(); + + 
db.insert(dummy_ua_profile_from_str("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0")); + db.insert(dummy_ua_profile_from_str("Mozilla/5.0 (Windows NT 11.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0")); + db.insert(dummy_ua_profile_from_str("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0")); + db.insert(dummy_ua_profile_from_str("Mozilla/5.0 (Macintosh; Intel Mac OS X 14_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0")); + db.insert(dummy_ua_profile_from_str("Mozilla/5.0 (Linux; Android 10; HD1913) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Mobile Safari/537.36 EdgA/120.0.0.0")); + db.insert(dummy_ua_profile_from_str("Mozilla/5.0 (iPhone; CPU iPhone OS 17_1_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.1 EdgiOS/120.0.0.0 Mobile/15E148 Safari/605.1.15")); + db.insert(dummy_ua_profile_from_str( + "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:120.0) Gecko/20100101 Firefox/120.0", + )); + db.insert(dummy_ua_profile_from_str( + "Mozilla/5.0 (Windows NT 11.0; Win64; x64; rv:120.0) Gecko/20100101 Firefox/120.0", + )); + db.insert(dummy_ua_profile_from_str( + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:120.0) Gecko/20100101 Firefox/120.0", + )); + db.insert(dummy_ua_profile_from_str( + "Mozilla/5.0 (Macintosh; Intel Mac OS X 14.1; rv:120.0) Gecko/20100101 Firefox/120.0", + )); + db.insert(dummy_ua_profile_from_str( + "Mozilla/5.0 (X11; Linux x86_64; rv:120.0) Gecko/20100101 Firefox/120.0", + )); + db.insert(dummy_ua_profile_from_str( + "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:120.0) Gecko/20100101 Firefox/120.0", + )); + db.insert(dummy_ua_profile_from_str( + "Mozilla/5.0 (Android 14; Mobile; rv:120.0) Gecko/120.0 Firefox/120.0", + )); + 
db.insert(dummy_ua_profile_from_str("Mozilla/5.0 (iPhone; CPU iPhone OS 17_1_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) FxiOS/120.0 Mobile/15E148 Safari/605.1.15")); + db.insert(dummy_ua_profile_from_str("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.1 Safari/605.1.15")); + db.insert(dummy_ua_profile_from_str("Mozilla/5.0 (Macintosh; Intel Mac OS X 14_1) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.1 Safari/605.1.15")); + db.insert(dummy_ua_profile_from_str("Mozilla/5.0 (iPhone; CPU iPhone OS 17_1_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.0 Mobile/15E148 Safari/604.1")); + db.insert(dummy_ua_profile_from_str("Mozilla/5.0 (iPad; CPU OS 17_1_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.0 Mobile/15E148 Safari/604.1")); + db.insert(dummy_ua_profile_from_str("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.0 Safari/605.1.15")); + + db + } +} diff --git a/rama-ua/src/profile/http.rs b/rama-ua/src/profile/http.rs new file mode 100644 index 000000000..6139cd4c5 --- /dev/null +++ b/rama-ua/src/profile/http.rs @@ -0,0 +1,44 @@ +use rama_http_types::{ + HeaderName, + proto::{h1::Http1HeaderMap, h2::PseudoHeader}, +}; +use serde::{Deserialize, Serialize}; + +pub static CUSTOM_HEADER_MARKER: HeaderName = + HeaderName::from_static("x-rama-custom-header-marker"); + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct HttpProfile { + pub h1: Http1Profile, + pub h2: Http2Profile, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct HttpHeadersProfile { + pub navigate: Http1HeaderMap, + pub fetch: Option, + pub xhr: Option, + pub form: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Http1Profile { + pub headers: HttpHeadersProfile, + pub settings: Http1Settings, +} + +#[derive(Debug, Clone, Deserialize, Serialize, Default)] +pub struct Http1Settings { + pub 
title_case_headers: bool, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Http2Profile { + pub headers: HttpHeadersProfile, + pub settings: Http2Settings, +} + +#[derive(Debug, Clone, Deserialize, Serialize, Default)] +pub struct Http2Settings { + pub http_pseudo_headers: Option>, +} diff --git a/rama-ua/src/profile/mod.rs b/rama-ua/src/profile/mod.rs new file mode 100644 index 000000000..662e04556 --- /dev/null +++ b/rama-ua/src/profile/mod.rs @@ -0,0 +1,13 @@ +mod http; +pub use http::*; + +#[cfg(feature = "tls")] +mod tls; +#[cfg(feature = "tls")] +pub use tls::*; + +mod ua; +pub use ua::*; + +mod db; +pub use db::*; diff --git a/rama-ua/src/profile/tls.rs b/rama-ua/src/profile/tls.rs new file mode 100644 index 000000000..240bddde9 --- /dev/null +++ b/rama-ua/src/profile/tls.rs @@ -0,0 +1,48 @@ +use std::sync::Arc; + +use rama_net::tls::client::{ClientConfig, ClientHello, ServerVerifyMode}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone)] +pub struct TlsProfile { + pub client_config: Arc, +} + +impl<'de> Deserialize<'de> for TlsProfile { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let input = TlsProfileSerde::deserialize(deserializer)?; + let mut cfg = ClientConfig::from(input.client_hello); + if input.insecure { + cfg.server_verify_mode = Some(ServerVerifyMode::Disable); + } + Ok(Self { + client_config: Arc::new(cfg), + }) + } +} + +impl Serialize for TlsProfile { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let insecure = matches!( + self.client_config.server_verify_mode, + Some(ServerVerifyMode::Disable) + ); + TlsProfileSerde { + client_hello: self.client_config.as_ref().clone().into(), + insecure, + } + .serialize(serializer) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct TlsProfileSerde { + client_hello: ClientHello, + insecure: bool, +} diff --git a/rama-ua/src/profile/ua.rs b/rama-ua/src/profile/ua.rs new file mode 
100644 index 000000000..824217b36 --- /dev/null +++ b/rama-ua/src/profile/ua.rs @@ -0,0 +1,47 @@ +use rama_http_types::header::USER_AGENT; +use serde::{Deserialize, Serialize}; + +use crate::{PlatformKind, UserAgentKind}; + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct UserAgentProfile { + /// The kind of [`crate::UserAgent`] + pub ua_kind: UserAgentKind, + /// The version of the [`crate::UserAgent`] + pub ua_version: Option, + /// The platform the [`crate::UserAgent`] is running on. + pub platform: Option, + + /// The profile information regarding the http implementation of the [`crate::UserAgent`]. + pub http: super::HttpProfile, + + #[cfg(feature = "tls")] + /// The profile information regarding the tls implementation of the [`crate::UserAgent`]. + pub tls: super::TlsProfile, +} + +impl UserAgentProfile { + pub fn ua_str(&self) -> Option<&str> { + if let Some(ua) = self + .http + .h1 + .headers + .navigate + .get(USER_AGENT) + .and_then(|v| v.to_str().ok()) + { + Some(ua) + } else if let Some(ua) = self + .http + .h2 + .headers + .navigate + .get(USER_AGENT) + .and_then(|v| v.to_str().ok()) + { + Some(ua) + } else { + None + } + } +} diff --git a/rama-ua/src/info.rs b/rama-ua/src/ua/info.rs similarity index 60% rename from rama-ua/src/info.rs rename to rama-ua/src/ua/info.rs index 6b12217a4..28d901686 100644 --- a/rama-ua/src/info.rs +++ b/rama-ua/src/ua/info.rs @@ -1,19 +1,22 @@ -use super::parse_http_user_agent_header; +use super::{RequestInitiator, parse_http_user_agent_header}; use rama_core::error::OpaqueError; +use rama_http_types::headers::ClientHint; use rama_utils::macros::match_ignore_ascii_case_str; use serde::{Deserialize, Deserializer, Serialize}; -use std::{convert::Infallible, fmt, str::FromStr}; +use std::{convert::Infallible, fmt, str::FromStr, sync::Arc}; /// User Agent (UA) information. /// /// See [the module level documentation](crate) for more information. 
#[derive(Debug, Clone)] pub struct UserAgent { - pub(super) header: String, + pub(super) header: Arc, pub(super) data: UserAgentData, pub(super) http_agent_overwrite: Option, pub(super) tls_agent_overwrite: Option, pub(super) preserve_ua_header: bool, + pub(super) request_initiator: Option, + pub(super) requested_client_hints: Option>, } impl fmt::Display for UserAgent { @@ -27,13 +30,28 @@ impl fmt::Display for UserAgent { pub(super) enum UserAgentData { Standard { info: UserAgentInfo, - platform: Option, + platform_like: Option, }, Platform(PlatformKind), Device(DeviceKind), Unknown, } +#[derive(Debug, Clone)] +pub(super) enum PlatformLike { + Platform(PlatformKind), + Device(DeviceKind), +} + +impl PlatformLike { + pub(super) fn device(&self) -> DeviceKind { + match self { + PlatformLike::Platform(platform_kind) => platform_kind.device(), + PlatformLike::Device(device_kind) => *device_kind, + } + } +} + /// Information about the [`UserAgent`] #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct UserAgentInfo { @@ -45,18 +63,30 @@ pub struct UserAgentInfo { impl UserAgent { /// Create a new [`UserAgent`] from a `User-Agent` (header) value. - pub fn new(header: impl Into) -> Self { + pub fn new(header: impl Into>) -> Self { parse_http_user_agent_header(header.into()) } /// Overwrite the [`HttpAgent`] advertised by the [`UserAgent`]. - pub fn with_http_agent(&mut self, http_agent: HttpAgent) -> &mut Self { + pub fn with_http_agent(mut self, http_agent: HttpAgent) -> Self { + self.http_agent_overwrite = Some(http_agent); + self + } + + /// Overwrite the [`HttpAgent`] advertised by the [`UserAgent`]. + pub fn set_http_agent(&mut self, http_agent: HttpAgent) -> &mut Self { self.http_agent_overwrite = Some(http_agent); self } /// Overwrite the [`TlsAgent`] advertised by the [`UserAgent`]. 
- pub fn with_tls_agent(&mut self, tls_agent: TlsAgent) -> &mut Self { + pub fn with_tls_agent(mut self, tls_agent: TlsAgent) -> Self { + self.tls_agent_overwrite = Some(tls_agent); + self + } + + /// Overwrite the [`TlsAgent`] advertised by the [`UserAgent`]. + pub fn set_tls_agent(&mut self, tls_agent: TlsAgent) -> &mut Self { self.tls_agent_overwrite = Some(tls_agent); self } @@ -65,7 +95,16 @@ impl UserAgent { /// /// This is used to indicate to emulators that they should respect the User-Agent header /// attached to this [`UserAgent`], if possible. - pub fn with_preserve_ua_header(&mut self, preserve: bool) -> &mut Self { + pub fn with_preserve_ua_header(mut self, preserve: bool) -> Self { + self.preserve_ua_header = preserve; + self + } + + /// Preserve the incoming `User-Agent` (header) value. + /// + /// This is used to indicate to emulators that they should respect the User-Agent header + /// attached to this [`UserAgent`], if possible. + pub fn set_preserve_ua_header(&mut self, preserve: bool) -> &mut Self { self.preserve_ua_header = preserve; self } @@ -76,28 +115,73 @@ impl UserAgent { self.preserve_ua_header } + /// Define the [`RequestInitiator`] hint. + pub fn with_request_initiator(mut self, req_init: RequestInitiator) -> Self { + self.request_initiator = Some(req_init); + self + } + + /// Define the [`RequestInitiator`] hint. + pub fn set_request_initiator(&mut self, req_init: RequestInitiator) -> &mut Self { + self.request_initiator = Some(req_init); + self + } + + /// returns the [`RequestInitiator`] hint if available. + pub fn request_initiator(&self) -> Option { + self.request_initiator + } + + /// Define the requested (High-Entropy) Client Hints. + pub fn with_requested_client_hints(mut self, req_client_hints: Vec) -> Self { + self.requested_client_hints = Some(req_client_hints); + self + } + + /// Define the requested (High-Entropy) Client Hints. 
+ pub fn set_requested_client_hints(&mut self, req_client_hints: Vec) -> &mut Self { + self.requested_client_hints = Some(req_client_hints); + self + } + + /// Append a requested (High-Entropy) Client Hint. + pub fn append_requested_client_hint(&mut self, hint: ClientHint) -> &mut Self { + self.requested_client_hints + .get_or_insert_default() + .push(hint); + self + } + + /// Extend the requested (High-Entropy) Client Hints. + pub fn extend_requested_client_hints( + &mut self, + hints: impl IntoIterator, + ) -> &mut Self { + self.requested_client_hints + .get_or_insert_default() + .extend(hints); + self + } + + /// returns the requested (High-Entropy) Client Hints. + pub fn requested_client_hints(&self) -> impl Iterator { + self.requested_client_hints.iter().flatten() + } + /// returns the `User-Agent` (header) value used by the [`UserAgent`]. pub fn header_str(&self) -> &str { &self.header } /// returns the device kind of the [`UserAgent`]. - pub fn device(&self) -> DeviceKind { + pub fn device(&self) -> Option { match &self.data { - UserAgentData::Standard { platform, .. } => match platform { - Some(PlatformKind::Windows | PlatformKind::MacOS | PlatformKind::Linux) | None => { - DeviceKind::Desktop - } - Some(PlatformKind::Android | PlatformKind::IOS) => DeviceKind::Mobile, - }, - UserAgentData::Platform(platform) => match platform { - PlatformKind::Windows | PlatformKind::MacOS | PlatformKind::Linux => { - DeviceKind::Desktop - } - PlatformKind::Android | PlatformKind::IOS => DeviceKind::Mobile, - }, - UserAgentData::Device(kind) => *kind, - UserAgentData::Unknown => DeviceKind::Desktop, + UserAgentData::Standard { platform_like, .. } => { + platform_like.as_ref().map(|p| p.device()) + } + UserAgentData::Platform(platform) => Some(platform.device()), + UserAgentData::Device(kind) => Some(*kind), + UserAgentData::Unknown => None, } } @@ -111,12 +195,40 @@ impl UserAgent { } } + /// returns the [`UserAgentKind`] used by the [`UserAgent`], if known. 
+ pub fn ua_kind(&self) -> Option { + match self.http_agent_overwrite { + Some(HttpAgent::Chromium) => Some(UserAgentKind::Chromium), + Some(HttpAgent::Safari) => Some(UserAgentKind::Safari), + Some(HttpAgent::Firefox) => Some(UserAgentKind::Firefox), + Some(HttpAgent::Preserve) => None, + None => match &self.data { + UserAgentData::Standard { + info: UserAgentInfo { kind, .. }, + .. + } => Some(*kind), + _ => None, + }, + } + } + + /// returns the version of the [`UserAgent`], if known. + pub fn ua_version(&self) -> Option { + match &self.data { + UserAgentData::Standard { info, .. } => info.version, + _ => None, + } + } + /// returns the [`PlatformKind`] used by the [`UserAgent`], if known. /// /// This is the platform the [`UserAgent`] is running on. pub fn platform(&self) -> Option { match &self.data { - UserAgentData::Standard { platform, .. } => *platform, + UserAgentData::Standard { platform_like, .. } => match platform_like { + Some(PlatformLike::Platform(platform)) => Some(*platform), + None | Some(PlatformLike::Device(_)) => None, + }, UserAgentData::Platform(platform) => Some(*platform), _ => None, } @@ -125,17 +237,17 @@ impl UserAgent { /// returns the [`HttpAgent`] used by the [`UserAgent`]. /// /// [`UserAgent`]: super::UserAgent - pub fn http_agent(&self) -> HttpAgent { - match &self.http_agent_overwrite { - Some(agent) => agent.clone(), + pub fn http_agent(&self) -> Option { + match self.http_agent_overwrite { + Some(agent) => Some(agent), None => match &self.data { - UserAgentData::Standard { info, .. } => match info.kind { + UserAgentData::Standard { info, .. 
} => Some(match info.kind { UserAgentKind::Chromium => HttpAgent::Chromium, UserAgentKind::Firefox => HttpAgent::Firefox, UserAgentKind::Safari => HttpAgent::Safari, - }, - UserAgentData::Device(_) | UserAgentData::Platform(_) | UserAgentData::Unknown => { - HttpAgent::Chromium + }), + UserAgentData::Platform(_) | UserAgentData::Device(_) | UserAgentData::Unknown => { + None } }, } @@ -144,17 +256,17 @@ impl UserAgent { /// returns the [`TlsAgent`] used by the [`UserAgent`]. /// /// [`UserAgent`]: super::UserAgent - pub fn tls_agent(&self) -> TlsAgent { - match &self.tls_agent_overwrite { - Some(agent) => agent.clone(), + pub fn tls_agent(&self) -> Option { + match self.tls_agent_overwrite { + Some(agent) => Some(agent), None => match &self.data { - UserAgentData::Standard { info, .. } => match info.kind { + UserAgentData::Standard { info, .. } => Some(match info.kind { UserAgentKind::Chromium => TlsAgent::Boringssl, UserAgentKind::Firefox => TlsAgent::Nss, UserAgentKind::Safari => TlsAgent::Rustls, - }, + }), UserAgentData::Device(_) | UserAgentData::Platform(_) | UserAgentData::Unknown => { - TlsAgent::Rustls + None } }, } @@ -180,16 +292,56 @@ pub enum UserAgentKind { Safari, } +impl UserAgentKind { + pub fn as_str(&self) -> &'static str { + match self { + UserAgentKind::Chromium => "Chromium", + UserAgentKind::Firefox => "Firefox", + UserAgentKind::Safari => "Safari", + } + } +} + impl fmt::Display for UserAgentKind { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - UserAgentKind::Chromium => write!(f, "Chromium"), - UserAgentKind::Firefox => write!(f, "Firefox"), - UserAgentKind::Safari => write!(f, "Safari"), + write!(f, "{}", self.as_str()) + } +} + +impl FromStr for UserAgentKind { + type Err = OpaqueError; + + fn from_str(s: &str) -> Result { + match_ignore_ascii_case_str! 
{ + match (s) { + "chromium" => Ok(UserAgentKind::Chromium), + "firefox" => Ok(UserAgentKind::Firefox), + "safari" => Ok(UserAgentKind::Safari), + _ => Err(OpaqueError::from_display(format!("invalid user agent kind: {}", s))), + } } } } +impl Serialize for UserAgentKind { + fn serialize(&self, serializer: S) -> Result + where + S: serde::ser::Serializer, + { + serializer.serialize_str(self.as_str()) + } +} + +impl<'de> Deserialize<'de> for UserAgentKind { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let s = >::deserialize(deserializer)?; + s.parse::().map_err(serde::de::Error::custom) + } +} + /// Device on which the [`UserAgent`] operates. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum DeviceKind { @@ -199,15 +351,21 @@ pub enum DeviceKind { Mobile, } -impl fmt::Display for DeviceKind { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { +impl DeviceKind { + pub fn as_str(&self) -> &'static str { match self { - DeviceKind::Desktop => write!(f, "Desktop"), - DeviceKind::Mobile => write!(f, "Mobile"), + DeviceKind::Desktop => "Desktop", + DeviceKind::Mobile => "Mobile", } } } +impl fmt::Display for DeviceKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.as_str()) + } +} + /// Platform within the [`UserAgent`] operates. 
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum PlatformKind { @@ -223,20 +381,71 @@ pub enum PlatformKind { IOS, } -impl fmt::Display for PlatformKind { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { +impl PlatformKind { + pub fn as_str(&self) -> &'static str { + match self { + PlatformKind::Windows => "Windows", + PlatformKind::MacOS => "MacOS", + PlatformKind::Linux => "Linux", + PlatformKind::Android => "Android", + PlatformKind::IOS => "iOS", + } + } + + pub fn device(&self) -> DeviceKind { match self { - PlatformKind::Windows => write!(f, "Windows"), - PlatformKind::MacOS => write!(f, "MacOS"), - PlatformKind::Linux => write!(f, "Linux"), - PlatformKind::Android => write!(f, "Android"), - PlatformKind::IOS => write!(f, "iOS"), + PlatformKind::Windows | PlatformKind::MacOS | PlatformKind::Linux => { + DeviceKind::Desktop + } + PlatformKind::Android | PlatformKind::IOS => DeviceKind::Mobile, } } } +impl FromStr for PlatformKind { + type Err = OpaqueError; + + fn from_str(s: &str) -> Result { + match_ignore_ascii_case_str! 
{ + match (s) { + "windows" => Ok(PlatformKind::Windows), + "macos" => Ok(PlatformKind::MacOS), + "linux" => Ok(PlatformKind::Linux), + "android" => Ok(PlatformKind::Android), + "ios" => Ok(PlatformKind::IOS), + _ => Err(OpaqueError::from_display(format!("invalid platform: {}", s))), + } + } + } +} + +impl Serialize for PlatformKind { + fn serialize(&self, serializer: S) -> Result + where + S: serde::ser::Serializer, + { + serializer.serialize_str(self.as_str()) + } +} + +impl<'de> Deserialize<'de> for PlatformKind { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let s = >::deserialize(deserializer)?; + s.parse::().map_err(serde::de::Error::custom) + } +} + +impl fmt::Display for PlatformKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.as_str()) + } +} + /// Http implementation used by the [`UserAgent`] -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum HttpAgent { /// Chromium based browsers share the same http implementation Chromium, @@ -251,17 +460,29 @@ pub enum HttpAgent { Preserve, } +impl HttpAgent { + pub fn as_str(&self) -> &'static str { + match self { + HttpAgent::Chromium => "Chromium", + HttpAgent::Firefox => "Firefox", + HttpAgent::Safari => "Safari", + HttpAgent::Preserve => "Preserve", + } + } +} + +impl fmt::Display for HttpAgent { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.as_str()) + } +} + impl Serialize for HttpAgent { fn serialize(&self, serializer: S) -> Result where S: serde::ser::Serializer, { - match self { - HttpAgent::Chromium => serializer.serialize_str("Chromium"), - HttpAgent::Firefox => serializer.serialize_str("Firefox"), - HttpAgent::Safari => serializer.serialize_str("Safari"), - HttpAgent::Preserve => serializer.serialize_str("Preserve"), - } + serializer.serialize_str(self.as_str()) } } @@ -271,15 +492,7 @@ impl<'de> Deserialize<'de> for HttpAgent { D: 
Deserializer<'de>, { let s = >::deserialize(deserializer)?; - match_ignore_ascii_case_str! { - match (s) { - "chrome" | "chromium" => Ok(HttpAgent::Chromium), - "Firefox" => Ok(HttpAgent::Firefox), - "Safari" => Ok(HttpAgent::Safari), - "preserve" => Ok(HttpAgent::Preserve), - _ => Err(serde::de::Error::custom("invalid http agent")), - } - } + s.parse::().map_err(serde::de::Error::custom) } } @@ -299,19 +512,8 @@ impl FromStr for HttpAgent { } } -impl fmt::Display for HttpAgent { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - HttpAgent::Chromium => write!(f, "Chromium"), - HttpAgent::Firefox => write!(f, "Firefox"), - HttpAgent::Safari => write!(f, "Safari"), - HttpAgent::Preserve => write!(f, "Preserve"), - } - } -} - /// Tls implementation used by the [`UserAgent`] -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum TlsAgent { /// Rustls is used as a fallback for all user agents, /// that are not chromium based. 
@@ -328,28 +530,29 @@ pub enum TlsAgent { Preserve, } -impl fmt::Display for TlsAgent { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { +impl TlsAgent { + pub fn as_str(&self) -> &'static str { match self { - TlsAgent::Rustls => write!(f, "Rustls"), - TlsAgent::Boringssl => write!(f, "Boringssl"), - TlsAgent::Nss => write!(f, "NSS"), - TlsAgent::Preserve => write!(f, "Preserve"), + TlsAgent::Rustls => "Rustls", + TlsAgent::Boringssl => "Boringssl", + TlsAgent::Nss => "NSS", + TlsAgent::Preserve => "Preserve", } } } +impl fmt::Display for TlsAgent { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.as_str()) + } +} + impl Serialize for TlsAgent { fn serialize(&self, serializer: S) -> Result where S: serde::ser::Serializer, { - match self { - TlsAgent::Rustls => serializer.serialize_str("Rustls"), - TlsAgent::Boringssl => serializer.serialize_str("Boringssl"), - TlsAgent::Nss => serializer.serialize_str("NSS"), - TlsAgent::Preserve => serializer.serialize_str("Preserve"), - } + serializer.serialize_str(self.as_str()) } } @@ -359,15 +562,7 @@ impl<'de> Deserialize<'de> for TlsAgent { D: Deserializer<'de>, { let s = >::deserialize(deserializer)?; - match_ignore_ascii_case_str! 
{ - match (s) { - "rustls" => Ok(TlsAgent::Rustls), - "boring" | "boringssl" => Ok(TlsAgent::Boringssl), - "nss" => Ok(TlsAgent::Nss), - "preserve" => Ok(TlsAgent::Preserve), - _ => Err(serde::de::Error::custom("invalid tls agent")), - } - } + s.parse::().map_err(serde::de::Error::custom) } } @@ -406,9 +601,9 @@ mod tests { }) ); assert_eq!(ua.platform(), Some(PlatformKind::MacOS)); - assert_eq!(ua.device(), DeviceKind::Desktop); - assert_eq!(ua.http_agent(), HttpAgent::Chromium); - assert_eq!(ua.tls_agent(), TlsAgent::Boringssl); + assert_eq!(ua.device(), Some(DeviceKind::Desktop)); + assert_eq!(ua.http_agent(), Some(HttpAgent::Chromium)); + assert_eq!(ua.tls_agent(), Some(TlsAgent::Boringssl)); } #[test] @@ -426,9 +621,9 @@ mod tests { }) ); assert_eq!(ua.platform(), Some(PlatformKind::MacOS)); - assert_eq!(ua.device(), DeviceKind::Desktop); - assert_eq!(ua.http_agent(), HttpAgent::Chromium); - assert_eq!(ua.tls_agent(), TlsAgent::Boringssl); + assert_eq!(ua.device(), Some(DeviceKind::Desktop)); + assert_eq!(ua.http_agent(), Some(HttpAgent::Chromium)); + assert_eq!(ua.tls_agent(), Some(TlsAgent::Boringssl)); } #[test] diff --git a/rama-ua/src/ua/mod.rs b/rama-ua/src/ua/mod.rs new file mode 100644 index 000000000..c0ccb0edb --- /dev/null +++ b/rama-ua/src/ua/mod.rs @@ -0,0 +1,107 @@ +use rama_core::error::OpaqueError; +use rama_http_types::headers::ClientHint; +use rama_utils::macros::match_ignore_ascii_case_str; +use serde::{Deserialize, Deserializer, Serialize}; +use std::{fmt, str::FromStr}; + +mod info; +pub use info::{ + DeviceKind, HttpAgent, PlatformKind, TlsAgent, UserAgent, UserAgentInfo, UserAgentKind, +}; + +mod parse; +use parse::parse_http_user_agent_header; +pub(crate) use parse::{contains_ignore_ascii_case, starts_with_ignore_ascii_case}; + +/// Information that can be used to overwrite the [`UserAgent`] of an http request. 
+/// +/// Used by the `UserAgentClassifier` (see `rama-http`) to overwrite the specified +/// information during the classification of the [`UserAgent`]. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct UserAgentOverwrites { + /// Overwrite the [`UserAgent`] of the http `Request` with a custom value. + /// + /// This value will be used instead of + /// the 'User-Agent' http (header) value. + /// + /// This is useful in case you cannot set the User-Agent header in your request. + pub ua: Option, + /// Overwrite the [`HttpAgent`] of the http `Request` with a custom value. + pub http: Option, + /// Overwrite the [`TlsAgent`] of the http `Request` with a custom value. + pub tls: Option, + /// Preserve the original [`UserAgent`] header of the http `Request`. + pub preserve_ua: Option, + /// Requested (High-Entropy) Client Hints. + pub req_client_hints: Option>, + /// Hint a specific request initiator for UA Emulation. A related + /// or default initiator might be chosen in case the hinted one is not available. + /// + /// In case this hint is not specified it will be guessed for you instead based + /// on the request method and headers.
+ pub req_init: Option, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum RequestInitiator { + Navigate, + Form, + Xhr, + Fetch, +} + +impl RequestInitiator { + pub fn as_str(&self) -> &'static str { + match self { + RequestInitiator::Navigate => "navigate", + RequestInitiator::Form => "form", + RequestInitiator::Xhr => "xhr", + RequestInitiator::Fetch => "fetch", + } + } +} + +impl fmt::Display for RequestInitiator { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.as_str()) + } +} + +impl Serialize for RequestInitiator { + fn serialize(&self, serializer: S) -> Result + where + S: serde::ser::Serializer, + { + serializer.serialize_str(self.as_str()) + } +} + +impl<'de> Deserialize<'de> for RequestInitiator { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let s = >::deserialize(deserializer)?; + s.parse::() + .map_err(serde::de::Error::custom) + } +} + +impl FromStr for RequestInitiator { + type Err = OpaqueError; + + fn from_str(s: &str) -> Result { + match_ignore_ascii_case_str! { + match (s) { + "navigate" => Ok(RequestInitiator::Navigate), + "form" => Ok(RequestInitiator::Form), + "xhr" => Ok(RequestInitiator::Xhr), + "fetch" => Ok(RequestInitiator::Fetch), + _ => Err(OpaqueError::from_display(format!("invalid request initiator: {}", s))), + } + } + } +} + +#[cfg(test)] +mod parse_tests; diff --git a/rama-ua/src/parse.rs b/rama-ua/src/ua/parse.rs similarity index 74% rename from rama-ua/src/parse.rs rename to rama-ua/src/ua/parse.rs index 7e375ae6a..c05abb03d 100644 --- a/rama-ua/src/parse.rs +++ b/rama-ua/src/ua/parse.rs @@ -1,8 +1,10 @@ #![allow(dead_code)] +use std::sync::Arc; + use super::{ DeviceKind, PlatformKind, UserAgent, UserAgentKind, - info::{UserAgentData, UserAgentInfo}, + info::{PlatformLike, UserAgentData, UserAgentInfo}, }; /// Maximum length of a User Agent string that we take into consideration. 
@@ -20,8 +22,9 @@ const MAX_UA_LENGTH: usize = 512; /// - complete: we do not care about all the possible user agents out there, only the popular ones. /// /// That said. Do open a ticket if you find bugs or think something is missing. -pub(crate) fn parse_http_user_agent_header(header: String) -> UserAgent { - let ua = header.as_str(); +pub(crate) fn parse_http_user_agent_header(header: impl Into>) -> UserAgent { + let header = header.into(); + let ua = header.as_ref(); let ua = if ua.len() > MAX_UA_LENGTH { match ua.get(..MAX_UA_LENGTH) { Some(s) => s, @@ -32,6 +35,8 @@ pub(crate) fn parse_http_user_agent_header(header: String) -> UserAgent { http_agent_overwrite: None, tls_agent_overwrite: None, preserve_ua_header: false, + request_initiator: None, + requested_client_hints: None, }; } } @@ -39,117 +44,125 @@ pub(crate) fn parse_http_user_agent_header(header: String) -> UserAgent { ua }; - let (kind, kind_version, maybe_platform) = if let Some(loc) = - contains_ignore_ascii_case(ua, "Firefox") - { - let kind = UserAgentKind::Firefox; - let kind_version = parse_ua_version_firefox_and_chromium(&ua[loc..]); - (Some(kind), kind_version, None) - } else if let Some(loc) = contains_ignore_ascii_case(ua, "Chrom") { - let kind = UserAgentKind::Chromium; - let kind_version = parse_ua_version_firefox_and_chromium(&ua[loc..]); - (Some(kind), kind_version, None) - } else if contains_ignore_ascii_case(ua, "Safari").is_some() { - if let Some(firefox_loc) = contains_ignore_ascii_case(ua, "FxiOS") { + let (kind, kind_version, maybe_platform) = + if let Some(loc) = contains_ignore_ascii_case(ua, "Firefox") { let kind = UserAgentKind::Firefox; - let kind_version = parse_ua_version_firefox_and_chromium(&ua[firefox_loc..]); - (Some(kind), kind_version, Some(PlatformKind::IOS)) - } else if let Some(chrome_loc) = contains_ignore_ascii_case(ua, "CriOS") { - let kind = UserAgentKind::Chromium; - let kind_version = parse_ua_version_firefox_and_chromium(&ua[chrome_loc..]); - (Some(kind), 
kind_version, Some(PlatformKind::IOS)) - } else if let Some(chromium_loc) = contains_any_ignore_ascii_case(ua, &["Opera"]) { + let kind_version = parse_ua_version_firefox_and_chromium(&ua[loc..]); + (Some(kind), kind_version, None) + } else if let Some(loc) = contains_ignore_ascii_case(ua, "Chrom") { let kind = UserAgentKind::Chromium; - let kind_version = parse_ua_version_firefox_and_chromium(&ua[chromium_loc..]); + let kind_version = parse_ua_version_firefox_and_chromium(&ua[loc..]); (Some(kind), kind_version, None) + } else if contains_ignore_ascii_case(ua, "Safari").is_some() { + if let Some(firefox_loc) = contains_ignore_ascii_case(ua, "FxiOS") { + let kind = UserAgentKind::Firefox; + let kind_version = parse_ua_version_firefox_and_chromium(&ua[firefox_loc..]); + (Some(kind), kind_version, Some(PlatformKind::IOS)) + } else if let Some(chrome_loc) = contains_ignore_ascii_case(ua, "CriOS") { + let kind = UserAgentKind::Chromium; + let kind_version = parse_ua_version_firefox_and_chromium(&ua[chrome_loc..]); + (Some(kind), kind_version, Some(PlatformKind::IOS)) + } else if let Some(chromium_loc) = contains_any_ignore_ascii_case(ua, &["Opera"]) { + let kind = UserAgentKind::Chromium; + let kind_version = parse_ua_version_firefox_and_chromium(&ua[chromium_loc..]); + (Some(kind), kind_version, None) + } else { + let kind = UserAgentKind::Safari; + let kind_version = parse_ua_version_safari(ua); + (Some(kind), kind_version, None) + } } else { - let kind = UserAgentKind::Safari; - let kind_version = parse_ua_version_safari(ua); - (Some(kind), kind_version, None) - } - } else if contains_any_ignore_ascii_case(ua, &["Mobile", "Phone", "Tablet", "Zune"]).is_some() { - return UserAgent { - header, - data: UserAgentData::Device(DeviceKind::Mobile), - http_agent_overwrite: None, - tls_agent_overwrite: None, - preserve_ua_header: false, - }; - } else if contains_ignore_ascii_case(ua, "Desktop").is_some() { - return UserAgent { - header, - data: 
UserAgentData::Device(DeviceKind::Desktop), - http_agent_overwrite: None, - tls_agent_overwrite: None, - preserve_ua_header: false, + (None, None, None) }; - } else { - (None, None, None) - }; - let maybe_platform = match maybe_platform { - Some(platform) => Some(platform), + let (maybe_platform, maybe_device) = match maybe_platform { + Some(platform) => (Some(platform), None), None => { if contains_ignore_ascii_case(ua, "Windows").is_some() { if contains_ignore_ascii_case(ua, "X11").is_some() { - None + (None, Some(DeviceKind::Mobile)) } else { - Some(PlatformKind::Windows) + (Some(PlatformKind::Windows), None) } } else if contains_ignore_ascii_case(ua, "Android").is_some() { if contains_ignore_ascii_case(ua, "iOS").is_some() { - Some(PlatformKind::IOS) + (Some(PlatformKind::IOS), None) } else { - Some(PlatformKind::Android) + (Some(PlatformKind::Android), None) } } else if contains_ignore_ascii_case(ua, "Linux").is_some() { if contains_any_ignore_ascii_case(ua, &["Mobile", "UCW"]).is_some() { - Some(PlatformKind::Android) + (Some(PlatformKind::Android), None) } else { - Some(PlatformKind::Linux) + (Some(PlatformKind::Linux), None) } } else if contains_any_ignore_ascii_case(ua, &["iOS", "iPad", "iPod", "iPhone"]) .is_some() { - Some(PlatformKind::IOS) + (Some(PlatformKind::IOS), None) } else if contains_ignore_ascii_case(ua, "Mac").is_some() { - Some(PlatformKind::MacOS) + (Some(PlatformKind::MacOS), None) } else if contains_ignore_ascii_case(ua, "Darwin").is_some() { if contains_ignore_ascii_case(ua, "86").is_some() { - Some(PlatformKind::MacOS) + (Some(PlatformKind::MacOS), None) } else { - Some(PlatformKind::IOS) + (Some(PlatformKind::IOS), None) } + } else if contains_any_ignore_ascii_case(ua, &["Mobile", "Phone", "Tablet", "Zune"]) + .is_some() + { + (None, Some(DeviceKind::Mobile)) + } else if contains_ignore_ascii_case(ua, "Desktop").is_some() { + (None, Some(DeviceKind::Desktop)) } else { - None + (None, None) } } }; - match (kind, kind_version, 
maybe_platform) { - (Some(kind), version, platform) => UserAgent { + match (kind, kind_version, maybe_platform, maybe_device) { + (Some(kind), version, platform, device) => UserAgent { header, data: UserAgentData::Standard { info: UserAgentInfo { kind, version }, - platform, + platform_like: match (platform, device) { + (Some(platform), _) => Some(PlatformLike::Platform(platform)), + (None, Some(device)) => Some(PlatformLike::Device(device)), + (None, None) => None, + }, }, http_agent_overwrite: None, tls_agent_overwrite: None, preserve_ua_header: false, + request_initiator: None, + requested_client_hints: None, }, - (None, _, Some(platform)) => UserAgent { + (None, _, Some(platform), _) => UserAgent { header, data: UserAgentData::Platform(platform), http_agent_overwrite: None, tls_agent_overwrite: None, preserve_ua_header: false, + request_initiator: None, + requested_client_hints: None, }, - (None, _, None) => UserAgent { + (None, _, None, Some(device)) => UserAgent { + header, + data: UserAgentData::Device(device), + http_agent_overwrite: None, + tls_agent_overwrite: None, + preserve_ua_header: false, + request_initiator: None, + requested_client_hints: None, + }, + (None, _, None, None) => UserAgent { header, data: UserAgentData::Unknown, http_agent_overwrite: None, tls_agent_overwrite: None, preserve_ua_header: false, + request_initiator: None, + requested_client_hints: None, }, } } @@ -178,7 +191,7 @@ fn parse_ua_version_safari(ua: &str) -> Option { }) } -fn contains_ignore_ascii_case(s: &str, sub: &str) -> Option { +pub(crate) fn contains_ignore_ascii_case(s: &str, sub: &str) -> Option { let n = sub.len(); if n > s.len() { return None; @@ -191,6 +204,17 @@ fn contains_ignore_ascii_case(s: &str, sub: &str) -> Option { }) } +pub(crate) fn starts_with_ignore_ascii_case(s: &str, sub: &str) -> bool { + let n = sub.len(); + if n > s.len() { + return false; + } + + s.get(..n) + .map(|s| s.eq_ignore_ascii_case(sub)) + .unwrap_or_default() +} + fn 
contains_any_ignore_ascii_case(s: &str, subs: &[&str]) -> Option { let max = s.len(); let smallest_length = subs.iter().map(|s| s.len()).min().unwrap_or(0); @@ -220,7 +244,13 @@ fn contains_any_ignore_ascii_case(s: &str, subs: &[&str]) -> Option { #[cfg(test)] mod tests { - // test contains_ignore_ascii_case + use super::*; + + #[test] + fn test_starts_with_ignore_ascii_case() { + assert!(starts_with_ignore_ascii_case("user-agent", "user")); + assert!(!starts_with_ignore_ascii_case("user-agent", "agent")); + } #[test] fn test_contains_ignore_ascii_case_empty_sub() { diff --git a/rama-ua/src/parse_tests.rs b/rama-ua/src/ua/parse_tests.rs similarity index 67% rename from rama-ua/src/parse_tests.rs rename to rama-ua/src/ua/parse_tests.rs index 7fc6b229a..ed473c0a1 100644 --- a/rama-ua/src/parse_tests.rs +++ b/rama-ua/src/ua/parse_tests.rs @@ -8,22 +8,22 @@ fn test_parse_desktop_ua() { let mut ua = UserAgent::new(ua_str); assert_eq!(ua.header_str(), ua_str); - assert_eq!(ua.device(), DeviceKind::Desktop); + assert_eq!(ua.device(), Some(DeviceKind::Desktop)); assert!(ua.info().is_none()); assert_eq!(ua.platform(), None); - // Http/Tls agents do have defaults - assert_eq!(ua.http_agent(), HttpAgent::Chromium); - assert_eq!(ua.tls_agent(), TlsAgent::Rustls); + // Http/Tls agents do not have defaults + assert_eq!(ua.http_agent(), None); + assert_eq!(ua.tls_agent(), None); // Overwrite http agent - ua.with_http_agent(HttpAgent::Firefox); - assert_eq!(ua.http_agent(), HttpAgent::Firefox); + ua.set_http_agent(HttpAgent::Firefox); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); // Overwrite tls agent - ua.with_tls_agent(TlsAgent::Nss); - assert_eq!(ua.tls_agent(), TlsAgent::Nss); - assert_eq!(ua.http_agent(), HttpAgent::Firefox); + ua.set_tls_agent(TlsAgent::Nss); + assert_eq!(ua.tls_agent(), Some(TlsAgent::Nss)); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); } #[test] @@ -32,22 +32,22 @@ fn test_parse_too_long_ua() { let mut ua = 
UserAgent::new(ua_str.clone()); assert_eq!(ua.header_str(), ua_str); - assert_eq!(ua.device(), DeviceKind::Desktop); + assert_eq!(ua.device(), None); assert_eq!(ua.info(), None); assert_eq!(ua.platform(), None); - // Http/Tls agents do have defaults - assert_eq!(ua.http_agent(), HttpAgent::Chromium); - assert_eq!(ua.tls_agent(), TlsAgent::Rustls); + // Http/Tls agents do not have defaults + assert_eq!(ua.http_agent(), None); + assert_eq!(ua.tls_agent(), None); // Overwrite http agent - ua.with_http_agent(HttpAgent::Firefox); - assert_eq!(ua.http_agent(), HttpAgent::Firefox); + ua.set_http_agent(HttpAgent::Firefox); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); // Overwrite tls agent - ua.with_tls_agent(TlsAgent::Nss); - assert_eq!(ua.tls_agent(), TlsAgent::Nss); - assert_eq!(ua.http_agent(), HttpAgent::Firefox); + ua.set_tls_agent(TlsAgent::Nss); + assert_eq!(ua.tls_agent(), Some(TlsAgent::Nss)); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); } #[test] @@ -56,22 +56,22 @@ fn test_parse_windows() { let mut ua = UserAgent::new(ua_str); assert_eq!(ua.header_str(), ua_str); - assert_eq!(ua.device(), DeviceKind::Desktop); + assert_eq!(ua.device(), Some(DeviceKind::Desktop)); assert_eq!(ua.info(), None); assert_eq!(ua.platform(), Some(PlatformKind::Windows)); - // Http/Tls agents do have defaults - assert_eq!(ua.http_agent(), HttpAgent::Chromium); - assert_eq!(ua.tls_agent(), TlsAgent::Rustls); + // Http/Tls agents have defaults for platforms + assert_eq!(ua.http_agent(), None); + assert_eq!(ua.tls_agent(), None); // Overwrite http agent - ua.with_http_agent(HttpAgent::Firefox); - assert_eq!(ua.http_agent(), HttpAgent::Firefox); + ua.set_http_agent(HttpAgent::Firefox); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); // Overwrite tls agent - ua.with_tls_agent(TlsAgent::Nss); - assert_eq!(ua.tls_agent(), TlsAgent::Nss); - assert_eq!(ua.http_agent(), HttpAgent::Firefox); + ua.set_tls_agent(TlsAgent::Nss); + assert_eq!(ua.tls_agent(), 
Some(TlsAgent::Nss)); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); } #[test] @@ -80,7 +80,7 @@ fn test_parse_chrome() { let mut ua = UserAgent::new(ua_str); assert_eq!(ua.header_str(), ua_str); - assert_eq!(ua.device(), DeviceKind::Desktop); + assert_eq!(ua.device(), None); assert_eq!( ua.info(), Some(UserAgentInfo { @@ -90,18 +90,18 @@ fn test_parse_chrome() { ); assert_eq!(ua.platform(), None); - // Http/Tls agents do have defaults - assert_eq!(ua.http_agent(), HttpAgent::Chromium); - assert_eq!(ua.tls_agent(), TlsAgent::Boringssl); + // Http/Tls agents can be linked to found UA Info + assert_eq!(ua.http_agent(), Some(HttpAgent::Chromium)); + assert_eq!(ua.tls_agent(), Some(TlsAgent::Boringssl)); // Overwrite http agent - ua.with_http_agent(HttpAgent::Firefox); - assert_eq!(ua.http_agent(), HttpAgent::Firefox); + ua.set_http_agent(HttpAgent::Firefox); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); // Overwrite tls agent - ua.with_tls_agent(TlsAgent::Nss); - assert_eq!(ua.tls_agent(), TlsAgent::Nss); - assert_eq!(ua.http_agent(), HttpAgent::Firefox); + ua.set_tls_agent(TlsAgent::Nss); + assert_eq!(ua.tls_agent(), Some(TlsAgent::Nss)); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); } #[test] @@ -110,7 +110,7 @@ fn test_parse_windows_chrome() { let mut ua = UserAgent::new(ua_str); assert_eq!(ua.header_str(), ua_str); - assert_eq!(ua.device(), DeviceKind::Desktop); + assert_eq!(ua.device(), Some(DeviceKind::Desktop)); assert_eq!( ua.info(), Some(UserAgentInfo { @@ -120,18 +120,18 @@ fn test_parse_windows_chrome() { ); assert_eq!(ua.platform(), Some(PlatformKind::Windows)); - // Http/Tls agents do have defaults - assert_eq!(ua.http_agent(), HttpAgent::Chromium); - assert_eq!(ua.tls_agent(), TlsAgent::Boringssl); + // Http/Tls agents can be linked to found UA Info + assert_eq!(ua.http_agent(), Some(HttpAgent::Chromium)); + assert_eq!(ua.tls_agent(), Some(TlsAgent::Boringssl)); // Overwrite http agent - 
ua.with_http_agent(HttpAgent::Firefox); - assert_eq!(ua.http_agent(), HttpAgent::Firefox); + ua.set_http_agent(HttpAgent::Firefox); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); // Overwrite tls agent - ua.with_tls_agent(TlsAgent::Nss); - assert_eq!(ua.tls_agent(), TlsAgent::Nss); - assert_eq!(ua.http_agent(), HttpAgent::Firefox); + ua.set_tls_agent(TlsAgent::Nss); + assert_eq!(ua.tls_agent(), Some(TlsAgent::Nss)); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); } #[test] @@ -140,7 +140,7 @@ fn test_parse_desktop_chrome() { let ua = UserAgent::new(ua_str); assert_eq!(ua.header_str(), ua_str); - assert_eq!(ua.device(), DeviceKind::Desktop); + assert_eq!(ua.device(), Some(DeviceKind::Desktop)); assert_eq!( ua.info(), Some(UserAgentInfo { @@ -152,12 +152,12 @@ fn test_parse_desktop_chrome() { } #[test] -fn test_parse_desktop_chrome_with_version() { +fn test_parse_desktop_chrome_set_version() { let ua_str = "desktop chrome/124"; let ua = UserAgent::new(ua_str); assert_eq!(ua.header_str(), ua_str); - assert_eq!(ua.device(), DeviceKind::Desktop); + assert_eq!(ua.device(), Some(DeviceKind::Desktop)); assert_eq!( ua.info(), Some(UserAgentInfo { @@ -169,12 +169,12 @@ fn test_parse_desktop_chrome_with_version() { } #[test] -fn test_parse_windows_chrome_with_version() { +fn test_parse_windows_chrome_set_version() { let ua_str = "windows chrome/124"; let mut ua = UserAgent::new(ua_str); assert_eq!(ua.header_str(), ua_str); - assert_eq!(ua.device(), DeviceKind::Desktop); + assert_eq!(ua.device(), Some(DeviceKind::Desktop)); assert_eq!( ua.info(), Some(UserAgentInfo { @@ -184,18 +184,18 @@ fn test_parse_windows_chrome_with_version() { ); assert_eq!(ua.platform(), Some(PlatformKind::Windows)); - // Http/Tls agents do have defaults - assert_eq!(ua.http_agent(), HttpAgent::Chromium); - assert_eq!(ua.tls_agent(), TlsAgent::Boringssl); + // Http/Tls agents can be linked to found UA Info + assert_eq!(ua.http_agent(), Some(HttpAgent::Chromium)); + 
assert_eq!(ua.tls_agent(), Some(TlsAgent::Boringssl)); // Overwrite http agent - ua.with_http_agent(HttpAgent::Firefox); - assert_eq!(ua.http_agent(), HttpAgent::Firefox); + ua.set_http_agent(HttpAgent::Firefox); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); // Overwrite tls agent - ua.with_tls_agent(TlsAgent::Nss); - assert_eq!(ua.tls_agent(), TlsAgent::Nss); - assert_eq!(ua.http_agent(), HttpAgent::Firefox); + ua.set_tls_agent(TlsAgent::Nss); + assert_eq!(ua.tls_agent(), Some(TlsAgent::Nss)); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); } #[test] @@ -204,22 +204,22 @@ fn test_parse_mobile_ua() { let mut ua = UserAgent::new(*ua_str); assert_eq!(ua.header_str(), *ua_str); - assert_eq!(ua.device(), DeviceKind::Mobile); + assert_eq!(ua.device(), Some(DeviceKind::Mobile)); assert_eq!(ua.info(), None); assert_eq!(ua.platform(), None); - // Http/Tls agents do have defaults - assert_eq!(ua.http_agent(), HttpAgent::Chromium); - assert_eq!(ua.tls_agent(), TlsAgent::Rustls); + // Http/Tls agents do not have defaults + assert_eq!(ua.http_agent(), None); + assert_eq!(ua.tls_agent(), None); // Overwrite http agent - ua.with_http_agent(HttpAgent::Firefox); - assert_eq!(ua.http_agent(), HttpAgent::Firefox); + ua.set_http_agent(HttpAgent::Firefox); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); // Overwrite tls agent - ua.with_tls_agent(TlsAgent::Nss); - assert_eq!(ua.tls_agent(), TlsAgent::Nss); - assert_eq!(ua.http_agent(), HttpAgent::Firefox); + ua.set_tls_agent(TlsAgent::Nss); + assert_eq!(ua.tls_agent(), Some(TlsAgent::Nss)); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); } } @@ -230,24 +230,24 @@ fn test_parse_happy_path_unknown_ua() { // UA Is always stored as is. assert_eq!(ua.header_str(), ua_str); - assert_eq!(ua.device(), DeviceKind::Desktop); + assert_eq!(ua.device(), None); // No information should be known about the UA. 
assert!(ua.info().is_none()); assert!(ua.platform().is_none()); - // Http/Tls agents do have defaults - assert_eq!(ua.http_agent(), HttpAgent::Chromium); - assert_eq!(ua.tls_agent(), TlsAgent::Rustls); + // Http/Tls agents do not have defaults + assert_eq!(ua.http_agent(), None); + assert_eq!(ua.tls_agent(), None); // Overwrite http agent - ua.with_http_agent(HttpAgent::Firefox); - assert_eq!(ua.http_agent(), HttpAgent::Firefox); + ua.set_http_agent(HttpAgent::Firefox); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); // Overwrite tls agent - ua.with_tls_agent(TlsAgent::Nss); - assert_eq!(ua.tls_agent(), TlsAgent::Nss); - assert_eq!(ua.http_agent(), HttpAgent::Firefox); + ua.set_tls_agent(TlsAgent::Nss); + assert_eq!(ua.tls_agent(), Some(TlsAgent::Nss)); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); } #[test] @@ -256,7 +256,7 @@ fn test_parse_happy_path_ua_macos_chrome() { let mut ua = UserAgent::new(ua_str); assert_eq!(ua.header_str(), ua_str); - assert_eq!(ua.device(), DeviceKind::Desktop); + assert_eq!(ua.device(), Some(DeviceKind::Desktop)); assert_eq!( ua.info(), Some(UserAgentInfo { @@ -266,18 +266,18 @@ fn test_parse_happy_path_ua_macos_chrome() { ); assert_eq!(ua.platform(), Some(PlatformKind::MacOS)); - // Http/Tls - assert_eq!(ua.http_agent(), HttpAgent::Chromium); - assert_eq!(ua.tls_agent(), TlsAgent::Boringssl); + // Http/Tls agents can be linked to found UA Info + assert_eq!(ua.http_agent(), Some(HttpAgent::Chromium)); + assert_eq!(ua.tls_agent(), Some(TlsAgent::Boringssl)); // Overwrite http agent - ua.with_http_agent(HttpAgent::Firefox); - assert_eq!(ua.http_agent(), HttpAgent::Firefox); + ua.set_http_agent(HttpAgent::Firefox); + assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); // Overwrite tls agent - ua.with_tls_agent(TlsAgent::Nss); - assert_eq!(ua.tls_agent(), TlsAgent::Nss); - assert_eq!(ua.http_agent(), HttpAgent::Firefox); + ua.set_tls_agent(TlsAgent::Nss); + assert_eq!(ua.tls_agent(), Some(TlsAgent::Nss)); + 
assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox)); } #[test] diff --git a/tests/integration/examples/example_tests/http_user_agent_classifier.rs b/tests/integration/examples/example_tests/http_user_agent_classifier.rs index 3e7414a97..8e260e7e0 100644 --- a/tests/integration/examples/example_tests/http_user_agent_classifier.rs +++ b/tests/integration/examples/example_tests/http_user_agent_classifier.rs @@ -35,8 +35,8 @@ async fn test_http_user_agent_classifier() { assert_eq!(ua_rama.kind, None); assert_eq!(ua_rama.version, None); assert_eq!(ua_rama.platform, None); - assert_eq!(ua_rama.http_agent, Some("Chromium".to_owned())); - assert_eq!(ua_rama.tls_agent, Some("Rustls".to_owned())); + assert_eq!(ua_rama.http_agent, None); + assert_eq!(ua_rama.tls_agent, None); const UA: &str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36 Edg/124.0.2478.67"; @@ -68,6 +68,8 @@ async fn test_http_user_agent_classifier() { http: Some(HttpAgent::Safari), tls: Some(TlsAgent::Boringssl), preserve_ua: Some(false), + req_init: None, + req_client_hints: None, }) .unwrap(), ) @@ -80,7 +82,7 @@ async fn test_http_user_agent_classifier() { assert_eq!(ua_app.ua, UA_APP); assert!(ua_app.kind.is_none()); assert!(ua_app.version.is_none()); - assert!(ua_app.platform.is_none()); + assert_eq!(ua_app.platform, Some("iOS".to_owned())); assert_eq!(ua_app.http_agent, Some("Safari".to_owned())); assert_eq!(ua_app.tls_agent, Some("Boringssl".to_owned())); }